This commit is contained in:
crumbtoo
2023-12-11 14:37:53 -07:00
parent b371b75c5e
commit 6c43f86397

View File

@@ -1,8 +1,10 @@
{-# LANGUAGE LambdaCase #-}
module Core2Core
( core2core
, gmPrep
-- internal utilities for convenience
, floatNonStrictCases
, floatCase
)
where
@@ -21,33 +23,23 @@ import Core.Utils
core2core :: Program' -> Program'
core2core p = undefined
-- assumes the provided expression is in a strict context
-- replaceNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef'])
-- replaceNonStrictCases names = runWriter . cata goE
-- where
-- goE :: ExprF Name (Writer [ScDef'] Expr')
-- -> Writer [ScDef'] Expr'
-- -- strict context
-- goE (VarF k) = pure (Var k)
-- goE (CaseF e as) = e *> ae'
-- where
-- ae = (\ (Alter _ _ b) -> b) <$> as
-- ae' = mconcat <$> traverse replaceNonStrictCases ae
gmPrep :: Program' -> Program'
gmPrep = undefined
type Replacer = StateT [Name] (Writer [ScDef'])
type Floater = StateT [Name] (Writer [ScDef'])
-- TODO: formally define a "strict context" and reference that here
replaceNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef'])
replaceNonStrictCases names = runWriter . flip evalStateT names . goE
floatNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef'])
floatNonStrictCases names = runWriter . flip evalStateT names . goE
where
goE :: Expr' -> Replacer Expr'
goE :: Expr' -> Floater Expr'
goE (Var k) = pure (Var k)
goE (LitE l) = pure (LitE l)
goE (Let Rec bs e) = Let Rec <$> bs' <*> goE e
where bs' = travBs goE bs
goE e = goC e
goC :: Expr' -> Replacer Expr'
goC :: Expr' -> Floater Expr'
-- the only truly non-trivial case: when a case expr is found in a
-- non-strict context, we float it into a supercombinator, give it a
-- name consumed from the state, record the newly created sc within the
@@ -68,7 +60,7 @@ replaceNonStrictCases names = runWriter . flip evalStateT names . goE
-- extract the right-hand sides of a list of bindings, traverse each
-- one, and return the original list of bindings
travBs :: (Expr' -> Replacer Expr') -> [Binding'] -> Replacer [Binding']
travBs :: (Expr' -> Floater Expr') -> [Binding'] -> Floater [Binding']
travBs c bs = bs ^.. each . _rhs
& traverse goC
& const (pure bs)