nixify
This commit is contained in:
@@ -0,0 +1,651 @@
|
||||
(** Operand class of an operation.  [show_op] prints the four
    classes as the suffixes "w", "l", "s" and "d" respectively. *)
type cls =
  | Kw
  | Kl
  | Ks
  | Kd
|
||||
(** Base operations, independent of the operand class. *)
type op_base = Oadd | Osub | Omul | Oor | Oshl | Oshr
|
||||
(** An operation: an operand class paired with a base operation. *)
type op = cls * op_base
|
||||
|
||||
(** Every constructor of [op_base], in declaration order. *)
let op_bases = [ Oadd; Osub; Omul; Oor; Oshl; Oshr ]
|
||||
|
||||
(** [commutative op] is [true] when the base operation of [op] is
    one of add, mul or or; the operand class is ignored. *)
let commutative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false
|
||||
|
||||
(** [associative op] is [true] when the base operation of [op] is
    one of add, mul or or; the operand class is ignored. *)
let associative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false
|
||||
|
||||
(** Leaf patterns.  [show_pattern] prints [Tmp] as "%", [AnyCon] as
    "$" and [Con k] as the decimal value of [k].

    Under polymorphic [compare]: Tmp < AnyCon < Con k. *)
type atomic_pattern = Tmp | AnyCon | Con of int64
|
||||
|
||||
(** Patterns over binary operations. *)
type pattern =
  | Bnr of op * pattern * pattern
      (** a binary operation applied to two sub-patterns *)
  | Atm of atomic_pattern
      (** an anonymous leaf *)
  | Var of string * atomic_pattern
      (** a named leaf: variable name and the atom it must match *)
|
||||
|
||||
(** [is_atomic p] is [true] for leaves ([Atm] and [Var]) and [false]
    for [Bnr] nodes. *)
let is_atomic p =
  match p with
  | Bnr _ -> false
  | Atm _ | Var _ -> true
|
||||
|
||||
(** Mnemonic of a base operation, without the class suffix. *)
let show_op_base = function
  | Oadd -> "add"
  | Osub -> "sub"
  | Omul -> "mul"
  | Oor -> "or"
  | Oshl -> "shl"
  | Oshr -> "shr"
|
||||
|
||||
(** [show_op (k, o)] is the mnemonic of [o] followed by the
    one-letter suffix of class [k], e.g. "addw". *)
let show_op (k, o) =
  let suffix =
    match k with
    | Kw -> "w"
    | Kl -> "l"
    | Ks -> "s"
    | Kd -> "d"
  in
  show_op_base o ^ suffix
|
||||
|
||||
(** Render a pattern as text.  Leaves print as "%" (temporary),
    "$" (any constant) or a decimal constant; a variable prints its
    atom followed by a quote and its name; a binary node prints as
    "(op left right)". *)
let rec show_pattern = function
  | Atm Tmp -> "%"
  | Atm AnyCon -> "$"
  | Atm (Con n) -> Int64.to_string n
  | Var (name, atom) -> show_pattern (Atm atom) ^ "'" ^ name
  | Bnr (o, lhs, rhs) ->
    String.concat ""
      [ "("; show_op o; " "; show_pattern lhs; " "; show_pattern rhs; ")" ]
|
||||
|
||||
(** [get_atomic p] is the atom carried by a leaf ([Atm] or [Var]),
    or [None] when [p] is a [Bnr] node. *)
let get_atomic = function
  | Bnr _ -> None
  | Atm a | Var (_, a) -> Some a
|
||||
|
||||
(** [pattern_match p w] tells whether pattern [p] accepts [w].

    - A [Var] behaves like its underlying atom.
    - [Atm Tmp] accepts everything except a leaf whose atom is a
      constant ([AnyCon] or [Con _]); in particular it accepts any
      [Bnr] node.
    - [Atm AnyCon] accepts exactly what [Atm Tmp] rejects.
    - [Atm (Con _)] accepts only a [w] structurally equal to [p].
      NOTE(review): a [Var]-wrapped constant in [w] therefore does
      not match -- confirm this is intended.
    - A [Bnr] accepts a [Bnr] with the same operation whose
      children are accepted pairwise. *)
