whole-program inference

This commit is contained in:
crumbtoo
2024-03-28 06:53:46 -06:00
parent 3ed6fc233f
commit 7e8be474c6
4 changed files with 162 additions and 126 deletions

View File

@@ -102,6 +102,10 @@ instance (Monoid w, Monad m, MonadWriter w m) => MonadWriter w (ErrorfulT e m) w
((,w) <$> ma, es)
pass (ErrorfulT m) = undefined
instance (Monad m, MonadReader r m) => MonadReader r (ErrorfulT e m) where
ask = lift ask
local rr = hoistErrorfulT (local rr)
instance (Monoid w, Monad m, MonadAccum w m)
=> MonadAccum w (ErrorfulT e m) where
accum = lift . accum

View File

@@ -4,8 +4,8 @@ module Rlp.AltSyntax
-- * AST
Program(..), Decl(..), ExprF(..), Pat(..)
, RlpExprF, RlpExpr, Binding(..), Alter(..)
, DataCon(..), Type(..)
, pattern IntT
, DataCon(..), Type(..), Kind
, pattern IntT, pattern TypeT
, Core.Rec(..)
, AnnotatedRlpExpr, TypedRlpExpr
@@ -18,6 +18,8 @@ module Rlp.AltSyntax
, programDecls
, _VarP, _FunB, _VarB
, _TySigD, _FunD
, _LetEF
, Core.applicants1, Core.arrowStops
-- * Functor-related tools
, Fix(..), Cofree(..), Sum(..), pattern Finl, pattern Finr
@@ -60,8 +62,10 @@ type PsName = T.Text
newtype Program b a = Program [Decl b a]
deriving (Show, Functor, Foldable, Traversable)
programDecls :: Lens' (Program b a) [Decl b a]
programDecls = lens (\ (Program ds) -> ds) (const Program)
programDecls :: Iso (Program b a) (Program b' a') [Decl b a] [Decl b' a']
programDecls = iso sa bt where
sa (Program ds) = ds
bt = Program
data Decl b a = FunD b [Pat b] a
| DataD Core.Name [Core.Name] [DataCon b]
@@ -78,11 +82,20 @@ data Type b = VarT Core.Name
| ForallT b (Type b)
deriving (Show, Eq, Generic, Functor, Foldable, Traversable)
instance Core.HasApplicants1 (Type b) (Type b) (Type b) (Type b) where
applicants1 k (AppT f x) = AppT <$> Core.applicants1 k f <*> k x
applicants1 k t = k t
instance (Hashable b) => Hashable (Type b)
pattern IntT :: (IsString b, Eq b) => Type b
pattern IntT = ConT "Int#"
type Kind = Type
pattern TypeT :: (IsString b, Eq b) => Type b
pattern TypeT = ConT "Type"
instance Core.HasArrowSyntax (Type b) (Type b) (Type b) where
_arrowSyntax = prism make unmake where
make (s,t) = FunT `AppT` s `AppT` t
@@ -205,6 +218,7 @@ instance (Out a, Out b) => Out (Program b a) where
instance (Out b) => Out1 (Program b) where
liftOutPrec pr p (Program ds) = vsep $ liftOutPrec pr p <$> ds
makePrisms ''ExprF
makePrisms ''Pat
makePrisms ''Binding
makePrisms ''Decl

View File

