resolve named data in case exprs

This commit is contained in:
crumbtoo
2024-01-25 12:39:57 -07:00
parent 4c99e44c04
commit 4f39dd36f1
4 changed files with 50 additions and 28 deletions

View File

@@ -85,8 +85,8 @@ Program : ScTypeSig ';' Program { insTypeSig $1 $3 }
| ScTypeSig OptSemi { singletonTypeSig $1 } | ScTypeSig OptSemi { singletonTypeSig $1 }
| ScDef ';' Program { insScDef $1 $3 } | ScDef ';' Program { insScDef $1 $3 }
| ScDef OptSemi { singletonScDef $1 } | ScDef OptSemi { singletonScDef $1 }
| TLPragma ';' Program {% doTLPragma $1 $3 } | TLPragma Program {% doTLPragma $1 $2 }
| TLPragma OptSemi {% doTLPragma $1 mempty } | TLPragma {% doTLPragma $1 mempty }
TLPragma :: { Pragma } TLPragma :: { Pragma }
: '{-#' Words '#-}' { Pragma $2 } : '{-#' Words '#-}' { Pragma $2 }
@@ -106,7 +106,6 @@ ScDefs :: { [ScDef Name] }
ScDefs : ScDef ';' ScDefs { $1 : $3 } ScDefs : ScDef ';' ScDefs { $1 : $3 }
| ScDef ';' { [$1] } | ScDef ';' { [$1] }
| ScDef { [$1] } | ScDef { [$1] }
| {- epsilon -} { [] }
ScDef :: { ScDef Name } ScDef :: { ScDef Name }
ScDef : Var ParList '=' Expr { ScDef $1 $2 $4 } ScDef : Var ParList '=' Expr { ScDef $1 $2 $4 }

View File