let rec pattern_match p w =
  match p, w with
  | Var (_, atom), _ -> pattern_match (Atm atom) w
  | Atm Tmp, _ ->
    (match get_atomic w with
     | Some Tmp | None -> true
     | Some (AnyCon | Con _) -> false)
  | Atm AnyCon, _ -> not (pattern_match (Atm Tmp) w)
  | Atm (Con _), _ -> w = p
  | Bnr (o, pl, pr), Bnr (o', wl, wr) ->
    o' = o && pattern_match pl wl && pattern_match pr wr
  | Bnr _, (Atm _ | Var _) -> false
|
||||
|
||||
(** A position inside a pattern, described from the position up to
    the root.  ['a] is the payload stored at the root. *)
type +'a cursor =
  | Bnrl of op * 'a cursor * pattern
      (** left child of a binary node: operation, cursor of the
          node itself, right sibling *)
  | Bnrr of op * pattern * 'a cursor
      (** right child of a binary node: operation, left sibling,
          cursor of the node itself *)
  | Top of 'a
      (** the root of the pattern *)
|
||||
|
||||
(** [fold_cursor c p] rebuilds the whole pattern obtained by
    plugging [p] at position [c]: it walks up to [Top], wrapping
    [p] in the binary nodes recorded along the way. *)
let rec fold_cursor c p =
  match c with
  | Top _ -> p
  | Bnrl (o, parent, right) -> fold_cursor parent (Bnr (o, p, right))
  | Bnrr (o, left, parent) -> fold_cursor parent (Bnr (o, left, p))
|
||||
|
||||
(** [peel p x] lists every leaf of [p] together with its cursor,
    using [x] as the payload of the [Top] cursor.  [Var] leaves are
    returned as plain [Atm] leaves.

    Binary nodes are expanded one level per pass; the loop stops at
    the first pass that does not change the number of entries, i.e.
    once only leaves remain. *)
let peel p x =
  let step acc (pat, cur) =
    match pat with
    | Atm _ -> (pat, cur) :: acc
    | Var (_, atom) -> (Atm atom, cur) :: acc
    | Bnr (o, lhs, rhs) ->
      (lhs, Bnrl (o, cur, rhs)) :: (rhs, Bnrr (o, lhs, cur)) :: acc
  in
  let rec saturate frontier =
    let next = List.fold_left step [] frontier in
    if List.compare_lengths next frontier = 0 then next else saturate next
  in
  saturate [ (p, Top x) ]
|
||||
|
||||
(** [fold_pairs l1 l2 ini f] folds [f] over the cartesian product
    of [l1] and [l2], starting from [ini].  Pairs are visited with
    [l1] as the outer loop and [l2] as the inner loop. *)
let fold_pairs l1 l2 ini f =
  List.fold_left
    (fun acc a -> List.fold_left (fun acc b -> f (a, b) acc) acc l2)
    ini l1
|
||||
|
||||
(** [iter_pairs l f] applies [f] to every ordered pair of elements
    of [l], including pairs of an element with itself. *)
let iter_pairs l f =
  let visit pair () = f pair in
  fold_pairs l l () visit
|
||||
|
||||
(** Swap the two components of every pair in the list. *)
let inverse l =
  let swap (a, b) = (b, a) in
  List.map swap l
|
||||
|
||||
(** A matcher state. *)
type 'a state =
  { id : int
    (** numeric identifier; -1 until [StateSet.add] assigns one *)
  ; seen : pattern
    (** a pattern representing the terms that reach this state *)
  ; point : 'a cursor list
    (** the positions inside rule patterns reached by those terms *)
  }
|
||||
|
||||
(** [binops side s] selects, among the cursors of state [s], those
    that sit on the given [side] of a binary node: [`L] keeps
    [Bnrl] cursors, [`R] keeps [Bnrr] cursors.  Each selected
    cursor yields [((op, parent_cursor), sibling_pattern)].

    The function is not recursive, so it is bound with a plain
    [let] (the original [let rec] triggers the unused-rec
    warning). *)
let binops side { point; _ } =
  List.filter_map
    (fun c ->
       match c, side with
       | Bnrl (o, parent, right), `L -> Some ((o, parent), right)
       | Bnrr (o, left, parent), `R -> Some ((o, parent), left)
       | _ -> None)
    point
|
||||
|
||||
(** [group_by_fst l] groups the second components of [l] by their
    first component.  The list is sorted on the keys with the
    polymorphic [compare]; the result lists keys in descending
    order, and the values of each group in the reverse of their
    sorted order. *)
let group_by_fst l =
  let by_key (a, _) (b, _) = compare a b in
  (* [key]/[members] is the group being built, [groups] the
   * finished ones *)
  let rec collect groups key members = function
    | [] -> (key, members) :: groups
    | (k, v) :: rest when k = key -> collect groups key (v :: members) rest
    | (k, v) :: rest -> collect ((key, members) :: groups) k [ v ] rest
  in
  match List.fast_sort by_key l with
  | [] -> []
  | (k, v) :: rest -> collect [] k [ v ] rest
|
||||
|
||||
(** [sort_uniq cmp l] sorts [l] with [cmp] and drops every element
    that compares equal to its predecessor, keeping the first
    element of each run.  The result is in ascending order. *)
let sort_uniq cmp l =
  (* [last] is the representative of the current run, [kept] the
   * earlier representatives in reverse order *)
  let rec dedup kept last = function
    | [] -> List.rev (last :: kept)
    | e :: rest when cmp last e = 0 -> dedup kept last rest
    | e :: rest -> dedup (last :: kept) e rest
  in
  match List.fast_sort cmp l with
  | [] -> []
  | first :: rest -> dedup [] first rest
|
||||
|
||||
(** Sort a list with the polymorphic [compare] and remove
    duplicates. *)
let setify l = l |> sort_uniq compare
|
||||
|
||||
(** Canonical form of a cursor list: sorted and free of
    duplicates, so that equal sets of cursors are structurally
    equal lists. *)
let normalize : 'a cursor list -> 'a cursor list =
  fun point -> setify point
|
||||
|
||||
(** [next_binary tmp s1 s2] computes the states reached by binary
    nodes whose left child is in state [s1] and right child in
    state [s2].

    A cursor of [s1] on the left of a binary node contributes when
    the node's right sibling pattern matches [s2.seen];
    symmetrically for cursors of [s2] on the right.  Contributions
    are grouped by operation; each group yields one state, with
    id -1, whose cursors are the parent cursors plus [tmp]. *)
let next_binary tmp s1 s2 =
  let candidates side self other =
    binops side self
    |> List.filter_map (fun (oc, sibling) ->
      if pattern_match sibling other.seen then Some oc else None)
  in
  let from_left = candidates `L s1 s2 in
  let from_right = candidates `R s2 s1 in
  group_by_fst (from_left @ from_right)
  |> List.map (fun (o, parents) ->
    ( o
    , { id = -1
      ; seen = Bnr (o, s1.seen, s2.seen)
      ; point = normalize (parents @ tmp)
      } ))
|
||||
|
||||
(* payload carried by Top cursors; generate_table fills it
 * with rule names *)
type p = string
|
||||
|
||||
(** A mutable set of states, keyed on their [point] field only.
    Ids are handed out in insertion order, starting from 0. *)
module StateSet : sig
  type t

  (** A fresh, empty set. *)
  val create : unit -> t

  (** [add set s] looks [s] up by its cursors.  It returns [`Found]
      with the id already assigned to an equal state, or [`Added]
      with a newly assigned id; in both cases the returned state is
      [s] with its [id] field updated.  [s.point] must already be
      normalized (checked by an assertion). *)
  val add : t -> p state -> [> `Added | `Found ] * p state

  (** Apply a function to every state, with its assigned id. *)
  val iter : t -> (p state -> unit) -> unit

  (** All the states, with their assigned ids. *)
  val elems : t -> p state list
end = struct
  module Tbl = Hashtbl.Make (struct
    type t = p state

    let equal a b = a.point = b.point
    let hash s = Hashtbl.hash s.point
  end)

  type t =
    { h : int Tbl.t
    ; mutable next_id : int
    }

  let create () = { h = Tbl.create 500; next_id = 0 }

  let add set s =
    assert (s.point = normalize s.point);
    match Tbl.find_opt set.h s with
    | Some id -> (`Found, { s with id })
    | None ->
      let id = set.next_id in
      set.next_id <- id + 1;
      Tbl.add set.h s id;
      (`Added, { s with id })

  (* the table stores states as they were passed to [add]; the
   * assigned id is patched in on the way out *)
  let iter set f = Tbl.iter (fun s id -> f { s with id }) set.h

  let elems set =
    let acc = ref [] in
    iter set (fun s -> acc := s :: !acc);
    !acc
end
|
||||
|
||||
(** Key of the transition table: an operation with the states of
    its left and right operands. *)
type table_key = K of op * p state * p state
|
||||
|
||||
(** Maps keyed on [table_key].  Keys are ordered on the operation
    and the ids of the two operand states; the other fields of the
    states are ignored. *)
module StateMap = struct
  include Map.Make (struct
    type t = table_key

    let compare (K (o, sl, sr)) (K (o', sl', sr')) =
      Stdlib.compare (o, sl.id, sr.id) (o', sl'.id, sr'.id)
  end)

  (** [invert n sm] builds the reverse of [sm] as an array of
      length [n]: cell [i] lists, grouped by operation, the
      [(left id, right id)] pairs that map to the state of id
      [i]. *)
  let invert n sm =
    let rmap = Array.make n [] in
    let record (K (o, sl, sr)) { id; _ } =
      rmap.(id) <- (o, (sl.id, sr.id)) :: rmap.(id)
    in
    iter record sm;
    Array.map group_by_fst rmap

  (** [by_ops sm] groups the bindings of [sm] by operation; each
      binding becomes [((left id, right id), target id)]. *)
  let by_ops sm =
    let entry (K (op, l, r)) s ops = (op, ((l.id, r.id), s.id)) :: ops in
    group_by_fst (fold entry sm [])
end
|
||||
|
||||
(** A named pattern.  Several rules may share a name;
    [lr_matcher] gathers them by name. *)
type rule =
  { name : string
  ; vars : string list
  ; pattern : pattern
  }
|
||||
|
||||
(** [generate_table rl] builds the matcher automaton for the rules
    [rl].  It returns [(states, atoms, map)] where [states] is the
    array of states sorted by id, [atoms] associates atomic
    patterns with their state, and [map] is the transition table.

    Raises [Not_found] only if the ground states lack the [Tmp] or
    [AnyCon] atoms, which the two rules added below prevent. *)
let generate_table rl =
  let states = StateSet.create () in
  (* these atomic patterns must occur in
   * rules so that we are able to number
   * all possible refs *)
  let rl =
    { name = "$"; vars = []; pattern = Atm AnyCon }
    :: { name = "%"; vars = []; pattern = Atm Tmp }
    :: rl
  in
  (* initialize states: one per distinct leaf of the rules *)
  let ground =
    rl
    |> List.concat_map (fun r -> peel r.pattern r.name)
    |> group_by_fst
  in
  let tmp = List.assoc (Atm Tmp) ground in
  let con = List.assoc (Atm AnyCon) ground in
  let atoms =
    List.fold_left
      (fun atoms (seen, l) ->
         (* a leaf also reaches the cursors of the generic atom
          * of its kind *)
         let base = if pattern_match (Atm Tmp) seen then tmp else con in
         let point = normalize (base @ l) in
         let _, s = StateSet.add states { id = -1; seen; point } in
         match get_atomic seen with
         | Some atm -> (atm, s) :: atoms
         | None -> atoms)
      [] ground
  in
  (* iterate until fixpoint: combine all pairs of known states
   * until a full pass adds no new state *)
  let map = ref StateMap.empty in
  let rec saturate () =
    let grew = ref false in
    iter_pairs (StateSet.elems states) (fun (sl, sr) ->
      List.iter
        (fun (o, s') ->
           let status, s' = StateSet.add states s' in
           if status = `Added then grew := true;
           map := StateMap.add (K (o, sl, sr)) s' !map)
        (next_binary tmp sl sr));
    if !grew then saturate ()
  in
  saturate ();
  let states =
    StateSet.elems states
    |> List.sort (fun a b -> compare a.id b.id)
    |> Array.of_list
  in
  (states, atoms, !map)
|
||||
|
||||
(** [intersperse x l] lists every way of inserting [x] into [l].
    The insertion at the end of [l] comes first and the insertion
    at the front comes last. *)
let intersperse x l =
  let insert_here before after = List.rev_append before (x :: after) in
  (* [before] holds the elements already passed, in reverse *)
  let rec go before after acc =
    let acc = insert_here before after :: acc in
    match after with
    | [] -> acc
    | y :: after' -> go (y :: before) after' acc
  in
  go [] l []
|
||||
|
||||
(** All the permutations of a list, obtained by inserting the head
    at every position of every permutation of the tail. *)
let rec permute l =
  match l with
  | [] -> [ [] ]
  | x :: rest -> List.concat_map (intersperse x) (permute rest)
|
||||
|
||||
(* build all binary trees whose leaves are, in
 * order, the elements of the list; 'build'
 * makes an inner node from two subtrees.  The
 * empty list yields no tree and a singleton
 * yields its element unchanged *)
let rec bins build leaves =
  (* add to [acc] every tree whose root splits the leaves into
   * [prefix] and [suffix] *)
  let combine acc prefix suffix =
    fold_pairs (bins build prefix) (bins build suffix) acc
      (fun (lt, rt) acc -> build lt rt :: acc)
  in
  (* try every split point with non-empty sides *)
  let rec split prefix suffix acc =
    match suffix with
    | [] -> acc
    | x :: rest -> split (prefix @ [ x ]) rest (combine acc prefix suffix)
  in
  match leaves with
  | [] -> []
  | [ leaf ] -> [ leaf ]
  | first :: rest -> split [ first ] rest []
|
||||
|
||||
(** [products l ini f] folds [f] over the cartesian product of the
    lists in [l]: [f] receives each choice of one element per list
    (in the order of [l]) and the accumulator, starting from [ini].
    An empty [l] yields a single empty choice. *)
let products l ini f =
  (* [picked] holds the elements chosen so far, in reverse *)
  let rec choose acc picked remaining =
    match remaining with
    | [] -> f (List.rev picked) acc
    | options :: rest ->
      List.fold_left (fun acc x -> choose acc (x :: picked) rest) acc options
  in
  choose ini [] l
|
||||
|
||||
(* combinatorial nuke...
 *
 * list the patterns obtained from the argument by
 * rearranging its operations: operands of an
 * associative operation are flattened and regrouped
 * into every binary tree (and reordered in every way
 * when the operation is also commutative); operands
 * of a merely commutative operation are swapped;
 * leaves are returned unchanged *)
let rec ac_equiv p =
  (* the operands of the maximal chain of [o] nodes *)
  let rec operands o = function
    | Bnr (o', l, r) when o' = o -> operands o l @ operands o r
    | leaf -> [ leaf ]
  in
  let node o l r = Bnr (o, l, r) in
  match p with
  | Atm _ | Var _ -> [ p ]
  | Bnr (o, _, _) when associative o ->
    let orderings choice =
      if commutative o then permute choice else [ choice ]
    in
    products (List.map ac_equiv (operands o p)) []
      (fun choice out ->
         List.concat_map (bins (node o)) (orderings choice) @ out)
  | Bnr (o, l, r) ->
    let both_orders = commutative o in
    fold_pairs (ac_equiv l) (ac_equiv r) []
      (fun (l', r') out ->
         if both_orders
         then node o l' r' :: node o r' l' :: out
         else node o l' r' :: out)
|
||||
|
||||
(** Hash-consed matcher programs.  Values of type [t] can only be
    built with the [mk_*] functions, which give structurally equal
    nodes the same id. *)
module Action : sig
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop

  and t = private
    { id : int
    ; node : node
    }

  (** Equality on ids. *)
  val equal : t -> t -> bool

  (** Number of distinct nodes reachable from the action. *)
  val size : t -> int

  val stop : t
  val mk_push : sym:bool -> t -> t

  (** [mk_pop a] is [a] itself when [a] is [stop]. *)
  val mk_pop : t -> t

  val mk_set : string -> t -> t

  (** [mk_switch ids f] builds the cases [f id] for [id] in [ids].
      When all the cases are equal the common case is returned
      instead of a switch.  Fails with [Failure "empty switch"]
      when [ids] is empty. *)
  val mk_switch : int list -> (int -> t) -> t

  val pp : Format.formatter -> t -> unit
end = struct
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop

  and t =
    { id : int
    ; node : node
    }

  let equal a a' = a.id = a'.id

  let size a =
    (* ids already counted, so that shared nodes count once *)
    let seen = Hashtbl.create 10 in
    let rec size { id; node } =
      if Hashtbl.mem seen id
      then 0
      else begin
        Hashtbl.add seen id ();
        1 + node_size node
      end
    and node_size = function
      | Stop -> 0
      | Push (_, a) | Pop a | Set (_, a) -> size a
      | Switch cases -> List.fold_left (fun n (_, a) -> n + size a) 0 cases
    in
    size a

  (* hash-consing constructor: a node seen before gets its
   * previous id, a new node gets the next fresh id *)
  let mk =
    let hcons = Hashtbl.create 100 in
    let fresh = ref 0 in
    fun node ->
      let id =
        match Hashtbl.find_opt hcons node with
        | Some id -> id
        | None ->
          let id = !fresh in
          Hashtbl.add hcons node id;
          incr fresh;
          id
      in
      { id; node }

  let stop = mk Stop
  let mk_push ~sym a = mk (Push (sym, a))

  let mk_pop a =
    match a.node with
    | Stop -> a
    | Switch _ | Push _ | Pop _ | Set _ -> mk (Pop a)

  let mk_set v a = mk (Set (v, a))

  let mk_switch ids f =
    match List.map f ids with
    | [] -> failwith "empty switch"
    | first :: rest as cases ->
      if List.for_all (equal first) rest
      then first
      else mk (Switch (List.combine ids cases))

  let rec pp_node fmt = function
    | Stop -> Format.fprintf fmt "•"
    | Set (v, a) -> Format.fprintf fmt "set(%s)@ %a" v pp a
    | Pop a -> Format.fprintf fmt "pop@ %a" pp a
    | Push (false, a) -> Format.fprintf fmt "push@ %a" pp a
    | Push (true, a) -> Format.fprintf fmt "pushsym@ %a" pp a
    | Switch l ->
      let pp_sep fmt () = Format.fprintf fmt "," in
      let pp_case (c, a) =
        Format.fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
          (Format.pp_print_list ~pp_sep Format.pp_print_int)
          c pp a
      in
      Format.fprintf fmt "@[<v>@[<v2>switch{";
      (* print together the ids that lead to the same action *)
      inverse l |> group_by_fst |> inverse |> List.iter pp_case;
      Format.fprintf fmt "@]@,}@]"

  and pp fmt a = pp_node fmt a.node
end
|
||||
|
||||
(* a state is symmetric if (a op b) enters
 * it iff (b op a) enters it as well; rmap is
 * the inverted transition table and id the
 * state to check *)
let symmetric rmap id =
  let balanced (_, pairs) =
    (* pairs (a, a) are trivially symmetric; split the others
     * by orientation and compare the two halves as sets *)
    let ascending, descending =
      pairs
      |> List.filter (fun (a, b) -> a <> b)
      |> List.partition (fun (a, b) -> a < b)
    in
    setify ascending = setify (inverse descending)
  in
  List.for_all balanced rmap.(id)
|
||||
|
||||
(* left-to-right matching of a set of patterns;
 * may raise if there is no lr matcher for the
 * input rule.  The result is an Action.t that
 * matches the patterns of the rules called 'name' *)
let lr_matcher statemap states rules name =
  let rmap = StateMap.invert (Array.length states) statemap in
  let exception Stuck in
  (* the list of ids represents a class of terms
   * whose root ends up being labelled with one
   * such id; the gen function generates a matcher
   * that will, given any such term, assign values
   * for the Var nodes of one pattern in pats *)
  let rec gen
    : 'a. int list -> (pattern * 'a) list
      -> (int -> (pattern * 'a) list -> Action.t)
      -> Action.t
    = fun ids pats k ->
      let case id_top =
        let sym = symmetric rmap id_top in
        (* for a symmetric state, keeping the pairs
         * (a, b) with a <= b is enough *)
        let id_ops =
          if not sym
          then rmap.(id_top)
          else
            List.map
              (fun (o, pairs) ->
                 (o, List.filter (fun (a, b) -> a <= b) pairs))
              rmap.(id_top)
        in
        (* consider only the patterns that are
         * compatible with the current id *)
        let compatible = function
          | Bnr (o, _, _), _ -> List.mem_assoc o id_ops
          | (Atm _ | Var _), _ -> true
        in
        let atm_pats, bin_pats =
          pats
          |> List.filter compatible
          |> List.partition (fun (pat, _) -> is_atomic pat)
        in
        try
          if bin_pats = [] then raise Stuck;
          (* focus on the left arm, remembering the rest *)
          let pats_l =
            List.map
              (function
                | Bnr (o, l, r), x -> (l, (o, x, r))
                | (Atm _ | Var _), _ -> assert false)
              bin_pats
          in
          (* move the focus from the left arm to the right arm *)
          let pats_r ps = List.map (fun (l, (o, x, r)) -> (r, (o, l, x))) ps in
          (* rebuild the full patterns from the right arm *)
          let patstop ps =
            List.map (fun (r, (o, l, x)) -> (Bnr (o, l, r), x)) ps
          in
          let id_pairs = List.concat_map snd id_ops in
          let ids_l = List.map fst id_pairs in
          let ids_r id_left =
            List.filter_map
              (fun (l, r) -> if l = id_left then Some r else None)
              id_pairs
          in
          (* match the left arm *)
          Action.mk_push ~sym
            (gen ids_l pats_l (fun lid pats ->
               (* then the right arm, considering
                * only the remaining possible
                * patterns and knowing that the
                * left arm was numbered 'lid' *)
               Action.mk_pop
                 (gen (ids_r lid) (pats_r pats) (fun _rid pats ->
                    (* continue with the parent *)
                    k id_top (patstop pats)))))
        with Stuck ->
          (* no binary pattern works: fall back on the atomic
           * patterns that accept this state *)
          let seen = states.(id_top).seen in
          (match
             List.filter (fun (pat, _) -> pattern_match pat seen) atm_pats
           with
           | [] -> raise Stuck
           | atm_pats ->
             let vars =
               List.filter_map
                 (function
                   | Var (v, _), _ -> Some v
                   | _ -> None)
                 atm_pats
               |> setify
             in
             (match vars with
              | [] -> k id_top atm_pats
              | [ v ] -> Action.mk_set v (k id_top atm_pats)
              | _ -> failwith "ambiguous var match"))
      in
      Action.mk_switch (setify ids) case
  in
  (* generate a matcher for the rule: start from the states
   * that contain the root cursor of the rule *)
  let ids_top =
    Array.to_list states
    |> List.filter_map (fun { id; point; _ } ->
      if List.mem (Top name) point then Some id else None)
  in
  (* drop a pattern when it accepts a later pattern of the list *)
  let rec filter_dups = function
    | [] -> []
    | p :: rest ->
      if List.exists (pattern_match p) rest
      then filter_dups rest
      else p :: filter_dups rest
  in
  let pats_top =
    rules
    |> List.filter_map (fun r ->
      if r.name = name then Some r.pattern else None)
    |> filter_dups
    |> List.map (fun p -> (p, ()))
  in
  gen ids_top pats_top (fun _ pats ->
    assert (pats <> []);
    Action.stop)
|
||||
|
||||
(** Bundle of the three results of [generate_table], plus a
    cache. *)
type numberer =
  { atoms : (atomic_pattern * p state) list
  ; statemap : p state StateMap.t
  ; states : p state array
  ; mutable ops : op list
    (* memoizes the list of possible operations
     * according to the statemap *)
  }
|
||||
|
||||
(** [make_numberer sa am sm] packs a state array, an atom
    association list and a state map into a [numberer] whose [ops]
    cache starts empty. *)
let make_numberer sa am sm =
  let states = sa and atoms = am and statemap = sm in
  { atoms; statemap; states; ops = [] }
|
||||
|
||||
(** [atom_state n atm] is the state associated with the atomic
    pattern [atm] in [n.atoms].
    @raise Not_found if [atm] has no associated state. *)
let atom_state { atoms; _ } atm = List.assoc atm atoms
|
||||
Reference in New Issue
Block a user