From 7e8be474c651ac76142ea6ddd9c02734d294637e Mon Sep 17 00:00:00 2001 From: crumbtoo Date: Thu, 28 Mar 2024 06:53:46 -0600 Subject: [PATCH] whole-program inference --- src/Control/Monad/Errorful.hs | 4 + src/Rlp/AltSyntax.hs | 22 +++- src/Rlp/HindleyMilner.hs | 200 ++++++++++++++++++++------------- src/Rlp/HindleyMilner/Types.hs | 62 +++------- 4 files changed, 162 insertions(+), 126 deletions(-) diff --git a/src/Control/Monad/Errorful.hs b/src/Control/Monad/Errorful.hs index 70c4a71..9ee363b 100644 --- a/src/Control/Monad/Errorful.hs +++ b/src/Control/Monad/Errorful.hs @@ -102,6 +102,10 @@ instance (Monoid w, Monad m, MonadWriter w m) => MonadWriter w (ErrorfulT e m) w ((,w) <$> ma, es) pass (ErrorfulT m) = undefined +instance (Monad m, MonadReader r m) => MonadReader r (ErrorfulT e m) where + ask = lift ask + local rr = hoistErrorfulT (local rr) + instance (Monoid w, Monad m, MonadAccum w m) => MonadAccum w (ErrorfulT e m) where accum = lift . accum diff --git a/src/Rlp/AltSyntax.hs b/src/Rlp/AltSyntax.hs index 37c3016..60a0f27 100644 --- a/src/Rlp/AltSyntax.hs +++ b/src/Rlp/AltSyntax.hs @@ -4,8 +4,8 @@ module Rlp.AltSyntax -- * AST Program(..), Decl(..), ExprF(..), Pat(..) , RlpExprF, RlpExpr, Binding(..), Alter(..) - , DataCon(..), Type(..) - , pattern IntT + , DataCon(..), Type(..), Kind + , pattern IntT, pattern TypeT , Core.Rec(..) , AnnotatedRlpExpr, TypedRlpExpr @@ -18,6 +18,8 @@ module Rlp.AltSyntax , programDecls , _VarP, _FunB, _VarB , _TySigD, _FunD + , _LetEF + , Core.applicants1, Core.arrowStops -- * Functor-related tools , Fix(..), Cofree(..), Sum(..), pattern Finl, pattern Finr @@ -60,8 +62,10 @@ type PsName = T.Text newtype Program b a = Program [Decl b a] deriving (Show, Functor, Foldable, Traversable) -programDecls :: Lens' (Program b a) [Decl b a] -programDecls = lens (\ (Program ds) -> ds) (const Program) +programDecls :: Iso (Program b a) (Program b' a') [Decl b a] [Decl b' a'] +programDecls = iso sa bt where + sa (Program ds) = ds + bt = Program data Decl b a = FunD b [Pat b] a | DataD Core.Name [Core.Name] [DataCon b] @@ -78,11 +82,20 @@ data Type b = VarT Core.Name | ForallT b (Type b) deriving (Show, Eq, Generic, Functor, Foldable, Traversable) +instance Core.HasApplicants1 (Type b) (Type b) (Type b) (Type b) where + applicants1 k (AppT f x) = AppT <$> Core.applicants1 k f <*> k x + applicants1 k t = k t + instance (Hashable b) => Hashable (Type b) pattern IntT :: (IsString b, Eq b) => Type b pattern IntT = ConT "Int#" +type Kind = Type + +pattern TypeT :: (IsString b, Eq b) => Type b +pattern TypeT = ConT "Type" + instance Core.HasArrowSyntax (Type b) (Type b) (Type b) where _arrowSyntax = prism make unmake where make (s,t) = FunT `AppT` s `AppT` t @@ -205,6 +218,7 @@ instance (Out a, Out b) => Out (Program b a) where instance (Out b) => Out1 (Program b) where liftOutPrec pr p (Program ds) = vsep $ liftOutPrec pr p <$> ds +makePrisms ''ExprF makePrisms ''Pat makePrisms ''Binding makePrisms ''Decl diff --git a/src/Rlp/HindleyMilner.hs b/src/Rlp/HindleyMilner.hs index 42944cd..159acd5 100644 --- a/src/Rlp/HindleyMilner.hs +++ b/src/Rlp/HindleyMilner.hs @@ -13,9 +13,11 @@ module Rlp.HindleyMilner -------------------------------------------------------------------------------- import Control.Lens hiding (Context', Context, (:<), para, uncons) import Control.Lens.Unsound +import Control.Lens.Extras import Control.Monad.Errorful import Control.Monad.State import Control.Monad.Accum +import Control.Monad.Reader import Control.Monad import Control.Monad.Extra import Control.Arrow ((>>>)) @@ -117,8 +119,7 @@ gather' = \case let ks = bs ^.. each . singular _VarB . _1 . singular _VarP (txs,txs',jxs) <- unzip3 <$> gatherBinds bs let jxsa = foldOf (each . assumptions) jxs - jxcs <- fmap concat . for (ks `zip` txs) $ \ (k,t) -> - elimAssumptions' jxsa k t + jxcs <- elimWithBinds (ks `zip` txs) jxsa (te,je) <- gather e -- ... why don't we need the map? (cs,_) <- fmap fold . for (ks `zip` txs') $ \ (k,t) -> @@ -183,6 +184,16 @@ instantiateMap (ForallT x m) = do & mapped . _1 %~ subst x tv instantiateMap t = pure (t, mempty) +saturated :: Type PsName -> HM [Type PsName] +saturated (ConT con `AppT` as) = do + mk <- view $ contextTyCons . at con + case mk of + Nothing -> addFatal $ TyErrUntypedVariable con + Just k | lengthOf arrowStops k - 1 == lengthOf applicants1 as + -> pure (as ^.. applicants1) + | otherwise + -> undefined + unify :: [Constraint] -> HM [(PsName, Type PsName)] unify [] = pure mempty @@ -190,6 +201,11 @@ unify [] = pure mempty unify (Equality (sx :-> sy) (tx :-> ty) : cs) = unify $ Equality sx tx : Equality sy ty : cs +unify (Equality a@(ConT ca `AppT` as) b@(ConT cb `AppT` bs) : cs) + | ca == cb = do + cs' <- liftA2 (zipWith Equality) (saturated a) (saturated b) + unify $ cs' ++ cs + -- elim unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs unify (Equality (VarT s) (VarT t) : cs) | s == t = unify cs @@ -248,18 +264,18 @@ elimAssumptionsG g as & itraverse (elimAssumptions' as) & fmap (H.elems >>> concat) -infer :: Context -> RlpExpr PsName - -> HM (Cofree (RlpExprF PsName) (Type PsName)) -infer g0 e = do +finalJudgement :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement) + -> PartialJudgement +finalJudgement = foldOf (folded . _2) + +infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName)) +infer e = do + g0 <- ask e' <- annotate e - let (as, concat -> cs) = unzip $ e' ^.. folded . _2 - . lensProduct assumptions constraints - cs' <- do - csa <- concatMapM (elimAssumptionsG g0) as - pure (csa <> cs) - g <- unify cs' - let sub t = ifoldrOf (reversed . assocs) subst t g + let (cs,as) = finalJudgement e' ^. lensProduct constraints assumptions + cs' <- (<>cs) <$> elimAssumptionsG g0 as checkUndefinedVariables e' + sub <- solve cs' pure $ e' & fmap (sub . view _1) & _extract %~ generaliseG g0 where @@ -267,6 +283,11 @@ infer g0 e = do -- the user-specified type: prioritise her. unifyTypes _ s t = unify [Equality s t] $> s +solve :: [Constraint] -> HM (Type PsName -> Type PsName) +solve cs = do + g <- unify cs + pure $ \t -> ifoldrOf (reversed . assocs) subst t g + checkUndefinedVariables :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement) -> HM () @@ -276,33 +297,8 @@ checkUndefinedVariables ((_,j) :< es) as -> doErrs *> checkUndefinedVariables `traverse_` es where doErrs = ifor as \n _ -> addWound $ TyErrUntypedVariable n -infer1 :: Context -> RlpExpr PsName -> HM (Type PsName) -infer1 g = fmap extract . infer g - --- unionContextWithKeyM :: Monad m --- => (PsName -> Type PsName -> Type PsName --- -> m (Type PsName)) --- -> Context -> Context -> m Context --- unionContextWithKeyM f a b = Context <$> unionWithKeyM f a' b' --- where --- a' = a ^. contextVars --- b' = b ^. contextVars - --- unionWithKeyM :: forall m k v. (Eq k, Hashable k, Monad m) --- => (k -> v -> v -> m v) -> HashMap k v -> HashMap k v --- -> m (HashMap k v) --- unionWithKeyM f a b = sequenceA $ H.unionWithKey f' ma mb --- where --- f' k x y = join $ liftA2 (f k) x y --- ma = fmap (pure @m) a --- mb = fmap (pure @m) b - --- solve :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName)) --- solve = solve' mempty - --- solve' :: Context -> RlpExpr PsName --- -> HM (Cofree (RlpExprF PsName) (Type PsName)) --- solve' g = sequenceA . fixtend (infer1' g . wrapFix) +infer1 :: RlpExpr PsName -> HM (Type PsName) +infer1 = fmap extract . infer occurs :: PsName -> Type PsName -> Bool occurs n = cata \case @@ -326,19 +322,100 @@ fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b fixtend c (Fix f) = c f :< fmap (fixtend c) f buildInitialContext :: Program PsName a -> Context -buildInitialContext = const mempty - -- Context . H.fromList . toListOf (programDecls . each . _TySigD) +buildInitialContext = foldMapOf (programDecls . each) \case + TySigD n t -> contextOfTySig n t + DataD n as cs -> contextOfData n as cs + _ -> mempty + +contextOfTySig :: PsName -> Type PsName -> Context +contextOfTySig = const $ const mempty + +contextOfData :: PsName -> [PsName] -> [DataCon PsName] -> Context +contextOfData n as cs = kindCtx <> consCtx where + kindCtx = mempty & contextTyCons . at n ?~ kind + where kind = foldr (\_ t -> TypeT :-> t) TypeT as + + consCtx = foldMap contextOfCon cs + + contextOfCon (DataCon c as) = + mempty & contextVars . at c ?~ ty + where ty = foralls $ foldr (:->) base as + + base = foldl (\f x -> AppT f (VarT x)) (VarT n) as + + foralls t = foldr ForallT t as typeCheckRlpProgR :: (Monad m) => Program PsName (RlpExpr PsName) -> RLPCT m (Program PsName (TypedRlpExpr PsName)) -typeCheckRlpProgR p = tc p +typeCheckRlpProgR p = liftHM g (inferProg . etaExpandAll $ p) where g = buildInitialContext p - tc = liftHM . traverse (infer g) . etaExpandAll etaExpandAll = programDecls . each %~ etaExpand +inferProg :: Program PsName (RlpExpr PsName) + -> HM (Program PsName (TypedRlpExpr PsName)) +inferProg p = do + g0 <- ask + p' <- annotateProg p + let (cs,as) = foldOf (folded . folded . _2) p' + ^. lensProduct constraints assumptions + cs' <- (<>cs) <$> elimAssumptionsG g0 as + sub <- solve cs' + pure $ p' & programDecls . traversed . _FunD . _3 + %~ ((_extract %~ generaliseG g0) . fmap (sub . view _1)) + +annotateProg :: Program PsName (RlpExpr PsName) + -> HM (Program PsName + (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))) +annotateProg = traverse annotate + +gatherProg :: Program PsName (RlpExpr PsName) + -> HM [Constraint] +gatherProg p = do + -- this should be nearly identical to the rule for `letrec` in gather' + let (ks,xs) = unzip $ funsToSimpleBinds (p ^. programDecls) + (txs,txs',jxs) <- unzip3 <$> gatherBinds (zipWith simpleBind ks xs) + let jxsa = foldOf (each . assumptions) jxs + jxcs <- elimWithBinds (ks `zip` txs) jxsa + -- TODO: any remaining assumptions should be errors at this point + pure jxcs + +elimWithBinds :: [(PsName, Type PsName)] + -> Assumptions + -> HM [Constraint] +elimWithBinds bs jxsa = fmap concat . for bs $ \ (k,t) -> + elimAssumptions' jxsa k t + +simpleBind :: b -> a -> Binding b a +simpleBind k v = VarB (VarP k) v + +funsToSimpleBinds :: [Decl PsName (RlpExpr PsName)] + -> [(PsName, RlpExpr PsName)] +funsToSimpleBinds = mapMaybe \case + d@(FunD n _ _) -> Just (n, etaExpand' d) + _ -> Nothing + +simpleBindsToFuns :: [(PsName, TypedRlpExpr PsName)] + -> [Decl PsName (TypedRlpExpr PsName)] +simpleBindsToFuns = fmap \ (n,e) -> FunD n [] e + +wrapLetrec :: [(PsName, RlpExpr PsName)] -> RlpExpr PsName +wrapLetrec ds = ds & each . _1 %~ VarP + & each %~ review _VarB + & \bs -> Finr $ LetEF Rec bs (Finl . LitF . IntL $ 123) + +unwrapLetrec :: TypedRlpExpr PsName -> [(PsName, TypedRlpExpr PsName)] +unwrapLetrec (_ :< InR (LetEF _ bs _)) + = bs ^.. each . _VarB + & each . _1 %~ view (singular _VarP) + +etaExpand' :: Decl b (RlpExpr b) -> RlpExpr b +etaExpand' (FunD _ [] e) = e +etaExpand' (FunD _ as e) = Finl . LamF as' $ e + where as' = as ^.. each . singular _VarP + etaExpand :: Decl b (RlpExpr b) -> Decl b (RlpExpr b) etaExpand (FunD n [] e) = FunD n [] e etaExpand (FunD n as e) @@ -348,8 +425,8 @@ etaExpand (FunD n as e) allVarP = traverse (matching _VarP) etaExpand a = a -liftHM :: (Monad m) => HM a -> RLPCT m a -liftHM = liftEither . runHM' +liftHM :: (Monad m) => Context -> HM a -> RLPCT m a +liftHM g = liftEither . runHM g freeVariables :: Type PsName -> HashSet PsName freeVariables = cata \case @@ -362,32 +439,18 @@ boundVariables = cata \case ForallTF x m -> S.singleton x <> m vs -> fold vs --- | rename all free variables for aesthetic purposes - -prettyVars' :: Type PsName -> Type PsName -prettyVars' = join prettyVars - freeVariablesLTR :: Type PsName -> [PsName] freeVariablesLTR = nub . cata \case VarTF x -> [x] ForallTF x m -> m \\ [x] vs -> concat vs --- | for some type, compute a substitution which will rename all free variables --- for aesthetic purposes - -prettyVars :: Type PsName -> Type PsName -> Type PsName -prettyVars root = appEndo (foldMap Endo subs) - where - alphabetNames = [ T.pack [c] | c <- ['a'..'z'] ] - names = alphabetNames \\ S.toList (boundVariables root) - subs = zipWith (\k v -> subst k (VarT v)) - (freeVariablesLTR root) - names - renamePrettily' :: Type PsName -> Type PsName renamePrettily' = join renamePrettily +-- | for some type, compute a substitution which will rename all free variables +-- for aesthetic purposes + renamePrettily :: Type PsName -> Type PsName -> Type PsName renamePrettily root = (`evalState` alphabetNames) . (renameFree <=< renameBound) where @@ -408,19 +471,6 @@ renamePrettily root = (`evalState` alphabetNames) . (renameFree <=< renameBound) getName :: State [PsName] PsName getName = state (fromJust . uncons) --- renamePrettily :: Type PsName -> Type PsName --- renamePrettily --- = (`evalState` alphabetNames) --- . (renameFrees <=< renameForalls) --- where --- -- TODO: --- renameFrees root = pure root --- renameForalls = cata \case --- ForallTF x m -> do --- n <- getName --- ForallT n <$> (subst x (VarT n) <$> m) --- t -> embed <$> sequenceA t - alphabetNames :: [PsName] alphabetNames = alphabet ++ concatMap appendAlphabet alphabetNames where alphabet = [ T.pack [c] | c <- ['a'..'z'] ] diff --git a/src/Rlp/HindleyMilner/Types.hs b/src/Rlp/HindleyMilner/Types.hs index 55a2cd7..908f33d 100644 --- a/src/Rlp/HindleyMilner/Types.hs +++ b/src/Rlp/HindleyMilner/Types.hs @@ -16,6 +16,7 @@ import Control.Monad.Accum import Control.Monad.Trans.Accum import Control.Monad.Errorful import Control.Monad.State +import Control.Monad.Reader import Text.Printf import Data.Pretty import Data.Function @@ -29,6 +30,7 @@ import Rlp.AltSyntax data Context = Context { _contextVars :: HashMap PsName (Type PsName) , _contextTyVars :: HashMap PsName (Type PsName) + , _contextTyCons :: HashMap PsName (Kind PsName) } deriving (Show, Generic) deriving (Semigroup, Monoid) @@ -58,7 +60,7 @@ instance Hashable Constraint type Memo = HashMap (RlpExpr PsName) (Type PsName, PartialJudgement) -type HM = ErrorfulT TypeError (StateT Int (Accum Memo)) +type HM = ErrorfulT TypeError (ReaderT Context (StateT Int (Accum Memo))) -- | Type error enum. data TypeError @@ -89,44 +91,6 @@ instance IsRlpcError TypeError where (rout @String t) (rout @String x) ] --- type Memo t = HashMap t (Type PsName, PartialJudgement) - --- newtype HM t a = HM { unHM :: Int -> Memo t -> (a, Int, Memo t) } - --- runHM :: (Hashable t) => HM t a -> (a, Memo t) --- runHM hm = let (a,_,m) = unHM hm 0 mempty in (a,m) - --- instance Functor (HM t) where --- fmap f (HM h) = HM \n m -> h n m & _1 %~ f - --- instance Applicative (HM t) where --- pure a = HM \n m -> (a,n,m) --- HM hf <*> HM ha = HM \n m -> --- let (f',n',m') = hf n m --- (a,n'',m'') = ha n' m' --- in (f' a, n'', m'') - --- instance Monad (HM t) where --- HM ha >>= k = HM \n m -> --- let (a,n',m') = ha n m --- (a',n'',m'') = unHM (k a) n' m' --- in (a',n'', m'') - --- instance Hashable t => MonadWriter (Memo t) (HM t) where --- -- IMPORTAN! (<>) is left-biased for HashMap! append `w` to the RIGHt! --- writer (a,w) = HM \n m -> (a,n,m <> w) --- listen ma = HM \n m -> --- let (a,n',m') = unHM ma n m --- in ((a,m'),n',m') --- pass maww = HM \n m -> --- let ((a,ww),n',m') = unHM maww n m --- in (a,n',ww m') - --- instance MonadState Int (HM t) where --- state f = HM \n m -> --- let (a,n') = f n --- in (a,n',m) - tvNameOfInt :: Int -> PsName tvNameOfInt n = "$a" <> T.pack (show n) @@ -146,13 +110,14 @@ listenFreshTvNames hm = do n' <- get pure (a, [ tvNameOfInt k | k <- [n .. pred n'] ]) -runHM' :: HM a -> Either [TypeError] a -runHM' e = maybe (Left es) Right ma +runHM :: Context -> HM a -> Either [TypeError] a +runHM g e = maybe (Left es) Right ma where - ((ma,es),m) = (`runAccum` mempty) . (`evalStateT` 0) . runErrorfulT $ e + ((ma,es),m) = (`runAccum` mempty) . (`evalStateT` 0) + . (`runReaderT` g) . runErrorfulT $ e --- addConstraint :: Constraint -> HM () --- addConstraint = tell . pure +runHM' :: HM a -> Either [TypeError] a +runHM' = runHM mempty makePrisms ''PartialJudgement makeLenses ''PartialJudgement @@ -164,11 +129,14 @@ supplement :: [(PsName, Type PsName)] -> Context -> Context supplement bs = contextVars %~ (H.fromList bs <>) demoContext :: Context -demoContext = Context - { _contextVars = +demoContext = mempty + & contextVars .~ [ ("+#", IntT :-> IntT :-> IntT) + , ("Nil", ForallT "a" $ ConT "List" `AppT` VarT "a") + ] + & contextTyCons .~ + [ ("List", TypeT :-> TypeT) ] - } constraintTypes :: Traversal' Constraint (Type PsName) constraintTypes k (Equality s t) = Equality <$> k s <*> k t