resolve named data in case exprs
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
{-# LANGUAGE ImplicitParams #-}
|
||||
{-# LANGUAGE LambdaCase #-}
|
||||
module Core2Core
|
||||
( core2core
|
||||
@@ -18,8 +19,9 @@ import Control.Monad.Writer
|
||||
import Control.Monad.State.Lazy
|
||||
import Control.Arrow ((>>>))
|
||||
import Data.Text qualified as T
|
||||
import Data.HashMap.Strict (HashMap)
|
||||
import Numeric (showHex)
|
||||
import Lens.Micro
|
||||
import Lens.Micro.Platform
|
||||
import Core.Syntax
|
||||
import Core.Utils
|
||||
----------------------------------------------------------------------------------
|
||||
@@ -28,19 +30,35 @@ core2core :: Program' -> Program'
|
||||
core2core p = undefined
|
||||
|
||||
gmPrep :: Program' -> Program'
|
||||
gmPrep p = p & traverseOf rhss (floatNonStrictCases globals)
|
||||
& runFloater
|
||||
& \ (me,caseScs) -> me & programScDefs %~ (<>caseScs)
|
||||
gmPrep p = p & appFloater (floatNonStrictCases globals)
|
||||
& tagData
|
||||
where
|
||||
rhss :: Traversal' (Program z) (Expr z)
|
||||
rhss = programScDefs . each . _rhs
|
||||
|
||||
globals = p ^.. programScDefs . each . _lhs . _1
|
||||
& S.fromList
|
||||
|
||||
tagData :: Program' -> Program'
|
||||
tagData p = let ?dt = p ^. programDataTags
|
||||
in p & programRhss %~ cata go where
|
||||
go :: (?dt :: HashMap Name (Tag, Int)) => ExprF' Expr' -> Expr'
|
||||
go (CaseF e as) = Case e (tagAlts <$> as)
|
||||
go x = embed x
|
||||
|
||||
tagAlts :: (?dt :: HashMap Name (Tag, Int)) => Alter' -> Alter'
|
||||
tagAlts (Alter (AltData c) bs e) = Alter (AltTag tag) bs e
|
||||
where tag = case ?dt ^. at c of
|
||||
Just (t,_) -> t
|
||||
-- TODO: errorful
|
||||
Nothing -> error $ "unknown constructor " <> show c
|
||||
tagAlts x = x
|
||||
|
||||
-- | Auxilary type used in @floatNonSrictCases@
|
||||
type Floater = StateT [Name] (Writer [ScDef'])
|
||||
|
||||
appFloater :: (Expr' -> Floater Expr') -> Program' -> Program'
|
||||
appFloater fl p = p & traverseOf programRhss fl
|
||||
& runFloater
|
||||
& \ (me,floats) -> me & programScDefs %~ (<>floats)
|
||||
|
||||
runFloater :: Floater a -> (a, [ScDef'])
|
||||
runFloater = flip evalStateT ns >>> runWriter
|
||||
where
|
||||
|
||||
Reference in New Issue
Block a user