letrec inference

This commit is contained in:
crumbtoo
2024-03-27 11:26:36 -06:00
parent e578adeb1f
commit 7795547de8
2 changed files with 50 additions and 19 deletions

View File

@@ -146,6 +146,8 @@ CaseAlt :: { Alter PsName (RlpExpr PsName) }
LetE :: { RlpExpr PsName }
: let layout1(Binding) in Expr
{ Finr $ LetEF Core.NonRec $2 $4 }
| letrec layout1(Binding) in Expr
{ Finr $ LetEF Core.Rec $2 $4 }
Binding :: { Binding PsName (RlpExpr PsName) }
: Pat '=' Expr { VarB $1 $3 }

View File

@@ -26,6 +26,7 @@ import Data.Monoid
import Data.Text qualified as T
import Data.Foldable (fold)
import Data.Function
import Data.Foldable
import Data.Pretty hiding (annotate)
import Data.Maybe
import Data.Hashable
@@ -108,30 +109,48 @@ gather' = \case
let jxcs = jxs ^. each . constraints
& each . constraintTypes %~ substMap m
as = foldr H.delete (je ^. assumptions) ks
j = mempty & constraints .~ (je ^. constraints <> jxcs <> cs)
& assumptions .~ as
pure (te, fold jxs <> j)
j = mempty & constraints .~ je ^. constraints <> jxcs <> cs
& assumptions .~ foldOf (each . assumptions) jxs <> as
pure (te, j)
-- Finr (LetEF NonRec [VarB (VarP k) x] e) -> do
-- ((tx,jx),freeTvs) <- listenFreshTvNames $ gather x
-- let tx' = generalise freeTvs tx
Finr (LetEF Rec bs e) -> do
let ks = bs ^.. each . singular _VarB . _1 . singular _VarP
(txs,txs',jxs) <- unzip3 <$> gatherBinds bs
let jxsa = foldOf (each . assumptions) jxs
jxcs <- fmap concat . for (ks `zip` txs) $ \ (k,t) ->
elimAssumptions' jxsa k t
(te,je) <- gather e
-- ... why don't we need the map?
(cs,_) <- fmap fold . for (ks `zip` txs') $ \ (k,t) ->
elimAssumptionsMap (je ^. assumptions) k t
let as = deleteKeys ks (je ^. assumptions <> jxsa)
j = mempty & constraints .~ je ^. constraints <> jxcs <> cs
& assumptions .~ as
pure (te,j)
-- Finr (LetEF Rec [VarB (VarP k) x] e) -> do
-- ((tx,jx),frees) <- listenFreshTvNames $ gather x
-- jxcs <- elimAssumptions' (jx ^. assumptions) k tx
-- let tx' = generalise frees tx
-- (te,je) <- gather e
-- (cs,m) <- elimAssumptionsMap (je ^. assumptions) k tx'
-- let jxcs = jx ^. constraints
-- & each . constraintTypes %~ substMap m
-- as = H.delete k (je ^. assumptions)
-- j = mempty & constraints .~ (je ^. constraints <> jxcs <> cs)
-- let as = H.delete k (je ^. assumptions)
-- <> H.delete k (jx ^. assumptions)
-- j = mempty & constraints .~ je ^. constraints <> jxcs <> cs
-- & assumptions .~ as
-- pure (te, jx <> j)
-- pure (te,j)
-- Finl (LamF [b] e) -> do
-- tb <- freshTv
-- (te,je) <- gather e
-- let cs = maybe [] (fmap $ Equality tb) (je ^. assumptions . at b)
-- as = je ^. assumptions & at b .~ Nothing
-- j = mempty & constraints .~ cs & assumptions .~ as
-- t = tb :-> te
-- pure (t,j)
deleteKeys :: (Eq k, Hashable k) => [k] -> HashMap k v -> HashMap k v
deleteKeys ks h = foldr H.delete h ks
gatherBinds :: [Binding PsName (RlpExpr PsName)]
-> HM [( Type PsName
, Type PsName
, PartialJudgement )]
gatherBinds bs = for bs $ \ (VarB (VarP k) x) -> do
((tx,jx),frees) <- listenFreshTvNames $ gather x
let tx' = generalise frees tx
pure (tx,tx',jx)
generaliseGatherBinds :: [Binding PsName (RlpExpr PsName)]
-> HM [(Type PsName, PartialJudgement)]
@@ -240,6 +259,7 @@ infer g0 e = do
pure (csa <> cs)
g <- unify cs'
let sub t = ifoldrOf (reversed . assocs) subst t g
checkUndefinedVariables e'
pure $ e' & fmap (sub . view _1)
& _extract %~ generaliseG g0
where
@@ -247,6 +267,15 @@ infer g0 e = do
-- the user-specified type: prioritise her.
unifyTypes _ s t = unify [Equality s t] $> s
checkUndefinedVariables
:: Cofree (RlpExprF PsName) (Type PsName, PartialJudgement)
-> HM ()
checkUndefinedVariables ((_,j) :< es)
= case j ^. assumptions of
[] -> checkUndefinedVariables `traverse_` es
as -> doErrs *> checkUndefinedVariables `traverse_` es
where doErrs = ifor as \n _ -> addWound $ TyErrUntypedVariable n
infer1 :: Context -> RlpExpr PsName -> HM (Type PsName)
infer1 g = fmap extract . infer g