seems to work

This commit is contained in:
crumbtoo
2024-03-13 18:10:29 -06:00
parent e00e0eff3b
commit 8fd75a67d3
5 changed files with 93 additions and 49 deletions

View File

@@ -71,6 +71,7 @@ library
, data-default-class >= 0.1.2 && < 0.2 , data-default-class >= 0.1.2 && < 0.2
, hashable >= 1.4.3 && < 1.5 , hashable >= 1.4.3 && < 1.5
, mtl >= 2.3.1 && < 2.4 , mtl >= 2.3.1 && < 2.4
, transformers
, text >= 2.0.2 && < 2.2 , text >= 2.0.2 && < 2.2
, unordered-containers >= 0.2.20 && < 0.3 , unordered-containers >= 0.2.20 && < 0.3
, recursion-schemes >= 5.2.2 && < 5.3 , recursion-schemes >= 5.2.2 && < 5.3

View File

@@ -16,6 +16,7 @@ module Control.Monad.Errorful
import Control.Monad.State.Strict import Control.Monad.State.Strict
import Control.Monad.Writer import Control.Monad.Writer
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad.Accum
import Control.Monad.Trans import Control.Monad.Trans
import Data.Functor.Identity import Data.Functor.Identity
import Data.Coerce import Data.Coerce
@@ -95,3 +96,7 @@ instance (Monoid w, Monad m, MonadWriter w m) => MonadWriter w (ErrorfulT e m) w
((,w) <$> ma, es) ((,w) <$> ma, es)
pass (ErrorfulT m) = undefined pass (ErrorfulT m) = undefined
instance (Monoid w, Monad m, MonadAccum w m)
=> MonadAccum w (ErrorfulT e m) where
accum = lift . accum

View File

@@ -128,7 +128,6 @@ Expr :: { RlpExpr PsName }
: AppE { $1 } : AppE { $1 }
| LetE { $1 } | LetE { $1 }
| CaseE { $1 } | CaseE { $1 }
| Expr1 { $1 }
| LamE { $1 } | LamE { $1 }
LamE :: { RlpExpr PsName } LamE :: { RlpExpr PsName }
@@ -155,6 +154,7 @@ Expr1 :: { RlpExpr PsName }
| litint { $1 ^. to extract | litint { $1 ^. to extract
. singular _TokenLitInt . singular _TokenLitInt
. to (Finl . Core.LitF . Core.IntL) } . to (Finl . Core.LitF . Core.IntL) }
| '(' Expr ')' { $2 }
AppE :: { RlpExpr PsName } AppE :: { RlpExpr PsName }
: AppE Expr1 { Finl $ Core.AppF $1 $2 } : AppE Expr1 { Finl $ Core.AppF $1 $2 }

View File

