it's so over (whole-program inference again)

This commit is contained in:
crumbtoo
2024-03-28 10:59:51 -06:00
parent ddd1e7b931
commit fa2b2d6ed5
7 changed files with 107 additions and 67 deletions

View File

@@ -61,6 +61,7 @@ Available debug flags include:
** TODO rlp to core desugaring :feature: ** TODO rlp to core desugaring :feature:
** TODO [#A] HM memoisation prevents shadowing :bug: ** TODO [#A] HM memoisation prevents shadowing :bug:
Example:
#+begin_src haskell #+begin_src haskell
-- >>> runHM' $ infer1 [rlpExpr|let f = \x -> x in f (let f = 2 in f)|] -- >>> runHM' $ infer1 [rlpExpr|let f = \x -> x in f (let f = 2 in f)|]
-- Left [TyErrCouldNotUnify -- Left [TyErrCouldNotUnify
@@ -69,6 +70,7 @@ Available debug flags include:
-- >>> :t let f = \x -> x in f (let f = 2 in f) -- >>> :t let f = \x -> x in f (let f = 2 in f)
-- let f = \x -> x in f (let f = 2 in f) :: Int -- let f = \x -> x in f (let f = 2 in f) :: Int
#+end_src #+end_src
For the time being, I just disabled the memoisation. This is very, very bad.
** DONE README.md -> README.org :docs: ** DONE README.md -> README.org :docs:
CLOSED: [2024-03-28 Thu 10:44] CLOSED: [2024-03-28 Thu 10:44]

View File

@@ -114,8 +114,10 @@ liftMaybe m = RLPCT . lift . ErrorfulT . pure $ (m, [])
liftEither :: (Monad m, IsRlpcError e) liftEither :: (Monad m, IsRlpcError e)
=> Either [e] a -> RLPCT m a => Either [e] a -> RLPCT m a
liftEither = RLPCT . lift . ErrorfulT . pure liftEither = RLPCT . lift . ErrorfulT . pure . f where
. either (const (Nothing,[])) ((,[]) . Just) f (Left es) = (Nothing, errorMsg s . liftRlpcError <$> es)
where s = SrcSpan 0 0 0 0
f (Right a) = (Just a, [])
hoistRlpcT :: (forall a. m a -> n a) hoistRlpcT :: (forall a. m a -> n a)
-> RLPCT m a -> RLPCT n a -> RLPCT m a -> RLPCT n a

View File

@@ -444,21 +444,6 @@ instance (Out b) => Out (ScDef b) where
instance (Out b, Out a) => Out (ExprF b a) where instance (Out b, Out a) => Out (ExprF b a) where
outPrec = outPrec1 outPrec = outPrec1
-- outPrec _ (VarF n) = ttext n
-- outPrec _ (ConF t a) = "Pack{" <> (ttext t <+> ttext a) <> "}"
-- outPrec p (LamF bs e) = maybeParens (p>0) $
-- hsep ["λ", hsep (outPrec appPrec1 <$> bs), "->", out e]
-- outPrec p (LetF r bs e) = maybeParens (p>0)
-- $ hsep [out r, explicitLayout bs]
-- $+$ hsep ["in", out e]
-- outPrec p (AppF f x) = maybeParens (p>appPrec) $
-- outPrec appPrec f <+> outPrec appPrec1 x
-- outPrec p (LitF l) = outPrec p l
-- outPrec p (CaseF e as) = maybeParens (p>0) $
-- "case" <+> out e <+> "of"
-- $+$ nest 2 (explicitLayout as)
-- outPrec p (TypeF t) = "@" <> outPrec appPrec1 t
instance (Out b) => Out1 (ExprF b) where instance (Out b) => Out1 (ExprF b) where
liftOutPrec pr _ (VarF n) = ttext n liftOutPrec pr _ (VarF n) = ttext n
liftOutPrec pr _ (ConF t a) = "Pack{" <> (ttext t <+> ttext a) <> "}" liftOutPrec pr _ (ConF t a) = "Pack{" <> (ttext t <+> ttext a) <> "}"

View File

@@ -158,6 +158,7 @@ Expr1 :: { RlpExpr PsName }
. singular _TokenLitInt . singular _TokenLitInt
. to (Finl . Core.LitF . Core.IntL) } . to (Finl . Core.LitF . Core.IntL) }
| '(' Expr ')' { $2 } | '(' Expr ')' { $2 }
| ConE { $1 }
AppE :: { RlpExpr PsName } AppE :: { RlpExpr PsName }
: AppE Expr1 { Finl $ Core.AppF $1 $2 } : AppE Expr1 { Finl $ Core.AppF $1 $2 }
@@ -166,6 +167,9 @@ AppE :: { RlpExpr PsName }
VarE :: { RlpExpr PsName } VarE :: { RlpExpr PsName }
: Var { Finl $ Core.VarF $1 } : Var { Finl $ Core.VarF $1 }
ConE :: { RlpExpr PsName }
: Con { Finl $ Core.VarF $1 }
Pat1s :: { [Pat PsName] } Pat1s :: { [Pat PsName] }
: list0(Pat1) { $1 } : list0(Pat1) { $1 }

View File

@@ -41,6 +41,7 @@ import Data.Hashable.Lifted
import GHC.Exts (IsString) import GHC.Exts (IsString)
import Control.Lens hiding ((.=)) import Control.Lens hiding ((.=))
import Data.Functor.Extend
import Data.Functor.Foldable.TH import Data.Functor.Foldable.TH
import Text.Show.Deriving import Text.Show.Deriving
import Data.Eq.Deriving import Data.Eq.Deriving
@@ -62,6 +63,11 @@ type PsName = T.Text
newtype Program b a = Program [Decl b a] newtype Program b a = Program [Decl b a]
deriving (Show, Functor, Foldable, Traversable) deriving (Show, Functor, Foldable, Traversable)
instance Extend (Decl b) where
extended c w@(FunD n as a) = FunD n as (c w)
extended _ (DataD n as cs) = DataD n as cs
extended _ (TySigD n t) = TySigD n t
programDecls :: Iso (Program b a) (Program b' a') [Decl b a] [Decl b' a'] programDecls :: Iso (Program b a) (Program b' a') [Decl b a] [Decl b' a']
programDecls = iso sa bt where programDecls = iso sa bt where
sa (Program ds) = ds sa (Program ds) = ds
@@ -158,6 +164,12 @@ instance Out b => Out1 (ExprF b) where
liftOutPrec pr p (CaseEF e as) = maybeParens (p>0) $ liftOutPrec pr p (CaseEF e as) = maybeParens (p>0) $
vsep [ hsep [ "case", pr 0 e, "of" ] vsep [ hsep [ "case", pr 0 e, "of" ]
, nest 2 (vcat $ liftOutPrec pr 0 <$> as) ] , nest 2 (vcat $ liftOutPrec pr 0 <$> as) ]
liftOutPrec pr p (LetEF r bs e) = maybeParens (p>0) $
vsep [ hsep [ letword r, "<bs>" ]
, nest 2 (hsep [ "in", pr 0 e ]) ]
where
letword Core.Rec = "letrec"
letword Core.NonRec = "let"
instance (Out b, Out a) => Out (Decl b a) where instance (Out b, Out a) => Out (Decl b a) where
outPrec = outPrec1 outPrec = outPrec1

View File

@@ -43,6 +43,7 @@ import GHC.Generics (Generic, Generically(..))
import Debug.Trace import Debug.Trace
import Data.Functor hiding (unzip) import Data.Functor hiding (unzip)
import Data.Functor.Extend
import Data.Functor.Foldable hiding (fold) import Data.Functor.Foldable hiding (fold)
import Data.Fix hiding (cata, para) import Data.Fix hiding (cata, para)
import Control.Comonad.Cofree import Control.Comonad.Cofree
@@ -68,11 +69,11 @@ lookupVar n g = case g ^. contextVars . at n of
Nothing -> addFatal $ TyErrUntypedVariable n Nothing -> addFatal $ TyErrUntypedVariable n
gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement) gather :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
gather e = look >>= (H.lookup e >>> maybe memoise pure) gather e = use hmMemo >>= (H.lookup e >>> maybe memoise pure)
where where
memoise = do memoise = do
r <- gather' e r <- gather' e
add (H.singleton e r) hmMemo <>= H.singleton e r
pure r pure r
gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement) gather' :: RlpExpr PsName -> HM (Type PsName, PartialJudgement)
@@ -145,8 +146,8 @@ deleteKeys :: (Eq k, Hashable k) => [k] -> HashMap k v -> HashMap k v
deleteKeys ks h = foldr H.delete h ks deleteKeys ks h = foldr H.delete h ks
gatherBinds :: [Binding PsName (RlpExpr PsName)] gatherBinds :: [Binding PsName (RlpExpr PsName)]
-> HM [( Type PsName -> HM [( Type PsName -- inferred type
, Type PsName , Type PsName -- generalised type
, PartialJudgement )] , PartialJudgement )]
gatherBinds bs = for bs $ \ (VarB (VarP k) x) -> do gatherBinds bs = for bs $ \ (VarB (VarP k) x) -> do
((tx,jx),frees) <- listenFreshTvNames $ gather x ((tx,jx),frees) <- listenFreshTvNames $ gather x
@@ -201,10 +202,10 @@ unify [] = pure mempty
unify (Equality (sx :-> sy) (tx :-> ty) : cs) = unify (Equality (sx :-> sy) (tx :-> ty) : cs) =
unify $ Equality sx tx : Equality sy ty : cs unify $ Equality sx tx : Equality sy ty : cs
unify (Equality a@(ConT ca `AppT` as) b@(ConT cb `AppT` bs) : cs) -- unify (Equality a@(ConT ca `AppT` as) b@(ConT cb `AppT` bs) : cs)
| ca == cb = do -- | ca == cb = do
cs' <- liftA2 (zipWith Equality) (saturated a) (saturated b) -- cs' <- liftA2 (zipWith Equality) (saturated a) (saturated b)
unify $ cs' ++ cs -- unify $ cs' ++ cs
-- elim -- elim
unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs unify (Equality (ConT s) (ConT t) : cs) | s == t = unify cs
@@ -274,7 +275,7 @@ infer e = do
e' <- annotate e e' <- annotate e
let (cs,as) = finalJudgement e' ^. lensProduct constraints assumptions let (cs,as) = finalJudgement e' ^. lensProduct constraints assumptions
cs' <- (<>cs) <$> elimAssumptionsG g0 as cs' <- (<>cs) <$> elimAssumptionsG g0 as
checkUndefinedVariables e' -- checkUndefinedVariables e'
sub <- solve cs' sub <- solve cs'
pure $ e' & fmap (sub . view _1) pure $ e' & fmap (sub . view _1)
& _extract %~ generaliseG g0 & _extract %~ generaliseG g0
@@ -349,38 +350,57 @@ typeCheckRlpProgR :: (Monad m)
=> Program PsName (RlpExpr PsName) => Program PsName (RlpExpr PsName)
-> RLPCT m (Program PsName -> RLPCT m (Program PsName
(TypedRlpExpr PsName)) (TypedRlpExpr PsName))
typeCheckRlpProgR p = liftHM g (inferProg . etaExpandAll $ p) typeCheckRlpProgR p = liftHM g (inferProg p)
where where
g = buildInitialContext p g = buildInitialContext p
etaExpandAll = programDecls . each %~ etaExpand
inferProg :: Program PsName (RlpExpr PsName) inferProg :: Program PsName (RlpExpr PsName)
-> HM (Program PsName (TypedRlpExpr PsName)) -> HM (Program PsName (TypedRlpExpr PsName))
inferProg p = do inferProg p = do
g0 <- ask g0 <- ask
p' <- annotateProg p -- we only wipe the memo here as a temporary solution to the memo shadowing
let (cs,as) = foldOf (folded . folded . _2) p' -- problem
^. lensProduct constraints assumptions p' <- (\e -> (hmMemo .= mempty) *> annotate e) `traverse` etaExpandAll p
cs' <- (<>cs) <$> elimAssumptionsG g0 as traceM $ "p' : " <> show p'
let (cs,as) = foldMap finalJudgement p' ^. lensProduct constraints assumptions
cs' <- (cs <>) <$> elimAssumptionsG g0 as
traceM $ "cs' : " <> show cs'
sub <- solve cs' sub <- solve cs'
pure $ p' & programDecls . traversed . _FunD . _3 pure $ p' & programDecls . traversed . _FunD . _3
%~ ((_extract %~ generaliseG g0) . fmap (sub . view _1)) %~ ((_extract %~ generaliseG g0) . fmap (sub . view _1))
where
etaExpandAll = programDecls . each %~ etaExpand
annotateProg :: Program PsName (RlpExpr PsName) annotateProg :: Program PsName (RlpExpr PsName)
-> HM (Program PsName -> HM ( Program PsName
(Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))) (Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotateProg = traverse annotate , [Constraint] )
annotateProg p = do
gatherProg :: Program PsName (RlpExpr PsName) let bs = funsToSimpleBinds (p ^. programDecls)
-> HM [Constraint] (ks,xs) = unzip bs
gatherProg p = do xs' <- annotate `traverse` xs
-- this should be nearly identical to the rule for `letrec` in gather' let jxs = foldOf (each . _extract . _2) xs'
let (ks,xs) = unzip $ funsToSimpleBinds (p ^. programDecls) txs = xs' ^.. each . _extract . _1
(txs,txs',jxs) <- unzip3 <$> gatherBinds (zipWith simpleBind ks xs) cs <- elimWithBinds (ks `zip` txs) (jxs ^. assumptions)
let jxsa = foldOf (each . assumptions) jxs -- let p' = annotateDecls (ks `zip` xs') p
jxcs <- elimWithBinds (ks `zip` txs) jxsa p' <- annotate `traverse` p
-- TODO: any remaining assumptions should be errors at this point -- TODO: any remaining assumptions should be errors at this point
pure jxcs pure (p',cs)
-- this sucks! FunDs should probably be stored as a hashmap in Program...
annotateDecls :: [( PsName
, Cofree (RlpExprF PsName) (Type PsName, PartialJudgement) )]
-> Program PsName a
-> Program PsName
(Cofree (RlpExprF PsName) (Type PsName, PartialJudgement))
annotateDecls bs = programDecls . traversed . _FunD %~ \case
(n,_,_)
| Just e <- lookup n bs
-> (n,[],e)
gatherBinds' :: [(PsName, RlpExpr PsName)]
-> HM [(Type PsName, Type PsName, PartialJudgement)]
gatherBinds' = gatherBinds . fmap (uncurry simpleBind)
elimWithBinds :: [(PsName, Type PsName)] elimWithBinds :: [(PsName, Type PsName)]
-> Assumptions -> Assumptions

View File

@@ -60,7 +60,21 @@ instance Hashable Constraint
type Memo = HashMap (RlpExpr PsName) (Type PsName, PartialJudgement) type Memo = HashMap (RlpExpr PsName) (Type PsName, PartialJudgement)
type HM = ErrorfulT TypeError (ReaderT Context (StateT Int (Accum Memo))) data HMState = HMState
{ _hmMemo :: Memo
, _hmUniq :: Int
}
deriving Show
newtype HM a = HM {
unHM :: ErrorfulT TypeError
(ReaderT Context (State HMState)) a
}
deriving (Functor, Applicative, Monad)
deriving ( MonadReader Context
, MonadState HMState
, MonadErrorful TypeError
)
-- | Type error enum. -- | Type error enum.
data TypeError data TypeError
@@ -91,30 +105,11 @@ instance IsRlpcError TypeError where
(rout @String t) (rout @String x) (rout @String t) (rout @String x)
] ]
tvNameOfInt :: Int -> PsName
tvNameOfInt n = "$a" <> T.pack (show n)
freshTv :: HM (Type PsName)
freshTv = do
n <- get
modify succ
pure (VarT $ tvNameOfInt n)
listenFreshTvs :: HM a -> HM (a, [Type PsName])
listenFreshTvs hm = listenFreshTvNames hm & mapped . _2 . each %~ VarT
listenFreshTvNames :: HM a -> HM (a, [PsName])
listenFreshTvNames hm = do
n <- get
a <- hm
n' <- get
pure (a, [ tvNameOfInt k | k <- [n .. pred n'] ])
runHM :: Context -> HM a -> Either [TypeError] a runHM :: Context -> HM a -> Either [TypeError] a
runHM g e = maybe (Left es) Right ma runHM g e = maybe (Left es) Right ma
where where
((ma,es),m) = (`runAccum` mempty) . (`evalStateT` 0) (ma,es) = (`evalState` (HMState mempty 0))
. (`runReaderT` g) . runErrorfulT $ e . (`runReaderT` g) . runErrorfulT $ unHM e
runHM' :: HM a -> Either [TypeError] a runHM' :: HM a -> Either [TypeError] a
runHM' = runHM mempty runHM' = runHM mempty
@@ -124,6 +119,7 @@ makeLenses ''PartialJudgement
makeLenses ''Context makeLenses ''Context
makePrisms ''Constraint makePrisms ''Constraint
makePrisms ''TypeError makePrisms ''TypeError
makeLenses ''HMState
supplement :: [(PsName, Type PsName)] -> Context -> Context supplement :: [(PsName, Type PsName)] -> Context -> Context
supplement bs = contextVars %~ (H.fromList bs <>) supplement bs = contextVars %~ (H.fromList bs <>)
@@ -147,3 +143,22 @@ instance Out Constraint where
out (Equality s t) = out (Equality s t) =
hsep [outPrec appPrec1 s, "~", outPrec appPrec1 t] hsep [outPrec appPrec1 s, "~", outPrec appPrec1 t]
tvNameOfInt :: Int -> PsName
tvNameOfInt n = "$a" <> T.pack (show n)
freshTv :: HM (Type PsName)
freshTv = do
n <- use hmUniq
hmUniq %= succ
pure (VarT $ tvNameOfInt n)
listenFreshTvs :: HM a -> HM (a, [Type PsName])
listenFreshTvs hm = listenFreshTvNames hm & mapped . _2 . each %~ VarT
listenFreshTvNames :: HM a -> HM (a, [PsName])
listenFreshTvNames hm = do
n <- use hmUniq
a <- hm
n' <- use hmUniq
pure (a, [ tvNameOfInt k | k <- [n .. pred n'] ])