infer under given context

This commit is contained in:
crumbtoo
2024-03-15 13:43:23 -06:00
parent 932fed8e5c
commit fcd784441a
3 changed files with 60 additions and 11 deletions

View File

@@ -15,6 +15,7 @@ module Rlp.AltSyntax
-- * Optics
, programDecls
, _VarP, _FunB, _VarB
, _TySigD, _FunD
-- * Functor-related tools
, Fix(..), Cofree(..), Sum(..), pattern Finl, pattern Finr
@@ -186,6 +187,7 @@ instance (Out b) => Out1 (Program b) where
makePrisms ''Pat
makePrisms ''Binding
makePrisms ''Decl
deriving instance (Lift b, Lift a) => Lift (Program b a)
deriving instance (Lift b, Lift a) => Lift (Decl b a)

View File

@@ -20,6 +20,7 @@ import Control.Monad
import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict
import Data.Text qualified as T
import Data.Function
import Data.Pretty hiding (annotate)
import Data.Hashable
import Data.HashMap.Strict (HashMap)
@@ -127,18 +128,44 @@ annotate :: RlpExpr PsName
-> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotate = sequenceA . fixtend (gather . wrapFix)
solveTree :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)
-> HM (Type PsName)
solveTree e = undefined
infer1 :: RlpExpr PsName -> HM (Type PsName)
infer1 e = do
infer1 = infer1' mempty
infer1' :: Context -> RlpExpr PsName -> HM (Type PsName)
infer1' g1 e = do
((t,j) :< _) <- annotate e
g <- unify (j ^. constraints)
pure $ ifoldrOf (contextVars . itraversed) subst t g
g2 <- unify (j ^. constraints)
g <- unionContextWithKeyM unifyTypes g1 g2
pure $ ifoldlOf (contextVars . itraversed) subst t g
where
-- intuitively, we'd return mgu(s,t) but the union is left-biased making `s`
-- the user-specified type: prioritise her.
unifyTypes _ s t = unify [Equality s t] $> s
unionContextWithKeyM :: Monad m
=> (PsName -> Type PsName -> Type PsName
-> m (Type PsName))
-> Context -> Context -> m Context
unionContextWithKeyM f a b = Context <$> unionWithKeyM f a' b'
where
a' = a ^. contextVars
b' = b ^. contextVars
unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m)
=> (k -> v -> v -> m v) -> HashMap k v -> HashMap k v
-> m (HashMap k v)
unionWithKeyM f a b = sequenceA $ H.unionWithKey f' ma mb
where
f' k x y = join $ liftA2 (f k) x y
ma = fmap (pure @m) a
mb = fmap (pure @m) b
solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
solve e = sequenceA $ fixtend (infer1 . wrapFix) e
solve = solve' mempty
solve' :: Context -> RlpExpr PsName
-> HM (Cofree (RlpExprF PsName) (Type PsName))
solve' g e = sequenceA $ fixtend (infer1' g . wrapFix) e
occurs :: PsName -> Type PsName -> Bool
occurs n = cata \case
@@ -163,13 +190,33 @@ fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend c (Fix f) = c f :< fmap (fixtend c) f
infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer = sequenceA . fixtend (infer1 . wrapFix)
infer = infer' mempty
infer' :: Context -> RlpExpr PsName
-> HM (Cofree (RlpExprF PsName) (Type PsName))
infer' g = sequenceA . fixtend (infer1' g . wrapFix)
buildInitialContext :: Program PsName a -> Context
buildInitialContext =
Context . H.fromList . toListOf (programDecls . each . _TySigD)
typeCheckRlpProgR :: (Monad m)
=> Program PsName (RlpExpr PsName)
-> RLPCT m (Program PsName
(Cofree (RlpExprF PsName) (Type PsName)))
typeCheckRlpProgR = liftHM . traverse infer
typeCheckRlpProgR p = tc p
where
g = buildInitialContext p
tc = liftHM . traverse (infer' g) . etaExpandAll
etaExpandAll = programDecls . each %~ etaExpand
etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b)
etaExpand (FunD n as e)
| Right as' <- allVarP as
= FunD n [] (Finl . LamF as' $ e)
where
allVarP = traverse (matching _VarP)
etaExpand a = a
liftHM :: (Monad m) => HM a -> RLPCT m a
liftHM = liftEither . runHM'

View File

@@ -283,7 +283,7 @@ lexStream = fmap extract <$> lexStream'
lexStream' :: P [Located RlpToken]
lexStream' = lexToken >>= \case
t@(Located _ TokenEOF) -> pure [t]
t -> (t:) <$> lexStream'
t -> (t:) <$> lexStream'
lexDebug :: (Located RlpToken -> P a) -> P a
lexDebug k = do