seems to work

This commit is contained in:
crumbtoo
2024-03-13 18:10:29 -06:00
parent e00e0eff3b
commit 8fd75a67d3
5 changed files with 93 additions and 49 deletions

View File

@@ -13,7 +13,9 @@ module Rlp.HindleyMilner
import Control.Lens hiding (Context', Context, (:<), para)
import Control.Monad.Errorful
import Control.Monad.State
import Control.Monad.Accum
import Control.Monad
import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict
import Data.Text qualified as T
import Data.Pretty
@@ -40,10 +42,6 @@ import Core.Syntax (ExprF(..), Lit(..))
import Rlp.HindleyMilner.Types
--------------------------------------------------------------------------------
-- | Synonym for @Errorful [TypeError]@. This means an @HMError@ action may
-- throw any number of fatal or nonfatal errors. Run with @runErrorful@.
type HMError = Errorful TypeError
fixCofree :: (Functor f, Functor g)
=> Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b)
fixCofree = iso sa bt where
@@ -55,53 +53,80 @@ lookupVar n g = case g ^. contextVars . at n of
Just t -> pure t
Nothing -> addFatal $ TyErrUntypedVariable n
gather :: Context -> RlpExpr PsName -> HM (Type PsName)
gather g = \case
gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather e = look >>= (H.lookup e >>> maybe memoise pure)
where
memoise = do
r <- gather' e
add (H.singleton e r)
pure r
Finl (LitF (IntL _)) -> pure IntT
gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather' = \case
Finl (LitF (IntL _)) -> pure (IntT, mempty)
Finl (VarF n) -> do
t <- freshTv
let j = mempty & assumptions .~ H.singleton n [t]
pure (t,j)
Finl (AppF f x) -> do
tf <- gather g f
tx <- gather g x
tfx <- freshTv
addConstraint $ Equality tf (tx :-> tfx)
pure tfx
(tf,jf) <- gather f
(tx,jx) <- gather x
let jtfx = mempty & constraints .~ [Equality tf (tx :-> tfx)]
pure (tfx, jf <> jx <> jtfx)
Finl (VarF n) -> lookupVar n g
Finl (LamF [b] e) -> do
tb <- freshTv
(te,je) <- gather e
let cs = maybe [] (fmap $ Equality tb) (je ^. assumptions . at b)
as = je ^. assumptions & at b .~ Nothing
j = mempty & constraints .~ cs & assumptions .~ as
t = tb :-> te
pure (t,j)
Finl (LamF bs e) -> do
tbs <- for bs $ \b -> (b,) <$> freshTv
te <- gather (supplement tbs g) e
pure $ foldrOf (each . _2) (:->) te tbs
unify :: [Constraint] -> HM Context
unify :: Context -> [Constraint] -> HM Context
unify [] = pure mempty
unify g [] = pure g
unify g (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify g $ Equality sx tx : Equality sy ty : cs
unify (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify $ Equality sx tx : Equality sy ty : cs
-- elim
unify g (Equality (ConT s) (ConT t) : cs) | s == t = unify g cs
unify g (Equality (VarT s) (VarT t) : cs) | s == t = unify g cs
unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs
unify (Equality (VarT s) (VarT t) : cs) | s == t = unify cs
unify g (Equality (VarT s) t : cs)
unify (Equality (VarT s) t : cs)
| occurs s t = addFatal $ TyErrRecursiveType s t
| otherwise = unify g' cs'
| otherwise = unify cs' <&> contextVars . at s ?~ t
where
g' = supplement [(s,t)] g
cs' = cs & each . constraintTypes %~ subst s t
-- swap
unify g (Equality s (VarT t) : cs) = unify g (Equality (VarT t) s : cs)
unify (Equality s (VarT t) : cs) = unify (Equality (VarT t) s : cs)
unify _ (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
unify (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
solve :: Context -> RlpExpr PsName -> HM (Type PsName)
solve g e = do
(t,cs) <- listen $ gather g e
g' <- unify g cs
pure $ ifoldrOf (contextVars . itraversed) subst t g'
annotate :: RlpExpr PsName
-> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotate = sequenceA . fixtend (gather . wrapFix)
solveTree :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)
-> HM (Type PsName)
solveTree e = undefined
infer1 :: RlpExpr PsName -> HM (Type PsName)
infer1 e = do
((t,j) :< _) <- annotate e
g <- unify (j ^. constraints)
pure $ ifoldrOf (contextVars . itraversed) subst t g
solve = undefined
-- solve g e = do
-- (t,j) <- gather e
-- g' <- unify cs
-- pure $ ifoldrOf (contextVars . itraversed) subst t g'
occurs :: PsName -> Type PsName -> Bool
occurs n = cata \case
@@ -122,11 +147,11 @@ prettyHM :: (Pretty a)
prettyHM = over (mapped . _1) rpretty
. over (mapped . _2 . each) rpretty
fixtend :: (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend = undefined
fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend c (Fix f) = c f :< fmap (fixtend c) f
infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer = _ . fixtend (solve _ . wrapFix)
infer = undefined
typeCheckRlpProgR :: (Monad m)
=> Program PsName (RlpExpr PsName)