diff --git a/src/Transfer/SyntaxToCore.hs b/src/Transfer/SyntaxToCore.hs index 0ba92a250..1b17e4a3f 100644 --- a/src/Transfer/SyntaxToCore.hs +++ b/src/Transfer/SyntaxToCore.hs @@ -36,8 +36,7 @@ declsToCore_ = numberMetas >>> optimize optimize :: [Decl] -> C [Decl] -optimize = removeUnusedVariables - >>> removeUselessMatch +optimize = removeUselessMatch >>> betaReduce newState :: CState @@ -263,7 +262,7 @@ betaReduce = return . map f _ -> composOp f t -- --- * Remove useless pattern matching. +-- * Remove useless pattern matching and variable binding. -- removeUselessMatch :: [Decl] -> C [Decl] @@ -271,53 +270,83 @@ removeUselessMatch = return . map f where f :: Tree a -> Tree a f x = case x of - -- replace \x -> case x of { y -> e } with \y -> e, - -- if x is not free in e - -- FIXME: this checks the result of the recursive call, - -- can we do something about this? - EAbs (VVar x) b -> + EAbs (VVar x) b -> case f b of + -- replace \x -> case x of { y -> e } with \y -> e, + -- if x is not free in e ECase (EVar x') [Case (PVar y) e] | x' == x && not (x `isFreeIn` e) -> f (EAbs (VVar y) e) + -- replace unused variable in lambda with wild card + e | not (x `isFreeIn` e) -> f (EAbs VWild e) e -> EAbs (VVar x) e + -- replace unused variable in pi with wild card + EPi (VVar x) t e -> + let e' = f e + v = if not (x `isFreeIn` e') then VWild else VVar x + in EPi v (f t) e' + -- replace unused variables in case patterns with wild cards + Case p e -> + let e' = f e + p' = f (removeUnusedVarPatts (freeVars e') p) + in Case p' e' -- for value declarations without patterns, compilePattDecls -- generates pattern matching on the empty record, remove these ECase (ERec []) [Case (PRec []) e] -> f e -- if the pattern matching is on a single field of a record expression -- with only one field, there is no need to wrap it in a record - ECase (ERec [FieldValue x e]) cs | all (isSingleFieldPattern x) [ p | Case p _ <- cs] + ECase (ERec [FieldValue x e]) cs | all (isSingleFieldPattern x) (casePatterns cs) -> f (ECase e [ Case p r | Case (PRec [FieldPattern _ p]) r <- cs ]) - -- In cases: remove record field patterns which only bind unused variables - Case (PRec fps) e -> Case (f (PRec (fps \\ unused))) (f e) - where unused = [fp | fp@(FieldPattern l (PVar id)) <- fps, - not (id `isFreeIn` e)] + -- for all fields in record matching where all patterns just + -- bind variables, substitute in the field value (if it is a variable) + -- in the right hand sides. + ECase (ERec fs) cs | all isPRec (casePatterns cs) -> + let g (FieldValue f v@(EVar _):fs) xs + | all (onlyBindsFieldToVariable f) (casePatterns xs) + = g fs (map (inlineField f v) xs) + g (f:fs) xs = let (fs',xs') = g fs xs in (f:fs',xs') + g [] xs = ([],xs) + inlineField f v (Case (PRec fps) e) = + let p' = PRec [fp | fp@(FieldPattern f' _) <- fps, f' /= f] + ss = zip (fieldPatternVars f fps) (repeat v) + in Case p' (substs ss e) + (fs',cs') = g fs cs + x' = ECase (ERec fs') cs' + in if length fs' < length fs then f x' else composOp f x' -- Remove wild card patterns in record patterns PRec fps -> PRec (map f (fps \\ wildcards)) where wildcards = [fp | fp@(FieldPattern _ PWild) <- fps] _ -> composOp f x - isSingleFieldPattern :: Ident -> Pattern -> Bool - isSingleFieldPattern x p = case p of + +removeUnusedVarPatts :: Set Ident -> Tree a -> Tree a +removeUnusedVarPatts keep x = case x of + PVar id | not (id `Set.member` keep) -> PWild + _ -> composOp (removeUnusedVarPatts keep) x + +isSingleFieldPattern :: Ident -> Pattern -> Bool +isSingleFieldPattern x p = case p of PRec [FieldPattern y _] -> x == y _ -> False --- --- * Change varibles which are not used to wildcards. --- -removeUnusedVariables :: [Decl] -> C [Decl] -removeUnusedVariables = return . map f - where - f :: Tree a -> Tree a - f x = case x of - EAbs (VVar id) e | not (id `isFreeIn` e) -> EAbs VWild (f e) - EPi (VVar id) t e | not (id `isFreeIn` e) -> EPi VWild (f t) (f e) - Case p e -> Case (g (freeVars e) p) (f e) - _ -> composOp f x - -- replace pattern variables not in the given set with wildcards - g :: Set Ident -> Tree a -> Tree a - g keep x = case x of - PVar id | not (id `Set.member` keep) -> PWild - _ -> composOp (g keep) x +casePatterns :: [Case] -> [Pattern] +casePatterns cs = [p | Case p _ <- cs] + +isPRec :: Pattern -> Bool +isPRec (PRec _) = True +isPRec _ = False + +-- | Checks if given pattern is a record pattern, and matches the field +-- with just a variable, with a wild card, or not at all. +onlyBindsFieldToVariable :: Ident -> Pattern -> Bool +onlyBindsFieldToVariable f (PRec fps) = + all isVar [p | FieldPattern f' p <- fps, f == f'] + where isVar (PVar _) = True + isVar PWild = True + isVar _ = False +onlyBindsFieldToVariable _ _ = False + +fieldPatternVars :: Ident -> [FieldPattern] -> [Ident] +fieldPatternVars f fps = [p | FieldPattern f' (PVar p) <- fps, f == f'] -- -- * Remove simple syntactic sugar. @@ -376,17 +405,24 @@ ifBool c t e = ECase c [Case (PCons (Ident "True") []) t, -- subst :: Ident -> Exp -> Exp -> Exp -subst x e = f - where - f :: Tree a -> Tree a - f t = case t of - ELet defs exp3 | x `Set.member` letDefBinds defs -> - ELet [ LetDef id (f exp1) exp2 | LetDef id exp1 exp2 <- defs] exp3 - Case p e | x `Set.member` binds p -> t - EAbs (VVar id) _ | x == id -> t - EPi (VVar id) exp1 exp2 | x == id -> EPi (VVar id) (f exp1) exp2 - EVar i | i == x -> e - _ -> composOp f t +subst x e = substs [(x,e)] + +-- | Simultaneuous substitution +substs :: [(Ident, Exp)] -> Exp -> Exp +substs ss = f (Map.fromList ss) + where + f :: Map Ident Exp -> Tree a -> Tree a + f ss t | Map.null ss = t + f ss t = case t of + ELet ds e3 -> + ELet [LetDef id (f ss e1) (f ss' e2) | LetDef id e1 e2 <- ds] (f ss' e3) + where ss' = ss `mapMinusSet` letDefBinds ds + Case p e -> Case p (f ss' e) where ss' = ss `mapMinusSet` binds p + EAbs (VVar id) e -> EAbs (VVar id) (f ss' e) where ss' = Map.delete id ss + EPi (VVar id) e1 e2 -> + EPi (VVar id) (f ss e1) (f ss' e2) where ss' = Map.delete id ss + EVar i -> Map.findWithDefault t i ss + _ -> composOp (f ss) t -- -- * Abstract syntax utilities @@ -512,3 +548,6 @@ infixl 1 >>> (>>>) :: Monad m => (a -> m b) -> (b -> m c) -> a -> m c f >>> g = (g =<<) . f + +mapMinusSet :: Ord k => Map k a -> Set k -> Map k a +mapMinusSet m s = m Map.\\ (Map.fromList [(x,()) | x <- Set.toList s]) diff --git a/transfer/examples/nat.tr b/transfer/examples/nat.tr index cd9101574..c529e5238 100644 --- a/transfer/examples/nat.tr +++ b/transfer/examples/nat.tr @@ -14,9 +14,5 @@ natToInt : Nat -> Int natToInt Zero = 0 natToInt (Succ n) = 1 + natToInt n -plus : Nat -> Nat -> Nat -plus Zero y = y -plus (Succ x) y = Succ (plus x y) - intToNat : Int -> Nat intToNat n = if n == 0 then Zero else Succ (intToNat (n-1))