i'm honestly rather disappointed in myself for not implementing a comonadic algo J.
cross my heart i'll come back to this and return stronger!
in the mean time, i really need to get this thing into a presentable state...
This commit is contained in:
crumbtoo
2024-03-11 10:36:38 -06:00
parent 35c770c63c
commit cf81b76c1a
4 changed files with 92 additions and 59 deletions

View File

@@ -14,6 +14,7 @@ module Control.Monad.Errorful
where where
---------------------------------------------------------------------------------- ----------------------------------------------------------------------------------
import Control.Monad.State.Strict import Control.Monad.State.Strict
import Control.Monad.Writer
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.Trans import Control.Monad.Trans
import Data.Functor.Identity import Data.Functor.Identity
@@ -88,3 +89,9 @@ instance (Monad m, MonadErrorful e m) => MonadErrorful e (ReaderT r m) where
instance (Monad m, MonadState s m) => MonadState s (ErrorfulT e m) where instance (Monad m, MonadState s m) => MonadState s (ErrorfulT e m) where
state = lift . state state = lift . state
instance (Monoid w, Monad m, MonadWriter w m) => MonadWriter w (ErrorfulT e m) where
tell = lift . tell
listen (ErrorfulT m) = ErrorfulT $ listen m <&> \ ((ma,es),w) ->
((,w) <$> ma, es)
pass (ErrorfulT m) = undefined

View File

@@ -4,6 +4,7 @@ module Data.Pretty
, prettyPrec1 , prettyPrec1
, rpretty , rpretty
, ttext , ttext
, Showing(..)
-- * Pretty-printing lens combinators -- * Pretty-printing lens combinators
, hsepOf, vsepOf, vcatOf, vlinesOf, vsepTerm , hsepOf, vsepOf, vcatOf, vlinesOf, vsepTerm
, vsep , vsep

View File

