Files
rlp/src/Rlp/HindleyMilner.hs

225 lines
7.1 KiB
Haskell

{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TemplateHaskell #-}
module Rlp.HindleyMilner
( typeCheckRlpProgR
, solve
, annotate
, TypeError(..)
, runHM'
, HM
)
where
--------------------------------------------------------------------------------
import Control.Lens hiding (Context', Context, (:<), para)
import Control.Lens.Unsound
import Control.Monad.Errorful
import Control.Monad.State
import Control.Monad.Accum
import Control.Monad
import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict
import Data.Text qualified as T
import Data.Function
import Data.Pretty hiding (annotate)
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as H
import Data.HashSet (HashSet)
import Data.HashSet qualified as S
import Data.Maybe (fromMaybe)
import Data.Traversable
import GHC.Generics (Generic(..), Generically(..))
import Data.Functor
import Data.Functor.Foldable
import Data.Fix hiding (cata, para)
import Control.Comonad.Cofree
import Control.Comonad
import Compiler.RLPC
import Compiler.RlpcError
import Rlp.AltSyntax as Rlp
import Core.Syntax qualified as Core
import Core.Syntax (ExprF(..), Lit(..))
import Rlp.HindleyMilner.Types
--------------------------------------------------------------------------------
-- | Isomorphism between a plain fixed point and a 'Cofree' whose
-- annotations carry no information.
--
-- Going forward, every layer is annotated with @()@. Going back, the
-- annotations (of any type @b@) are simply dropped; the optic is
-- type-changing, so the functor may differ between the two sides.
fixCofree :: (Functor f, Functor g)
          => Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b)
fixCofree = iso toCofree fromCofree
  where
    toCofree = foldFix (\layer -> () :< layer)
    fromCofree (_ :< children) = Fix (fromCofree <$> children)
-- | Fetch the type a context records for a variable name. An absent name
-- is a fatal 'TyErrUntypedVariable'.
lookupVar :: PsName -> Context -> HM (Type PsName)
lookupVar n g =
    maybe (addFatal $ TyErrUntypedVariable n) pure (g ^. contextVars . at n)
-- | Memoising front end for 'gather''.
--
-- The cache lives in the accumulated 'HM' state (read with 'look', extended
-- with 'add') and is keyed by the expression itself via its 'Eq' and
-- 'Hashable' instances. On a miss, 'gather'' runs and its result is
-- recorded; on a hit, the stored type and judgement are returned unchanged.
--
-- NOTE(review): since the key is the expression value, two separate
-- occurrences of an equal subexpression (e.g. @\x -> x@ written twice)
-- receive the very same type variables; confirm this sharing is intended.
gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather e = do
    cache <- look
    case H.lookup e cache of
      Just hit -> pure hit
      Nothing  -> do
        result <- gather' e
        add (H.singleton e result)
        pure result
-- | Produce a type and a partial judgement (constraints, plus assumptions
-- about free variables) for a single expression node. Subexpressions are
-- handled through the memoising 'gather'.
--
-- * An integer literal has type 'IntT' and an empty judgement.
-- * A variable is given a fresh type variable, recorded as an assumption.
-- * An application @f x@ is given a fresh result type @r@ together with
--   the constraint @tf ~ tx -> r@.
-- * A lambda is given one fresh type variable per binder. Each assumption
--   the body made about a binder becomes an equality with that binder's
--   type variable, and the binders are removed from the assumptions that
--   flow outward. The resulting type is @tb1 -> ... -> tbn -> tbody@.
--
-- Any other node is not handled here and is a pattern-match failure.
gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather' = \case
  Finl (LitF (IntL _)) ->
      pure (IntT, mempty)

  Finl (VarF n) -> do
      tv <- freshTv
      pure (tv, mempty & assumptions .~ H.singleton n [tv])

  Finl (AppF f x) -> do
      tr <- freshTv
      (tf, jf) <- gather f
      (tx, jx) <- gather x
      let japp = mempty & constraints .~ [Equality tf (tx :-> tr)]
      pure (tr, jf <> jx <> japp)

  Finl (LamF bs body) -> do
      tbs <- for bs (const freshTv)
      (tbody, jbody) <- gather body
      let asm  = jbody ^. assumptions
          cs   = concat (zipWith (binderConstraints asm) bs tbs)
          asm' = foldr H.delete asm bs
          j    = mempty & constraints .~ cs & assumptions .~ asm'
      pure (foldr (:->) tbody tbs, j)
    where
      -- every type the body assumed for binder @b@ must equal @tb@
      binderConstraints asm b tb = maybe [] (fmap (Equality tb)) (asm ^. at b)
-- | Solve a list of equality constraints. The result is a substitution,
-- represented as a 'Context' from type-variable names to types.
--
-- The substitution is built to be idempotent: no variable it binds occurs
-- free in any of the types it binds. Applying it one binding at a time
-- therefore gives the same answer in every order. This matters because the
-- bindings live in a 'HashMap', whose iteration order is unspecified.
--
-- Failures are fatal: 'TyErrRecursiveType' when the occurs check fires,
-- 'TyErrCouldNotUnify' when two types cannot be made equal.
unify :: [Constraint] -> HM Context
unify [] = pure mempty

-- decompose: (sx -> sy) ~ (tx -> ty) becomes sx ~ tx and sy ~ ty
unify (Equality (sx :-> sy) (tx :-> ty) : cs) =
    unify $ Equality sx tx : Equality sy ty : cs

-- elim
unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs
unify (Equality (VarT s) (VarT t) : cs) | s == t = unify cs

unify (Equality (VarT s) t : cs)
    | occurs s t = addFatal $ TyErrRecursiveType s t
    | otherwise  = do
        g <- unify cs'
        -- `t` may mention variables that `g` solves. Resolve them before
        -- recording `s := t`; otherwise the final substitution is only
        -- correct when its bindings happen to be applied in a lucky order.
        -- `g` cannot mention `s`: it was substituted out of cs'.
        pure $ g & contextVars . at s ?~ applyContext g t
  where
    cs' = cs & each . constraintTypes %~ subst s t
    -- apply every binding of an (idempotent) substitution to a type
    applyContext g ty = ifoldrOf (contextVars . itraversed) subst ty g

-- swap: put the variable on the left
unify (Equality s (VarT t) : cs) = unify (Equality (VarT t) s : cs)

unify (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
-- | Run 'gather' at every node of an expression. The result is the
-- expression annotated, node by node, with that node's type and judgement.
-- Subtrees already gathered are answered from 'gather''s cache.
annotate :: RlpExpr PsName
         -> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotate e = sequence (fixtend (\layer -> gather (Fix layer)) e)
-- | Infer an expression's type starting from an empty context.
infer1 :: RlpExpr PsName -> HM (Type PsName)
infer1 e = infer1' mempty e
-- | Infer the type of an expression under an initial context @g1@.
--
-- 1. Constraints are gathered for the expression's root.
-- 2. They are solved with 'unify', giving @g2@.
-- 3. @g2@ is merged into @g1@. For a key present in both, the two types
--    are unified and @g1@'s (user-specified) type is kept.
-- 4. The merged context is applied to the gathered root type.
--
-- The bindings are applied one at a time in 'HashMap' order. This yields a
-- complete substitution only when no bound variable occurs in another
-- binding's type.
infer1' :: Context -> RlpExpr PsName -> HM (Type PsName)
infer1' g1 e = do
    ((t, j) :< _) <- annotate e
    g2 <- unify (j ^. constraints)
    g  <- unionContextWithKeyM unifyTypes g1 g2
    -- `ifoldrOf` calls its step as (key, boundType, acc), which lines up
    -- with `subst name replacement target`. `ifoldlOf` would pass
    -- (key, acc, boundType) and substitute into the bound types instead.
    pure $ ifoldrOf (contextVars . itraversed) subst t g
  where
    -- intuitively, we'd return mgu(s,t) but the union is left-biased making `s`
    -- the user-specified type: prioritise her.
    unifyTypes _ s t = unify [Equality s t] $> s
-- | Merge the variable maps of two contexts with 'unionWithKeyM'. The
-- monadic combiner runs for every name bound in both contexts.
unionContextWithKeyM :: Monad m
                     => (PsName -> Type PsName -> Type PsName
                         -> m (Type PsName))
                     -> Context -> Context -> m Context
unionContextWithKeyM f a b =
    fmap Context $ (unionWithKeyM f `on` view contextVars) a b
-- | A monadic 'H.unionWithKey'.
--
-- A key found in only one map keeps its value and runs no effect. For a
-- key found in both, @f k x y@ is run and its result kept. The effects are
-- sequenced in the traversal order of the merged 'HashMap'.
unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m)
              => (k -> v -> v -> m v) -> HashMap k v -> HashMap k v
              -> m (HashMap k v)
unionWithKeyM f a b = sequenceA (H.unionWithKey combine (inject a) (inject b))
  where
    inject :: HashMap k v -> HashMap k (m v)
    inject = fmap pure
    combine k mx my = do
      x <- mx
      y <- my
      f k x y
-- | Annotate each node of an expression with its inferred type, starting
-- from an empty context.
solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
solve e = solve' mempty e
-- | Annotate each node of an expression with the type 'infer1'' infers for
-- the subtree rooted at that node, under the initial context @g@.
solve' :: Context -> RlpExpr PsName
       -> HM (Cofree (RlpExprF PsName) (Type PsName))
solve' g = sequence . fixtend (\layer -> infer1' g (Fix layer))
-- | @occurs n t@ is 'True' when the type variable @n@ occurs free in @t@.
--
-- This is the occurs check used by 'unify'. An occurrence under an
-- enclosing @forall n.@ is bound by that forall rather than free, which
-- matches how 'subst' treats shadowing. Counting it would make 'unify'
-- report a spurious recursive type.
occurs :: PsName -> Type PsName -> Bool
occurs n = cata \case
    VarTF m | n == m      -> True
    ForallTF x _ | x == n -> False
    t                     -> or t