@@ -13,9 +13,11 @@ module Rlp.HindleyMilner
--------------------------------------------------------------------------------
import Control.Lens hiding (Context', Context, (:<), para, uncons)
import Control.Lens.Unsound
import Control.Lens.Extras
import Control.Monad.Errorful
import Control.Monad.State
import Control.Monad.Accum
import Control.Monad.Reader
import Control.Monad
import Control.Monad.Extra
import Control.Arrow ((>>>))
@@ -117,8 +119,7 @@ gather' = \case
let ks = bs ^.. each . singular _VarB . _1 . singular _VarP
(txs,txs',jxs) <- unzip3 <$> gatherBinds bs
let jxsa = foldOf (each . assumptions) jxs
jxcs <- fmap concat . for (ks `zip` txs) $ \ (k,t) ->
elimAssumptions' jxsa k t
jxcs <- elimWithBinds (ks `zip` txs) jxsa
(te,je) <- gather e
-- ... why don't we need the map?
(cs,_) <- fmap fold . for (ks `zip` txs') $ \ (k,t) ->
@@ -183,6 +184,16 @@ instantiateMap (ForallT x m) = do
& mapped . _1 %~ subst x tv
instantiateMap t = pure (t, mempty)
saturated :: Type PsName -> HM [Type PsName]
saturated (ConT con `AppT` as) = do
mk <- view $ contextTyCons . at con
case mk of
Nothing -> addFatal $ TyErrUntypedVariable con
Just k | lengthOf arrowStops k - 1 == lengthOf applicants1 as
-> pure (as ^.. applicants1)
| otherwise
-> undefined
unify :: [Constraint] -> HM [(PsName, Type PsName)]
unify [] = pure mempty
@@ -190,6 +201,11 @@ unify [] = pure mempty
unify (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify $ Equality sx tx : Equality sy ty : cs
unify (Equality a@(ConT ca `AppT` as) b@(ConT cb `AppT` bs) : cs)
| ca == cb = do
cs' <- liftA2 (zipWith Equality) (saturated a) (saturated b)
unify $ cs' ++ cs
-- elim
unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs
unify (Equality (VarT s) (VarT t) : cs) | s == t = unify cs
@@ -248,18 +264,18 @@ elimAssumptionsG g as
& itraverse (elimAssumptions' as)
& fmap (H.elems >>> concat)
infer :: Context -> RlpExpr PsName
-> HM (Cofree (RlpExprF PsName) (Type PsName))
infer g0 e = do
finalJudgement :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)
-> PartialJudgement
finalJudgement = foldOf (folded . _2)
infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer e = do
g0 <- ask
e' <- annotate e
let (as, concat -> cs) = unzip $ e' ^.. folded . _2
. lensProduct assumptions constraints
cs' <- do
csa <- concatMapM (elimAssumptionsG g0) as
pure (csa <> cs)
g <- unify cs'
let sub t = ifoldrOf (reversed . assocs) subst t g
let (cs,as) = finalJudgement e' ^. lensProduct constraints assumptions
cs' <- (<>cs) <$> elimAssumptionsG g0 as
checkUndefinedVariables e'
sub <- solve cs'
pure $ e' & fmap (sub . view _1)
& _extract %~ generaliseG g0
where
@@ -267,6 +283,11 @@ infer g0 e = do
-- the user-specified type: prioritise her.
unifyTypes _ s t = unify [Equality s t] $> s
solve :: [Constraint] -> HM (Type PsName -> Type PsName)
solve cs = do
g <- unify cs
pure $ \t -> ifoldrOf (reversed . assocs) subst t g
checkUndefinedVariables
:: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)
-> HM ()
@@ -276,33 +297,8 @@ checkUndefinedVariables ((_,j) :< es)
as -> doErrs *> checkUndefinedVariables `traverse_` es
where doErrs = ifor as \n _ -> addWound $ TyErrUntypedVariable n
infer1 :: Context -> RlpExpr PsName -> HM (Type PsName)
infer1 g = fmap extract . infer g
-- unionContextWithKeyM :: Monad m
-- => (PsName -> Type PsName -> Type PsName
-- -> m (Type PsName))
-- -> Context -> Context -> m Context
-- unionContextWithKeyM f a b = Context <$> unionWithKeyM f a' b'
-- where
-- a' = a ^. contextVars
-- b' = b ^. contextVars
-- unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m)
-- => (k -> v -> v -> m v) -> HashMap k v -> HashMap k v
-- -> m (HashMap k v)
-- unionWithKeyM f a b = sequenceA $ H.unionWithKey f' ma mb
-- where
-- f' k x y = join $ liftA2 (f k) x y
-- ma = fmap (pure @m) a
-- mb = fmap (pure @m) b
-- solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- solve = solve' mempty
-- solve' :: Context -> RlpExpr PsName
-- -> HM (Cofree (RlpExprF PsName) (Type PsName))
-- solve' g = sequenceA . fixtend (infer1' g . wrapFix)
infer1 :: RlpExpr PsName -> HM (Type PsName)
infer1 = fmap extract . infer
occurs :: PsName -> Type PsName -> Bool
occurs n = cata \case
@@ -326,19 +322,100 @@ fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend c (Fix f) = c f :< fmap (fixtend c) f
buildInitialContext :: Program PsName a -> Context
buildInitialContext = const mempty
-- Context . H.fromList . toListOf (programDecls . each . _TySigD)
buildInitialContext = foldMapOf (programDecls . each) \case
TySigD n t -> contextOfTySig n t
DataD n as cs -> contextOfData n as cs
_ -> mempty
contextOfTySig :: PsName -> Type PsName -> Context
contextOfTySig = const $ const mempty
contextOfData :: PsName -> [PsName] -> [DataCon PsName] -> Context
contextOfData n as cs = kindCtx <> consCtx where
kindCtx = mempty & contextTyCons . at n ?~ kind
where kind = foldr (\_ t -> TypeT :-> t) TypeT as
consCtx = foldMap contextOfCon cs
contextOfCon (DataCon c as) =
mempty & contextVars . at c ?~ ty
where ty = foralls $ foldr (:->) base as
base = foldl (\f x -> AppT f (VarT x)) (VarT n) as
foralls t = foldr ForallT t as
typeCheckRlpProgR :: (Monad m)
=> Program PsName (RlpExpr PsName)
-> RLPCT m (Program PsName
(TypedRlpExpr PsName))
typeCheckRlpProgR p = tc p
typeCheckRlpProgR p = liftHM g (inferProg . etaExpandAll $ p)
where
g = buildInitialContext p
tc = liftHM . traverse (infer g) . etaExpandAll
etaExpandAll = programDecls . each %~ etaExpand
inferProg :: Program PsName (RlpExpr PsName)
-> HM (Program PsName (TypedRlpExpr PsName))
inferProg p = do
g0 <- ask
p' <- annotateProg p
let (cs,as) = foldOf (folded . folded . _2) p'
^. lensProduct constraints assumptions
cs' <- (<>cs) <$> elimAssumptionsG g0 as
sub <- solve cs'
pure $ p' & programDecls . traversed . _FunD . _3
%~ ((_extract %~ generaliseG g0) . fmap (sub . view _1))
annotateProg :: Program PsName (RlpExpr PsName)
-> HM (Program PsName
(Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)))
annotateProg = traverse annotate
gatherProg :: Program PsName (RlpExpr PsName)
-> HM [Constraint]
gatherProg p = do
-- this should be nearly identical to the rule for `letrec` in gather'
let (ks,xs) = unzip $ funsToSimpleBinds (p ^. programDecls)
(txs,txs',jxs) <- unzip3 <$> gatherBinds (zipWith simpleBind ks xs)
let jxsa = foldOf (each . assumptions) jxs
jxcs <- elimWithBinds (ks `zip` txs) jxsa
-- TODO: any remaining assumptions should be errors at this point
pure jxcs
elimWithBinds :: [(PsName, Type PsName)]
-> Assumptions
-> HM [Constraint]
elimWithBinds bs jxsa = fmap concat . for bs $ \ (k,t) ->
elimAssumptions' jxsa k t
simpleBind :: b -> a -> Binding b a
simpleBind k v = VarB (VarP k) v
funsToSimpleBinds :: [Decl PsName (RlpExpr PsName)]
-> [(PsName, RlpExpr PsName)]
funsToSimpleBinds = mapMaybe \case
d@(FunD n _ _) -> Just (n, etaExpand' d)
_ -> Nothing
simpleBindsToFuns :: [(PsName, TypedRlpExpr PsName)]
-> [Decl PsName (TypedRlpExpr PsName)]
simpleBindsToFuns = fmap \ (n,e) -> FunD n [] e
wrapLetrec :: [(PsName, RlpExpr PsName)] -> RlpExpr PsName
wrapLetrec ds = ds & each . _1 %~ VarP
& each %~ review _VarB
& \bs -> Finr $ LetEF Rec bs (Finl . LitF . IntL $ 123)
unwrapLetrec :: TypedRlpExpr PsName -> [(PsName, TypedRlpExpr PsName)]
unwrapLetrec (_ :< InR (LetEF _ bs _))
= bs ^.. each . _VarB
& each . _1 %~ view (singular _VarP)
etaExpand' :: Decl b (RlpExpr b) -> RlpExpr b
etaExpand' (FunD _ [] e) = e
etaExpand' (FunD _ as e) = Finl . LamF as' $ e
where as' = as ^.. each . singular _VarP
etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b)
etaExpand (FunD n [] e) = FunD n [] e
etaExpand (FunD n as e)
@@ -348,8 +425,8 @@ etaExpand (FunD n as e)
allVarP = traverse (matching _VarP)
etaExpand a = a
liftHM :: (Monad m) => HM a -> RLPCT m a
liftHM = liftEither . runHM'
liftHM :: (Monad m) => Context -> HM a -> RLPCT m a
liftHM g = liftEither . runHM g
freeVariables :: Type PsName -> HashSet PsName
freeVariables = cata \case
@@ -362,32 +439,18 @@ boundVariables = cata \case
ForallTF x m -> S.singleton x <> m
vs -> fold vs
-- | rename all free variables for aesthetic purposes
prettyVars' :: Type PsName -> Type PsName
prettyVars' = join prettyVars
freeVariablesLTR :: Type PsName -> [PsName]
freeVariablesLTR = nub . cata \case
VarTF x -> [x]
ForallTF x m -> m \\ [x]
vs -> concat vs
-- | for some type, compute a substitution which will rename all free variables
-- for aesthetic purposes
prettyVars :: Type PsName -> Type PsName -> Type PsName
prettyVars root = appEndo (foldMap Endo subs)
where
alphabetNames = [ T.pack [c] | c <- ['a'..'z'] ]
names = alphabetNames \\ S.toList (boundVariables root)
subs = zipWith (\k v -> subst k (VarT v))
(freeVariablesLTR root)
names
renamePrettily' :: Type PsName -> Type PsName
renamePrettily' = join renamePrettily
-- | for some type, compute a substitution which will rename all free variables
-- for aesthetic purposes
renamePrettily :: Type PsName -> Type PsName -> Type PsName
renamePrettily root = (`evalState` alphabetNames) . (renameFree <=< renameBound)
where
@@ -408,19 +471,6 @@ renamePrettily root = (`evalState` alphabetNames) . (renameFree <=< renameBound)
getName :: State [PsName] PsName
getName = state (fromJust . uncons)
-- renamePrettily :: Type PsName -> Type PsName
-- renamePrettily
-- = (`evalState` alphabetNames)
-- . (renameFrees <=< renameForalls)
-- where
-- -- TODO:
-- renameFrees root = pure root
-- renameForalls = cata \case
-- ForallTF x m -> do
-- n <- getName
-- ForallT n <$> (subst x (VarT n) <$> m)
-- t -> embed <$> sequenceA t
alphabetNames :: [PsName]
alphabetNames = alphabet ++ concatMap appendAlphabet alphabetNames
where alphabet = [ T.pack [c] | c <- ['a'..'z'] ]

View File

@@ -16,6 +16,7 @@ import Control.Monad.Accum
import Control.Monad.Trans.Accum
import Control.Monad.Errorful
import Control.Monad.State
import Control.Monad.Reader
import Text.Printf
import Data.Pretty
import Data.Function
@@ -29,6 +30,7 @@ import Rlp.AltSyntax
data Context = Context
{ _contextVars :: HashMap PsName (Type PsName)
, _contextTyVars :: HashMap PsName (Type PsName)
, _contextTyCons :: HashMap PsName (Kind PsName)
}
deriving (Show, Generic)
deriving (Semigroup, Monoid)
@@ -58,7 +60,7 @@ instance Hashable Constraint
type Memo = HashMap (RlpExpr PsName) (Type PsName, PartialJudgement)
type HM = ErrorfulT TypeError (StateT Int (Accum Memo))
type HM = ErrorfulT TypeError (ReaderT Context (StateT Int (Accum Memo)))
-- | Type error enum.
data TypeError
@@ -89,44 +91,6 @@ instance IsRlpcError TypeError where
(rout @String t) (rout @String x)
]
-- type Memo t = HashMap t (Type PsName, PartialJudgement)
-- newtype HM t a = HM { unHM :: Int -> Memo t -> (a, Int, Memo t) }
-- runHM :: (Hashable t) => HM t a -> (a, Memo t)
-- runHM hm = let (a,_,m) = unHM hm 0 mempty in (a,m)
-- instance Functor (HM t) where
-- fmap f (HM h) = HM \n m -> h n m & _1 %~ f
-- instance Applicative (HM t) where
-- pure a = HM \n m -> (a,n,m)
-- HM hf <*> HM ha = HM \n m ->
-- let (f',n',m') = hf n m
-- (a,n'',m'') = ha n' m'
-- in (f' a, n'', m'')
-- instance Monad (HM t) where
-- HM ha >>= k = HM \n m ->
-- let (a,n',m') = ha n m
-- (a',n'',m'') = unHM (k a) n' m'
-- in (a',n'', m'')
-- instance Hashable t => MonadWriter (Memo t) (HM t) where
-- -- IMPORTAN! (<>) is left-biased for HashMap! append `w` to the RIGHt!
-- writer (a,w) = HM \n m -> (a,n,m <> w)
-- listen ma = HM \n m ->
-- let (a,n',m') = unHM ma n m
-- in ((a,m'),n',m')
-- pass maww = HM \n m ->
-- let ((a,ww),n',m') = unHM maww n m
-- in (a,n',ww m')
-- instance MonadState Int (HM t) where
-- state f = HM \n m ->
-- let (a,n') = f n
-- in (a,n',m)
tvNameOfInt :: Int -> PsName
tvNameOfInt n = "$a" <> T.pack (show n)
@@ -146,13 +110,14 @@ listenFreshTvNames hm = do
n' <- get
pure (a, [ tvNameOfInt k | k <- [n .. pred n'] ])
runHM' :: HM a -> Either [TypeError] a
runHM' e = maybe (Left es) Right ma
runHM :: Context -> HM a -> Either [TypeError] a
runHM g e = maybe (Left es) Right ma
where
((ma,es),m) = (`runAccum` mempty) . (`evalStateT` 0) . runErrorfulT $ e
((ma,es),m) = (`runAccum` mempty) . (`evalStateT` 0)
. (`runReaderT` g) . runErrorfulT $ e
-- addConstraint :: Constraint -> HM ()
-- addConstraint = tell . pure
runHM' :: HM a -> Either [TypeError] a
runHM' = runHM mempty
makePrisms ''PartialJudgement
makeLenses ''PartialJudgement
@@ -164,11 +129,14 @@ supplement :: [(PsName, Type PsName)] -> Context -> Context
supplement bs = contextVars %~ (H.fromList bs <>)
demoContext :: Context
demoContext = Context
{ _contextVars =
demoContext = mempty
& contextVars .~
[ ("+#", IntT :-> IntT :-> IntT)
, ("Nil", ForallT "a" $ ConT "List" `AppT` VarT "a")
]
& contextTyCons .~
[ ("List", TypeT :-> TypeT)
]
}
constraintTypes :: Traversal' Constraint (Type PsName)
constraintTypes k (Equality s t) = Equality <$> k s <*> k t