case expression inference

This commit is contained in:
crumbtoo
2024-04-05 15:21:49 -06:00
parent 5511d70e26
commit bcf6dc1951
3 changed files with 56 additions and 13 deletions

View File

@@ -15,6 +15,7 @@ import Control.Monad.Accum
import Control.Monad.Reader
import Control.Monad
import Control.Monad.Extra
import Control.Monad.Free
import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict
@@ -40,7 +41,7 @@ import Debug.Trace
import Data.Functor hiding (unzip)
import Data.Functor.Extend
import Data.Functor.Foldable hiding (fold)
import Data.Fix hiding (cata, para, cataM)
import Data.Fix hiding (cata, para, cataM, ana)
import Control.Comonad.Cofree
import Control.Comonad
@@ -136,11 +137,41 @@ gather (InR (LetEF Rec (withoutPatterns -> bs) (te,je))) = do
elimRecBind (x,(tx,_)) j = elim x tx j
elimBind (x,(tx,_)) j = elimGenerally x tx j
gather (InR (CaseEF (te,je) [Alter (ConP' n []) (ta,ja)])) = do
tc <- freshTv
let j = equal te tc <> je <> assume n tc <> ja
gather (InR (CaseEF (te,je) as)) = do
as' <- gatherAlter te `traverse` as
t <- freshTv
let eqs = allEqual (t : (as' ^.. each . _1))
j = je <> foldOf (each . _2) as' <> eqs
pure (t,j)
-- gather (InR (CaseEF (te,je) [Alter (ConP' n bs) (ta,ja)])) = do
-- -- let tc' be the type of the saturated type constructor
-- tc' <- freshTv
-- bs <- for bs (\b -> (b ^. singular _VarP,) <$> freshTv)
-- let tbs = bs ^.. each . _2
-- tc = foldr (:->) tc' tbs
-- j = equal te tc' <> je <> assume n tc <> forBinds elim bs ja
-- pure (ta,j)
gatherAlter :: (Unique :> es)
=> Type'
-> Alter PsName (Type', Judgement)
-> Eff es (Type', Judgement)
gatherAlter te (Alter (ConP' n bs) (ta,ja)) = do
-- let tc' be the type of the saturated type constructor
tc' <- freshTv
bs' <- for bs (\b -> (b ^. singular _VarP,) <$> freshTv)
let tbs = bs' ^.. each . _2
tc = foldr (:->) tc' tbs
j = equal te tc' <> assume n tc <> forBinds elim bs' ja
pure (ta,j)
allEqual :: [Type'] -> Judgement
allEqual = fold . ana @[_] \case
[] -> Nil
[a] -> Nil
(a:b:xs) -> Cons (equal a b) (b:xs)
forBinds :: (PsName -> Type' -> Judgement -> Judgement)
-> [(PsName, Type')] -> Judgement -> Judgement
forBinds f bs j = foldr (uncurry f) j bs

View File

@@ -103,6 +103,15 @@
(defn LitExpr [_ l]
[:code (str l)])
(defn Alter [colours a]
(pprint a)
[:code "<alter>"])
(defn CaseExpr [colours e as]
[:<> "case " [Expr colours 0 e] " of { "
"<alters>"
" }"])
(defn Expr [[c & colours] p {e :e t :type}]
(match e
{:InL {:tag "LamF" :contents [bs body & _]}}
@@ -118,6 +127,9 @@
[Typed c t [LetExpr colours r bs body]])
{:InL {:tag "LitF" :contents l}}
[Typed c t [LitExpr colours l]]
{:InR {:tag "CaseEF" :contents [scrut as]}}
(maybe-parens (< ppr/app-prec1 p)
[Typed c t [CaseExpr colours scrut as]])
:else [:code "<expr>"]))
(def rainbow-cycle (cycle ["red"