From 4f39dd36f1f84b0e2f8fc11453962469f1ab8ad0 Mon Sep 17 00:00:00 2001 From: crumbtoo Date: Thu, 25 Jan 2024 12:39:57 -0700 Subject: [PATCH] resolve named data in case exprs --- src/Core/Parse.y | 5 ++--- src/Core/Syntax.hs | 15 +++++++++++++++ src/Core/Utils.hs | 26 ++++++++------------------ src/Core2Core.hs | 32 +++++++++++++++++++++++++------- 4 files changed, 50 insertions(+), 28 deletions(-) diff --git a/src/Core/Parse.y b/src/Core/Parse.y index 969d3e5..7dbb6b5 100644 --- a/src/Core/Parse.y +++ b/src/Core/Parse.y @@ -85,8 +85,8 @@ Program : ScTypeSig ';' Program { insTypeSig $1 $3 } | ScTypeSig OptSemi { singletonTypeSig $1 } | ScDef ';' Program { insScDef $1 $3 } | ScDef OptSemi { singletonScDef $1 } - | TLPragma ';' Program {% doTLPragma $1 $3 } - | TLPragma OptSemi {% doTLPragma $1 mempty } + | TLPragma Program {% doTLPragma $1 $2 } + | TLPragma {% doTLPragma $1 mempty } TLPragma :: { Pragma } : '{-#' Words '#-}' { Pragma $2 } @@ -106,7 +106,6 @@ ScDefs :: { [ScDef Name] } ScDefs : ScDef ';' ScDefs { $1 : $3 } | ScDef ';' { [$1] } | ScDef { [$1] } - | {- epsilon -} { [] } ScDef :: { ScDef Name } ScDef : Var ParList '=' Expr { ScDef $1 $2 $4 } diff --git a/src/Core/Syntax.hs b/src/Core/Syntax.hs index 9717b61..83b4934 100644 --- a/src/Core/Syntax.hs +++ b/src/Core/Syntax.hs @@ -6,8 +6,13 @@ Description : Core ASTs and the like {-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE DerivingStrategies, DerivingVia #-} +-- for recursion-schemes +{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable + , TemplateHaskell, TypeFamilies #-} module Core.Syntax ( Expr(..) + , ExprF(..) + , ExprF'(..) , Type(..) , pattern TyInt , Lit(..) @@ -43,6 +48,8 @@ import Data.Coerce import Data.Pretty import Data.List (intersperse) import Data.Function ((&)) +import Data.Functor.Foldable +import Data.Functor.Foldable.TH (makeBaseFunctor) import Data.String import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict qualified as H @@ -142,8 +149,11 @@ data Program b = Program via Generically (Program b) makeLenses ''Program +makeBaseFunctor ''Expr pure [] +type ExprF' = ExprF Name + type Program' = Program Name type Expr' = Expr Name type ScDef' = ScDef Name @@ -193,3 +203,8 @@ instance HasLHS (ScDef b) (ScDef b) (b, [b]) (b, [b]) where (\ (ScDef n as _) -> (n,as)) (\ (ScDef _ _ e) (n',as') -> (ScDef n' as' e)) +instance HasLHS (Binding b) (Binding b) b b where + _lhs = lens + (\ (k := _) -> k) + (\ (_ := e) k' -> k' := e) + diff --git a/src/Core/Utils.hs b/src/Core/Utils.hs index 1a47785..956a067 100644 --- a/src/Core/Utils.hs +++ b/src/Core/Utils.hs @@ -1,16 +1,10 @@ --- for recursion schemes -{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-} --- for recursion schemes -{-# LANGUAGE TemplateHaskell, TypeFamilies #-} - module Core.Utils - ( bindersOf - , rhssOf + ( programRhss + , programGlobals , isAtomic -- , insertModule , extractProgram , freeVariables - , ExprF(..) ) where ---------------------------------------------------------------------------------- @@ -23,13 +17,11 @@ import Lens.Micro import GHC.Exts (IsList(..)) ---------------------------------------------------------------------------------- -bindersOf :: (IsList l, Item l ~ b) => [Binding b] -> l -bindersOf bs = fromList $ fmap f bs - where f (k := _) = k +programGlobals :: Traversal' (Program b) b +programGlobals = programScDefs . each . _lhs . _1 -rhssOf :: (IsList l, Item l ~ Expr b) => [Binding b] -> l -rhssOf = fromList . fmap f - where f (_ := v) = v +programRhss :: Traversal' (Program b) (Expr b) +programRhss = programScDefs . each . _rhs isAtomic :: Expr b -> Bool isAtomic (Var _) = True @@ -47,8 +39,6 @@ extractProgram (Module _ p) = p ---------------------------------------------------------------------------------- -makeBaseFunctor ''Expr - freeVariables :: Expr' -> Set Name freeVariables = cata go where @@ -57,8 +47,8 @@ freeVariables = cata go -- TODO: collect free vars in rhss of bs go (LetF _ bs e) = (e `S.union` esFree) `S.difference` ns where - es = rhssOf bs :: [Expr'] - ns = bindersOf bs + es = bs ^.. each . _rhs :: [Expr'] + ns = S.fromList $ bs ^.. each . _lhs -- TODO: this feels a little wrong. maybe a different scheme is -- appropriate esFree = foldMap id $ freeVariables <$> es diff --git a/src/Core2Core.hs b/src/Core2Core.hs index c21bd92..2036915 100644 --- a/src/Core2Core.hs +++ b/src/Core2Core.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE ImplicitParams #-} {-# LANGUAGE LambdaCase #-} module Core2Core ( core2core @@ -18,8 +19,9 @@ import Control.Monad.Writer import Control.Monad.State.Lazy import Control.Arrow ((>>>)) import Data.Text qualified as T +import Data.HashMap.Strict (HashMap) import Numeric (showHex) -import Lens.Micro +import Lens.Micro.Platform import Core.Syntax import Core.Utils ---------------------------------------------------------------------------------- @@ -28,19 +30,35 @@ core2core :: Program' -> Program' core2core p = undefined gmPrep :: Program' -> Program' -gmPrep p = p & traverseOf rhss (floatNonStrictCases globals) - & runFloater - & \ (me,caseScs) -> me & programScDefs %~ (<>caseScs) +gmPrep p = p & appFloater (floatNonStrictCases globals) + & tagData where - rhss :: Traversal' (Program z) (Expr z) - rhss = programScDefs . each . _rhs - globals = p ^.. programScDefs . each . _lhs . _1 & S.fromList +tagData :: Program' -> Program' +tagData p = let ?dt = p ^. programDataTags + in p & programRhss %~ cata go where + go :: (?dt :: HashMap Name (Tag, Int)) => ExprF' Expr' -> Expr' + go (CaseF e as) = Case e (tagAlts <$> as) + go x = embed x + + tagAlts :: (?dt :: HashMap Name (Tag, Int)) => Alter' -> Alter' + tagAlts (Alter (AltData c) bs e) = Alter (AltTag tag) bs e + where tag = case ?dt ^. at c of + Just (t,_) -> t + -- TODO: errorful + Nothing -> error $ "unknown constructor " <> show c + tagAlts x = x + -- | Auxilary type used in @floatNonSrictCases@ type Floater = StateT [Name] (Writer [ScDef']) +appFloater :: (Expr' -> Floater Expr') -> Program' -> Program' +appFloater fl p = p & traverseOf programRhss fl + & runFloater + & \ (me,floats) -> me & programScDefs %~ (<>floats) + runFloater :: Floater a -> (a, [ScDef']) runFloater = flip evalStateT ns >>> runWriter where