Files
rlp/src/Rlp/HindleyMilner.hs
2024-03-18 10:27:06 -06:00

328 lines
11 KiB
Haskell

{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TemplateHaskell #-}
module Rlp.HindleyMilner
( typeCheckRlpProgR
, annotate
, TypeError(..)
, runHM'
, HM
, prettyVars
, prettyVars'
)
where
--------------------------------------------------------------------------------
import Control.Lens hiding (Context', Context, (:<), para)
import Control.Lens.Unsound
import Control.Monad.Errorful
import Control.Monad.State
import Control.Monad.Accum
import Control.Monad
import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict
import Data.List
import Data.Monoid
import Data.Text qualified as T
import Data.Foldable (fold)
import Data.Function
import Data.Pretty hiding (annotate)
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as H
import Data.HashSet (HashSet)
import Data.HashSet qualified as S
import Data.Maybe (fromMaybe)
import Data.Traversable
import GHC.Generics (Generic(..), Generically(..))
import Debug.Trace
import Data.Functor
import Data.Functor.Foldable hiding (fold)
import Data.Fix hiding (cata, para)
import Control.Comonad.Cofree
import Control.Comonad
import Compiler.RLPC
import Compiler.RlpcError
import Rlp.AltSyntax as Rlp
import Core.Syntax qualified as Core
import Core.Syntax (ExprF(..), Lit(..))
import Rlp.HindleyMilner.Types
--------------------------------------------------------------------------------
-- | Isomorphism between a plain fixed point and a 'Cofree' tree whose every
-- node carries @()@. Converting back to 'Fix' discards the annotations, which
-- is why the reverse direction accepts a tree with any annotation type @b@.
fixCofree :: (Functor f, Functor g)
          => Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b)
fixCofree = iso toCofree fromCofree
  where
    toCofree = foldFix (() :<)
    fromCofree (_ :< layer) = Fix (fmap fromCofree layer)
-- | Look up the type of a variable in a context, failing fatally with
-- 'TyErrUntypedVariable' when the context has no entry for it.
lookupVar :: PsName -> Context -> HM (Type PsName)
lookupVar n g =
  maybe (addFatal $ TyErrUntypedVariable n) pure $ H.lookup n (g ^. contextVars)
-- | Memoised front end to 'gather''.
--
-- Results are cached per expression in the 'HM' accumulator: 'look' reads
-- the table and 'add' extends it. Structurally equal subexpressions therefore
-- receive the same type and judgement. This is what lets 'annotate' gather
-- every subtree separately and still agree with the types used while
-- gathering the enclosing expression.
gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather expr = do
  memo <- look
  case H.lookup expr memo of
    Just cached -> pure cached
    Nothing -> do
      result <- gather' expr
      add (H.singleton expr result)
      pure result
-- | Generate a type and a partial judgement for one layer of an expression.
-- Subexpressions are handled by delegating to 'gather'.
--
-- A partial judgement carries two things:
--
-- * the equality constraints collected so far;
-- * the /assumptions/: for each variable still free in the expression, the
--   types given to its occurrences.
--
-- Only integer literals, variables, applications and lambdas are handled.
-- Any other expression form fails with a pattern-match error.
gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather' = \case
  Finl (LitF (IntL _)) -> pure (IntT, mempty)

  -- a variable gets a fresh type, recorded as an assumption that an
  -- enclosing binder may later discharge
  Finl (VarF n) -> do
    tv <- freshTv
    pure (tv, mempty & assumptions .~ H.singleton n [tv])

  -- application: a fresh result type r with the constraint  tf ~ tx -> r
  Finl (AppF f x) -> do
    tResult <- freshTv
    (tf, jf) <- gather f
    (tx, jx) <- gather x
    let jApp = mempty & constraints .~ [Equality tf (tx :-> tResult)]
    pure (tResult, jf <> jx <> jApp)

  -- lambda: each binder gets a fresh type; every assumption the body makes
  -- about a binder is equated with that type; the binders' assumptions are
  -- then removed
  Finl (LamF bs body) -> do
    tbs <- traverse (const freshTv) bs
    (tBody, jBody) <- gather body
    let asBody = jBody ^. assumptions
        bindCs = concat
          [ maybe [] (fmap (Equality tb)) (asBody ^. at b)
          | b  <- bs
          | tb <- tbs
          ]
        jLam = mempty
          & constraints .~ (jBody ^. constraints <> bindCs)
          & assumptions .~ foldr H.delete asBody bs
    pure (foldr (:->) tBody tbs, jLam)
-- Finl (LamF [b] e) -> do
-- tb <- freshTv
-- (te,je) <- gather e
-- let cs = maybe [] (fmap $ Equality tb) (je ^. assumptions . at b)
-- as = je ^. assumptions & at b .~ Nothing
-- j = mempty & constraints .~ cs & assumptions .~ as
-- t = tb :-> te
-- pure (t,j)
-- | Solve equality constraints and return the solution as a 'Context' that
-- maps each solved type variable to its binding.
--
-- The solving is done by 'unify''; this function only repackages its
-- bindings. 'unify'' binds every variable at most once, because a variable
-- is substituted out of the remaining constraints as soon as it is solved.
-- Building the map with 'H.fromList' therefore loses nothing.
unify :: [Constraint] -> HM Context
unify cs = Context . H.fromList <$> unify' cs
-- | Solve equality constraints by syntactic (first-order) unification.
--
-- Returns the bindings in the order they were solved. Each bound variable is
-- substituted out of the remaining constraints as soon as it is solved. As a
-- result, a binding's type may mention variables bound later in the list,
-- but never ones bound earlier.
--
-- Fails fatally with:
--
-- * 'TyErrRecursiveType' when the occurs check fails;
-- * 'TyErrCouldNotUnify' when two types clash structurally.
--
-- Function arrows are decomposed first. Any other pair of type applications
-- is decomposed structurally (@f a ~ g b@ becomes @f ~ g@ and @a ~ b@).
unify' :: [Constraint] -> HM [(PsName, Type PsName)]
unify' [] = pure mempty
-- decompose
unify' (Equality (sx :-> sy) (tx :-> ty) : cs) =
  unify' $ Equality sx tx : Equality sy ty : cs
unify' (Equality (AppT sf sa) (AppT tf ta) : cs) =
  unify' $ Equality sf tf : Equality sa ta : cs
-- elim
unify' (Equality FunT FunT : cs) = unify' cs
unify' (Equality (ConT s) (ConT t) : cs) | s == t = unify' cs
unify' (Equality (VarT s) (VarT t) : cs) | s == t = unify' cs
-- bind
unify' (Equality (VarT s) t : cs)
  | occurs s t = addFatal $ TyErrRecursiveType s t
  | otherwise  = unify' cs' <&> ((s,t):)
  where
    -- eliminate s from every remaining constraint
    cs' = cs & each . constraintTypes %~ subst s t
-- swap
unify' (Equality s (VarT t) : cs) = unify' (Equality (VarT t) s : cs)
unify' (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
-- | Annotate every node of an expression with the type and partial judgement
-- that 'gather' assigns to it. Because 'gather' is memoised, a node's
-- annotation matches the result used while gathering its parent.
annotate :: RlpExpr PsName
         -> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotate expr = sequenceA (fixtend gatherNode expr)
  where
    gatherNode = gather . wrapFix
-- infer1 :: RlpExpr PsName -> HM (Type PsName)
-- infer1 = infer1' mempty
-- infer1' :: Context -> RlpExpr PsName -> HM (Type PsName)
-- infer1' g1 e = do
-- ((t,j) :< _) <- annotate e
-- g2 <- unify (j ^. constraints)
-- g <- unionContextWithKeyM unifyTypes g1 g2
-- pure $ ifoldrOf (contextVars . itraversed) subst t g
-- where
-- -- intuitively, we'd return mgu(s,t) but the union is left-biased making `s`
-- -- the user-specified type: prioritise her.
-- unifyTypes _ s t = unify [Equality s t] $> s
-- | Indexed traversal over the values of an association list. Each entry's
-- key is used as its index; keys and entry order are preserved.
assocs :: IndexedTraversal k [(k,v)] [(k,v')] v v'
assocs f = traverse visit
  where
    visit (k, v) = (,) k <$> indexed f k v
-- | Debugging variant of 'subst'. When its result is evaluated, it emits a
-- 'trace' line @subst k v t@, with each argument shown at application
-- precedence.
traceSubst :: PsName -> Type PsName -> Type PsName -> Type PsName
traceSubst k v t = trace message (subst k v t)
  where
    message = unwords ["subst", showArg k, showArg v, showArg t]
    showArg x = showsPrec 11 x ""
-- | Infer a type for every node of an expression.
--
-- The steps are:
--
-- 1. 'annotate' gathers a type and constraints for each node.
-- 2. The constraints of all nodes are solved together by 'unify''.
-- 3. The resulting substitution is applied to every node's type.
--
-- NOTE(review): the context argument is currently ignored, so the program's
-- type signatures (see 'buildInitialContext') do not yet constrain
-- inference.
infer :: Context -> RlpExpr PsName
      -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer _g1 e = do
  e' <- annotate e
  g2 <- unify' $ concatOf (folded . _2 . constraints) e'
  -- Bindings from unify' are in solving order. A binding's type may mention
  -- variables bound later, never earlier ones, so they are applied
  -- first-to-last (a reversed right fold applies the head first).
  let sub t = ifoldrOf (reversed . assocs) subst t g2
  pure $ sub . view _1 <$> e'
-- Debugging fixture (not referenced by the visible code). It is a hard-coded,
-- type-annotated tree for  \f x -> f (f x). The type variables are
-- presumably the ones 'annotate' produced for it.
e :: Cofree (RlpExprF PsName) (Type PsName)
e = rootTy :< InL (LamF ["f","x"] outerApp)
  where
    rootTy   = AppT (AppT FunT (VarT "$a2"))
                    (AppT (AppT FunT (VarT "$a3")) (VarT "$a4"))
    useOfF   = VarT "$a5" :< InL (VarF "f")
    useOfX   = VarT "$a1" :< InL (VarF "x")
    innerApp = VarT "$a6" :< InL (AppF useOfF useOfX)
    outerApp = VarT "$a4" :< InL (AppF useOfF innerApp)
-- Debugging fixture (not referenced by the visible code). It is a hard-coded
-- context binding the type variables that appear in the fixture 'e'.
g :: Context
g = Context $ H.fromList
  [ ("$a1", VarT "$a6")
  , ("$a3", VarT "$a4")
  , ("$a2", AppT (AppT FunT (VarT "$a4")) (VarT "$a4"))
  , ("$a5", AppT (AppT FunT (VarT "$a1")) (VarT "$a6"))
  , ("$a6", VarT "$a4")
  ]
-- | Monadic, key-aware union of the variable maps of two contexts; see
-- 'unionWithKeyM'. For a variable present in both, the merge function gets
-- the first context's type and then the second's.
unionContextWithKeyM :: Monad m
                     => (PsName -> Type PsName -> Type PsName
                         -> m (Type PsName))
                     -> Context -> Context -> m Context
unionContextWithKeyM f a b =
  Context <$> (unionWithKeyM f `on` view contextVars) a b
-- | Monadic version of 'H.unionWithKey'.
--
-- * A key present in only one map keeps its value.
-- * For a key present in both, @f k x y@ computes the merged value, where
--   @x@ comes from the first map.
--
-- The effects of @f@ run while the combined map is sequenced.
unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m)
              => (k -> v -> v -> m v) -> HashMap k v -> HashMap k v
              -> m (HashMap k v)
unionWithKeyM f a b = sequenceA (H.unionWithKey merge (pure <$> a) (pure <$> b))
  where
    merge k mx my = do
      x <- mx
      y <- my
      f k x y
-- solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- solve = solve' mempty
-- solve' :: Context -> RlpExpr PsName
-- -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- solve' g = sequenceA . fixtend (infer1' g . wrapFix)
-- | Occurs check: does the type variable @n@ appear anywhere in the type?
-- Occurrences are counted even when an enclosing @forall@ binds @n@.
occurs :: PsName -> Type PsName -> Bool
occurs n = cata go
  where
    go (VarTF m) = m == n
    go layer     = or layer
-- | @subst n t' t@ replaces the occurrences of the type variable @n@ in @t@
-- with @t'@. A @forall n.@ in @t@ shadows @n@, so its body is left
-- untouched. 'para' gives each layer both the original subterm and its
-- substituted version.
--
-- NOTE(review): this is not capture-avoiding. If @t'@ mentions a variable
-- that a @forall@ in @t@ binds, the substituted occurrence gets captured.
subst :: PsName -> Type PsName -> Type PsName -> Type PsName
subst n t' = para step
  where
    step (VarTF m)
      | m == n = t'
    step (ForallTF x (original, _))
      | x == n = ForallT x original
    step layer = embed (snd <$> layer)
-- | Render a successful @(value, constraints)@ result with 'rout'. Errors
-- pass through untouched.
prettyHM :: (Out a)
         => Either [TypeError] (a, [Constraint])
         -> Either [TypeError] (String, [String])
prettyHM = fmap render
  where
    render ~(x, cs) = (rout x, rout <$> cs)
-- | An 'extend'-like operation for 'Fix'. It annotates every node with the
-- result of applying @c@ to that node's top layer, producing a 'Cofree' tree
-- of the same shape.
fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend c = go
  where
    go (Fix layer) = c layer :< (go <$> layer)
-- infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- infer = infer' mempty
-- infer' :: Context -> RlpExpr PsName
-- -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- infer' g = sequenceA . fixtend (infer1' g . wrapFix)
-- | Build the initial typing context from the program's type-signature
-- declarations.
buildInitialContext :: Program PsName a -> Context
buildInitialContext p =
  Context (H.fromList (p ^.. programDecls . each . _TySigD))
