define datatags

This commit is contained in:
crumbtoo
2024-02-07 23:45:38 -07:00
parent c6f9c615b4
commit bb2a07d2e9
2 changed files with 22 additions and 6 deletions

View File

@@ -207,8 +207,6 @@ namedConsCase :: Program'
namedConsCase = [coreProg|
{-# PackData Nil 0 0 #-}
{-# PackData Cons 1 2 #-}
Nil = Pack{0 0};
Cons = Pack{1 2};
foldr f z l = case l of
{ Nil -> z
; Cons x xs -> f x (foldr f z xs)

View File

@@ -1,5 +1,4 @@
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-}
module Core2Core
( core2core
, gmPrep
@@ -15,13 +14,15 @@ import Data.Maybe (fromJust)
import Data.Set (Set)
import Data.Set qualified as S
import Data.List
import Data.Foldable
import Control.Monad.Writer
import Control.Monad.State.Lazy
import Control.Arrow ((>>>))
import Data.Text qualified as T
import Data.HashMap.Strict (HashMap)
import Numeric (showHex)
import Lens.Micro.Platform
-- import Lens.Micro.Platform
import Control.Lens
import Core.Syntax
import Core.Utils
----------------------------------------------------------------------------------
@@ -29,13 +30,28 @@ import Core.Utils
core2core :: Program' -> Program'
core2core p = undefined
-- | G-machine preprocessing.
gmPrep :: Program' -> Program'
gmPrep p = p & appFloater (floatNonStrictCases globals)
& tagData
& defineData
where
globals = p ^.. programScDefs . each . _lhs . _1
& S.fromList
-- | Define concrete supercombinators for all datatags defined via pragmas (or
-- desugaring)
defineData :: Program' -> Program'
defineData p = p & programScDefs <>~ defs
where
-- defs = ifoldMap' _ (p ^. programDataTags)
defs = p ^. programDataTags
. to (ifoldMap (\k (t,a) -> [ScDef k [] (Con t a)]))
-- | Substitute all pattern matches on named constructors for matches on tags
tagData :: Program' -> Program'
tagData p = let ?dt = p ^. programDataTags
in p & programRhss %~ cata go where
@@ -59,6 +75,7 @@ appFloater fl p = p & traverseOf programRhss fl
& runFloater
& \ (me,floats) -> me & programScDefs %~ (<>floats)
-- TODO: move NameSupply from Rlp2Core into a common module to share here
runFloater :: Floater a -> (a, [ScDef'])
runFloater = flip evalStateT ns >>> runWriter
where
@@ -88,7 +105,7 @@ floatNonStrictCases g = goE
altBodies = (\(Alter _ _ b) -> b) <$> as
tell [sc]
goE e
traverse goE altBodies
traverse_ goE altBodies
pure e'
goC (f :$ x) = (:$) <$> goC f <*> goC x
goC (Let r bs e) = Let r <$> bs' <*> goE e
@@ -97,7 +114,7 @@ floatNonStrictCases g = goE
goC (Var k) = pure (Var k)
goC (Con t as) = pure (Con t as)
name = state (fromJust . uncons)
name = state (fromJust . Data.List.uncons)
-- extract the right-hand sides of a list of bindings, traverse each
-- one, and return the original list of bindings
@@ -105,6 +122,7 @@ floatNonStrictCases g = goE
travBs c bs = bs ^.. each . _rhs
& traverse goC
& const (pure bs)
-- ^ ??? what the fuck?
-- when provided with a case expr, floatCase will float the case into a
-- supercombinator of its free variables. the sc is returned along with an