diff --git a/src/Core/Syntax.hs b/src/Core/Syntax.hs index 174e63d..2f66e00 100644 --- a/src/Core/Syntax.hs +++ b/src/Core/Syntax.hs @@ -25,6 +25,7 @@ module Core.Syntax , Alter' , Binding' , HasRHS(_rhs) + , HasLHS(_lhs) ) where ---------------------------------------------------------------------------------- @@ -132,3 +133,16 @@ instance HasRHS (Binding b) b where (\ (_ := e) -> e) (\ (k := _) e' -> k := e') +class HasLHS s a | s -> a where + _lhs :: Lens' s a + +instance HasLHS (Alter b) (AltCon, [b]) where + _lhs = lens + (\ (Alter a bs _) -> (a,bs)) + (\ (Alter _ _ e) (a',bs') -> Alter a' bs' e) + +instance HasLHS (ScDef b) (b, [b]) where + _lhs = lens + (\ (ScDef n as _) -> (n,as)) + (\ (ScDef _ _ e) (n',as') -> (ScDef n' as' e)) + diff --git a/src/Core2Core.hs b/src/Core2Core.hs index e4b6f89..c2f5a03 100644 --- a/src/Core2Core.hs +++ b/src/Core2Core.hs @@ -11,6 +11,7 @@ module Core2Core ---------------------------------------------------------------------------------- import Data.Functor.Foldable import Data.Maybe (fromJust) +import Data.Set (Set) import Data.Set qualified as S import Data.List import Control.Monad.Writer @@ -30,10 +31,13 @@ gmPrep p = p' <> Program caseScs where rhss :: Applicative f => (Expr z -> f (Expr z)) -> Program z -> f (Program z) rhss = programScDefs . each . _rhs + globals = p ^.. programScDefs . each . _lhs . _1 + & S.fromList + -- i kinda don't like that we're calling floatNonStrictCases twice tbh - p' = p & rhss %~ fst . runFloater . floatNonStrictCases + p' = p & rhss %~ fst . runFloater . floatNonStrictCases globals caseScs = (p ^.. rhss) - <&> snd . runFloater . floatNonStrictCases + <&> snd . runFloater . floatNonStrictCases globals & mconcat -- | Auxilary type used in @floatNonSrictCases@ @@ -47,8 +51,8 @@ runFloater = flip evalStateT ns >>> runWriter -- TODO: formally define a "strict context" and reference that here -- the returned ScDefs are guaranteed to be free of non-strict cases. -floatNonStrictCases :: Expr' -> Floater Expr' -floatNonStrictCases = goE +floatNonStrictCases :: Set Name -> Expr' -> Floater Expr' +floatNonStrictCases g = goE where goE :: Expr' -> Floater Expr' goE (Var k) = pure (Var k) @@ -64,7 +68,7 @@ floatNonStrictCases = goE -- Writer, and finally return an expression appropriately calling the sc goC p@(Case e as) = do n <- name - let (e',sc) = floatCase n p + let (e',sc) = floatCase g n p altBodies = (\(Alter _ _ b) -> b) <$> as tell [sc] goE e @@ -75,6 +79,7 @@ floatNonStrictCases = goE where bs' = travBs goC bs goC (LitE l) = pure (LitE l) goC (Var k) = pure (Var k) + goC (Con t as) = pure (Con t as) name = state (fromJust . uncons) @@ -88,10 +93,10 @@ floatNonStrictCases = goE -- when provided with a case expr, floatCase will float the case into a -- supercombinator of its free variables. the sc is returned along with an -- expression that calls the sc with the necessary arguments -floatCase :: Name -> Expr' -> (Expr', ScDef') -floatCase n c@(Case e as) = (e', sc) +floatCase :: Set Name -> Name -> Expr' -> (Expr', ScDef') +floatCase g n c@(Case e as) = (e', sc) where sc = ScDef n caseFrees c - caseFrees = S.toList $ freeVariables c + caseFrees = S.toList $ freeVariables c `S.difference` g e' = foldl App (Var n) (Var <$> caseFrees) diff --git a/src/GM.hs b/src/GM.hs index 38d6e75..9f3a27b 100644 --- a/src/GM.hs +++ b/src/GM.hs @@ -27,6 +27,7 @@ import System.IO (Handle, hPutStrLn) import Data.String (IsString) import Data.Heap import Debug.Trace +import Core2Core import Core ---------------------------------------------------------------------------------- @@ -551,9 +552,10 @@ pop [] = [] compile :: Program' -> GmState compile p = GmState c [] [] h g sts where + p' = gmPrep p -- find the entry point and evaluate it c = [PushGlobal "main", Eval] - (h,g) = buildInitialHeap p + (h,g) = buildInitialHeap p' sts = def type CompiledSC = (Name, Int, Code)