infer under given context
This commit is contained in:
@@ -15,6 +15,7 @@ module Rlp.AltSyntax
|
||||
-- * Optics
|
||||
, programDecls
|
||||
, _VarP, _FunB, _VarB
|
||||
, _TySigD, _FunD
|
||||
|
||||
-- * Functor-related tools
|
||||
, Fix(..), Cofree(..), Sum(..), pattern Finl, pattern Finr
|
||||
@@ -186,6 +187,7 @@ instance (Out b) => Out1 (Program b) where
|
||||
|
||||
makePrisms ''Pat
|
||||
makePrisms ''Binding
|
||||
makePrisms ''Decl
|
||||
|
||||
deriving instance (Lift b, Lift a) => Lift (Program b a)
|
||||
deriving instance (Lift b, Lift a) => Lift (Decl b a)
|
||||
|
||||
@@ -20,6 +20,7 @@ import Control.Monad
|
||||
import Control.Arrow ((>>>))
|
||||
import Control.Monad.Writer.Strict
|
||||
import Data.Text qualified as T
|
||||
import Data.Function
|
||||
import Data.Pretty hiding (annotate)
|
||||
import Data.Hashable
|
||||
import Data.HashMap.Strict (HashMap)
|
||||
@@ -127,18 +128,44 @@ annotate :: RlpExpr PsName
|
||||
-> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
|
||||
annotate = sequenceA . fixtend (gather . wrapFix)
|
||||
|
||||
solveTree :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)
|
||||
-> HM (Type PsName)
|
||||
solveTree e = undefined
|
||||
|
||||
infer1 :: RlpExpr PsName -> HM (Type PsName)
|
||||
infer1 e = do
|
||||
infer1 = infer1' mempty
|
||||
|
||||
infer1' :: Context -> RlpExpr PsName -> HM (Type PsName)
|
||||
infer1' g1 e = do
|
||||
((t,j) :< _) <- annotate e
|
||||
g <- unify (j ^. constraints)
|
||||
pure $ ifoldrOf (contextVars . itraversed) subst t g
|
||||
g2 <- unify (j ^. constraints)
|
||||
g <- unionContextWithKeyM unifyTypes g1 g2
|
||||
pure $ ifoldlOf (contextVars . itraversed) subst t g
|
||||
where
|
||||
-- intuitively, we'd return mgu(s,t) but the union is left-biased making `s`
|
||||
-- the user-specified type: prioritise her.
|
||||
unifyTypes _ s t = unify [Equality s t] $> s
|
||||
|
||||
unionContextWithKeyM :: Monad m
|
||||
=> (PsName -> Type PsName -> Type PsName
|
||||
-> m (Type PsName))
|
||||
-> Context -> Context -> m Context
|
||||
unionContextWithKeyM f a b = Context <$> unionWithKeyM f a' b'
|
||||
where
|
||||
a' = a ^. contextVars
|
||||
b' = b ^. contextVars
|
||||
|
||||
unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m)
|
||||
=> (k -> v -> v -> m v) -> HashMap k v -> HashMap k v
|
||||
-> m (HashMap k v)
|
||||
unionWithKeyM f a b = sequenceA $ H.unionWithKey f' ma mb
|
||||
where
|
||||
f' k x y = join $ liftA2 (f k) x y
|
||||
ma = fmap (pure @m) a
|
||||
mb = fmap (pure @m) b
|
||||
|
||||
solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
solve e = sequenceA $ fixtend (infer1 . wrapFix) e
|
||||
solve = solve' mempty
|
||||
|
||||
solve' :: Context -> RlpExpr PsName
|
||||
-> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
solve' g e = sequenceA $ fixtend (infer1' g . wrapFix) e
|
||||
|
||||
occurs :: PsName -> Type PsName -> Bool
|
||||
occurs n = cata \case
|
||||
@@ -163,13 +190,33 @@ fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
|
||||
fixtend c (Fix f) = c f :< fmap (fixtend c) f
|
||||
|
||||
infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
infer = sequenceA . fixtend (infer1 . wrapFix)
|
||||
infer = infer' mempty
|
||||
|
||||
infer' :: Context -> RlpExpr PsName
|
||||
-> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
infer' g = sequenceA . fixtend (infer1' g . wrapFix)
|
||||
|
||||
buildInitialContext :: Program PsName a -> Context
|
||||
buildInitialContext =
|
||||
Context . H.fromList . toListOf (programDecls . each . _TySigD)
|
||||
|
||||
typeCheckRlpProgR :: (Monad m)
|
||||
=> Program PsName (RlpExpr PsName)
|
||||
-> RLPCT m (Program PsName
|
||||
(Cofree (RlpExprF PsName) (Type PsName)))
|
||||
typeCheckRlpProgR = liftHM . traverse infer
|
||||
typeCheckRlpProgR p = tc p
|
||||
where
|
||||
g = buildInitialContext p
|
||||
tc = liftHM . traverse (infer' g) . etaExpandAll
|
||||
etaExpandAll = programDecls . each %~ etaExpand
|
||||
|
||||
etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b)
|
||||
etaExpand (FunD n as e)
|
||||
| Right as' <- allVarP as
|
||||
= FunD n [] (Finl . LamF as' $ e)
|
||||
where
|
||||
allVarP = traverse (matching _VarP)
|
||||
etaExpand a = a
|
||||
|
||||
liftHM :: (Monad m) => HM a -> RLPCT m a
|
||||
liftHM = liftEither . runHM'
|
||||
|
||||
@@ -283,7 +283,7 @@ lexStream = fmap extract <$> lexStream'
|
||||
lexStream' :: P [Located RlpToken]
|
||||
lexStream' = lexToken >>= \case
|
||||
t@(Located _ TokenEOF) -> pure [t]
|
||||
t -> (t:) <$> lexStream'
|
||||
t -> (t:) <$> lexStream'
|
||||
|
||||
lexDebug :: (Located RlpToken -> P a) -> P a
|
||||
lexDebug k = do
|
||||
|
||||
Reference in New Issue
Block a user