-- | @subst n t' t@ replaces each free occurrence of type variable @n@ in
-- @t@ with @t'@.
--
-- Below a @forall n.@ the variable is bound, so that subtree is returned
-- as it was.
-- NOTE(review): nothing is renamed. A free variable of @t'@ that shares a
-- name with an enclosing @forall@ binder in @t@ will be captured.
subst :: PsName -> Type PsName -> Type PsName -> Type PsName
subst n t' = para step
  where
    step (VarTF m) | m == n = t'
    step (ForallTF x (before, after))
      | x == n    = ForallT x before
      | otherwise = ForallT x after
    step layer = embed (snd <$> layer)
-- | Render a successful result and each of its constraints with 'rout'.
-- Errors pass through untouched.
prettyHM :: (Out a)
         => Either [TypeError] (a, [Constraint])
         -> Either [TypeError] (String, [String])
prettyHM = fmap (\(x, cs) -> (rout x, map rout cs))
-- | Annotate every node of a fixed point with @c@ applied to that node's
-- top layer, whose children are still plain 'Fix' values. The tree shape
-- is unchanged.
fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend c = go
  where
    go (Fix layer) = c layer :< fmap go layer
-- | Type every node of an expression, starting from an empty context.
infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer e = infer' mempty e
-- | Type every node of an expression under an initial context.
--
-- This is exactly 'solve''. The body used to be a verbatim copy of it;
-- delegating keeps the two entry points from drifting apart.
infer' :: Context -> RlpExpr PsName
       -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer' = solve'
