type-checker and working visualiser
This commit is contained in:
@@ -1,13 +1,15 @@
|
||||
{-# LANGUAGE ParallelListComp #-}
|
||||
{-# LANGUAGE PartialTypeSignatures #-}
|
||||
{-# LANGUAGE OverloadedLists #-}
|
||||
{-# LANGUAGE TemplateHaskell #-}
|
||||
module Rlp.HindleyMilner
|
||||
( typeCheckRlpProgR
|
||||
, solve
|
||||
, annotate
|
||||
, TypeError(..)
|
||||
, runHM'
|
||||
, HM
|
||||
, prettyVars
|
||||
, prettyVars'
|
||||
)
|
||||
where
|
||||
--------------------------------------------------------------------------------
|
||||
@@ -19,7 +21,10 @@ import Control.Monad.Accum
|
||||
import Control.Monad
|
||||
import Control.Arrow ((>>>))
|
||||
import Control.Monad.Writer.Strict
|
||||
import Data.List
|
||||
import Data.Monoid
|
||||
import Data.Text qualified as T
|
||||
import Data.Foldable (fold)
|
||||
import Data.Function
|
||||
import Data.Pretty hiding (annotate)
|
||||
import Data.Hashable
|
||||
@@ -30,9 +35,10 @@ import Data.HashSet qualified as S
|
||||
import Data.Maybe (fromMaybe)
|
||||
import Data.Traversable
|
||||
import GHC.Generics (Generic(..), Generically(..))
|
||||
import Debug.Trace
|
||||
|
||||
import Data.Functor
|
||||
import Data.Functor.Foldable
|
||||
import Data.Functor.Foldable hiding (fold)
|
||||
import Data.Fix hiding (cata, para)
|
||||
import Control.Comonad.Cofree
|
||||
import Control.Comonad
|
||||
@@ -125,24 +131,80 @@ unify (Equality s (VarT t) : cs) = unify (Equality (VarT t) s : cs)
|
||||
|
||||
unify (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
|
||||
|
||||
unify' :: [Constraint] -> HM [(PsName, Type PsName)]
|
||||
|
||||
unify' [] = pure mempty
|
||||
|
||||
unify' (Equality (sx :-> sy) (tx :-> ty) : cs) =
|
||||
unify' $ Equality sx tx : Equality sy ty : cs
|
||||
|
||||
-- elim
|
||||
unify' (Equality (ConT s) (ConT t) : cs) | s == t = unify' cs
|
||||
unify' (Equality (VarT s) (VarT t) : cs) | s == t = unify' cs
|
||||
|
||||
unify' (Equality (VarT s) t : cs)
|
||||
| occurs s t = addFatal $ TyErrRecursiveType s t
|
||||
| otherwise = unify' cs' <&> ((s,t):)
|
||||
where
|
||||
cs' = cs & each . constraintTypes %~ subst s t
|
||||
|
||||
-- swap
|
||||
unify' (Equality s (VarT t) : cs) = unify' (Equality (VarT t) s : cs)
|
||||
|
||||
unify' (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
|
||||
|
||||
annotate :: RlpExpr PsName
|
||||
-> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
|
||||
annotate = sequenceA . fixtend (gather . wrapFix)
|
||||
|
||||
infer1 :: RlpExpr PsName -> HM (Type PsName)
|
||||
infer1 = infer1' mempty
|
||||
-- infer1 :: RlpExpr PsName -> HM (Type PsName)
|
||||
-- infer1 = infer1' mempty
|
||||
|
||||
infer1' :: Context -> RlpExpr PsName -> HM (Type PsName)
|
||||
infer1' g1 e = do
|
||||
((t,j) :< _) <- annotate e
|
||||
g2 <- unify (j ^. constraints)
|
||||
g <- unionContextWithKeyM unifyTypes g1 g2
|
||||
pure $ ifoldrOf (contextVars . itraversed) subst t g
|
||||
-- infer1' :: Context -> RlpExpr PsName -> HM (Type PsName)
|
||||
-- infer1' g1 e = do
|
||||
-- ((t,j) :< _) <- annotate e
|
||||
-- g2 <- unify (j ^. constraints)
|
||||
-- g <- unionContextWithKeyM unifyTypes g1 g2
|
||||
-- pure $ ifoldrOf (contextVars . itraversed) subst t g
|
||||
-- where
|
||||
-- -- intuitively, we'd return mgu(s,t) but the union is left-biased making `s`
|
||||
-- -- the user-specified type: prioritise her.
|
||||
-- unifyTypes _ s t = unify [Equality s t] $> s
|
||||
|
||||
assocs :: IndexedTraversal k [(k,v)] [(k,v')] v v'
|
||||
assocs f [] = pure []
|
||||
assocs f ((k,v):xs) = (\v' xs' -> (k,v') : xs')
|
||||
<$> indexed f k v <*> assocs f xs
|
||||
|
||||
traceSubst k v t = trace ("subst " <> show' k <> " " <> show' v <> " " <> show' t)
|
||||
$ subst k v t
|
||||
where show' a = showsPrec 11 a mempty
|
||||
|
||||
infer :: Context -> RlpExpr PsName
|
||||
-> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
infer g1 e = do
|
||||
e' <- annotate e
|
||||
g2 <- unify' $ concatOf (folded . _2 . constraints) e'
|
||||
traceM $ "e': " <> show (view _1 <$> e')
|
||||
traceM $ "g2: " <> show g2
|
||||
let sub t = ifoldrOf (reversed . assocs) traceSubst t g2
|
||||
pure $ sub . view _1 <$> e'
|
||||
where
|
||||
-- intuitively, we'd return mgu(s,t) but the union is left-biased making `s`
|
||||
-- the user-specified type: prioritise her.
|
||||
unifyTypes _ s t = unify [Equality s t] $> s
|
||||
|
||||
e :: Cofree (RlpExprF PsName) (Type PsName)
|
||||
e = AppT (AppT FunT (VarT "$a2")) (AppT (AppT FunT (VarT "$a3")) (VarT "$a4")) :< InL (LamF ["f","x"] (VarT "$a4" :< InL (AppF (VarT "$a5" :< InL (VarF "f")) (VarT "$a6" :< InL (AppF (VarT "$a5" :< InL (VarF "f")) (VarT "$a1" :< InL (VarF "x")))))))
|
||||
|
||||
g = Context
|
||||
{ _contextVars = H.fromList
|
||||
[("$a1",VarT "$a6")
|
||||
,("$a3",VarT "$a4")
|
||||
,("$a2",AppT (AppT FunT (VarT "$a4")) (VarT "$a4"))
|
||||
,("$a5",AppT (AppT FunT (VarT "$a1")) (VarT "$a6"))
|
||||
,("$a6",VarT "$a4")]}
|
||||
|
||||
unionContextWithKeyM :: Monad m
|
||||
=> (PsName -> Type PsName -> Type PsName
|
||||
-> m (Type PsName))
|
||||
@@ -161,12 +223,12 @@ unionWithKeyM f a b = sequenceA $ H.unionWithKey f' ma mb
|
||||
ma = fmap (pure @m) a
|
||||
mb = fmap (pure @m) b
|
||||
|
||||
solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
solve = solve' mempty
|
||||
-- solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
-- solve = solve' mempty
|
||||
|
||||
solve' :: Context -> RlpExpr PsName
|
||||
-> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
solve' g e = sequenceA $ fixtend (infer1' g . wrapFix) e
|
||||
-- solve' :: Context -> RlpExpr PsName
|
||||
-- -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
-- solve' g = sequenceA . fixtend (infer1' g . wrapFix)
|
||||
|
||||
occurs :: PsName -> Type PsName -> Bool
|
||||
occurs n = cata \case
|
||||
@@ -178,7 +240,6 @@ subst n t' = para \case
|
||||
VarTF m | n == m -> t'
|
||||
-- shadowing
|
||||
ForallTF x (pre,post) | x == n -> ForallT x pre
|
||||
| otherwise -> ForallT x post
|
||||
t -> embed $ t <&> view _2
|
||||
|
||||
prettyHM :: (Out a)
|
||||
@@ -190,12 +251,12 @@ prettyHM = over (mapped . _1) rout
|
||||
fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
|
||||
fixtend c (Fix f) = c f :< fmap (fixtend c) f
|
||||
|
||||
infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
infer = infer' mempty
|
||||
-- infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
-- infer = infer' mempty
|
||||
|
||||
infer' :: Context -> RlpExpr PsName
|
||||
-> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
infer' g = sequenceA . fixtend (infer1' g . wrapFix)
|
||||
-- infer' :: Context -> RlpExpr PsName
|
||||
-- -> HM (Cofree (RlpExprF PsName) (Type PsName))
|
||||
-- infer' g = sequenceA . fixtend (infer1' g . wrapFix)
|
||||
|
||||
buildInitialContext :: Program PsName a -> Context
|
||||
buildInitialContext =
|
||||
@@ -208,7 +269,7 @@ typeCheckRlpProgR :: (Monad m)
|
||||
typeCheckRlpProgR p = tc p
|
||||
where
|
||||
g = buildInitialContext p
|
||||
tc = liftHM . traverse (solve' g) . etaExpandAll
|
||||
tc = liftHM . traverse (infer g) . etaExpandAll
|
||||
etaExpandAll = programDecls . each %~ etaExpand
|
||||
|
||||
etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b)
|
||||
@@ -223,3 +284,44 @@ etaExpand a = a
|
||||
liftHM :: (Monad m) => HM a -> RLPCT m a
|
||||
liftHM = liftEither . runHM'
|
||||
|
||||
freeVariables :: Type PsName -> HashSet PsName
|
||||
freeVariables = cata \case
|
||||
VarTF x -> S.singleton x
|
||||
ForallTF x m -> m `S.difference` S.singleton x
|
||||
vs -> fold vs
|
||||
|
||||
boundVariables :: Type PsName -> HashSet PsName
|
||||
boundVariables = cata \case
|
||||
ForallTF x m -> S.singleton x <> m
|
||||
vs -> fold vs
|
||||
|
||||
-- | rename all free variables for aesthetic purposes
|
||||
|
||||
prettyVars' :: Type PsName -> Type PsName
|
||||
prettyVars' = join prettyVars
|
||||
|
||||
freeVariablesLTR :: Type PsName -> [PsName]
|
||||
freeVariablesLTR = nub . cata \case
|
||||
VarTF x -> [x]
|
||||
ForallTF x m -> m \\ [x]
|
||||
vs -> concat vs
|
||||
|
||||
-- | for some type, compute a substitution which will rename all free variables
|
||||
-- for aesthetic purposes
|
||||
|
||||
prettyVars :: Type PsName -> Type PsName -> Type PsName
|
||||
prettyVars root = appEndo (foldMap Endo subs)
|
||||
where
|
||||
alphabetNames = [ T.pack [c] | c <- ['a'..'z'] ]
|
||||
names = alphabetNames \\ S.toList (boundVariables root)
|
||||
subs = zipWith (\k v -> subst k (VarT v))
|
||||
(freeVariablesLTR root)
|
||||
names
|
||||
|
||||
-- test :: Type PsName -> [(PsName, PsName)]
|
||||
-- test root = subs
|
||||
-- where
|
||||
-- alphabetNames = [ T.pack [c] | c <- ['a'..'z'] ]
|
||||
-- names = alphabetNames \\ S.toList (boundVariables root)
|
||||
-- subs = zip (freeVariablesLTR root) names
|
||||
|
||||
|
||||
Reference in New Issue
Block a user