diff --git a/rlp.cabal b/rlp.cabal index 34a5c00..5880571 100644 --- a/rlp.cabal +++ b/rlp.cabal @@ -92,6 +92,7 @@ test-suite rlp-test , rlp , QuickCheck , hspec ==2.* + , microlens other-modules: Arith , GMSpec , Core.HindleyMilnerSpec diff --git a/src/Core/HindleyMilner.hs b/src/Core/HindleyMilner.hs index 18bd346..4aa6c77 100644 --- a/src/Core/HindleyMilner.hs +++ b/src/Core/HindleyMilner.hs @@ -4,8 +4,9 @@ Description : Hindley-Milner inference -} {-# LANGUAGE LambdaCase #-} module Core.HindleyMilner - ( infer - , Context' + ( Context' + , infer + , check , TypeError(..) , HMError ) @@ -15,7 +16,8 @@ import Lens.Micro import Lens.Micro.Mtl import Data.Maybe (fromMaybe) import Data.Text qualified as T -import Control.Monad (foldM) +import Data.HashMap.Strict qualified as H +import Control.Monad (foldM, void) import Control.Monad.State import Control.Monad.Utils (mapAccumLM) import Core.Syntax @@ -43,6 +45,26 @@ data TypeError -- | Synonym for @Either TypeError@ type HMError = Either TypeError +-- | Assert that an expression unifies with a given type +-- +-- >>> let e = [coreProg|3|] +-- >>> check [] (TyCon "Bool") e +-- Left (TyErrCouldNotUnify (TyCon "Bool") (TyCon "Int#")) +-- >>> check [] (TyCon "Int#") e +-- Right () + +check :: Context' -> Type -> Expr' -> HMError () +check g t1 e = do + t2 <- infer g e + unify [(t1,t2)] + pure () + +checkProg :: Program' -> HMError () +checkProg p = p ^. programScDefs + & traversalOf k + where + k sc = undefined + -- | Infer the type of an expression under some context. -- -- >>> let g1 = [("id", TyVar "a" :-> TyVar "a")] @@ -55,6 +77,7 @@ type HMError = Either TypeError infer :: Context' -> Expr' -> HMError Type infer g e = do (t,cs) <- gather g e + -- apply all unified constraints foldr (uncurry subst) t <$> unify cs -- | A @Constraint@ between two types describes the requirement that the pair @@ -89,6 +112,7 @@ gather = \g e -> runStateT (go g e) ([],0) <&> \ (t,(cs,_)) -> (t,cs) where Let NonRec bs e -> do g' <- buildLetContext g bs go g' e + -- TODO letrec, lambda, case buildLetContext :: Context' -> [Binding'] -> StateT ([Constraint], Int) HMError Context' @@ -149,8 +173,17 @@ unify = go mempty where | x == y = True occurs _ = False +buildInitialContext :: Program b -> Context b +buildInitialContext p = p ^. programTypeSigs + & H.toList + -- | The expression @subst x t e@ substitutes all occurences of @x@ in @e@ with -- @t@ +-- +-- >>> subst "a" (TyCon "Int") (TyVar "a") +-- TyCon "Int" +-- >>> subst "a" (TyCon "Int") (TyVar "a" :-> TyVar "a") +-- TyCon "Int" :-> TyCon "Int" subst :: Name -> Type -> Type -> Type subst x t (TyVar y) | x == y = t diff --git a/src/Core/Lex.x b/src/Core/Lex.x index 341b51b..9fb9d31 100644 --- a/src/Core/Lex.x +++ b/src/Core/Lex.x @@ -87,6 +87,7 @@ rlp :- "where" { constTok TokenWhere } "Pack" { constTok TokenPack } -- temp + -- TODO: this should be "\" "\\" { constTok TokenLambda } "λ" { constTok TokenLambda } "=" { constTok TokenEquals } diff --git a/src/Core/Syntax.hs b/src/Core/Syntax.hs index ddc3b66..fb9b720 100644 --- a/src/Core/Syntax.hs +++ b/src/Core/Syntax.hs @@ -24,6 +24,7 @@ module Core.Syntax , Module(..) , Program(..) , Program' + , unliftScDef , programScDefs , programTypeSigs , Expr' @@ -37,13 +38,13 @@ module Core.Syntax ---------------------------------------------------------------------------------- import Data.Coerce import Data.Pretty -import GHC.Generics import Data.List (intersperse) import Data.Function ((&)) import Data.String import Data.HashMap.Strict qualified as H import Data.Hashable import Data.Text qualified as T +import Data.Char -- Lift instances for the Core quasiquoters import Language.Haskell.TH.Syntax (Lift) import Lens.Micro.TH (makeLenses) @@ -116,6 +117,9 @@ type Tag = Int data ScDef b = ScDef b [b] (Expr b) deriving (Show, Lift) +unliftScDef :: ScDef b -> Expr b +unliftScDef (ScDef _ as e) = Lam as e + data Module b = Module (Maybe (Name, [Name])) (Program b) deriving (Show, Lift) @@ -138,7 +142,11 @@ instance IsString (Expr b) where fromString = Var . fromString instance IsString Type where - fromString = TyVar . fromString + fromString "" = error "IsString Type string may not be empty" + fromString s + | isUpper c = TyCon . fromString $ s + | otherwise = TyVar . fromString $ s + where (c:_) = s instance (Hashable b) => Semigroup (Program b) where (<>) = undefined diff --git a/tst/Arith.hs b/tst/Arith.hs index 700849b..2c168c4 100644 --- a/tst/Arith.hs +++ b/tst/Arith.hs @@ -6,6 +6,7 @@ module Arith ) where ---------------------------------------------------------------------------------- import Data.Functor.Classes (eq1) +import Lens.Micro import Core.Syntax import GM import Test.QuickCheck @@ -70,7 +71,7 @@ instance Arbitrary ArithExpr where -- coreResult = evalCore (toCore e) toCore :: ArithExpr -> Program' -toCore expr = Program +toCore expr = mempty & programScDefs .~ [ ScDef "id" ["x"] $ Var "x" , ScDef "main" [] $ go expr ] diff --git a/tst/Core/HindleyMilnerSpec.hs b/tst/Core/HindleyMilnerSpec.hs index 74a2468..07940e6 100644 --- a/tst/Core/HindleyMilnerSpec.hs +++ b/tst/Core/HindleyMilnerSpec.hs @@ -6,7 +6,8 @@ module Core.HindleyMilnerSpec ---------------------------------------------------------------------------------- import Core.Syntax import Core.TH (coreExpr) -import Core.HindleyMilner (infer, TypeError(..), HMError) +import Core.HindleyMilner (infer, check, TypeError(..), HMError) +import Data.Either (isLeft) import Test.Hspec ---------------------------------------------------------------------------------- @@ -19,7 +20,7 @@ spec = do it "should not infer `id 3` when `id` is specialised to `a -> a`" $ let g = [ ("id", ("a" :-> "a") :-> "a" :-> "a") ] - in infer g [coreExpr|id 3|] `shouldSatisfy` isUntypedVariableErr + in infer g [coreExpr|id 3|] `shouldSatisfy` isLeft -- TODO: property-based tests for let it "should infer `let x = 3 in id x` :: Int" $ @@ -31,8 +32,8 @@ spec = do let g = [ ("+#", TyInt :-> TyInt :-> TyInt) ] e = [coreExpr|let {x=3;y=2} in (+#) x y|] in infer g e `shouldBe` Right TyInt - -isUntypedVariableErr :: HMError a -> Bool -isUntypedVariableErr (Left (TyErrCouldNotUnify _ _)) = True -isUntypedVariableErr _ = False + + it "should find `3 :: Bool` contradictory" $ + let e = [coreExpr|3|] + in check [] (TyCon "Bool") e `shouldSatisfy` isLeft