resolve named data in case exprs
This commit is contained in:
@@ -85,8 +85,8 @@ Program : ScTypeSig ';' Program { insTypeSig $1 $3 }
|
|||||||
| ScTypeSig OptSemi { singletonTypeSig $1 }
|
| ScTypeSig OptSemi { singletonTypeSig $1 }
|
||||||
| ScDef ';' Program { insScDef $1 $3 }
|
| ScDef ';' Program { insScDef $1 $3 }
|
||||||
| ScDef OptSemi { singletonScDef $1 }
|
| ScDef OptSemi { singletonScDef $1 }
|
||||||
| TLPragma ';' Program {% doTLPragma $1 $3 }
|
| TLPragma Program {% doTLPragma $1 $2 }
|
||||||
| TLPragma OptSemi {% doTLPragma $1 mempty }
|
| TLPragma {% doTLPragma $1 mempty }
|
||||||
|
|
||||||
TLPragma :: { Pragma }
|
TLPragma :: { Pragma }
|
||||||
: '{-#' Words '#-}' { Pragma $2 }
|
: '{-#' Words '#-}' { Pragma $2 }
|
||||||
@@ -106,7 +106,6 @@ ScDefs :: { [ScDef Name] }
|
|||||||
ScDefs : ScDef ';' ScDefs { $1 : $3 }
|
ScDefs : ScDef ';' ScDefs { $1 : $3 }
|
||||||
| ScDef ';' { [$1] }
|
| ScDef ';' { [$1] }
|
||||||
| ScDef { [$1] }
|
| ScDef { [$1] }
|
||||||
| {- epsilon -} { [] }
|
|
||||||
|
|
||||||
ScDef :: { ScDef Name }
|
ScDef :: { ScDef Name }
|
||||||
ScDef : Var ParList '=' Expr { ScDef $1 $2 $4 }
|
ScDef : Var ParList '=' Expr { ScDef $1 $2 $4 }
|
||||||
|
|||||||
@@ -6,8 +6,13 @@ Description : Core ASTs and the like
|
|||||||
{-# LANGUAGE FunctionalDependencies #-}
|
{-# LANGUAGE FunctionalDependencies #-}
|
||||||
{-# LANGUAGE TemplateHaskell #-}
|
{-# LANGUAGE TemplateHaskell #-}
|
||||||
{-# LANGUAGE DerivingStrategies, DerivingVia #-}
|
{-# LANGUAGE DerivingStrategies, DerivingVia #-}
|
||||||
|
-- for recursion-schemes
|
||||||
|
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable
|
||||||
|
, TemplateHaskell, TypeFamilies #-}
|
||||||
module Core.Syntax
|
module Core.Syntax
|
||||||
( Expr(..)
|
( Expr(..)
|
||||||
|
, ExprF(..)
|
||||||
|
, ExprF'(..)
|
||||||
, Type(..)
|
, Type(..)
|
||||||
, pattern TyInt
|
, pattern TyInt
|
||||||
, Lit(..)
|
, Lit(..)
|
||||||
@@ -43,6 +48,8 @@ import Data.Coerce
|
|||||||
import Data.Pretty
|
import Data.Pretty
|
||||||
import Data.List (intersperse)
|
import Data.List (intersperse)
|
||||||
import Data.Function ((&))
|
import Data.Function ((&))
|
||||||
|
import Data.Functor.Foldable
|
||||||
|
import Data.Functor.Foldable.TH (makeBaseFunctor)
|
||||||
import Data.String
|
import Data.String
|
||||||
import Data.HashMap.Strict (HashMap)
|
import Data.HashMap.Strict (HashMap)
|
||||||
import Data.HashMap.Strict qualified as H
|
import Data.HashMap.Strict qualified as H
|
||||||
@@ -142,8 +149,11 @@ data Program b = Program
|
|||||||
via Generically (Program b)
|
via Generically (Program b)
|
||||||
|
|
||||||
makeLenses ''Program
|
makeLenses ''Program
|
||||||
|
makeBaseFunctor ''Expr
|
||||||
pure []
|
pure []
|
||||||
|
|
||||||
|
type ExprF' = ExprF Name
|
||||||
|
|
||||||
type Program' = Program Name
|
type Program' = Program Name
|
||||||
type Expr' = Expr Name
|
type Expr' = Expr Name
|
||||||
type ScDef' = ScDef Name
|
type ScDef' = ScDef Name
|
||||||
@@ -193,3 +203,8 @@ instance HasLHS (ScDef b) (ScDef b) (b, [b]) (b, [b]) where
|
|||||||
(\ (ScDef n as _) -> (n,as))
|
(\ (ScDef n as _) -> (n,as))
|
||||||
(\ (ScDef _ _ e) (n',as') -> (ScDef n' as' e))
|
(\ (ScDef _ _ e) (n',as') -> (ScDef n' as' e))
|
||||||
|
|
||||||
|
instance HasLHS (Binding b) (Binding b) b b where
|
||||||
|
_lhs = lens
|
||||||
|
(\ (k := _) -> k)
|
||||||
|
(\ (_ := e) k' -> k' := e)
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,10 @@
|
|||||||
-- for recursion schemes
|
|
||||||
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
|
|
||||||
-- for recursion schemes
|
|
||||||
{-# LANGUAGE TemplateHaskell, TypeFamilies #-}
|
|
||||||
|
|
||||||
module Core.Utils
|
module Core.Utils
|
||||||
( bindersOf
|
( programRhss
|
||||||
, rhssOf
|
, programGlobals
|
||||||
, isAtomic
|
, isAtomic
|
||||||
-- , insertModule
|
-- , insertModule
|
||||||
, extractProgram
|
, extractProgram
|
||||||
, freeVariables
|
, freeVariables
|
||||||
, ExprF(..)
|
|
||||||
)
|
)
|
||||||
where
|
where
|
||||||
----------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------
|
||||||
@@ -23,13 +17,11 @@ import Lens.Micro
|
|||||||
import GHC.Exts (IsList(..))
|
import GHC.Exts (IsList(..))
|
||||||
----------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------
|
||||||
|
|
||||||
bindersOf :: (IsList l, Item l ~ b) => [Binding b] -> l
|
programGlobals :: Traversal' (Program b) b
|
||||||
bindersOf bs = fromList $ fmap f bs
|
programGlobals = programScDefs . each . _lhs . _1
|
||||||
where f (k := _) = k
|
|
||||||
|
|
||||||
rhssOf :: (IsList l, Item l ~ Expr b) => [Binding b] -> l
|
programRhss :: Traversal' (Program b) (Expr b)
|
||||||
rhssOf = fromList . fmap f
|
programRhss = programScDefs . each . _rhs
|
||||||
where f (_ := v) = v
|
|
||||||
|
|
||||||
isAtomic :: Expr b -> Bool
|
isAtomic :: Expr b -> Bool
|
||||||
isAtomic (Var _) = True
|
isAtomic (Var _) = True
|
||||||
@@ -47,8 +39,6 @@ extractProgram (Module _ p) = p
|
|||||||
|
|
||||||
----------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------
|
||||||
|
|
||||||
makeBaseFunctor ''Expr
|
|
||||||
|
|
||||||
freeVariables :: Expr' -> Set Name
|
freeVariables :: Expr' -> Set Name
|
||||||
freeVariables = cata go
|
freeVariables = cata go
|
||||||
where
|
where
|
||||||
@@ -57,8 +47,8 @@ freeVariables = cata go
|
|||||||
-- TODO: collect free vars in rhss of bs
|
-- TODO: collect free vars in rhss of bs
|
||||||
go (LetF _ bs e) = (e `S.union` esFree) `S.difference` ns
|
go (LetF _ bs e) = (e `S.union` esFree) `S.difference` ns
|
||||||
where
|
where
|
||||||
es = rhssOf bs :: [Expr']
|
es = bs ^.. each . _rhs :: [Expr']
|
||||||
ns = bindersOf bs
|
ns = S.fromList $ bs ^.. each . _lhs
|
||||||
-- TODO: this feels a little wrong. maybe a different scheme is
|
-- TODO: this feels a little wrong. maybe a different scheme is
|
||||||
-- appropriate
|
-- appropriate
|
||||||
esFree = foldMap id $ freeVariables <$> es
|
esFree = foldMap id $ freeVariables <$> es
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
{-# LANGUAGE ImplicitParams #-}
|
||||||
{-# LANGUAGE LambdaCase #-}
|
{-# LANGUAGE LambdaCase #-}
|
||||||
module Core2Core
|
module Core2Core
|
||||||
( core2core
|
( core2core
|
||||||
@@ -18,8 +19,9 @@ import Control.Monad.Writer
|
|||||||
import Control.Monad.State.Lazy
|
import Control.Monad.State.Lazy
|
||||||
import Control.Arrow ((>>>))
|
import Control.Arrow ((>>>))
|
||||||
import Data.Text qualified as T
|
import Data.Text qualified as T
|
||||||
|
import Data.HashMap.Strict (HashMap)
|
||||||
import Numeric (showHex)
|
import Numeric (showHex)
|
||||||
import Lens.Micro
|
import Lens.Micro.Platform
|
||||||
import Core.Syntax
|
import Core.Syntax
|
||||||
import Core.Utils
|
import Core.Utils
|
||||||
----------------------------------------------------------------------------------
|
----------------------------------------------------------------------------------
|
||||||
@@ -28,19 +30,35 @@ core2core :: Program' -> Program'
|
|||||||
core2core p = undefined
|
core2core p = undefined
|
||||||
|
|
||||||
gmPrep :: Program' -> Program'
|
gmPrep :: Program' -> Program'
|
||||||
gmPrep p = p & traverseOf rhss (floatNonStrictCases globals)
|
gmPrep p = p & appFloater (floatNonStrictCases globals)
|
||||||
& runFloater
|
& tagData
|
||||||
& \ (me,caseScs) -> me & programScDefs %~ (<>caseScs)
|
|
||||||
where
|
where
|
||||||
rhss :: Traversal' (Program z) (Expr z)
|
|
||||||
rhss = programScDefs . each . _rhs
|
|
||||||
|
|
||||||
globals = p ^.. programScDefs . each . _lhs . _1
|
globals = p ^.. programScDefs . each . _lhs . _1
|
||||||
& S.fromList
|
& S.fromList
|
||||||
|
|
||||||
|
tagData :: Program' -> Program'
|
||||||
|
tagData p = let ?dt = p ^. programDataTags
|
||||||
|
in p & programRhss %~ cata go where
|
||||||
|
go :: (?dt :: HashMap Name (Tag, Int)) => ExprF' Expr' -> Expr'
|
||||||
|
go (CaseF e as) = Case e (tagAlts <$> as)
|
||||||
|
go x = embed x
|
||||||
|
|
||||||
|
tagAlts :: (?dt :: HashMap Name (Tag, Int)) => Alter' -> Alter'
|
||||||
|
tagAlts (Alter (AltData c) bs e) = Alter (AltTag tag) bs e
|
||||||
|
where tag = case ?dt ^. at c of
|
||||||
|
Just (t,_) -> t
|
||||||
|
-- TODO: errorful
|
||||||
|
Nothing -> error $ "unknown constructor " <> show c
|
||||||
|
tagAlts x = x
|
||||||
|
|
||||||
-- | Auxilary type used in @floatNonSrictCases@
|
-- | Auxilary type used in @floatNonSrictCases@
|
||||||
type Floater = StateT [Name] (Writer [ScDef'])
|
type Floater = StateT [Name] (Writer [ScDef'])
|
||||||
|
|
||||||
|
appFloater :: (Expr' -> Floater Expr') -> Program' -> Program'
|
||||||
|
appFloater fl p = p & traverseOf programRhss fl
|
||||||
|
& runFloater
|
||||||
|
& \ (me,floats) -> me & programScDefs %~ (<>floats)
|
||||||
|
|
||||||
runFloater :: Floater a -> (a, [ScDef'])
|
runFloater :: Floater a -> (a, [ScDef'])
|
||||||
runFloater = flip evalStateT ns >>> runWriter
|
runFloater = flip evalStateT ns >>> runWriter
|
||||||
where
|
where
|
||||||
|
|||||||
Reference in New Issue
Block a user