Files
rlp/src/Core2Core.hs
2024-04-15 10:07:22 -06:00

152 lines
5.2 KiB
Haskell

{-# LANGUAGE ImplicitParams #-}
module Core2Core
( core2core
, gmPrep
-- internal utilities for convenience
, floatNonStrictCases
, floatCase
)
where
----------------------------------------------------------------------------------
import Data.Functor.Foldable
import Data.Maybe (fromJust)
import Data.HashSet (HashSet)
import Data.HashSet qualified as S
import Data.List
import Data.Foldable
import Control.Monad.Writer
import Control.Monad.State.Lazy
import Control.Arrow ((>>>))
import Data.Text qualified as T
import Data.HashMap.Strict (HashMap)
import Numeric (showHex)
import Misc.MonadicRecursionSchemes
import Data.Pretty
import Compiler.RLPC
import Control.Lens
import Core.Syntax
import Core.Utils
----------------------------------------------------------------------------------
-- | General optimisations.
--
-- This pass is currently a stub: no optimisation is implemented yet, so
-- forcing the result of any call raises an error that names this function
-- (rather than an anonymous @undefined@).
core2core :: Program' -> Program'
core2core _ = error "Core2Core.core2core: not yet implemented"
-- | Monadic wrapper around 'gmPrep'. The preprocessed program is passed
-- through @show . out@ and recorded as a debug message under the
-- @dump-gm-preprocessed@ key, then returned to the caller.
gmPrepR :: (Monad m) => Program' -> RLPCT m Program'
gmPrepR prog = do
    addDebugMsg "dump-gm-preprocessed" (show (out prepped))
    pure prepped
  where
    prepped = gmPrep prog
-- | G-machine-specific preprocessing.
--
-- The passes are applied in this order: floating of non-strict cases,
-- 'tagData', 'defineData' and finally 'etaReduce'.
gmPrep :: Program' -> Program'
gmPrep prog = etaReduce . defineData . tagData . floatCases $ prog
  where
    floatCases = appFloater (floatNonStrictCases topLevelNames)

    -- the first component of every supercombinator's left-hand side; these
    -- names are handed to the floater as the set of globals
    topLevelNames = S.fromList (prog ^.. programScDefs . each . _lhs . _1)
-- | Intended to compute the set of globals of a program.
--
-- Currently a stub: evaluating it raises an error naming this function.
-- NOTE(review): 'gmPrep' builds its own set of globals inline and does not
-- call this.
programGlobals :: Program b -> HashSet b
programGlobals = error "Core2Core.programGlobals: not yet implemented"
-- | Define concrete supercombinators for all datatags defined via pragmas (or
-- desugaring).
--
-- Every entry of the program's data-tag table yields one argument-less
-- supercombinator, named after the entry's key, whose body is the @Con@ built
-- from the entry's pair. The new definitions are appended after the existing
-- ones.
defineData :: Program' -> Program'
defineData prog = over programScDefs (<> conDefs) prog
  where
    conDefs = ifoldMap conSc (prog ^. programDataTags)

    -- one singleton definition per table entry
    conSc name (tag, n) = [ScDef name [] (Con tag n)]
-- | Substitute all pattern matches on named constructors for matches on tags.
--
-- Every right-hand side of the program is rewritten; a case alternative that
-- matches on a constructor name is replaced by one matching on the tag the
-- program's data-tag table records for that name. Looking up a constructor
-- that is missing from the table is an error.
tagData :: Program' -> Program'
tagData prog = over programRhss (cata retag) prog
  where
    tags :: HashMap Name (Tag, Int)
    tags = prog ^. programDataTags

    retag :: ExprF' Expr' -> Expr'
    retag (CaseF scrut alts) = Case scrut (fmap retagAlt alts)
    retag other              = embed other

    -- alternative bodies are rewritten explicitly here
    retagAlt :: Alter' -> Alter'
    retagAlt (Alter (AltData con) binders body) =
        Alter (AltTag (tagOf con)) binders (cata retag body)
    retagAlt alt = alt

    tagOf :: Name -> Tag
    tagOf con = maybe unknown fst (tags ^. at con)
      where
        -- TODO: errorful
        unknown = error $ "unknown constructor " <> show con
-- | Auxiliary type used in @floatNonStrictCases@: a 'StateT' whose state is a
-- list of 'Name's, layered over a 'Writer' that accumulates a list of
-- @ScDef'@s. 'runFloater' seeds the state with generated names and returns
-- the accumulated definitions alongside the result.
type Floater = StateT [Name] (Writer [ScDef'])
-- | Apply a floating transformation to every right-hand side of a program,
-- run the resulting 'Floater' action, and append the supercombinators it
-- emitted to the end of the program's definitions.
appFloater :: (Expr' -> Floater Expr') -> Program' -> Program'
appFloater fl prog =
    case runFloater (traverseOf programRhss fl prog) of
        (rewritten, floated) -> over programScDefs (<> floated) rewritten
-- TODO: move NameSupply from Rlp2Core into a common module to share here
-- | Run a 'Floater' action, returning its result together with the
-- supercombinators it wrote. The state starts out as the infinite list
-- @$nonstrict_case_0@, @$nonstrict_case_1@, ... with the counter rendered in
-- hexadecimal.
runFloater :: Floater a -> (a, [ScDef'])
runFloater floater = runWriter (evalStateT floater freshNames)
  where
    freshNames = map (\i -> T.pack ("$nonstrict_case_" ++ showHex i "")) [0 ..]
-- TODO: formally define a "strict context" and reference that here
-- | Float case expressions found in non-strict positions out into fresh
-- supercombinators.
--
-- The first argument is the set of global names, forwarded to 'floatCase'.
-- Fresh supercombinator names are consumed from the 'Floater' state and each
-- floated definition is recorded through the 'Floater' writer.
--
-- the returned ScDefs are guaranteed to be free of non-strict cases.
floatNonStrictCases :: HashSet Name -> Expr' -> Floater Expr'
floatNonStrictCases g = goE
  where
    -- entry point: variables, literals and a case at this position are
    -- returned untouched; a recursive let has its bindings and body
    -- traversed; anything else is handed to 'goC'
    goE :: Expr' -> Floater Expr'
    goE (Var k)        = pure (Var k)
    goE (Lit l)        = pure (Lit l)
    goE (Case e as)    = pure (Case e as)
    goE (Let Rec bs e) = Let Rec <$> bs' <*> goE e
      where bs' = travBs goE bs
    goE e              = goC e

    goC :: Expr' -> Floater Expr'
    goC = cataM \case
        -- the only truly non-trivial case: when a case expr is found in a
        -- non-strict context, we float it into a supercombinator, give it a
        -- name consumed from the state, record the newly created sc within
        -- the Writer, and finally return an expression appropriately calling
        -- the sc
        CaseF e as -> do
            n <- name
            let (e', sc)  = floatCase g n (Case e as)
                altBodies = (\(Alter _ _ b) -> b) <$> as
            tell [sc]
            -- the scrutinee and the alternative bodies are traversed for
            -- their effects only; the rewritten expressions are discarded
            _ <- goE e
            traverse_ goE altBodies
            pure e'
        t -> pure $ embed t

    -- pop the next fresh name off the supply held in the state
    name :: Floater Name
    name = state $ \case
        (n : ns) -> (n, ns)
        []       -> error "Core2Core.floatNonStrictCases: name supply exhausted"

    -- extract the right-hand sides of a list of bindings, traverse each
    -- one, and return the original list of bindings
    --
    -- TODO: not implemented yet; any recursive let reaching 'goE' fails here.
    -- NOTE(review): the description above discards the traversed right-hand
    -- sides -- confirm that this is really the intended behaviour before
    -- implementing it.
    travBs :: (Expr' -> Floater Expr') -> [Binding'] -> Floater [Binding']
    travBs _ _ =
        error "Core2Core.floatNonStrictCases.travBs: not yet implemented"
-- | For every supercombinator whose body is a top-level lambda, move the
-- lambda's binders onto the end of the supercombinator's own argument list
-- and use the lambda's body as the new body. All other definitions are left
-- as they are.
etaReduce :: Program' -> Program'
etaReduce = over (programScDefs . each) absorb
  where
    absorb (ScDef name params (Lam binders body)) =
        ScDef name (params ++ binders) body
    absorb (ScDef name params body) =
        ScDef name params body
-- | when provided with a case expr, floatCase will float the case into a
-- supercombinator of its free variables. the sc is returned along with an
-- expression that calls the sc with the necessary arguments
--
-- The first argument is the set of names to exclude from the free variables
-- (the globals); the second is the name given to the new supercombinator.
-- Passing anything other than a case expression is a caller error and is
-- reported with an explicit message.
floatCase :: HashSet Name -> Name -> Expr' -> (Expr', ScDef')
floatCase g n c@(Case _ _) = (call, sc)
  where
    sc        = ScDef n caseFrees c
    -- free variables of the case that are not globals become the parameters
    caseFrees = S.toList $ freeVariables c `S.difference` g
    -- the new supercombinator applied to those same variables, in order
    call      = foldl' App (Var n) (Var <$> caseFrees)
floatCase _ n _ =
    error $ "Core2Core.floatCase: expected a case expression while floating "
         <> show n