From e222dae6ac4598b5a464a69fe886e550f82d7f19 Mon Sep 17 00:00:00 2001 From: crumbtoo Date: Mon, 18 Dec 2023 12:21:53 -0700 Subject: [PATCH] infer nonrec let binds infer nonrec let binds --- rlp.cabal | 1 + src/Control/Monad/Utils.hs | 21 +++++++++++++++++++++ src/Core/HindleyMilner.hs | 15 +++++++++++++-- tst/Core/HindleyMilnerSpec.hs | 11 +++++++++++ 4 files changed, 46 insertions(+), 2 deletions(-) create mode 100644 src/Control/Monad/Utils.hs diff --git a/rlp.cabal b/rlp.cabal index f5b26ab..4803362 100644 --- a/rlp.cabal +++ b/rlp.cabal @@ -34,6 +34,7 @@ library , Core.Lex , Core2Core , Control.Monad.Errorful + , Control.Monad.Utils , RLP.Syntax build-tool-depends: happy:happy, alex:alex diff --git a/src/Control/Monad/Utils.hs b/src/Control/Monad/Utils.hs new file mode 100644 index 0000000..6cc5521 --- /dev/null +++ b/src/Control/Monad/Utils.hs @@ -0,0 +1,21 @@ +module Control.Monad.Utils + ( mapAccumLM + ) + where +---------------------------------------------------------------------------------- +import Data.Tuple (swap) +import Control.Monad.State +---------------------------------------------------------------------------------- + +-- | Monadic variant of @mapAccumL@ + +mapAccumLM :: forall m t s a b. (Monad m, Traversable t) + => (s -> a -> m (s, b)) + -> s + -> t a + -> m (s, t b) +mapAccumLM k s t = swap <$> runStateT (traverse k' t) s + where + k' :: a -> StateT s m b + k' a = StateT $ fmap swap <$> flip k a + diff --git a/src/Core/HindleyMilner.hs b/src/Core/HindleyMilner.hs index b59b7a5..0897ece 100644 --- a/src/Core/HindleyMilner.hs +++ b/src/Core/HindleyMilner.hs @@ -13,10 +13,10 @@ module Core.HindleyMilner ---------------------------------------------------------------------------------- import Lens.Micro import Lens.Micro.Mtl -import Data.Set qualified as S -import Data.Set (Set) import Data.Maybe (fromMaybe) +import Control.Monad (foldM) import Control.Monad.State +import Control.Monad.Utils (mapAccumLM) import Core.Syntax ---------------------------------------------------------------------------------- @@ -85,6 +85,17 @@ gather = \g e -> runStateT (go g e) ([],0) <&> \ (t,(cs,_)) -> (t,cs) where tfx <- uniqueVar addConstraint tf (tx :-> tfx) pure tfx + Let NonRec bs e -> do + g' <- buildLetContext g bs + go g' e + + buildLetContext :: Context' -> [Binding'] + -> StateT ([Constraint], Int) HMError Context' + buildLetContext = foldM k where + k :: Context' -> Binding' -> StateT ([Constraint], Int) HMError Context' + k g (x := y) = do + ty <- go g y + pure ((x,ty) : g) uniqueVar :: StateT ([Constraint], Int) HMError Type uniqueVar = do diff --git a/tst/Core/HindleyMilnerSpec.hs b/tst/Core/HindleyMilnerSpec.hs index c50be15..74a2468 100644 --- a/tst/Core/HindleyMilnerSpec.hs +++ b/tst/Core/HindleyMilnerSpec.hs @@ -21,6 +21,17 @@ spec = do let g = [ ("id", ("a" :-> "a") :-> "a" :-> "a") ] in infer g [coreExpr|id 3|] `shouldSatisfy` isUntypedVariableErr + -- TODO: property-based tests for let + it "should infer `let x = 3 in id x` :: Int" $ + let g = [ ("id", "a" :-> "a") ] + e = [coreExpr|let {x = 3} in id x|] + in infer g e `shouldBe` Right TyInt + + it "should infer `let x = 3; y = 2 in (+#) x y` :: Int" $ + let g = [ ("+#", TyInt :-> TyInt :-> TyInt) ] + e = [coreExpr|let {x=3;y=2} in (+#) x y|] + in infer g e `shouldBe` Right TyInt + isUntypedVariableErr :: HMError a -> Bool isUntypedVariableErr (Left (TyErrCouldNotUnify _ _)) = True isUntypedVariableErr _ = False