From de41536e1d52ac1902b556a75fa896d203abd9f0 Mon Sep 17 00:00:00 2001 From: crumbtoo Date: Mon, 11 Mar 2024 10:36:38 -0600 Subject: [PATCH] algW i'm honestly rather disappointed in myself for not implementing a comonadic algo J. cross my heart i'll come back to this and return stronger! in the mean time, i really need to get this thing into a presentable state... --- src/Control/Monad/Errorful.hs | 7 +++ src/Data/Pretty.hs | 1 + src/Rlp/HindleyMilner.hs | 110 +++++++++++++++++---------------- src/Rlp/HindleyMilner/Types.hs | 33 ++++++++-- 4 files changed, 92 insertions(+), 59 deletions(-) diff --git a/src/Control/Monad/Errorful.hs b/src/Control/Monad/Errorful.hs index fc4dae3..175d39b 100644 --- a/src/Control/Monad/Errorful.hs +++ b/src/Control/Monad/Errorful.hs @@ -14,6 +14,7 @@ module Control.Monad.Errorful where ---------------------------------------------------------------------------------- import Control.Monad.State.Strict +import Control.Monad.Writer import Control.Monad.Reader import Control.Monad.Trans import Data.Functor.Identity @@ -88,3 +89,9 @@ instance (Monad m, MonadErrorful e m) => MonadErrorful e (ReaderT r m) where instance (Monad m, MonadState s m) => MonadState s (ErrorfulT e m) where state = lift . state +instance (Monoid w, Monad m, MonadWriter w m) => MonadWriter w (ErrorfulT e m) where + tell = lift . tell + listen (ErrorfulT m) = ErrorfulT $ listen m <&> \ ((ma,es),w) -> + ((,w) <$> ma, es) + pass (ErrorfulT m) = undefined + diff --git a/src/Data/Pretty.hs b/src/Data/Pretty.hs index 820a934..5302677 100644 --- a/src/Data/Pretty.hs +++ b/src/Data/Pretty.hs @@ -4,6 +4,7 @@ module Data.Pretty , prettyPrec1 , rpretty , ttext + , Showing(..) -- * Pretty-printing lens combinators , hsepOf, vsepOf, vcatOf, vlinesOf, vsepTerm , vsep diff --git a/src/Rlp/HindleyMilner.hs b/src/Rlp/HindleyMilner.hs index 616aee4..3212e35 100644 --- a/src/Rlp/HindleyMilner.hs +++ b/src/Rlp/HindleyMilner.hs @@ -15,6 +15,7 @@ import Control.Monad.State import Control.Monad import Control.Monad.Writer.Strict import Data.Text qualified as T +import Data.Pretty import Data.Hashable import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict qualified as H @@ -50,69 +51,72 @@ fixCofree = iso sa bt where lookupVar :: PsName -> Context -> HM (Type PsName) lookupVar n g = case g ^. contextVars . at n of Just t -> pure t - Nothing -> addFatal (TyErrUntypedVariable n) + Nothing -> addFatal $ TyErrUntypedVariable n --- | Instantiate a polytype by replacing the bound type variables with fresh --- monotype (free) variables -inst :: Type PsName -> HM (Type PsName) -inst = para \case - ForallTF x (_,t) -> do - m <- t - tv <- freshTv - pure $ subst x tv m - -- discard the recursive results by selected fst - t -> pure . embed . fmap fst $ t +gather :: Context -> RlpExpr PsName -> HM (Type PsName) +gather g = \case -generalise :: Type PsName -> Type PsName -generalise = foldr ForallT <*> toListOf tyVars + Finl (LitF (IntL _)) -> pure IntT -tyVars :: Traversal (Type b) (Type b') b b' -tyVars = traverse + Finl (AppF f x) -> do + tf <- gather g f + tx <- gather g x + tfx <- freshTv + addConstraint $ Equality tf (tx :-> tfx) + pure tfx -polytypeBinds :: Traversal' (Type b) b -polytypeBinds k (ForallT x m) = ForallT <$> k x <*> polytypeBinds k m -polytypeBinds k t = pure t + Finl (VarF n) -> lookupVar n g -subst :: PsName -> Type PsName -> Type PsName -> Type PsName -subst n t' = para \case - VarTF m | n == m -> t' - ForallTF m (pre,post) | n == m -> pre - | otherwise -> post - t -> embed . fmap snd $ t + Finl (LamF bs e) -> do + tbs <- for bs $ \b -> (b,) <$> freshTv + te <- gather (supplement tbs g) e + pure $ foldrOf (each . _2) (:->) te tbs + +unify :: Context -> [Constraint] -> HM Context + +unify g [] = pure g + +unify g (Equality (sx :-> sy) (tx :-> ty) : cs) = + unify g $ Equality sx tx : Equality sy ty : cs + +-- elim +unify g (Equality (ConT s) (ConT t) : cs) | s == t = unify g cs +unify g (Equality (VarT s) (VarT t) : cs) | s == t = unify g cs + +unify g (Equality (VarT s) t : cs) + | occurs s t = addFatal $ TyErrRecursiveType s t + | otherwise = unify g' cs' + where + g' = supplement [(s,t)] g + cs' = cs & each . constraintTypes %~ subst s t + +-- swap +unify g (Equality s (VarT t) : cs) = unify g (Equality (VarT t) s : cs) + +unify _ (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t + +solve :: Context -> RlpExpr PsName -> HM (Type PsName) +solve g e = do + (t,cs) <- listen $ gather g e + g' <- unify g cs + pure $ ifoldrOf (contextVars . itraversed) subst t g' occurs :: PsName -> Type PsName -> Bool occurs n = cata \case VarTF m | n == m -> True t -> or t -infer :: Context -> RlpExpr PsName -> HM (Type PsName) -infer g = \case +subst :: PsName -> Type PsName -> Type PsName -> Type PsName +subst n t' = para \case + VarTF m | n == m -> t' + -- shadowing + ForallTF x (pre,post) | x == n -> ForallT x pre + | otherwise -> ForallT x post + t -> embed $ t <&> view _2 - Finl (LitF (IntL _)) -> pure IntT - - {- Var - - x : τ ∈ Γ - - τ' = inst τ - - ----------- - - Γ |- x : τ' - -} - Finl (VarF x) -> do - t <- lookupVar x g - let t' = inst t - t' - - Finl (AppF f x) -> do - te <- infer g f - tx <- infer g x - t' <- freshTv - undefined - -unify :: Context -> Type PsName -> Type PsName -> Context -unify g = \cases - IntT IntT -> g - (VarT a) b | Just a' <- g ^. contextTyVars . at a -> unify g a' b - b (VarT a) | Just a' <- g ^. contextTyVars . at a -> unify g b a' - - s@(VarT a) b | Nothing <- g ^. contextTyVars . at a - | s == b +prettyHM :: (Pretty a) + => Either [TypeError] (a, [Constraint]) + -> Either [TypeError] (String, [String]) +prettyHM = over (mapped . _1) rpretty + . over (mapped . _2 . each) rpretty diff --git a/src/Rlp/HindleyMilner/Types.hs b/src/Rlp/HindleyMilner/Types.hs index de0dc3a..3271b6d 100644 --- a/src/Rlp/HindleyMilner/Types.hs +++ b/src/Rlp/HindleyMilner/Types.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE TemplateHaskell #-} module Rlp.HindleyMilner.Types where @@ -22,9 +23,8 @@ import Compiler.RlpcError import Rlp.AltSyntax -------------------------------------------------------------------------------- -data Context = Context +newtype Context = Context { _contextVars :: HashMap PsName (Type PsName) - , _contextTyVars :: HashMap PsName (Type PsName) } data Constraint = Equality (Type PsName) (Type PsName) @@ -38,7 +38,7 @@ data PartialJudgement = PartialJudgement [Constraint] instance Hashable Constraint -type HM = ErrorfulT TypeError (State Int) +type HM = ErrorfulT TypeError (StateT Int (Writer [Constraint])) -- | Type error enum. data TypeError @@ -112,12 +112,33 @@ freshTv = do modify succ pure . VarT $ "$a" <> T.pack (show n) -runHM' :: HM a -> Either [TypeError] a -runHM' e = maybe (Left es) Right ma +runHM' :: HM a -> Either [TypeError] (a, [Constraint]) +runHM' e = maybe (Left es) (Right . (,cs)) ma where - (ma,es) = (`evalState` 0) . runErrorfulT $ e + ((ma,es),cs) = runWriter . (`evalStateT` 0) . runErrorfulT $ e + +addConstraint :: Constraint -> HM () +addConstraint = tell . pure -- makePrisms ''PartialJudgement makeLenses ''Context +makePrisms ''Constraint + +supplement :: [(PsName, Type PsName)] -> Context -> Context +supplement bs = contextVars %~ (H.fromList bs <>) + +demoContext :: Context +demoContext = Context + { _contextVars = + [ ("+#", IntT :-> IntT :-> IntT) + ] + } + +constraintTypes :: Traversal' Constraint (Type PsName) +constraintTypes k (Equality s t) = Equality <$> k s <*> k t + +instance Pretty Constraint where + pretty (Equality s t) = + hsep [prettyPrec appPrec1 s, "~", prettyPrec appPrec1 t]