From 010c719eacd83906377dd4965120e6ec0161390b Mon Sep 17 00:00:00 2001 From: crumbtoo Date: Fri, 15 Mar 2024 13:43:23 -0600 Subject: [PATCH] infer under given context --- src/Rlp/AltSyntax.hs | 2 ++ src/Rlp/HindleyMilner.hs | 67 ++++++++++++++++++++++++++++++++++------ src/Rlp/Lex.x | 2 +- 3 files changed, 60 insertions(+), 11 deletions(-) diff --git a/src/Rlp/AltSyntax.hs b/src/Rlp/AltSyntax.hs index eec5234..81d2025 100644 --- a/src/Rlp/AltSyntax.hs +++ b/src/Rlp/AltSyntax.hs @@ -15,6 +15,7 @@ module Rlp.AltSyntax -- * Optics , programDecls , _VarP, _FunB, _VarB + , _TySigD, _FunD -- * Functor-related tools , Fix(..), Cofree(..), Sum(..), pattern Finl, pattern Finr @@ -186,6 +187,7 @@ instance (Out b) => Out1 (Program b) where makePrisms ''Pat makePrisms ''Binding +makePrisms ''Decl deriving instance (Lift b, Lift a) => Lift (Program b a) deriving instance (Lift b, Lift a) => Lift (Decl b a) diff --git a/src/Rlp/HindleyMilner.hs b/src/Rlp/HindleyMilner.hs index 00403ad..c950ef9 100644 --- a/src/Rlp/HindleyMilner.hs +++ b/src/Rlp/HindleyMilner.hs @@ -20,6 +20,7 @@ import Control.Monad import Control.Arrow ((>>>)) import Control.Monad.Writer.Strict import Data.Text qualified as T +import Data.Function import Data.Pretty hiding (annotate) import Data.Hashable import Data.HashMap.Strict (HashMap) @@ -127,18 +128,44 @@ annotate :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)) annotate = sequenceA . fixtend (gather . wrapFix) -solveTree :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement) - -> HM (Type PsName) -solveTree e = undefined - infer1 :: RlpExpr PsName -> HM (Type PsName) -infer1 e = do +infer1 = infer1' mempty + +infer1' :: Context -> RlpExpr PsName -> HM (Type PsName) +infer1' g1 e = do ((t,j) :< _) <- annotate e - g <- unify (j ^. constraints) - pure $ ifoldrOf (contextVars . itraversed) subst t g + g2 <- unify (j ^. constraints) + g <- unionContextWithKeyM unifyTypes g1 g2 + pure $ ifoldlOf (contextVars . itraversed) subst t g + where + -- intuitively, we'd return mgu(s,t) but the union is left-biased making `s` + -- the user-specified type: prioritise her. + unifyTypes _ s t = unify [Equality s t] $> s + +unionContextWithKeyM :: Monad m + => (PsName -> Type PsName -> Type PsName + -> m (Type PsName)) + -> Context -> Context -> m Context +unionContextWithKeyM f a b = Context <$> unionWithKeyM f a' b' + where + a' = a ^. contextVars + b' = b ^. contextVars + +unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m) + => (k -> v -> v -> m v) -> HashMap k v -> HashMap k v + -> m (HashMap k v) +unionWithKeyM f a b = sequenceA $ H.unionWithKey f' ma mb + where + f' k x y = join $ liftA2 (f k) x y + ma = fmap (pure @m) a + mb = fmap (pure @m) b solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName)) -solve e = sequenceA $ fixtend (infer1 . wrapFix) e +solve = solve' mempty + +solve' :: Context -> RlpExpr PsName + -> HM (Cofree (RlpExprF PsName) (Type PsName)) +solve' g e = sequenceA $ fixtend (infer1' g . wrapFix) e occurs :: PsName -> Type PsName -> Bool occurs n = cata \case @@ -163,13 +190,33 @@ fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b fixtend c (Fix f) = c f :< fmap (fixtend c) f infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName)) -infer = sequenceA . fixtend (infer1 . wrapFix) +infer = infer' mempty + +infer' :: Context -> RlpExpr PsName + -> HM (Cofree (RlpExprF PsName) (Type PsName)) +infer' g = sequenceA . fixtend (infer1' g . wrapFix) + +buildInitialContext :: Program PsName a -> Context +buildInitialContext = + Context . H.fromList . toListOf (programDecls . each . _TySigD) typeCheckRlpProgR :: (Monad m) => Program PsName (RlpExpr PsName) -> RLPCT m (Program PsName (Cofree (RlpExprF PsName) (Type PsName))) -typeCheckRlpProgR = liftHM . traverse infer +typeCheckRlpProgR p = tc p + where + g = buildInitialContext p + tc = liftHM . traverse (infer' g) . etaExpandAll + etaExpandAll = programDecls . each %~ etaExpand + +etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b) +etaExpand (FunD n as e) + | Right as' <- allVarP as + = FunD n [] (Finl . LamF as' $ e) + where + allVarP = traverse (matching _VarP) +etaExpand a = a liftHM :: (Monad m) => HM a -> RLPCT m a liftHM = liftEither . runHM' diff --git a/src/Rlp/Lex.x b/src/Rlp/Lex.x index c2e316d..7bd4406 100644 --- a/src/Rlp/Lex.x +++ b/src/Rlp/Lex.x @@ -283,7 +283,7 @@ lexStream = fmap extract <$> lexStream' lexStream' :: P [Located RlpToken] lexStream' = lexToken >>= \case t@(Located _ TokenEOF) -> pure [t] - t -> (t:) <$> lexStream' + t -> (t:) <$> lexStream' lexDebug :: (Located RlpToken -> P a) -> P a lexDebug k = do