resolve named data in case exprs

This commit is contained in:
crumbtoo
2024-01-25 12:39:57 -07:00
parent 4c99e44c04
commit 4f39dd36f1
4 changed files with 50 additions and 28 deletions

View File

@@ -1,16 +1,10 @@
-- for recursion schemes
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}
-- for recursion schemes
{-# LANGUAGE TemplateHaskell, TypeFamilies #-}
module Core.Utils
( bindersOf
, rhssOf
( programRhss
, programGlobals
, isAtomic
-- , insertModule
, extractProgram
, freeVariables
, ExprF(..)
)
where
----------------------------------------------------------------------------------
@@ -23,13 +17,11 @@ import Lens.Micro
import GHC.Exts (IsList(..))
----------------------------------------------------------------------------------
bindersOf :: (IsList l, Item l ~ b) => [Binding b] -> l
bindersOf bs = fromList $ fmap f bs
where f (k := _) = k
programGlobals :: Traversal' (Program b) b
programGlobals = programScDefs . each . _lhs . _1
rhssOf :: (IsList l, Item l ~ Expr b) => [Binding b] -> l
rhssOf = fromList . fmap f
where f (_ := v) = v
programRhss :: Traversal' (Program b) (Expr b)
programRhss = programScDefs . each . _rhs
isAtomic :: Expr b -> Bool
isAtomic (Var _) = True
@@ -47,8 +39,6 @@ extractProgram (Module _ p) = p
----------------------------------------------------------------------------------
makeBaseFunctor ''Expr
freeVariables :: Expr' -> Set Name
freeVariables = cata go
where
@@ -57,8 +47,8 @@ freeVariables = cata go
-- TODO: collect free vars in rhss of bs
go (LetF _ bs e) = (e `S.union` esFree) `S.difference` ns
where
es = rhssOf bs :: [Expr']
ns = bindersOf bs
es = bs ^.. each . _rhs :: [Expr']
ns = S.fromList $ bs ^.. each . _lhs
-- TODO: this feels a little wrong. maybe a different scheme is
-- appropriate
esFree = foldMap id $ freeVariables <$> es