This commit is contained in:
crumbtoo
2023-12-11 17:25:41 -07:00
parent 6c43f86397
commit 7391148d62
2 changed files with 27 additions and 3 deletions

View File

@@ -19,6 +19,7 @@ module Core.Syntax
, Module(..)
, Program(..)
, Program'
, programScDefs
, Expr'
, ScDef'
, Alter'
@@ -93,6 +94,9 @@ data Module b = Module (Maybe (Name, [Name])) (Program b)
newtype Program b = Program [ScDef b]
deriving (Show, Lift)
programScDefs :: Lens' (Program b) [ScDef b]
programScDefs = lens coerce (const coerce)
type Program' = Program Name
type Expr' = Expr Name
type ScDef' = ScDef Name

View File

@@ -15,6 +15,8 @@ import Data.Set qualified as S
import Data.List
import Control.Monad.Writer
import Control.Monad.State
import Control.Arrow ((>>>))
import Numeric (showHex)
import Lens.Micro
import Core.Syntax
import Core.Utils
@@ -24,13 +26,29 @@ core2core :: Program' -> Program'
core2core p = undefined
gmPrep :: Program' -> Program'
gmPrep = undefined
gmPrep p = p' <> Program caseScs
where
rhss :: Applicative f => (Expr z -> f (Expr z)) -> Program z -> f (Program z)
rhss = programScDefs . each . _rhs
-- i kinda don't like that we're calling floatNonStrictCases twice tbh
p' = p & rhss %~ fst . runFloater . floatNonStrictCases
caseScs = (p ^.. rhss)
<&> snd . runFloater . floatNonStrictCases
& mconcat
-- | Auxilary type used in @floatNonSrictCases@
type Floater = StateT [Name] (Writer [ScDef'])
runFloater :: Floater a -> (a, [ScDef'])
runFloater = flip evalStateT ns >>> runWriter
where
-- TODO: safer, uncapturable names
ns = [ "nonstrict_case_" ++ showHex n "" | n <- [0..] ]
-- TODO: formally define a "strict context" and reference that here
floatNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef'])
floatNonStrictCases names = runWriter . flip evalStateT names . goE
-- the returned ScDefs are guaranteed to be free of non-strict cases.
floatNonStrictCases :: Expr' -> Floater Expr'
floatNonStrictCases = goE
where
goE :: Expr' -> Floater Expr'
goE (Var k) = pure (Var k)
@@ -55,6 +73,8 @@ floatNonStrictCases names = runWriter . flip evalStateT names . goE
goC (f :$ x) = (:$) <$> goC f <*> goC x
goC (Let r bs e) = Let r <$> bs' <*> goE e
where bs' = travBs goC bs
goC (LitE l) = pure (LitE l)
goC (Var k) = pure (Var k)
name = state (fromJust . uncons)