it's so over (whole-program inference again)

This commit is contained in:
crumbtoo
2024-03-28 10:59:51 -06:00
parent d360edc476
commit ff006abac0
7 changed files with 107 additions and 67 deletions

View File

@@ -61,6 +61,7 @@ Available debug flags include:
** TODO rlp to core desugaring :feature:
** TODO [#A] HM memoisation prevents shadowing :bug:
Example:
#+begin_src haskell
-- >>> runHM' $ infer1 [rlpExpr|let f = \x -> x in f (let f = 2 in f)|]
-- Left [TyErrCouldNotUnify
@@ -69,6 +70,7 @@ Available debug flags include:
-- >>> :t let f = \x -> x in f (let f = 2 in f)
-- let f = \x -> x in f (let f = 2 in f) :: Int
#+end_src
For the time being, I just disabled the memoisation. This is very, very bad.
** DONE README.md -> README.org :docs:
CLOSED: [2024-03-28 Thu 10:44]

View File

@@ -114,8 +114,10 @@ liftMaybe m = RLPCT . lift . ErrorfulT . pure $ (m, [])
liftEither :: (Monad m, IsRlpcError e)
=> Either [e] a -> RLPCT m a
liftEither = RLPCT . lift . ErrorfulT . pure
. either (const (Nothing,[])) ((,[]) . Just)
liftEither = RLPCT . lift . ErrorfulT . pure . f where
f (Left es) = (Nothing, errorMsg s . liftRlpcError <$> es)
where s = SrcSpan 0 0 0 0
f (Right a) = (Just a, [])
hoistRlpcT :: (forall a. m a -> n a)
-> RLPCT m a -> RLPCT n a

View File

@@ -444,21 +444,6 @@ instance (Out b) => Out (ScDef b) where
instance (Out b, Out a) => Out (ExprF b a) where
outPrec = outPrec1
-- outPrec _ (VarF n) = ttext n
-- outPrec _ (ConF t a) = "Pack{" <> (ttext t <+> ttext a) <> "}"
-- outPrec p (LamF bs e) = maybeParens (p>0) $
-- hsep ["λ", hsep (outPrec appPrec1 <$> bs), "->", out e]
-- outPrec p (LetF r bs e) = maybeParens (p>0)
-- $ hsep [out r, explicitLayout bs]
-- $+$ hsep ["in", out e]
-- outPrec p (AppF f x) = maybeParens (p>appPrec) $
-- outPrec appPrec f <+> outPrec appPrec1 x
-- outPrec p (LitF l) = outPrec p l
-- outPrec p (CaseF e as) = maybeParens (p>0) $
-- "case" <+> out e <+> "of"
-- $+$ nest 2 (explicitLayout as)
-- outPrec p (TypeF t) = "@" <> outPrec appPrec1 t
instance (Out b) => Out1 (ExprF b) where
liftOutPrec pr _ (VarF n) = ttext n
liftOutPrec pr _ (ConF t a) = "Pack{" <> (ttext t <+> ttext a) <> "}"

View File

@@ -158,6 +158,7 @@ Expr1 :: { RlpExpr PsName }
. singular _TokenLitInt
. to (Finl . Core.LitF . Core.IntL) }
| '(' Expr ')' { $2 }
| ConE { $1 }
AppE :: { RlpExpr PsName }
: AppE Expr1 { Finl $ Core.AppF $1 $2 }
@@ -166,6 +167,9 @@ AppE :: { RlpExpr PsName }
VarE :: { RlpExpr PsName }
: Var { Finl $ Core.VarF $1 }
ConE :: { RlpExpr PsName }
: Con { Finl $ Core.VarF $1 }
Pat1s :: { [Pat PsName] }
: list0(Pat1) { $1 }

View File

@@ -41,6 +41,7 @@ import Data.Hashable.Lifted
import GHC.Exts (IsString)
import Control.Lens hiding ((.=))
import Data.Functor.Extend
import Data.Functor.Foldable.TH
import Text.Show.Deriving
import Data.Eq.Deriving
@@ -62,6 +63,11 @@ type PsName = T.Text
newtype Program b a = Program [Decl b a]
deriving (Show, Functor, Foldable, Traversable)
instance Extend (Decl b) where
extended c w@(FunD n as a) = FunD n as (c w)
extended _ (DataD n as cs) = DataD n as cs
extended _ (TySigD n t) = TySigD n t
programDecls :: Iso (Program b a) (Program b' a') [Decl b a] [Decl b' a']
programDecls = iso sa bt where
sa (Program ds) = ds
@@ -158,6 +164,12 @@ instance Out b => Out1 (ExprF b) where
liftOutPrec pr p (CaseEF e as) = maybeParens (p>0) $
vsep [ hsep [ "case", pr 0 e, "of" ]
, nest 2 (vcat $ liftOutPrec pr 0 <$> as) ]
liftOutPrec pr p (LetEF r bs e) = maybeParens (p>0) $
vsep [ hsep [ letword r, "<bs>" ]
, nest 2 (hsep [ "in", pr 0 e ]) ]
where
letword Core.Rec = "letrec"
letword Core.NonRec = "let"
instance (Out b, Out a) => Out (Decl b a) where
outPrec = outPrec1

View File

@@ -43,6 +43,7 @@ import GHC.Generics (Generic, Generically(..))
import Debug.Trace
import Data.Functor hiding (unzip)
import Data.Functor.Extend
import Data.Functor.Foldable hiding (fold)
import Data.Fix hiding (cata, para)
import Control.Comonad.Cofree
@@ -68,11 +69,11 @@ lookupVar n g = case g ^. contextVars . at n of
Nothing -> addFatal $ TyErrUntypedVariable n
gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather e = look >>= (H.lookup e >>> maybe memoise pure)
gather e = use hmMemo >>= (H.lookup e >>> maybe memoise pure)
where
memoise = do
r <- gather' e
add (H.singleton e r)
hmMemo <>= H.singleton e r
pure r
gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
@@ -145,8 +146,8 @@ deleteKeys :: (Eq k, Hashable k) => [k] -> HashMap k v -> HashMap k v
deleteKeys ks h = foldr H.delete h ks
gatherBinds :: [Binding PsName (RlpExpr PsName)]
-> HM [( Type PsName
, Type PsName
-> HM [( Type PsName -- inferred type
, Type PsName -- generalised type
, PartialJudgement )]
gatherBinds bs = for bs $ \ (VarB (VarP k) x) -> do
((tx,jx),frees) <- listenFreshTvNames $ gather x
@@ -201,10 +202,10 @@ unify [] = pure mempty
unify (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify $ Equality sx tx : Equality sy ty : cs
unify (Equality a@(ConT ca `AppT` as) b@(ConT cb `AppT` bs) : cs)
| ca == cb = do
cs' <- liftA2 (zipWith Equality) (saturated a) (saturated b)
unify $ cs' ++ cs
-- unify (Equality a@(ConT ca `AppT` as) b@(ConT cb `AppT` bs) : cs)
-- | ca == cb = do
-- cs' <- liftA2 (zipWith Equality) (saturated a) (saturated b)
-- unify $ cs' ++ cs
-- elim
unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs
@@ -274,7 +275,7 @@ infer e = do
e' <- annotate e
let (cs,as) = finalJudgement e' ^. lensProduct constraints assumptions
cs' <- (<>cs) <$> elimAssumptionsG g0 as
checkUndefinedVariables e'
-- checkUndefinedVariables e'
sub <- solve cs'
pure $ e' & fmap (sub . view _1)
& _extract %~ generaliseG g0
@@ -349,38 +350,57 @@ typeCheckRlpProgR :: (Monad m)
=> Program PsName (RlpExpr PsName)
-> RLPCT m (Program PsName
(TypedRlpExpr PsName))
typeCheckRlpProgR p = liftHM g (inferProg . etaExpandAll $ p)
typeCheckRlpProgR p = liftHM g (inferProg p)
where
g = buildInitialContext p
etaExpandAll = programDecls . each %~ etaExpand
inferProg :: Program PsName (RlpExpr PsName)
-> HM (Program PsName (TypedRlpExpr PsName))
inferProg p = do
g0 <- ask
p' <- annotateProg p
let (cs,as) = foldOf (folded . folded . _2) p'
^. lensProduct constraints assumptions
cs' <- (<>cs) <$> elimAssumptionsG g0 as
-- we only wipe the memo here as a temporary solution to the memo shadowing
-- problem
p' <- (\e -> (hmMemo .= mempty) *> annotate e) `traverse` etaExpandAll p
traceM $ "p' : " <> show p'
let (cs,as) = foldMap finalJudgement p' ^. lensProduct constraints assumptions
cs' <- (cs <>) <$> elimAssumptionsG g0 as
traceM $ "cs' : " <> show cs'
sub <- solve cs'
pure $ p' & programDecls . traversed . _FunD . _3
%~ ((_extract %~ generaliseG g0) . fmap (sub . view _1))
where
etaExpandAll = programDecls . each %~ etaExpand
annotateProg :: Program PsName (RlpExpr PsName)
-> HM ( Program PsName
(Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)))
annotateProg = traverse annotate
gatherProg :: Program PsName (RlpExpr PsName)
-> HM [Constraint]
gatherProg p = do
-- this should be nearly identical to the rule for `letrec` in gather'
let (ks,xs) = unzip $ funsToSimpleBinds (p ^. programDecls)
(txs,txs',jxs) <- unzip3 <$> gatherBinds (zipWith simpleBind ks xs)
let jxsa = foldOf (each . assumptions) jxs
jxcs <- elimWithBinds (ks `zip` txs) jxsa
(Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
, [Constraint] )
annotateProg p = do
let bs = funsToSimpleBinds (p ^. programDecls)
(ks,xs) = unzip bs
xs' <- annotate `traverse` xs
let jxs = foldOf (each . _extract . _2) xs'
txs = xs' ^.. each . _extract . _1
cs <- elimWithBinds (ks `zip` txs) (jxs ^. assumptions)
-- let p' = annotateDecls (ks `zip` xs') p
p' <- annotate `traverse` p
-- TODO: any remaining assumptions should be errors at this point
pure jxcs
pure (p',cs)
-- this sucks! FunDs should probably be stored as a hashmap in Program...
annotateDecls :: [( PsName
, Cofree (RlpExprF PsName) (Type PsName, PartialJudgement) )]
-> Program PsName a
-> Program PsName
(Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotateDecls bs = programDecls . traversed . _FunD %~ \case
(n,_,_)
| Just e <- lookup n bs
-> (n,[],e)
gatherBinds' :: [(PsName, RlpExpr PsName)]
-> HM [(Type PsName, Type PsName, PartialJudgement)]
gatherBinds' = gatherBinds . fmap (uncurry simpleBind)
elimWithBinds :: [(PsName, Type PsName)]
-> Assumptions

View File

@@ -60,7 +60,21 @@ instance Hashable Constraint
type Memo = HashMap (RlpExpr PsName) (Type PsName, PartialJudgement)
type HM = ErrorfulT TypeError (ReaderT Context (StateT Int (Accum Memo)))
data HMState = HMState
{ _hmMemo :: Memo
, _hmUniq :: Int
}
deriving Show
newtype HM a = HM {
unHM :: ErrorfulT TypeError
(ReaderT Context (State HMState)) a
}
deriving (Functor, Applicative, Monad)
deriving ( MonadReader Context
, MonadState HMState
, MonadErrorful TypeError
)
-- | Type error enum.
data TypeError
@@ -91,30 +105,11 @@ instance IsRlpcError TypeError where
(rout @String t) (rout @String x)
]
tvNameOfInt :: Int -> PsName
tvNameOfInt n = "$a" <> T.pack (show n)
freshTv :: HM (Type PsName)
freshTv = do
n <- get
modify succ
pure (VarT $ tvNameOfInt n)
listenFreshTvs :: HM a -> HM (a, [Type PsName])
listenFreshTvs hm = listenFreshTvNames hm & mapped . _2 . each %~ VarT
listenFreshTvNames :: HM a -> HM (a, [PsName])
listenFreshTvNames hm = do
n <- get
a <- hm
n' <- get
pure (a, [ tvNameOfInt k | k <- [n .. pred n'] ])
runHM :: Context -> HM a -> Either [TypeError] a
runHM g e = maybe (Left es) Right ma
where
((ma,es),m) = (`runAccum` mempty) . (`evalStateT` 0)
. (`runReaderT` g) . runErrorfulT $ e
(ma,es) = (`evalState` (HMState mempty 0))
. (`runReaderT` g) . runErrorfulT $ unHM e
runHM' :: HM a -> Either [TypeError] a
runHM' = runHM mempty
@@ -124,6 +119,7 @@ makeLenses ''PartialJudgement
makeLenses ''Context
makePrisms ''Constraint
makePrisms ''TypeError
makeLenses ''HMState
supplement :: [(PsName, Type PsName)] -> Context -> Context
supplement bs = contextVars %~ (H.fromList bs <>)
@@ -147,3 +143,22 @@ instance Out Constraint where
out (Equality s t) =
hsep [outPrec appPrec1 s, "~", outPrec appPrec1 t]
tvNameOfInt :: Int -> PsName
tvNameOfInt n = "$a" <> T.pack (show n)
freshTv :: HM (Type PsName)
freshTv = do
n <- use hmUniq
hmUniq %= succ
pure (VarT $ tvNameOfInt n)
listenFreshTvs :: HM a -> HM (a, [Type PsName])
listenFreshTvs hm = listenFreshTvNames hm & mapped . _2 . each %~ VarT
listenFreshTvNames :: HM a -> HM (a, [PsName])
listenFreshTvNames hm = do
n <- use hmUniq
a <- hm
n' <- use hmUniq
pure (a, [ tvNameOfInt k | k <- [n .. pred n'] ])