-- | Seed a typing context from a program's type signatures. Each pair
-- focused by '_TySigD' becomes a binding. If a name is declared more than
-- once, 'H.fromList' keeps the last declaration.
buildInitialContext :: Program PsName a -> Context
buildInitialContext p =
    Context $ H.fromList (p ^.. programDecls . each . _TySigD)
-- | Type-check a whole program, annotating every expression node with its
-- inferred type.
--
-- 1. A function declaration whose parameters are all plain variable
--    patterns is eta-expanded into a parameterless declaration whose body
--    is a lambda over those variables.
-- 2. Every expression is solved under a context seeded from the program's
--    type signatures, all within a single 'HM' run.
-- 3. The outcome is lifted into 'RLPCT'.
--
-- NOTE(review): a declaration with any non-variable parameter pattern is
-- left as is. Its parameters then appear never to be bound during
-- inference; confirm how such declarations should be handled.
typeCheckRlpProgR :: (Monad m)
                  => Program PsName (RlpExpr PsName)
                  -> RLPCT m (Program PsName
                                (Cofree (RlpExprF PsName) (Type PsName)))
typeCheckRlpProgR p = liftHM . traverse (solve' ctx) $ etaExpandAll p
  where
    ctx = buildInitialContext p
    etaExpandAll = over (programDecls . each) etaExpand

    etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b)
    etaExpand decl = case decl of
      FunD n params body
        | not (null params)
        , Right vars <- traverse (matching _VarP) params
        -> FunD n [] (Finl (LamF vars body))
      _ -> decl
-- | Run an 'HM' computation with 'runHM'' and lift its outcome into
-- 'RLPCT' with 'liftEither'.
liftHM :: (Monad m) => HM a -> RLPCT m a
liftHM hm = liftEither (runHM' hm)