@@ -6,8 +6,13 @@ Description : Core ASTs and the like
{-# LANGUAGE FunctionalDependencies #-} {-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE DerivingStrategies, DerivingVia #-} {-# LANGUAGE DerivingStrategies, DerivingVia #-}
-- for recursion-schemes
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable
, TemplateHaskell, TypeFamilies #-}
module Core.Syntax module Core.Syntax
( Expr(..) ( Expr(..)
, ExprF(..)
, ExprF'(..)
, Type(..) , Type(..)
, pattern TyInt , pattern TyInt
, Lit(..) , Lit(..)
@@ -43,6 +48,8 @@ import Data.Coerce
import Data.Pretty import Data.Pretty
import Data.List (intersperse) import Data.List (intersperse)
import Data.Function ((&)) import Data.Function ((&))
import Data.Functor.Foldable
import Data.Functor.Foldable.TH (makeBaseFunctor)
import Data.String import Data.String
import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as H import Data.HashMap.Strict qualified as H
@@ -142,8 +149,11 @@ data Program b = Program
via Generically (Program b) via Generically (Program b)
makeLenses ''Program makeLenses ''Program
makeBaseFunctor ''Expr
pure [] pure []
type ExprF' = ExprF Name
type Program' = Program Name type Program' = Program Name
type Expr' = Expr Name type Expr' = Expr Name
type ScDef' = ScDef Name type ScDef' = ScDef Name
@@ -193,3 +203,8 @@ instance HasLHS (ScDef b) (ScDef b) (b, [b]) (b, [b]) where
(\ (ScDef n as _) -> (n,as)) (\ (ScDef n as _) -> (n,as))
(\ (ScDef _ _ e) (n',as') -> (ScDef n' as' e)) (\ (ScDef _ _ e) (n',as') -> (ScDef n' as' e))
instance HasLHS (Binding b) (Binding b) b b where
_lhs = lens
(\ (k := _) -> k)
(\ (_ := e) k' -> k' := e)

View File

@@ -1,16 +1,10 @@
-- for recursion schemes
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
-- for recursion schemes
{-# LANGUAGE TemplateHaskell, TypeFamilies #-}
module Core.Utils module Core.Utils
( bindersOf ( programRhss
, rhssOf , programGlobals
, isAtomic , isAtomic
-- , insertModule -- , insertModule
, extractProgram , extractProgram
, freeVariables , freeVariables
, ExprF(..)
) )
where where
---------------------------------------------------------------------------------- ----------------------------------------------------------------------------------
@@ -23,13 +17,11 @@ import Lens.Micro
import GHC.Exts (IsList(..)) import GHC.Exts (IsList(..))
---------------------------------------------------------------------------------- ----------------------------------------------------------------------------------
bindersOf :: (IsList l, Item l ~ b) => [Binding b] -> l programGlobals :: Traversal' (Program b) b
bindersOf bs = fromList $ fmap f bs programGlobals = programScDefs . each . _lhs . _1
where f (k := _) = k
rhssOf :: (IsList l, Item l ~ Expr b) => [Binding b] -> l programRhss :: Traversal' (Program b) (Expr b)
rhssOf = fromList . fmap f programRhss = programScDefs . each . _rhs
where f (_ := v) = v
isAtomic :: Expr b -> Bool isAtomic :: Expr b -> Bool
isAtomic (Var _) = True isAtomic (Var _) = True
@@ -47,8 +39,6 @@ extractProgram (Module _ p) = p
---------------------------------------------------------------------------------- ----------------------------------------------------------------------------------
makeBaseFunctor ''Expr
freeVariables :: Expr' -> Set Name freeVariables :: Expr' -> Set Name
freeVariables = cata go freeVariables = cata go
where where
@@ -57,8 +47,8 @@ freeVariables = cata go
-- TODO: collect free vars in rhss of bs -- TODO: collect free vars in rhss of bs
go (LetF _ bs e) = (e `S.union` esFree) `S.difference` ns go (LetF _ bs e) = (e `S.union` esFree) `S.difference` ns
where where
es = rhssOf bs :: [Expr'] es = bs ^.. each . _rhs :: [Expr']
ns = bindersOf bs ns = S.fromList $ bs ^.. each . _lhs
-- TODO: this feels a little wrong. maybe a different scheme is -- TODO: this feels a little wrong. maybe a different scheme is
-- appropriate -- appropriate
esFree = foldMap id $ freeVariables <$> es esFree = foldMap id $ freeVariables <$> es

View File

@@ -1,3 +1,4 @@
{-# LANGUAGE ImplicitParams #-}
{-# LANGUAGE LambdaCase #-} {-# LANGUAGE LambdaCase #-}
module Core2Core module Core2Core
( core2core ( core2core
@@ -18,8 +19,9 @@ import Control.Monad.Writer
import Control.Monad.State.Lazy import Control.Monad.State.Lazy
import Control.Arrow ((>>>)) import Control.Arrow ((>>>))
import Data.Text qualified as T import Data.Text qualified as T
import Data.HashMap.Strict (HashMap)
import Numeric (showHex) import Numeric (showHex)
import Lens.Micro import Lens.Micro.Platform
import Core.Syntax import Core.Syntax
import Core.Utils import Core.Utils
---------------------------------------------------------------------------------- ----------------------------------------------------------------------------------
@@ -28,19 +30,35 @@ core2core :: Program' -> Program'
core2core p = undefined core2core p = undefined
gmPrep :: Program' -> Program' gmPrep :: Program' -> Program'
gmPrep p = p & traverseOf rhss (floatNonStrictCases globals) gmPrep p = p & appFloater (floatNonStrictCases globals)
& runFloater & tagData
& \ (me,caseScs) -> me & programScDefs %~ (<>caseScs)
where where
rhss :: Traversal' (Program z) (Expr z)
rhss = programScDefs . each . _rhs
globals = p ^.. programScDefs . each . _lhs . _1 globals = p ^.. programScDefs . each . _lhs . _1
& S.fromList & S.fromList
tagData :: Program' -> Program'
tagData p = let ?dt = p ^. programDataTags
in p & programRhss %~ cata go where
go :: (?dt :: HashMap Name (Tag, Int)) => ExprF' Expr' -> Expr'
go (CaseF e as) = Case e (tagAlts <$> as)
go x = embed x
tagAlts :: (?dt :: HashMap Name (Tag, Int)) => Alter' -> Alter'
tagAlts (Alter (AltData c) bs e) = Alter (AltTag tag) bs e
where tag = case ?dt ^. at c of
Just (t,_) -> t
-- TODO: errorful
Nothing -> error $ "unknown constructor " <> show c
tagAlts x = x
-- | Auxilary type used in @floatNonSrictCases@ -- | Auxilary type used in @floatNonSrictCases@
type Floater = StateT [Name] (Writer [ScDef']) type Floater = StateT [Name] (Writer [ScDef'])
appFloater :: (Expr' -> Floater Expr') -> Program' -> Program'
appFloater fl p = p & traverseOf programRhss fl
& runFloater
& \ (me,floats) -> me & programScDefs %~ (<>floats)
runFloater :: Floater a -> (a, [ScDef']) runFloater :: Floater a -> (a, [ScDef'])
runFloater = flip evalStateT ns >>> runWriter runFloater = flip evalStateT ns >>> runWriter
where where