diff --git a/src/Core/Syntax.hs b/src/Core/Syntax.hs index 3d0a1ca..174e63d 100644 --- a/src/Core/Syntax.hs +++ b/src/Core/Syntax.hs @@ -19,6 +19,7 @@ module Core.Syntax , Module(..) , Program(..) , Program' + , programScDefs , Expr' , ScDef' , Alter' @@ -93,6 +94,9 @@ data Module b = Module (Maybe (Name, [Name])) (Program b) newtype Program b = Program [ScDef b] deriving (Show, Lift) +programScDefs :: Lens' (Program b) [ScDef b] +programScDefs = lens coerce (const coerce) + type Program' = Program Name type Expr' = Expr Name type ScDef' = ScDef Name diff --git a/src/Core2Core.hs b/src/Core2Core.hs index 9f895b4..e4b6f89 100644 --- a/src/Core2Core.hs +++ b/src/Core2Core.hs @@ -15,6 +15,8 @@ import Data.Set qualified as S import Data.List import Control.Monad.Writer import Control.Monad.State +import Control.Arrow ((>>>)) +import Numeric (showHex) import Lens.Micro import Core.Syntax import Core.Utils @@ -24,13 +26,29 @@ core2core :: Program' -> Program' core2core p = undefined gmPrep :: Program' -> Program' -gmPrep = undefined +gmPrep p = p' <> Program caseScs + where + rhss :: Applicative f => (Expr z -> f (Expr z)) -> Program z -> f (Program z) + rhss = programScDefs . each . _rhs + -- i kinda don't like that we're calling floatNonStrictCases twice tbh + p' = p & rhss %~ fst . runFloater . floatNonStrictCases + caseScs = (p ^.. rhss) + <&> snd . runFloater . floatNonStrictCases + & mconcat +-- | Auxilary type used in @floatNonSrictCases@ type Floater = StateT [Name] (Writer [ScDef']) +runFloater :: Floater a -> (a, [ScDef']) +runFloater = flip evalStateT ns >>> runWriter + where + -- TODO: safer, uncapturable names + ns = [ "nonstrict_case_" ++ showHex n "" | n <- [0..] ] + -- TODO: formally define a "strict context" and reference that here -floatNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef']) -floatNonStrictCases names = runWriter . flip evalStateT names . goE +-- the returned ScDefs are guaranteed to be free of non-strict cases. +floatNonStrictCases :: Expr' -> Floater Expr' +floatNonStrictCases = goE where goE :: Expr' -> Floater Expr' goE (Var k) = pure (Var k) @@ -55,6 +73,8 @@ floatNonStrictCases names = runWriter . flip evalStateT names . goE goC (f :$ x) = (:$) <$> goC f <*> goC x goC (Let r bs e) = Let r <$> bs' <*> goE e where bs' = travBs goC bs + goC (LitE l) = pure (LitE l) + goC (Var k) = pure (Var k) name = state (fromJust . uncons)