core2core

This commit is contained in:
crumbtoo
2023-12-11 14:18:48 -07:00
parent e477891bc3
commit 238729cf1e
5 changed files with 193 additions and 63 deletions

View File

@@ -23,6 +23,7 @@ library
, GM
, Compiler.RLPC
, Core.Syntax
, Core.Utils
other-modules: Data.Heap
, Data.Pretty
@@ -31,6 +32,7 @@ library
, Core.Examples
, Core.Lex
, Control.Monad.Errorful
, Core2Core
build-tool-depends: happy:happy, alex:alex

View File

@@ -3,11 +3,7 @@ Module : Core.Syntax
Description : Core ASTs and the like
-}
{-# LANGUAGE PatternSynonyms, OverloadedStrings #-}
-- for recursion schemes
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
-- for recursion schemes
{-# LANGUAGE TemplateHaskell, TypeFamilies #-}
{-# LANGUAGE FunctionalDependencies #-}
module Core.Syntax
( Expr(..)
, Literal(..)
@@ -22,27 +18,24 @@ module Core.Syntax
, ScDef(..)
, Module(..)
, Program(..)
, CoreProgram
, CoreExpr
, CoreScDef
, CoreAlter
, CoreBinding
, bindersOf
, rhssOf
, isAtomic
, insertModule
, extractProgram
, Program'
, Expr'
, ScDef'
, Alter'
, Binding'
, HasRHS(_rhs)
)
where
----------------------------------------------------------------------------------
import Data.Coerce
import Data.Pretty
import GHC.Generics
import Data.List (intersperse)
import Data.Function ((&))
import Data.String
-- Lift instances for the Core quasiquoters
import Lens.Micro
import Language.Haskell.TH.Syntax (Lift)
import Data.Functor.Foldable.TH (makeBaseFunctor)
----------------------------------------------------------------------------------
data Expr b = Var Name
@@ -100,17 +93,15 @@ data Module b = Module (Maybe (Name, [Name])) (Program b)
newtype Program b = Program [ScDef b]
deriving (Show, Lift)
type CoreProgram = Program Name
type CoreExpr = Expr Name
type CoreScDef = ScDef Name
type CoreAlter = Alter Name
type CoreBinding = Binding Name
-- Shorthand aliases: each Core AST type with its type parameter fixed to 'Name'.
type Program' = Program Name
type Expr' = Expr Name
type ScDef' = ScDef Name
type Alter' = Alter Name
type Binding' = Binding Name
-- | A string converts to an expression by wrapping it in 'Var'.
instance IsString (Expr b) where
    fromString ident = Var ident
----------------------------------------------------------------------------------
-- | Programs combine by appending their lists of definitions, left operand
-- first.
instance Semigroup (Program b) where
    Program ls <> Program rs = Program (ls <> rs)
@@ -119,27 +110,21 @@ instance Monoid (Program b) where
----------------------------------------------------------------------------------
bindersOf :: [(Name, b)] -> [Name]
bindersOf = fmap fst
-- | Types @s@ that contain a right-hand-side expression, exposed through the
-- '_rhs' lens. The functional dependency says that @s@ determines the
-- expression's type parameter @z@.
class HasRHS s z | s -> z where
_rhs :: Lens' s (Expr z)
rhssOf :: [(Name, b)] -> [b]
rhssOf = fmap snd
-- | The right-hand side of an alternative is its third field; the first two
-- fields are left untouched when it is replaced.
instance HasRHS (Alter b) b where
    _rhs = lens getBody setBody
      where
        getBody (Alter _ _ body)           = body
        setBody (Alter tag params _) body' = Alter tag params body'
isAtomic :: Expr b -> Bool
isAtomic (Var _) = True
isAtomic (LitE _) = True
isAtomic _ = False
-- | The right-hand side of a supercombinator definition is its third field;
-- the name and parameter list are preserved when it is replaced.
instance HasRHS (ScDef b) b where
    _rhs = lens getBody setBody
      where
        getBody (ScDef _ _ body)            = body
        setBody (ScDef name params _) body' = ScDef name params body'
----------------------------------------------------------------------------------
-- TODO: export list awareness
insertModule :: (Module b) -> (Program b) -> (Program b)
insertModule (Module _ m) p = p <> m
extractProgram :: (Module b) -> (Program b)
extractProgram (Module _ p) = p
----------------------------------------------------------------------------------
makeBaseFunctor ''Expr
-- | The right-hand side of a binding is the expression to the right of ':=';
-- the binder on the left is preserved when it is replaced.
instance HasRHS (Binding b) b where
    _rhs = lens getRhs setRhs
      where
        getRhs (_ := rhs)       = rhs
        setRhs (binder := _) rhs' = binder := rhs'

72
src/Core/Utils.hs Normal file
View File

@@ -0,0 +1,72 @@
-- for recursion schemes
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
-- for recursion schemes
{-# LANGUAGE TemplateHaskell, TypeFamilies #-}
module Core.Utils
( bindersOf
, rhssOf
, isAtomic
, insertModule
, extractProgram
, freeVariables
, ExprF(..)
)
where
----------------------------------------------------------------------------------
import Data.Functor.Foldable.TH (makeBaseFunctor)
import Data.Functor.Foldable
import Data.Set (Set)
import Data.Set qualified as S
import Core.Syntax
import GHC.Exts (IsList(..))
----------------------------------------------------------------------------------
-- | Collect the binder (left-hand side) of every binding and hand them, in
-- list order, to 'fromList' of whichever 'IsList' container the caller asks
-- for.
bindersOf :: (IsList l, Item l ~ b) => [Binding b] -> l
bindersOf = fromList . map binder
  where
    binder (name := _) = name
-- | Collect the right-hand-side expression of every binding and hand them, in
-- list order, to 'fromList' of whichever 'IsList' container the caller asks
-- for.
rhssOf :: (IsList l, Item l ~ Expr b) => [Binding b] -> l
rhssOf bs = fromList (map (\(_ := rhs) -> rhs) bs)
-- | An expression is atomic exactly when it is a variable ('Var') or a
-- literal ('LitE').
isAtomic :: Expr b -> Bool
isAtomic expr = case expr of
    Var _  -> True
    LitE _ -> True
    _      -> False
----------------------------------------------------------------------------------
-- TODO: export list awareness
-- | Append the program carried by a module to the end of an existing program.
-- The module's first field is ignored.
insertModule :: Module b -> Program b -> Program b
insertModule modul prog = case modul of
    Module _ defs -> prog <> defs
-- | Project the program out of a module, discarding the module's first field.
extractProgram :: Module b -> Program b
extractProgram modul = case modul of
    Module _ prog -> prog
----------------------------------------------------------------------------------
-- Template Haskell splice: derives the base functor of 'Expr' for use with
-- recursion schemes. This is where the 'ExprF' type in this module's export
-- list and its constructors ('VarF', 'LetF', 'CaseF', 'LamF', ...) come from.
makeBaseFunctor ''Expr
-- | Compute the set of names that occur free in an expression.
--
-- Written as a catamorphism over 'ExprF': when the algebra sees a node, the
-- recursive sub-expression positions already hold their free-variable sets.
-- Expressions nested inside bindings and alternatives are still plain
-- expressions at that point, so they are handled by calling 'freeVariables'
-- on them directly.
freeVariables :: Expr' -> Set Name
freeVariables = cata alg
  where
    alg :: ExprF Name (Set Name) -> Set Name
    alg (VarF v) = S.singleton v
    -- the binders are removed from the body's AND the right-hand sides' free
    -- variables, regardless of the let's first field
    -- NOTE(review): that is letrec-style scoping; confirm it is also what is
    -- wanted for non-recursive lets
    alg (LetF _ bs body) = (body <> rhsFrees) S.\\ bound
      where
        bound    = bindersOf bs
        -- TODO: this feels a little wrong. maybe a different scheme is
        -- appropriate
        rhsFrees = foldMap freeVariables (rhssOf bs :: [Expr'])
    alg (CaseF scrut alts) = scrut <> foldMap altFrees alts
      where
        -- an alternative binds its names over its body, exactly like a lambda
        altFrees (Alter _ params rhs) =
            freeVariables rhs S.\\ S.fromList params
    alg (LamF params body) = body S.\\ S.fromList params
    -- every other node: the union of whatever its children contributed
    alg other = foldMap id other

View File

@@ -1,14 +1,85 @@
{-# LANGUAGE LambdaCase #-}
module Core2Core
(
( core2core
-- internal utilities for convenience
, floatCase
)
where
----------------------------------------------------------------------------------
import Data.Functor.Foldable
import Data.Maybe (fromJust)
import Data.Set qualified as S
import Data.List
import Control.Monad.Writer
import Control.Monad.State
import Lens.Micro
import Core.Syntax
import Core.Utils
----------------------------------------------------------------------------------
core2core :: Program -> Program
core2core = undefined
-- | Entry point of the Core-to-Core pass. Not implemented yet: the body is
-- 'undefined', so forcing the result of any call crashes.
core2core :: Program' -> Program'
core2core p = undefined
floatNonStrictCase :: Expr -> Expr
floatNonStrictCase (Case e as) = Case e ()
-- assumes the provided expression is in a strict context
-- replaceNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef'])
-- replaceNonStrictCases names = runWriter . cata goE
-- where
-- goE :: ExprF Name (Writer [ScDef'] Expr')
-- -> Writer [ScDef'] Expr'
-- -- strict context
-- goE (VarF k) = pure (Var k)
-- goE (CaseF e as) = e *> ae'
-- where
-- ae = (\ (Alter _ _ b) -> b) <$> as
-- ae' = mconcat <$> traverse replaceNonStrictCases ae
-- | Monad stack used by 'replaceNonStrictCases': a state holding a list of
-- names (consumed one at a time as fresh supercombinator names), layered over
-- a writer that accumulates the supercombinator definitions created.
type Replacer = StateT [Name] (Writer [ScDef'])
-- TODO: formally define a "strict context" and reference that here
-- | Replace @case@ expressions with calls to freshly created
-- supercombinators.
--
-- The first argument is the supply of names for the generated
-- supercombinators; one name is consumed per floated @case@, and running out
-- of names is a fatal error. The expression is assumed to start in a strict
-- context. The result is the rewritten expression paired with every
-- supercombinator that was generated; supercombinators for nested cases are
-- emitted before the one for the enclosing case.
--
-- NOTE(review): 'goE' has no rule of its own for @case@ and delegates to
-- 'goC', so as written a @case@ in a strict context is floated as well.
replaceNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef'])
replaceNonStrictCases names = runWriter . flip evalStateT names . goE
  where
    -- traversal of an expression in a strict context
    goE :: Expr' -> Replacer Expr'
    goE (Var k)        = pure (Var k)
    goE (LitE l)       = pure (LitE l)
    goE (Let Rec bs e) = Let Rec <$> travBs goE bs <*> goE e
    goE e              = goC e

    -- traversal of an expression in a non-strict context
    goC :: Expr' -> Replacer Expr'
    -- the only truly non-trivial case: the case expr is floated into a
    -- supercombinator named by the next name from the state, the new sc is
    -- recorded in the Writer, and an expression calling the sc is returned.
    -- the scrutinee and the alternative bodies are rewritten first and the
    -- case is rebuilt from the results, so that nested cases are replaced
    -- inside the floated body too.
    -- NOTE(review): names of nested scs therefore appear as variables in the
    -- floated body and are picked up by 'floatCase' as free variables.
    goC (Case e as) = do
        n   <- name
        e'  <- goE e
        as' <- traverse (_rhs goE) as
        let (call, sc) = floatCase n (Case e' as')
        tell [sc]
        pure call
    goC (f :$ x)     = (:$) <$> goC f <*> goC x
    goC (Let r bs e) = Let r <$> travBs goC bs <*> goE e
    -- anything without a dedicated rule is returned unchanged instead of
    -- failing the pattern match.
    -- NOTE(review): this includes lambdas, whose bodies are not traversed yet
    goC e            = pure e

    -- take the next unused name from the supply
    name :: Replacer Name
    name = state $ \case
        n : ns -> (n, ns)
        []     -> error "replaceNonStrictCases: ran out of fresh names"

    -- run the given traversal over the right-hand side of every binding and
    -- return the bindings with their rewritten right-hand sides
    travBs :: (Expr' -> Replacer Expr') -> [Binding'] -> Replacer [Binding']
    travBs c = traverse (_rhs c)
-- | Float an expression (intended for @case@ expressions) into a
-- supercombinator of its free variables.
--
-- Given a name @n@ and an expression, returns
--
--   * an expression that applies @n@ to each free variable of the given
--     expression, in the order 'S.toList' yields them, and
--
--   * the definition of @n@: a supercombinator taking those free variables as
--     parameters, with the given expression as its body.
--
-- Nothing here depends on the expression actually being a @case@, so any
-- expression is accepted rather than failing a pattern match.
floatCase :: Name -> Expr' -> (Expr', ScDef')
floatCase n c = (call, sc)
  where
    frees = S.toList (freeVariables c)
    sc    = ScDef n frees c
    -- left-nested application: ((n v1) v2) ...
    call  = foldl' App (Var n) (Var <$> frees)

View File

@@ -118,7 +118,7 @@ pure []
----------------------------------------------------------------------------------
evalProg :: CoreProgram -> Maybe (Node, Stats)
evalProg :: Program' -> Maybe (Node, Stats)
evalProg p = res <&> (,sts)
where
final = eval (compile p) & last
@@ -127,7 +127,7 @@ evalProg p = res <&> (,sts)
resAddr = final ^. gmStack ^? _head
res = resAddr >>= flip hLookup h
hdbgProg :: CoreProgram -> Handle -> IO (Node, Stats)
hdbgProg :: Program' -> Handle -> IO (Node, Stats)
hdbgProg p hio = do
(renderOut . showState) `traverse_` states
-- TODO: i'd like the statistics to be at the top of the file, but `sts`
@@ -548,7 +548,7 @@ pop [] = []
----------------------------------------------------------------------------------
compile :: CoreProgram -> GmState
compile :: Program' -> GmState
compile p = GmState c [] [] h g sts
where
-- find the entry point and evaluate it
@@ -575,7 +575,7 @@ compiledPrims =
binop k i = (k, 2, [Push 1, Eval, Push 1, Eval, i, Update 2, Pop 2, Unwind])
buildInitialHeap :: CoreProgram -> (GmHeap, Env)
buildInitialHeap :: Program' -> (GmHeap, Env)
buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
where
compiledScs = fmap compileSc ss <> compiledPrims
@@ -588,20 +588,20 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
-- >> [ref/compileSc]
-- type CompiledSC = (Name, Int, Code)
compileSc :: CoreScDef -> CompiledSC
compileSc :: ScDef' -> CompiledSC
compileSc (ScDef n as b) = (n, d, compileR env b)
where
env = (NameKey <$> as) `zip` [0..]
d = length as
-- << [ref/compileSc]
compileR :: Env -> CoreExpr -> Code
compileR :: Env -> Expr' -> Code
compileR g e = compileE g e <> [Update d, Pop d, Unwind]
where
d = length g
-- compile an expression in a lazy context
compileC :: Env -> CoreExpr -> Code
-- compile an expression in a non-strict context
compileC :: Env -> Expr' -> Code
compileC g (Var k)
| k `elem` domain = [Push n]
| otherwise = [PushGlobal k]
@@ -627,7 +627,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
-- kinda gross. revisit this
addressed = bs `zip` reverse [0 .. d-1]
compileBinder :: Env -> (CoreBinding, Int) -> (Env, Code)
compileBinder :: Env -> (Binding', Int) -> (Env, Code)
compileBinder m (k := v, a) = (m',c)
where
m' = (NameKey k, a) : m
@@ -645,7 +645,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
initialisers = mconcat $ compileBinder <$> addressed
body = compileC g' e
compileBinder :: (CoreBinding, Int) -> Code
compileBinder :: (Binding', Int) -> Code
compileBinder (_ := v, a) = compileC g' v <> [Update a]
compileC _ (Con t n) = [PushConstr t n]
@@ -663,7 +663,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
-- compile an expression in a strict context such that a pointer to the
-- expression is left on top of the stack in WHNF
compileE :: Env -> CoreExpr -> Code
compileE :: Env -> Expr' -> Code
compileE _ (LitE l) = compileEL l
compileE g (Let NonRec bs e) =
-- we use compileE instead of compileC
@@ -674,7 +674,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
-- kinda gross. revisit this
addressed = bs `zip` reverse [0 .. d-1]
compileBinder :: Env -> (CoreBinding, Int) -> (Env, Code)
compileBinder :: Env -> (Binding', Int) -> (Env, Code)
compileBinder m (k := v, a) = (m',c)
where
m' = (NameKey k, a) : m
@@ -695,7 +695,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
body = compileE g' e
-- we use compileE instead of compileC
compileBinder :: (CoreBinding, Int) -> Code
compileBinder :: (Binding', Int) -> Code
compileBinder (_ := v, a) = compileC g' v <> [Update a]
-- special cases for prim functions; essentially inlining
@@ -710,10 +710,10 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs
compileE g e = compileC g e ++ [Eval]
compileD :: Env -> [CoreAlter] -> [(Tag, Code)]
compileD :: Env -> [Alter'] -> [(Tag, Code)]
compileD g as = fmap (compileA g) as
compileA :: Env -> CoreAlter -> (Tag, Code)
compileA :: Env -> Alter' -> (Tag, Code)
compileA g (Alter (AltData t) as e) = (t, [Split n] <> c <> [Slide n])
where
n = length as