mirror of
https://github.com/GrammaticalFramework/gf-core.git
synced 2026-04-22 11:19:32 -06:00
Transfer added guards and Eq derivation.
This commit is contained in:
@@ -28,11 +28,11 @@ declsToCore :: [Decl] -> [Decl]
|
||||
declsToCore m = evalState (declsToCore_ m) newState
|
||||
|
||||
declsToCore_ :: [Decl] -> C [Decl]
|
||||
declsToCore_ = desugar
|
||||
>>> numberMetas
|
||||
>>> deriveDecls
|
||||
>>> replaceCons
|
||||
declsToCore_ = deriveDecls
|
||||
>>> desugar
|
||||
>>> compilePattDecls
|
||||
>>> numberMetas
|
||||
>>> replaceCons
|
||||
>>> expandOrPatts
|
||||
>>> optimize
|
||||
|
||||
@@ -61,13 +61,14 @@ numberMetas = mapM f
|
||||
return $ EVar $ Ident $ "?" ++ show (nextMeta st) -- FIXME: hack
|
||||
_ -> composOpM f t
|
||||
|
||||
|
||||
--
|
||||
-- * Pattern equations
|
||||
--
|
||||
|
||||
compilePattDecls :: [Decl] -> C [Decl]
|
||||
compilePattDecls [] = return []
|
||||
compilePattDecls (d@(ValueDecl x _ _):ds) =
|
||||
compilePattDecls (d@(ValueDecl x _ _ _):ds) =
|
||||
do
|
||||
let (xs,rest) = span (isValueDecl x) ds
|
||||
d <- mergeDecls (d:xs)
|
||||
@@ -75,20 +76,26 @@ compilePattDecls (d@(ValueDecl x _ _):ds) =
|
||||
return (d:rs)
|
||||
compilePattDecls (d:ds) = liftM (d:) (compilePattDecls ds)
|
||||
|
||||
-- | Take a non-empty list of pattern equations for the same
|
||||
-- function, and produce a single declaration.
|
||||
-- | Checks if a declaration is a value declaration
|
||||
-- of the given identifier.
|
||||
isValueDecl :: Ident -> Decl -> Bool
|
||||
isValueDecl x (ValueDecl y _ _ _) = x == y
|
||||
isValueDecl _ _ = False
|
||||
|
||||
-- | Take a non-empty list of pattern equations with guards
|
||||
-- for the same function, and produce a single declaration.
|
||||
mergeDecls :: [Decl] -> C Decl
|
||||
mergeDecls ds@(ValueDecl x p _:_)
|
||||
= do let cs = [ (ps,rhs) | ValueDecl _ ps rhs <- ds ]
|
||||
(pss,rhss) = unzip cs
|
||||
mergeDecls ds@(ValueDecl x p _ _:_)
|
||||
= do let cs = [ (ps,g,rhs) | ValueDecl _ ps g rhs <- ds ]
|
||||
(pss,_,_) = unzip3 cs
|
||||
n = length p
|
||||
when (not (all ((== n) . length) pss))
|
||||
$ fail $ "Pattern count mismatch for " ++ printTree x
|
||||
vs <- freshIdents n
|
||||
let cases = map (\ (ps,rhs) -> Case (mkPRec ps) rhs) cs
|
||||
let cases = map (\ (ps,g,rhs) -> Case (mkPRec ps) g rhs) cs
|
||||
c = ECase (mkERec (map EVar vs)) cases
|
||||
f = foldr (EAbs . VVar) c vs
|
||||
return $ ValueDecl x [] f
|
||||
return $ ValueDecl x [] GuardNo f
|
||||
where mkRec r f = r . zipWith (\i e -> f (Ident ("p"++show i)) e) [0..]
|
||||
mkPRec = mkRec PRec FieldPattern
|
||||
mkERec = mkRec ERec FieldValue
|
||||
@@ -118,6 +125,10 @@ derivators = [
|
||||
("Ord", deriveOrd)
|
||||
]
|
||||
|
||||
--
|
||||
-- * Deriving instances of Compos
|
||||
--
|
||||
|
||||
deriveCompos :: Derivator
|
||||
deriveCompos t@(Ident ts) k cs =
|
||||
do
|
||||
@@ -128,7 +139,7 @@ deriveCompos t@(Ident ts) k cs =
|
||||
dt = apply (EVar (Ident "Compos")) [c, EVar t]
|
||||
r = ERec [FieldValue (Ident "composOp") co,
|
||||
FieldValue (Ident "composFold") cf]
|
||||
return [TypeDecl d dt, ValueDecl d [] r]
|
||||
return [TypeDecl d dt, ValueDecl d [] GuardNo r]
|
||||
|
||||
deriveComposOp :: Ident -> Exp -> [(Ident,Exp)] -> C Exp
|
||||
deriveComposOp t k cs =
|
||||
@@ -149,9 +160,9 @@ deriveComposOp t k cs =
|
||||
EApp (EVar t') c | t' == t -> apply (e f) [c, e v]
|
||||
_ -> e v
|
||||
calls = zipWith rec vars (argumentTypes ct)
|
||||
return $ Case (PCons ci (map PVar vars)) (apply (e ci) calls)
|
||||
return $ Case (PCons ci (map PVar vars)) gtrue (apply (e ci) calls)
|
||||
cases <- mapM (uncurry mkCase) cs
|
||||
let cases' = cases ++ [Case PWild (e x)]
|
||||
let cases' = cases ++ [Case PWild gtrue (e x)]
|
||||
fb <- abstract (arity k) $ const $ pv f \-> pv x \-> ECase (e x) cases'
|
||||
return fb
|
||||
|
||||
@@ -180,17 +191,61 @@ deriveComposFold t k cs =
|
||||
p = EProj (e r) (Ident "mplus")
|
||||
joinCalls [] = z
|
||||
joinCalls cs = foldr1 (\x y -> apply p [x,y]) cs
|
||||
return $ Case (PCons ci (map PVar vars)) (joinCalls calls)
|
||||
return $ Case (PCons ci (map PVar vars)) gtrue (joinCalls calls)
|
||||
cases <- mapM (uncurry mkCase) cs
|
||||
let cases' = cases ++ [Case PWild (e x)]
|
||||
let cases' = cases ++ [Case PWild gtrue (e x)]
|
||||
fb <- abstract (arity k) $ const $ pv f \-> pv x \-> ECase (e x) cases'
|
||||
return $ VWild \-> pv r \-> fb
|
||||
|
||||
--
|
||||
-- * Deriving instances of Show
|
||||
--
|
||||
|
||||
deriveShow :: Derivator
|
||||
deriveShow t k cs = fail $ "derive Show not implemented"
|
||||
|
||||
--
|
||||
-- * Deriving instances of Eq
|
||||
--
|
||||
|
||||
-- FIXME: how do we require Eq instances for all
|
||||
-- constructor arguments?
|
||||
|
||||
deriveEq :: Derivator
|
||||
deriveEq t k cs = fail $ "derive Eq not implemented"
|
||||
deriveEq t@(Ident tn) k cs =
|
||||
do
|
||||
let ats = argumentTypes k
|
||||
d = Ident ("eq_"++tn)
|
||||
dt <- abstractType ats (EApp (EVar (Ident "Eq")) . apply (EVar t))
|
||||
eq <- mkEq
|
||||
r <- abstract (arity k) (\_ -> ERec [FieldValue (Ident "eq") eq])
|
||||
return [TypeDecl d dt, ValueDecl d [] GuardNo r]
|
||||
where
|
||||
mkEq = do
|
||||
x <- freshIdent
|
||||
cases <- mapM (uncurry mkEqCase) cs
|
||||
return $ EAbs (VVar x) (ECase (EVar x) cases)
|
||||
mkEqCase c ct =
|
||||
do
|
||||
let n = arity ct
|
||||
vs1 <- freshIdents n
|
||||
vs2 <- freshIdents n
|
||||
y <- freshIdent
|
||||
let p1 = PCons c (map PVar vs1)
|
||||
p2 = PCons c (map PVar vs2)
|
||||
es1 = map EVar vs1
|
||||
es2 = map EVar vs2
|
||||
tc | n == 0 = true
|
||||
-- FIXME: using EEq doesn't work right now
|
||||
| otherwise = foldr1 EAnd (zipWith EEq es1 es2)
|
||||
c1 = Case p2 gtrue tc
|
||||
c2 = Case PWild gtrue false
|
||||
return $ Case p1 gtrue (EAbs (VVar y) (ECase (EVar y) [c1,c2]))
|
||||
|
||||
|
||||
--
|
||||
-- * Deriving instances of Ord
|
||||
--
|
||||
|
||||
deriveOrd :: Derivator
|
||||
deriveOrd t k cs = fail $ "derive Ord not implemented"
|
||||
@@ -268,10 +323,10 @@ removeUselessMatch = return . map f
|
||||
f x = case x of
|
||||
EAbs (VVar x) b ->
|
||||
case f b of
|
||||
-- replace \x -> case x of { y -> e } with \y -> e,
|
||||
-- replace \x -> case x of { y | True -> e } with \y -> e,
|
||||
-- if x is not free in e
|
||||
ECase (EVar x') [Case (PVar y) e]
|
||||
| x' == x && not (x `isFreeIn` e)
|
||||
ECase (EVar x') [Case (PVar y) g e]
|
||||
| x' == x && isTrueGuard g && not (x `isFreeIn` e)
|
||||
-> f (EAbs (VVar y) e)
|
||||
-- replace unused variable in lambda with wild card
|
||||
e | not (x `isFreeIn` e) -> f (EAbs VWild e)
|
||||
@@ -282,31 +337,33 @@ removeUselessMatch = return . map f
|
||||
v = if not (x `isFreeIn` e') then VWild else VVar x
|
||||
in EPi v (f t) e'
|
||||
-- replace unused variables in case patterns with wild cards
|
||||
Case p e ->
|
||||
let e' = f e
|
||||
p' = f (removeUnusedVarPatts (freeVars e') p)
|
||||
in Case p' e'
|
||||
Case p (GuardExp g) e ->
|
||||
let g' = f g
|
||||
e' = f e
|
||||
used = freeVars g' `Set.union` freeVars e'
|
||||
p' = f (removeUnusedVarPatts used p)
|
||||
in Case p' (GuardExp g') e'
|
||||
-- for value declarations without patterns, compilePattDecls
|
||||
-- generates pattern matching on the empty record, remove these
|
||||
ECase (ERec []) [Case (PRec []) e] -> f e
|
||||
ECase (ERec []) [Case (PRec []) g e] | isTrueGuard g -> f e
|
||||
-- if the pattern matching is on a single field of a record expression
|
||||
-- with only one field, there is no need to wrap it in a record
|
||||
ECase (ERec [FieldValue x e]) cs | all (isSingleFieldPattern x) (casePatterns cs)
|
||||
-> f (ECase e [ Case p r | Case (PRec [FieldPattern _ p]) r <- cs ])
|
||||
-- for all fields in record matching where all patterns just
|
||||
-> f (ECase e [ Case p g r | Case (PRec [FieldPattern _ p]) g r <- cs ])
|
||||
-- for all fields in record matching where all patterns for the field just
|
||||
-- bind variables, substitute in the field value (if it is a variable)
|
||||
-- in the right hand sides.
|
||||
-- in the guards and right hand sides.
|
||||
ECase (ERec fs) cs | all isPRec (casePatterns cs) ->
|
||||
let g (FieldValue f v@(EVar _):fs) xs
|
||||
let h (FieldValue f v@(EVar _):fs) xs
|
||||
| all (onlyBindsFieldToVariable f) (casePatterns xs)
|
||||
= g fs (map (inlineField f v) xs)
|
||||
g (f:fs) xs = let (fs',xs') = g fs xs in (f:fs',xs')
|
||||
g [] xs = ([],xs)
|
||||
inlineField f v (Case (PRec fps) e) =
|
||||
= h fs (map (inlineField f v) xs)
|
||||
h (f:fs) xs = let (fs',xs') = h fs xs in (f:fs',xs')
|
||||
h [] xs = ([],xs)
|
||||
inlineField f v (Case (PRec fps) (GuardExp g) e) =
|
||||
let p' = PRec [fp | fp@(FieldPattern f' _) <- fps, f' /= f]
|
||||
ss = zip (fieldPatternVars f fps) (repeat v)
|
||||
in Case p' (substs ss e)
|
||||
(fs',cs') = g fs cs
|
||||
in Case p' (GuardExp (substs ss g)) (substs ss e)
|
||||
(fs',cs') = h fs cs
|
||||
x' = ECase (ERec fs') cs'
|
||||
in if length fs' < length fs then f x' else composOp f x'
|
||||
-- Remove wild card patterns in record patterns
|
||||
@@ -314,6 +371,11 @@ removeUselessMatch = return . map f
|
||||
where wildcards = [fp | fp@(FieldPattern _ PWild) <- fps]
|
||||
_ -> composOp f x
|
||||
|
||||
isTrueGuard :: Guard -> Bool
|
||||
isTrueGuard (GuardExp (EVar (Ident "True"))) = True
|
||||
isTrueGuard GuardNo = True
|
||||
isTrueGuard _ = False
|
||||
|
||||
removeUnusedVarPatts :: Set Ident -> Tree a -> Tree a
|
||||
removeUnusedVarPatts keep x = case x of
|
||||
PVar id | not (id `Set.member` keep) -> PWild
|
||||
@@ -325,7 +387,7 @@ isSingleFieldPattern x p = case p of
|
||||
_ -> False
|
||||
|
||||
casePatterns :: [Case] -> [Pattern]
|
||||
casePatterns cs = [p | Case p _ <- cs]
|
||||
casePatterns cs = [p | Case p _ _ <- cs]
|
||||
|
||||
isPRec :: Pattern -> Bool
|
||||
isPRec (PRec _) = True
|
||||
@@ -357,7 +419,7 @@ expandOrPatts = return . map f
|
||||
_ -> composOp f x
|
||||
|
||||
expandCase :: Case -> [Case]
|
||||
expandCase (Case p e) = [ Case p' e | p' <- expandPatt p ]
|
||||
expandCase (Case p g e) = [ Case p' g e | p' <- expandPatt p ]
|
||||
|
||||
expandPatt :: Pattern -> [Pattern]
|
||||
expandPatt p = case p of
|
||||
@@ -383,14 +445,15 @@ desugar = return . map f
|
||||
f x = case x of
|
||||
PListCons p1 p2 -> pListCons <| p1 <| p2
|
||||
PList xs -> pList (map f [p | PListElem p <- xs])
|
||||
GuardNo -> gtrue
|
||||
EIf exp0 exp1 exp2 -> ifBool <| exp0 <| exp1 <| exp2
|
||||
EDo bs e -> mkDo (map f bs) (f e)
|
||||
BindNoVar exp0 -> BindVar VWild <| exp0
|
||||
EPiNoVar exp0 exp1 -> EPi VWild <| exp0 <| exp1
|
||||
EBind exp0 exp1 -> appBind <| exp0 <| exp1
|
||||
EBindC exp0 exp1 -> appBindC <| exp0 <| exp1
|
||||
EOr exp0 exp1 -> andBool <| exp0 <| exp1
|
||||
EAnd exp0 exp1 -> orBool <| exp0 <| exp1
|
||||
EOr exp0 exp1 -> orBool <| exp0 <| exp1
|
||||
EAnd exp0 exp1 -> andBool <| exp0 <| exp1
|
||||
EEq exp0 exp1 -> overlBin "eq" <| exp0 <| exp1
|
||||
ENe exp0 exp1 -> overlBin "ne" <| exp0 <| exp1
|
||||
ELt exp0 exp1 -> overlBin "lt" <| exp0 <| exp1
|
||||
@@ -457,14 +520,14 @@ appCons e1 e2 = apply (EVar (Ident "Cons")) [EMeta,e1,e2]
|
||||
--
|
||||
|
||||
andBool :: Exp -> Exp -> Exp
|
||||
andBool e1 e2 = ifBool e1 e2 (var "False")
|
||||
andBool e1 e2 = ifBool e1 e2 false
|
||||
|
||||
orBool :: Exp -> Exp -> Exp
|
||||
orBool e1 e2 = ifBool e1 (var "True") e2
|
||||
orBool e1 e2 = ifBool e1 true e2
|
||||
|
||||
ifBool :: Exp -> Exp -> Exp -> Exp
|
||||
ifBool c t e = ECase c [Case (PCons (Ident "True") []) t,
|
||||
Case (PCons (Ident "False") []) e]
|
||||
ifBool c t e = ECase c [Case (PCons (Ident "True") []) gtrue t,
|
||||
Case (PCons (Ident "False") []) gtrue e]
|
||||
|
||||
--
|
||||
-- * Substitution
|
||||
@@ -483,7 +546,7 @@ substs ss = f (Map.fromList ss)
|
||||
ELet ds e3 ->
|
||||
ELet [LetDef id (f ss e1) (f ss' e2) | LetDef id e1 e2 <- ds] (f ss' e3)
|
||||
where ss' = ss `mapMinusSet` letDefBinds ds
|
||||
Case p e -> Case p (f ss' e) where ss' = ss `mapMinusSet` binds p
|
||||
Case p g e -> Case p (f ss' g) (f ss' e) where ss' = ss `mapMinusSet` binds p
|
||||
EAbs (VVar id) e -> EAbs (VVar id) (f ss' e) where ss' = Map.delete id ss
|
||||
EPi (VVar id) e1 e2 ->
|
||||
EPi (VVar id) (f ss e1) (f ss' e2) where ss' = Map.delete id ss
|
||||
@@ -497,6 +560,15 @@ substs ss = f (Map.fromList ss)
|
||||
var :: String -> Exp
|
||||
var s = EVar (Ident s)
|
||||
|
||||
true :: Exp
|
||||
true = var "True"
|
||||
|
||||
false :: Exp
|
||||
false = var "False"
|
||||
|
||||
gtrue :: Guard
|
||||
gtrue = GuardExp true
|
||||
|
||||
-- | Apply an expression to a list of arguments.
|
||||
apply :: Exp -> [Exp] -> Exp
|
||||
apply = foldl EApp
|
||||
@@ -511,7 +583,8 @@ abstract n f =
|
||||
|
||||
-- | Abstract a type over some arguments.
|
||||
abstractType :: [Exp] -- ^ argument types
|
||||
-> ([Exp] -> Exp)
|
||||
-> ([Exp] -> Exp) -- ^ function from variable expressions
|
||||
-- to the expression to return
|
||||
-> C Exp
|
||||
abstractType ts f =
|
||||
do
|
||||
@@ -551,7 +624,8 @@ freeVars = f
|
||||
(Set.unions (f exp3:map f (letDefRhss defs)) Set.\\ letDefBinds defs)
|
||||
:map f (letDefTypes defs)
|
||||
ECase exp cases -> f exp `Set.union`
|
||||
Set.unions [ f e Set.\\ binds p | Case p e <- cases]
|
||||
Set.unions [(f g `Set.union` f e) Set.\\ binds p
|
||||
| Case p g e <- cases]
|
||||
EAbs (VVar id) exp -> Set.delete id (f exp)
|
||||
EPi (VVar id) exp1 exp2 -> f exp1 `Set.union` Set.delete id (f exp2)
|
||||
EVar i -> Set.singleton i
|
||||
@@ -568,7 +642,7 @@ countFreeOccur x = f
|
||||
f t = case t of
|
||||
ELet defs _ | x `Set.member` letDefBinds defs ->
|
||||
sum (map f (letDefTypes defs))
|
||||
Case p e | x `Set.member` binds p -> 0
|
||||
Case p _ _ | x `Set.member` binds p -> 0
|
||||
EAbs (VVar id) _ | id == x -> 0
|
||||
EPi (VVar id) exp1 _ | id == x -> f exp1
|
||||
EVar id | id == x -> 1
|
||||
@@ -584,11 +658,6 @@ binds = f
|
||||
PVar id -> Set.singleton id
|
||||
_ -> composOpMonoid f p
|
||||
|
||||
-- | Checks if a declaration is a value declaration
|
||||
-- of the given identifier.
|
||||
isValueDecl :: Ident -> Decl -> Bool
|
||||
isValueDecl x (ValueDecl y _ _) = x == y
|
||||
isValueDecl _ _ = False
|
||||
|
||||
fromPRec :: [FieldPattern] -> [(Ident,Pattern)]
|
||||
fromPRec fps = [ (l,p) | FieldPattern l p <- fps ]
|
||||
|
||||
Reference in New Issue
Block a user