From 22b5b477956ec884dc3295bc67c580d2f1e4cfed Mon Sep 17 00:00:00 2001 From: crumbtoo Date: Tue, 23 Jan 2024 20:19:16 -0700 Subject: [PATCH] letrec --- src/Core/HindleyMilner.hs | 38 +++++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/src/Core/HindleyMilner.hs b/src/Core/HindleyMilner.hs index 4cffcca..6d9bfe9 100644 --- a/src/Core/HindleyMilner.hs +++ b/src/Core/HindleyMilner.hs @@ -3,6 +3,7 @@ Module : Core.HindleyMilner Description : Hindley-Milner type system -} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} module Core.HindleyMilner ( Context' , infer @@ -16,12 +17,13 @@ module Core.HindleyMilner ---------------------------------------------------------------------------------- import Lens.Micro import Lens.Micro.Mtl +import Lens.Micro.Platform import Data.Maybe (fromMaybe) import Data.Text qualified as T import Data.HashMap.Strict qualified as H import Data.Foldable (traverse_) import Compiler.RLPC -import Control.Monad (foldM, void) +import Control.Monad (foldM, void, forM) import Control.Monad.Errorful (Errorful, addFatal) import Control.Monad.State import Control.Monad.Utils (mapAccumLM) @@ -152,8 +154,28 @@ gather = \g e -> runStateT (go g e) ([],0) <&> \ (t,(cs,_)) -> (t,cs) where Let NonRec bs e -> do g' <- buildLetContext g bs go g' e + Let Rec bs e -> do + g' <- buildLetrecContext g bs + go g' e + -- TODO letrec, lambda, case + buildLetrecContext :: Context' -> [Binding'] + -> StateT ([Constraint], Int) HMError Context' + buildLetrecContext g bs = do + let f ag (k := _) = do + n <- uniqueVar + pure ((k,n) : ag) + rg <- foldM f g bs + let k ag (k := v) = do + t <- go rg v + pure ((k,t) : ag) + foldM k g bs + + -- | augment a context with the inferred types of each binder. the returned + -- context is linearly accumulated, meaning that the context used to infer each binder + -- will include the inferred types of all previous binder + buildLetContext :: Context' -> [Binding'] -> StateT ([Constraint], Int) HMError Context' buildLetContext = foldM k where @@ -230,3 +252,17 @@ subst x t (TyVar y) | x == y = t subst x t (a :-> b) = subst x t a :-> subst x t b subst _ _ e = e +-------------------------------------------------------------------------------- + +demoContext :: Context' +demoContext = + [ ("fix", (TyVar "a" :-> TyVar "a") :-> TyVar "a") + , ("add", TyInt :-> TyInt :-> TyInt) + ] + +pprintType :: Type -> String +pprintType (s :-> t) = "(" <> pprintType s <> " -> " <> pprintType t <> ")" +pprintType TyFun = "(->)" +pprintType (TyVar x) = x ^. unpacked +pprintType (TyCon t) = t ^. unpacked +