case expression inference

This commit is contained in:
crumbtoo
2024-04-05 15:21:49 -06:00
parent 5511d70e26
commit bcf6dc1951
3 changed files with 56 additions and 13 deletions

View File

@@ -15,6 +15,7 @@ import Control.Monad.Accum
import Control.Monad.Reader import Control.Monad.Reader
import Control.Monad import Control.Monad
import Control.Monad.Extra import Control.Monad.Extra
import Control.Monad.Free
import Control.Arrow ((>>>)) import Control.Arrow ((>>>))
import Control.Monad.Writer.Strict import Control.Monad.Writer.Strict
@@ -40,7 +41,7 @@ import Debug.Trace
import Data.Functor hiding (unzip) import Data.Functor hiding (unzip)
import Data.Functor.Extend import Data.Functor.Extend
import Data.Functor.Foldable hiding (fold) import Data.Functor.Foldable hiding (fold)
import Data.Fix hiding (cata, para, cataM) import Data.Fix hiding (cata, para, cataM, ana)
import Control.Comonad.Cofree import Control.Comonad.Cofree
import Control.Comonad import Control.Comonad
@@ -136,11 +137,41 @@ gather (InR (LetEF Rec (withoutPatterns -> bs) (te,je))) = do
elimRecBind (x,(tx,_)) j = elim x tx j elimRecBind (x,(tx,_)) j = elim x tx j
elimBind (x,(tx,_)) j = elimGenerally x tx j elimBind (x,(tx,_)) j = elimGenerally x tx j
gather (InR (CaseEF (te,je) [Alter (ConP' n []) (ta,ja)])) = do gather (InR (CaseEF (te,je) as)) = do
tc <- freshTv as' <- gatherAlter te `traverse` as
let j = equal te tc <> je <> assume n tc <> ja t <- freshTv
let eqs = allEqual (t : (as' ^.. each . _1))
j = je <> foldOf (each . _2) as' <> eqs
pure (t,j)
-- gather (InR (CaseEF (te,je) [Alter (ConP' n bs) (ta,ja)])) = do
-- -- let tc' be the type of the saturated type constructor
-- tc' <- freshTv
-- bs <- for bs (\b -> (b ^. singular _VarP,) <$> freshTv)
-- let tbs = bs ^.. each . _2
-- tc = foldr (:->) tc' tbs
-- j = equal te tc' <> je <> assume n tc <> forBinds elim bs ja
-- pure (ta,j)
gatherAlter :: (Unique :> es)
=> Type'
-> Alter PsName (Type', Judgement)
-> Eff es (Type', Judgement)
gatherAlter te (Alter (ConP' n bs) (ta,ja)) = do
-- let tc' be the type of the saturated type constructor
tc' <- freshTv
bs' <- for bs (\b -> (b ^. singular _VarP,) <$> freshTv)
let tbs = bs' ^.. each . _2
tc = foldr (:->) tc' tbs
j = equal te tc' <> assume n tc <> forBinds elim bs' ja
pure (ta,j) pure (ta,j)
allEqual :: [Type'] -> Judgement
allEqual = fold . ana @[_] \case
[] -> Nil
[a] -> Nil
(a:b:xs) -> Cons (equal a b) (b:xs)
forBinds :: (PsName -> Type' -> Judgement -> Judgement) forBinds :: (PsName -> Type' -> Judgement -> Judgement)
-> [(PsName, Type')] -> Judgement -> Judgement -> [(PsName, Type')] -> Judgement -> Judgement
forBinds f bs j = foldr (uncurry f) j bs forBinds f bs j = foldr (uncurry f) j bs

View File

@@ -103,6 +103,15 @@
(defn LitExpr [_ l] (defn LitExpr [_ l]
[:code (str l)]) [:code (str l)])
(defn Alter [colours a]
(pprint a)
[:code "<alter>"])
(defn CaseExpr [colours e as]
[:<> "case " [Expr colours 0 e] " of { "
"<alters>"
" }"])
(defn Expr [[c & colours] p {e :e t :type}] (defn Expr [[c & colours] p {e :e t :type}]
(match e (match e
{:InL {:tag "LamF" :contents [bs body & _]}} {:InL {:tag "LamF" :contents [bs body & _]}}
@@ -118,6 +127,9 @@
[Typed c t [LetExpr colours r bs body]]) [Typed c t [LetExpr colours r bs body]])
{:InL {:tag "LitF" :contents l}} {:InL {:tag "LitF" :contents l}}
[Typed c t [LitExpr colours l]] [Typed c t [LitExpr colours l]]
{:InR {:tag "CaseEF" :contents [scrut as]}}
(maybe-parens (< ppr/app-prec1 p)
[Typed c t [CaseExpr colours scrut as]])
:else [:code "<expr>"])) :else [:code "<expr>"]))
(def rainbow-cycle (cycle ["red" (def rainbow-cycle (cycle ["red"