-- | Type-check a program, annotating every node of each declaration's body
-- with its inferred type.
--
-- Function declarations whose parameters are all variable patterns are
-- first eta-expanded: they become a parameterless declaration whose body is
-- a lambda. The whole program is then checked in a single 'HM' run, whose
-- errors are lifted into 'RLPCT'.
typeCheckRlpProgR :: (Monad m)
                  => Program PsName (RlpExpr PsName)
                  -> RLPCT m (Program PsName
                                (Cofree (RlpExprF PsName) (Type PsName)))
typeCheckRlpProgR p = liftHM $ traverse (infer ctx) (etaExpandAll p)
  where
    ctx = buildInitialContext p
    etaExpandAll = over (programDecls . each) etaExpand

    etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b)
    etaExpand d = case d of
      FunD _ [] _ -> d
      FunD n ps body -> case traverse (matching _VarP) ps of
        Right vars -> FunD n [] (Finl (LamF vars body))
        Left _     -> d
      _ -> d
-- | Run an 'HM' computation with 'runHM'' and lift its outcome into 'RLPCT'
-- via 'liftEither'.
liftHM :: (Monad m) => HM a -> RLPCT m a
liftHM hm = liftEither (runHM' hm)
-- | The set of type variables occurring free in a type. Variables bound by
-- an enclosing @forall@ are excluded.
freeVariables :: Type PsName -> HashSet PsName
freeVariables = cata go
  where
    go (VarTF x)      = S.singleton x
    go (ForallTF x m) = S.delete x m
    go layer          = fold layer
-- | The set of variable names bound by a @forall@ anywhere in a type.
boundVariables :: Type PsName -> HashSet PsName
boundVariables = cata go
  where
    go (ForallTF x m) = S.insert x m
    go layer          = fold layer
-- | Rename every free variable of a type to a short name, for aesthetic
-- purposes. The substitution is computed from the type itself; see
-- 'prettyVars'.
prettyVars' :: Type PsName -> Type PsName
prettyVars' t = prettyVars t t
-- | Free type variables of a type, each listed once, in left-to-right order
-- of first occurrence. 'prettyVars' uses this to pick a stable renaming
-- order.
freeVariablesLTR :: Type PsName -> [PsName]
freeVariablesLTR = nub . cata \case
  VarTF x -> [x]
  -- Drop every occurrence of the bound variable. Using (\\ [x]) would
  -- remove only the first one, so `forall a. a -> a` would report `a` free.
  ForallTF x m -> filter (/= x) m
  vs -> concat vs
-- | For a type @root@, compute a substitution that renames every free
-- variable of @root@ to a short name, for aesthetic purposes. The
-- substitution is returned as a function that can be applied to any type
-- (usually @root@ itself; see 'prettyVars'').
--
-- Naming rules:
--
-- * Free variables are named in left-to-right order of first occurrence.
-- * Names are drawn from @a@ to @z@, then @a1@ .. @z1@, @a2@, and so on.
-- * Names that @root@ already uses for @forall@-bound variables are skipped.
--
-- All renamings are applied simultaneously, so a variable is never renamed
-- twice. Applying them as separate substitutions, one after another, would
-- turn e.g. @b -> a@ into @a -> a@. Within the scope of a @forall x@ in the
-- target type, @x@ itself is left alone.
prettyVars :: Type PsName -> Type PsName -> Type PsName
prettyVars root = \t -> cata rename t renaming
  where
    letters = ['a'..'z']
    -- an unbounded supply of candidate names, so every free variable gets one
    candidates = [ T.pack [c] | c <- letters ]
              ++ [ T.pack (c : show i) | i <- [1 :: Int ..], c <- letters ]
    bound = boundVariables root
    names = filter (not . (`S.member` bound)) candidates
    renaming = H.fromList (zip (freeVariablesLTR root) names)
    -- top-down pass: each layer receives the renaming in effect at that point
    rename (VarTF x) m = VarT (H.findWithDefault x x m)
    rename (ForallTF x body) m = ForallT x (body (H.delete x m))
    rename layer m = embed (($ m) <$> layer)
-- test :: Type PsName -> [(PsName, PsName)]
-- test root = subs
-- where
-- alphabetNames = [ T.pack [c] | c <- ['a'..'z'] ]
-- names = alphabetNames \\ S.toList (boundVariables root)
-- subs = zip (freeVariablesLTR root) names