@@ -13,7 +13,9 @@ module Rlp.HindleyMilner
import Control.Lens hiding (Context', Context, (:<), para) import Control.Lens hiding (Context', Context, (:<), para)
import Control.Monad.Errorful import Control.Monad.Errorful
import Control.Monad.State import Control.Monad.State
import Control.Monad.Accum
import Control.Monad import Control.Monad
import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict import Control.Monad.Writer.Strict
import Data.Text qualified as T import Data.Text qualified as T
import Data.Pretty import Data.Pretty
@@ -40,10 +42,6 @@ import Core.Syntax (ExprF(..), Lit(..))
import Rlp.HindleyMilner.Types import Rlp.HindleyMilner.Types
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- | Synonym for @Errorful [TypeError]@. This means an @HMError@ action may
-- throw any number of fatal or nonfatal errors. Run with @runErrorful@.
type HMError = Errorful TypeError
fixCofree :: (Functor f, Functor g) fixCofree :: (Functor f, Functor g)
=> Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b) => Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b)
fixCofree = iso sa bt where fixCofree = iso sa bt where
@@ -55,53 +53,80 @@ lookupVar n g = case g ^. contextVars . at n of
Just t -> pure t Just t -> pure t
Nothing -> addFatal $ TyErrUntypedVariable n Nothing -> addFatal $ TyErrUntypedVariable n
gather :: Context -> RlpExpr PsName -> HM (Type PsName) gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather g = \case gather e = look >>= (H.lookup e >>> maybe memoise pure)
where
memoise = do
r <- gather' e
add (H.singleton e r)
pure r
Finl (LitF (IntL _)) -> pure IntT gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather' = \case
Finl (LitF (IntL _)) -> pure (IntT, mempty)
Finl (VarF n) -> do
t <- freshTv
let j = mempty & assumptions .~ H.singleton n [t]
pure (t,j)
Finl (AppF f x) -> do Finl (AppF f x) -> do
tf <- gather g f
tx <- gather g x
tfx <- freshTv tfx <- freshTv
addConstraint $ Equality tf (tx :-> tfx) (tf,jf) <- gather f
pure tfx (tx,jx) <- gather x
let jtfx = mempty & constraints .~ [Equality tf (tx :-> tfx)]
pure (tfx, jf <> jx <> jtfx)
Finl (VarF n) -> lookupVar n g Finl (LamF [b] e) -> do
tb <- freshTv
(te,je) <- gather e
let cs = maybe [] (fmap $ Equality tb) (je ^. assumptions . at b)
as = je ^. assumptions & at b .~ Nothing
j = mempty & constraints .~ cs & assumptions .~ as
t = tb :-> te
pure (t,j)
Finl (LamF bs e) -> do unify :: [Constraint] -> HM Context
tbs <- for bs $ \b -> (b,) <$> freshTv
te <- gather (supplement tbs g) e
pure $ foldrOf (each . _2) (:->) te tbs
unify :: Context -> [Constraint] -> HM Context unify [] = pure mempty
unify g [] = pure g unify (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify $ Equality sx tx : Equality sy ty : cs
unify g (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify g $ Equality sx tx : Equality sy ty : cs
-- elim -- elim
unify g (Equality (ConT s) (ConT t) : cs) | s == t = unify g cs unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs
unify g (Equality (VarT s) (VarT t) : cs) | s == t = unify g cs unify (Equality (VarT s) (VarT t) : cs) | s == t = unify cs
unify g (Equality (VarT s) t : cs) unify (Equality (VarT s) t : cs)
| occurs s t = addFatal $ TyErrRecursiveType s t | occurs s t = addFatal $ TyErrRecursiveType s t
| otherwise = unify g' cs' | otherwise = unify cs' <&> contextVars . at s ?~ t
where where
g' = supplement [(s,t)] g
cs' = cs & each . constraintTypes %~ subst s t cs' = cs & each . constraintTypes %~ subst s t
-- swap -- swap
unify g (Equality s (VarT t) : cs) = unify g (Equality (VarT t) s : cs) unify (Equality s (VarT t) : cs) = unify (Equality (VarT t) s : cs)
unify _ (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t unify (Equality s t : _) = addFatal $ TyErrCouldNotUnify s t
solve :: Context -> RlpExpr PsName -> HM (Type PsName) annotate :: RlpExpr PsName
solve g e = do -> HM (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
(t,cs) <- listen $ gather g e annotate = sequenceA . fixtend (gather . wrapFix)
g' <- unify g cs
pure $ ifoldrOf (contextVars . itraversed) subst t g' solveTree :: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)
-> HM (Type PsName)
solveTree e = undefined
infer1 :: RlpExpr PsName -> HM (Type PsName)
infer1 e = do
((t,j) :< _) <- annotate e
g <- unify (j ^. constraints)
pure $ ifoldrOf (contextVars . itraversed) subst t g
solve = undefined
-- solve g e = do
-- (t,j) <- gather e
-- g' <- unify cs
-- pure $ ifoldrOf (contextVars . itraversed) subst t g'
occurs :: PsName -> Type PsName -> Bool occurs :: PsName -> Type PsName -> Bool
occurs n = cata \case occurs n = cata \case
@@ -122,11 +147,11 @@ prettyHM :: (Pretty a)
prettyHM = over (mapped . _1) rpretty prettyHM = over (mapped . _1) rpretty
. over (mapped . _2 . each) rpretty . over (mapped . _2 . each) rpretty
fixtend :: (f (Fix f) -> b) -> Fix f -> Cofree f b fixtend :: Functor f => (f (Fix f) -> b) -> Fix f -> Cofree f b
fixtend = undefined fixtend c (Fix f) = c f :< fmap (fixtend c) f
infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName)) infer :: RlpExpr PsName -> HM (Cofree (RlpExprF PsName) (Type PsName))
infer = _ . fixtend (solve _ . wrapFix) infer = undefined
typeCheckRlpProgR :: (Monad m) typeCheckRlpProgR :: (Monad m)
=> Program PsName (RlpExpr PsName) => Program PsName (RlpExpr PsName)

View File

@@ -12,10 +12,13 @@ import GHC.Generics (Generic(..), Generically(..))
import Data.Kind qualified import Data.Kind qualified
import Data.Text qualified as T import Data.Text qualified as T
import Control.Monad.Writer import Control.Monad.Writer
import Control.Monad.Accum
import Control.Monad.Trans.Accum
import Control.Monad.Errorful import Control.Monad.Errorful
import Control.Monad.State import Control.Monad.State
import Text.Printf import Text.Printf
import Data.Pretty import Data.Pretty
import Data.Function
import Control.Lens hiding (Context', Context) import Control.Lens hiding (Context', Context)
@@ -26,22 +29,32 @@ import Rlp.AltSyntax
newtype Context = Context newtype Context = Context
{ _contextVars :: HashMap PsName (Type PsName) { _contextVars :: HashMap PsName (Type PsName)
} }
deriving (Generic) deriving (Show, Generic)
deriving (Semigroup, Monoid) deriving (Semigroup, Monoid)
via Generically Context via Generically Context
data Constraint = Equality (Type PsName) (Type PsName) data Constraint = Equality (Type PsName) (Type PsName)
deriving (Eq, Generic, Show) deriving (Eq, Generic, Show)
data PartialJudgement = PartialJudgement [Constraint] data PartialJudgement = PartialJudgement
(HashMap PsName [Type PsName]) { _constraints :: [Constraint]
, _assumptions :: HashMap PsName [Type PsName]
}
deriving (Generic, Show) deriving (Generic, Show)
deriving (Semigroup, Monoid) deriving (Monoid)
via Generically PartialJudgement via Generically PartialJudgement
instance Semigroup PartialJudgement where
a <> b = PartialJudgement
{ _constraints = ((<>) `on` _constraints) a b
, _assumptions = (H.unionWith (<>) `on` _assumptions) a b
}
instance Hashable Constraint instance Hashable Constraint
type HM = ErrorfulT TypeError (StateT Int (Writer [Constraint])) type Memo = HashMap (RlpExpr PsName) (Type PsName, PartialJudgement)
type HM = ErrorfulT TypeError (StateT Int (Accum Memo))
-- | Type error enum. -- | Type error enum.
data TypeError data TypeError
@@ -116,16 +129,16 @@ freshTv = do
modify succ modify succ
pure . VarT $ "$a" <> T.pack (show n) pure . VarT $ "$a" <> T.pack (show n)
runHM' :: HM a -> Either [TypeError] (a, [Constraint]) runHM' :: HM a -> Either [TypeError] a
runHM' e = maybe (Left es) (Right . (,cs)) ma runHM' e = maybe (Left es) Right ma
where where
((ma,es),cs) = runWriter . (`evalStateT` 0) . runErrorfulT $ e ((ma,es),m) = (`runAccum` mempty) . (`evalStateT` 0) . runErrorfulT $ e
addConstraint :: Constraint -> HM () -- addConstraint :: Constraint -> HM ()
addConstraint = tell . pure -- addConstraint = tell . pure
-- makePrisms ''PartialJudgement
makePrisms ''PartialJudgement
makeLenses ''PartialJudgement
makeLenses ''Context makeLenses ''Context
makePrisms ''Constraint makePrisms ''Constraint
makePrisms ''TypeError makePrisms ''TypeError