@@ -15,6 +15,7 @@ import Control.Monad.State
import Control.Monad import Control.Monad
import Control.Monad.Writer.Strict import Control.Monad.Writer.Strict
import Data.Text qualified as T import Data.Text qualified as T
import Data.Pretty
import Data.Hashable import Data.Hashable
import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as H import Data.HashMap.Strict qualified as H
@@ -50,69 +51,72 @@ fixCofree = iso sa bt where
lookupVar :: PsName -> Context -> HM (Type PsName) lookupVar :: PsName -> Context -> HM (Type PsName)
lookupVar n g = case g ^. contextVars . at n of lookupVar n g = case g ^. contextVars . at n of
Just t -> pure t Just t -> pure t
Nothing -> addFatal (TyErrUntypedVariable n) Nothing -> addFatal $ TyErrUntypedVariable n
-- | Instantiate a polytype by replacing the bound type variables with fresh gather :: Context -> RlpExpr PsName -> HM (Type PsName)
-- monotype (free) variables gather g = \case
inst :: Type PsName -> HM (Type PsName)
inst = para \case
ForallTF x (_,t) -> do
m <- t
tv <- freshTv
pure $ subst x tv m
-- discard the recursive results by selected fst
t -> pure . embed . fmap fst $ t
generalise :: Type PsName -> Type PsName Finl (LitF (IntL _)) -> pure IntT
generalise = foldr ForallT <*> toListOf tyVars
tyVars :: Traversal (Type b) (Type b') b b' Finl (AppF f x) -> do
tyVars = traverse tf <- gather g f
tx <- gather g x
tfx <- freshTv
addConstraint $ Equality tf (tx :-> tfx)
pure tfx
polytypeBinds :: Traversal' (Type b) b Finl (VarF n) -> lookupVar n g
polytypeBinds k (ForallT x m) = ForallT <$> k x <*> polytypeBinds k m
polytypeBinds k t = pure t
subst :: PsName -> Type PsName -> Type PsName -> Type PsName Finl (LamF bs e) -> do
subst n t' = para \case tbs <- for bs $ \b -> (b,) <$> freshTv
VarTF m | n == m -> t' te <- gather (supplement tbs g) e
ForallTF m (pre,post) | n == m -> pre pure $ foldrOf (each . _2) (:->) te tbs
| otherwise -> post
t -> embed . fmap snd $ t unify :: Context -> [Constraint] -> HM Context
unify g [] = pure g
unify g (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify g $ Equality sx tx : Equality sy ty : cs
-- elim
unify g (Equality (ConT s) (ConT t) : cs) | s == t = unify g cs
unify g (Equality (VarT s) (VarT t) : cs) | s == t = unify g cs
unify g (Equality (VarT s) t : cs)
| occurs s t = addFatal $ TyErrRecursiveType s t
| otherwise = unify g' cs'
where
g' = supplement [(s,t)] g
cs' = cs & each . constraintTypes %~ subst s t
-- swap
unify g (Equality s (VarT t) : cs) = unify g (Equality (VarT t) s : cs)
unify _ (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
solve :: Context -> RlpExpr PsName -> HM (Type PsName)
solve g e = do
(t,cs) <- listen $ gather g e
g' <- unify g cs
pure $ ifoldrOf (contextVars . itraversed) subst t g'
occurs :: PsName -> Type PsName -> Bool occurs :: PsName -> Type PsName -> Bool
occurs n = cata \case occurs n = cata \case
VarTF m | n == m -> True VarTF m | n == m -> True
t -> or t t -> or t
infer :: Context -> RlpExpr PsName -> HM (Type PsName) subst :: PsName -> Type PsName -> Type PsName -> Type PsName
infer g = \case subst n t' = para \case
VarTF m | n == m -> t'
-- shadowing
ForallTF x (pre,post) | x == n -> ForallT x pre
| otherwise -> ForallT x post
t -> embed $ t <&> view _2
Finl (LitF (IntL _)) -> pure IntT prettyHM :: (Pretty a)
=> Either [TypeError] (a, [Constraint])
{- Var -> Either [TypeError] (String, [String])
- x : τ ∈ Γ prettyHM = over (mapped . _1) rpretty
- τ' = inst τ . over (mapped . _2 . each) rpretty
- -----------
- Γ |- x : τ'
-}
Finl (VarF x) -> do
t <- lookupVar x g
let t' = inst t
t'
Finl (AppF f x) -> do
te <- infer g f
tx <- infer g x
t' <- freshTv
undefined
unify :: Context -> Type PsName -> Type PsName -> Context
unify g = \cases
IntT IntT -> g
(VarT a) b | Just a' <- g ^. contextTyVars . at a -> unify g a' b
b (VarT a) | Just a' <- g ^. contextTyVars . at a -> unify g b a'
s@(VarT a) b | Nothing <- g ^. contextTyVars . at a
| s == b

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TemplateHaskell #-}
module Rlp.HindleyMilner.Types module Rlp.HindleyMilner.Types
where where
@@ -22,9 +23,8 @@ import Compiler.RlpcError
import Rlp.AltSyntax import Rlp.AltSyntax
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
data Context = Context newtype Context = Context
{ _contextVars :: HashMap PsName (Type PsName) { _contextVars :: HashMap PsName (Type PsName)
, _contextTyVars :: HashMap PsName (Type PsName)
} }
data Constraint = Equality (Type PsName) (Type PsName) data Constraint = Equality (Type PsName) (Type PsName)
@@ -38,7 +38,7 @@ data PartialJudgement = PartialJudgement [Constraint]
instance Hashable Constraint instance Hashable Constraint
type HM = ErrorfulT TypeError (State Int) type HM = ErrorfulT TypeError (StateT Int (Writer [Constraint]))
-- | Type error enum. -- | Type error enum.
data TypeError data TypeError
@@ -112,12 +112,33 @@ freshTv = do
modify succ modify succ
pure . VarT $ "$a" <> T.pack (show n) pure . VarT $ "$a" <> T.pack (show n)
runHM' :: HM a -> Either [TypeError] a runHM' :: HM a -> Either [TypeError] (a, [Constraint])
runHM' e = maybe (Left es) Right ma runHM' e = maybe (Left es) (Right . (,cs)) ma
where where
(ma,es) = (`evalState` 0) . runErrorfulT $ e ((ma,es),cs) = runWriter . (`evalStateT` 0) . runErrorfulT $ e
addConstraint :: Constraint -> HM ()
addConstraint = tell . pure
-- makePrisms ''PartialJudgement -- makePrisms ''PartialJudgement
makeLenses ''Context makeLenses ''Context
makePrisms ''Constraint
supplement :: [(PsName, Type PsName)] -> Context -> Context
supplement bs = contextVars %~ (H.fromList bs <>)
demoContext :: Context
demoContext = Context
{ _contextVars =
[ ("+#", IntT :-> IntT :-> IntT)
]
}
constraintTypes :: Traversal' Constraint (Type PsName)
constraintTypes k (Equality s t) = Equality <$> k s <*> k t
instance Pretty Constraint where
pretty (Equality s t) =
hsep [prettyPrec appPrec1 s, "~", prettyPrec appPrec1 t]