diff --git a/rlp.cabal b/rlp.cabal index 187962d..c4249e6 100644 --- a/rlp.cabal +++ b/rlp.cabal @@ -23,6 +23,7 @@ library , GM , Compiler.RLPC , Core.Syntax + , Core.Utils other-modules: Data.Heap , Data.Pretty @@ -31,6 +32,7 @@ library , Core.Examples , Core.Lex , Control.Monad.Errorful + , Core2Core build-tool-depends: happy:happy, alex:alex diff --git a/src/Core/Syntax.hs b/src/Core/Syntax.hs index decc325..3d0a1ca 100644 --- a/src/Core/Syntax.hs +++ b/src/Core/Syntax.hs @@ -3,11 +3,7 @@ Module : Core.Syntax Description : Core ASTs and the like -} {-# LANGUAGE PatternSynonyms, OverloadedStrings #-} --- for recursion schemes -{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-} --- for recursion schemes -{-# LANGUAGE TemplateHaskell, TypeFamilies #-} - +{-# LANGUAGE FunctionalDependencies #-} module Core.Syntax ( Expr(..) , Literal(..) @@ -22,27 +18,24 @@ module Core.Syntax , ScDef(..) , Module(..) , Program(..) - , CoreProgram - , CoreExpr - , CoreScDef - , CoreAlter - , CoreBinding - , bindersOf - , rhssOf - , isAtomic - , insertModule - , extractProgram + , Program' + , Expr' + , ScDef' + , Alter' + , Binding' + , HasRHS(_rhs) ) where ---------------------------------------------------------------------------------- import Data.Coerce import Data.Pretty +import GHC.Generics import Data.List (intersperse) import Data.Function ((&)) import Data.String -- Lift instances for the Core quasiquoters +import Lens.Micro import Language.Haskell.TH.Syntax (Lift) -import Data.Functor.Foldable.TH (makeBaseFunctor) ---------------------------------------------------------------------------------- data Expr b = Var Name @@ -100,17 +93,15 @@ data Module b = Module (Maybe (Name, [Name])) (Program b) newtype Program b = Program [ScDef b] deriving (Show, Lift) -type CoreProgram = Program Name -type CoreExpr = Expr Name -type CoreScDef = ScDef Name -type CoreAlter = Alter Name -type CoreBinding = Binding Name +type Program' = Program Name +type Expr' = Expr Name 
-- | Syntax nodes that carry a right-hand-side expression. The functional
-- dependency @s -> z@ lets the node type determine the binder type of the
-- expression it holds.
class HasRHS s z | s -> z where
    -- | Lens onto the right-hand-side expression of a node.
    _rhs :: Lens' s (Expr z)

-- | The right-hand side of a case alternative is its body.
instance HasRHS (Alter b) b where
    _rhs = lens body withBody
      where
        body (Alter _ _ e) = e
        withBody (Alter tag args _) e' = Alter tag args e'

-- | The right-hand side of a supercombinator definition is its body.
instance HasRHS (ScDef b) b where
    _rhs = lens body withBody
      where
        body (ScDef _ _ e) = e
        withBody (ScDef name args _) e' = ScDef name args e'

-- | The right-hand side of a binding is the bound expression.
instance HasRHS (Binding b) b where
    _rhs = lens bound withBound
      where
        bound (_ := e) = e
        withBound (k := _) e' = k := e'
-- | Collect the binder of every binding, in order, into any 'IsList'
-- container whose items are binders (a list, a 'Set', ...).
bindersOf :: (IsList l, Item l ~ b) => [Binding b] -> l
bindersOf bs = fromList $ fmap f bs
    where f (k := _) = k

-- | Collect the right-hand side of every binding, in order, into any
-- 'IsList' container whose items are expressions.
rhssOf :: (IsList l, Item l ~ Expr b) => [Binding b] -> l
rhssOf = fromList . fmap f
    where f (_ := v) = v

-- | An expression is atomic when it is a variable or a literal.
isAtomic :: Expr b -> Bool
isAtomic (Var _)  = True
isAtomic (LitE _) = True
isAtomic _        = False

----------------------------------------------------------------------------------

-- TODO: export list awareness
-- | Append the definitions of a module after those of a program. The
-- module's first field is ignored.
insertModule :: Module b -> Program b -> Program b
insertModule (Module _ m) p = p <> m

-- | Project the program out of a module, ignoring the module's first field.
extractProgram :: Module b -> Program b
extractProgram (Module _ p) = p

----------------------------------------------------------------------------------

makeBaseFunctor ''Expr

-- | Compute the set of variables that occur free in an expression.
--
-- Binders introduced by lambdas, case alternatives and lets are removed from
-- the free variables of exactly the sub-expressions they scope over.
freeVariables :: Expr' -> Set Name
freeVariables = cata go
    where
        go :: ExprF Name (Set Name) -> Set Name
        go (VarF k) = S.singleton k
        -- recursive let: the binders scope over the right-hand sides as well
        -- as over the body
        go (LetF Rec bs e) =
            (e `S.union` rhsFrees bs) `S.difference` bindersOf bs
        -- non-recursive let: the binders scope over the body only, so an
        -- occurrence of a binder inside a right-hand side is still free
        go (LetF NonRec bs e) =
            rhsFrees bs `S.union` (e `S.difference` bindersOf bs)
        go (CaseF e as) = e `S.union` foldMap altFrees as
        go (LamF bs e)  = e `S.difference` S.fromList bs
        -- every remaining node binds nothing: union the children's results
        go e = foldMap id e

        -- bindings still hold plain expressions at this point (the fold has
        -- not descended into them), so they are traversed explicitly
        rhsFrees :: [Binding'] -> Set Name
        rhsFrees bs = foldMap freeVariables (rhssOf bs :: [Expr'])

        -- an alternative binds its pattern variables over its body
        altFrees :: Alter' -> Set Name
        altFrees (Alter _ ns e) = freeVariables e `S.difference` S.fromList ns
-- | The monad of the case-floating pass: a supply of fresh names kept in
-- state, layered over a writer collecting the generated supercombinators.
type Replacer = StateT [Name] (Writer [ScDef'])

-- TODO: formally define a "strict context" and reference that here
-- | Replace case expressions by calls to freshly generated supercombinators.
--
-- The first argument is the supply of names for the new supercombinators;
-- one name is consumed per floated case, and running out of names is a
-- fatal error. The result is the rewritten expression together with every
-- supercombinator that was created.
replaceNonStrictCases :: [Name] -> Expr' -> (Expr', [ScDef'])
replaceNonStrictCases names = runWriter . flip evalStateT names . goE
    where
        -- NOTE: goE hands every expression it does not match itself over to
        -- goC, so at present a case reached through goE is floated as well.
        goE :: Expr' -> Replacer Expr'
        goE (Var k)        = pure (Var k)
        goE (LitE l)       = pure (LitE l)
        goE (Let Rec bs e) = Let Rec <$> travBs goE bs <*> goE e
        goE e              = goC e

        goC :: Expr' -> Replacer Expr'
        -- the only truly non-trivial case: the case expr is floated into a
        -- supercombinator named from the state, the new sc is recorded in
        -- the Writer, and an expression calling the sc is returned. the
        -- scrutinee and the alternative bodies are rewritten first, so that
        -- the floated sc contains their rewritten forms
        goC (Case e as) = do
            n   <- name
            e'  <- goE e
            as' <- traverse (_rhs goE) as
            let (call, sc) = floatCase n (Case e' as')
            tell [sc]
            pure call
        goC (f :$ x)     = (:$) <$> goC f <*> goC x
        goC (Let r bs e) = Let r <$> travBs goC bs <*> goE e
        -- anything else is returned untouched
        goC e            = pure e

        -- take the next fresh name from the supply
        name :: Replacer Name
        name = do
            supply <- get
            case supply of
                n : rest -> n <$ put rest
                []       -> error
                    "replaceNonStrictCases: supply of fresh names exhausted"

        -- rewrite the right-hand side of every binding with the given
        -- action, keeping the binders as they are
        travBs :: (Expr' -> Replacer Expr') -> [Binding'] -> Replacer [Binding']
        travBs c = traverse (_rhs c)

-- | Float an expression (in practice a case expr) into a supercombinator of
-- its free variables. The result is an expression calling the new sc with
-- those variables as arguments, paired with the sc itself.
floatCase :: Name -> Expr' -> (Expr', ScDef')
floatCase n c = (call, sc)
    where
        sc    = ScDef n frees c
        frees = S.toList $ freeVariables c
        call  = foldl' App (Var n) (Var <$> frees)
_head res = resAddr >>= flip hLookup h -hdbgProg :: CoreProgram -> Handle -> IO (Node, Stats) +hdbgProg :: Program' -> Handle -> IO (Node, Stats) hdbgProg p hio = do (renderOut . showState) `traverse_` states -- TODO: i'd like the statistics to be at the top of the file, but `sts` @@ -548,7 +548,7 @@ pop [] = [] ---------------------------------------------------------------------------------- -compile :: CoreProgram -> GmState +compile :: Program' -> GmState compile p = GmState c [] [] h g sts where -- find the entry point and evaluate it @@ -575,7 +575,7 @@ compiledPrims = binop k i = (k, 2, [Push 1, Eval, Push 1, Eval, i, Update 2, Pop 2, Unwind]) -buildInitialHeap :: CoreProgram -> (GmHeap, Env) +buildInitialHeap :: Program' -> (GmHeap, Env) buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs where compiledScs = fmap compileSc ss <> compiledPrims @@ -588,20 +588,20 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs -- >> [ref/compileSc] -- type CompiledSC = (Name, Int, Code) - compileSc :: CoreScDef -> CompiledSC + compileSc :: ScDef' -> CompiledSC compileSc (ScDef n as b) = (n, d, compileR env b) where env = (NameKey <$> as) `zip` [0..] d = length as -- << [ref/compileSc] - compileR :: Env -> CoreExpr -> Code + compileR :: Env -> Expr' -> Code compileR g e = compileE g e <> [Update d, Pop d, Unwind] where d = length g - -- compile an expression in a lazy context - compileC :: Env -> CoreExpr -> Code + -- compile an expression in a non-strict context + compileC :: Env -> Expr' -> Code compileC g (Var k) | k `elem` domain = [Push n] | otherwise = [PushGlobal k] @@ -627,7 +627,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs -- kinda gross. revisit this addressed = bs `zip` reverse [0 .. 
d-1] - compileBinder :: Env -> (CoreBinding, Int) -> (Env, Code) + compileBinder :: Env -> (Binding', Int) -> (Env, Code) compileBinder m (k := v, a) = (m',c) where m' = (NameKey k, a) : m @@ -645,7 +645,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs initialisers = mconcat $ compileBinder <$> addressed body = compileC g' e - compileBinder :: (CoreBinding, Int) -> Code + compileBinder :: (Binding', Int) -> Code compileBinder (_ := v, a) = compileC g' v <> [Update a] compileC _ (Con t n) = [PushConstr t n] @@ -663,7 +663,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs -- compile an expression in a strict context such that a pointer to the -- expression is left on top of the stack in WHNF - compileE :: Env -> CoreExpr -> Code + compileE :: Env -> Expr' -> Code compileE _ (LitE l) = compileEL l compileE g (Let NonRec bs e) = -- we use compileE instead of compileC @@ -674,7 +674,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs -- kinda gross. revisit this addressed = bs `zip` reverse [0 .. 
d-1] - compileBinder :: Env -> (CoreBinding, Int) -> (Env, Code) + compileBinder :: Env -> (Binding', Int) -> (Env, Code) compileBinder m (k := v, a) = (m',c) where m' = (NameKey k, a) : m @@ -695,7 +695,7 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs body = compileE g' e -- we use compileE instead of compileC - compileBinder :: (CoreBinding, Int) -> Code + compileBinder :: (Binding', Int) -> Code compileBinder (_ := v, a) = compileC g' v <> [Update a] -- special cases for prim functions; essentially inlining @@ -710,10 +710,10 @@ buildInitialHeap (Program ss) = mapAccumL allocateSc mempty compiledScs compileE g e = compileC g e ++ [Eval] - compileD :: Env -> [CoreAlter] -> [(Tag, Code)] + compileD :: Env -> [Alter'] -> [(Tag, Code)] compileD g as = fmap (compileA g) as - compileA :: Env -> CoreAlter -> (Tag, Code) + compileA :: Env -> Alter' -> (Tag, Code) compileA g (Alter (AltData t) as e) = (t, [Split n] <> c <> [Slide n]) where n = length as