nixify
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
*.cm[iox]
|
||||
*.o
|
||||
mgen
|
||||
@@ -0,0 +1 @@
|
||||
match_clause=4
|
||||
@@ -0,0 +1,16 @@
|
||||
BIN = mgen
|
||||
SRC = \
|
||||
match.ml \
|
||||
fuzz.ml \
|
||||
cgen.ml \
|
||||
sexp.ml \
|
||||
test.ml \
|
||||
main.ml
|
||||
|
||||
$(BIN): $(SRC)
|
||||
ocamlopt -o $(BIN) -g str.cmxa $(SRC)
|
||||
|
||||
clean:
|
||||
rm -f *.cm? *.o $(BIN)
|
||||
|
||||
.PHONY: clean
|
||||
@@ -0,0 +1,420 @@
|
||||
open Match
|
||||
|
||||
type options =
|
||||
{ pfx: string
|
||||
; static: bool
|
||||
; oc: out_channel }
|
||||
|
||||
type side = L | R
|
||||
|
||||
type id_pred =
|
||||
| InBitSet of Int64.t
|
||||
| Ge of int
|
||||
| Eq of int
|
||||
|
||||
and id_test =
|
||||
| Pred of (side * id_pred)
|
||||
| And of id_test * id_test
|
||||
|
||||
type case_code =
|
||||
| Table of ((int * int) * int) list
|
||||
| IfThen of
|
||||
{ test: id_test
|
||||
; cif: case_code
|
||||
; cthen: case_code option }
|
||||
| Return of int
|
||||
|
||||
type case =
|
||||
{ swap: bool
|
||||
; code: case_code }
|
||||
|
||||
let cgen_case tmp nstates map =
|
||||
let cgen_test ids =
|
||||
match ids with
|
||||
| [id] -> Eq id
|
||||
| _ ->
|
||||
let min_id =
|
||||
List.fold_left min max_int ids in
|
||||
if List.length ids = nstates - min_id
|
||||
then Ge min_id
|
||||
else begin
|
||||
assert (nstates <= 64);
|
||||
InBitSet
|
||||
(List.fold_left (fun bs id ->
|
||||
Int64.logor bs
|
||||
(Int64.shift_left 1L id))
|
||||
0L ids)
|
||||
end
|
||||
in
|
||||
let symmetric =
|
||||
let inverse ((l, r), x) = ((r, l), x) in
|
||||
setify map = setify (List.map inverse map) in
|
||||
let map =
|
||||
let ordered ((l, r), _) = r <= l in
|
||||
if symmetric then
|
||||
List.filter ordered map
|
||||
else map
|
||||
in
|
||||
let exception BailToTable in
|
||||
try
|
||||
let st =
|
||||
match setify (List.map snd map) with
|
||||
| [st] -> st
|
||||
| _ -> raise BailToTable
|
||||
in
|
||||
(* the operation considered can only
|
||||
* generate a single state *)
|
||||
let pairs = List.map fst map in
|
||||
let ls, rs = List.split pairs in
|
||||
let ls = setify ls and rs = setify rs in
|
||||
if List.length ls > 1 && List.length rs > 1 then
|
||||
raise BailToTable;
|
||||
{ swap = symmetric
|
||||
; code =
|
||||
let pl = Pred (L, cgen_test ls)
|
||||
and pr = Pred (R, cgen_test rs) in
|
||||
IfThen
|
||||
{ test = And (pl, pr)
|
||||
; cif = Return st
|
||||
; cthen = Some (Return tmp) } }
|
||||
with BailToTable ->
|
||||
{ swap = symmetric
|
||||
; code = Table map }
|
||||
|
||||
(* Name used for an operation in the emitted C (e.g. as the
 * label of a `case`): the class component of the pair is
 * ignored and the base name is prefixed with "O". *)
let show_op op =
  let _, base = op in
  String.concat "" ["O"; show_op_base base]
|
||||
|
||||
(* [indent oc i] writes [i] tab characters to [oc].
 * Any non-negative depth is accepted; a negative [i]
 * raises [Invalid_argument] (from [String.make]). *)
let indent oc i =
  output_string oc (String.make i '\t')
|
||||
|
||||
(* Writes, at indentation depth [i], the two C lines that
 * exchange `l` and `r` through `t` whenever `l < r`. *)
let emit_swap oc i =
  let line depth text =
    indent oc depth;
    output_string oc text
  in
  line i "if (l < r)\n";
  line (i + 1) "t = l, l = r, r = t;\n"
|
||||
|
||||
(* [gen_tables oc tmp pfx nstates (op, c)] writes to [oc] one
 * `static uchar` lookup table for every [Table] node found in
 * [c.code].
 *
 * - [tmp] is the entry written for (l, r) pairs absent from
 *   the table's map;
 * - [pfx] and the operation name form the table's base name;
 * - [nstates] is the size of each table dimension.
 *
 * When [c.swap] is set only the entries with r <= l are
 * written, as a flat array of nstates*(nstates+1)/2 cells;
 * otherwise a full nstates x nstates table is written. *)
let gen_tables oc tmp pfx nstates (op, c) =
  let i = 1 in
  let pf m = Printf.fprintf oc m in
  let pfi n m = indent oc n; pf m in
  let ntables = ref 0 in
  (* we must follow the order in which
   * we visit code in emit_case, or
   * else ntables goes out of sync *)
  let base = pfx ^ show_op op in
  let swap = c.swap in
  let rec gen code =
    match code with
    | Table map ->
      let name =
        if !ntables = 0 then base
        else base ^ string_of_int !ntables
      in
      (* emit_case bumps its counter after naming each table;
       * do the same so that a second table does not reuse
       * the name of the first one *)
      incr ntables;
      (* entries are emitted as uchar *)
      assert (nstates <= 256);
      if swap then begin
        let n = nstates * (nstates + 1) / 2 in
        pfi i "static uchar %stbl[%d] = {\n" name n
      end else
        pfi i "static uchar %stbl[%d][%d] = {\n"
          name nstates nstates;
      for l = 0 to nstates - 1 do
        pfi (i + 1) "";
        for r = 0 to nstates - 1 do
          if not swap || r <= l then begin
            let st =
              match List.assoc_opt (l, r) map with
              | Some st -> st
              | None -> tmp
            in
            pf "%d," st
          end
        done;
        pf "\n"
      done;
      pfi i "};\n"
    | IfThen {cif; cthen; _} ->
      gen cif;
      Option.iter gen cthen
    | Return _ -> ()
  in
  gen c.code
|
||||
|
||||
let emit_case oc pfx no_swap (op, c) =
|
||||
let fpf = Printf.fprintf in
|
||||
let pf m = fpf oc m in
|
||||
let pfi n m = indent oc n; pf m in
|
||||
let rec side oc = function
|
||||
| L -> fpf oc "l"
|
||||
| R -> fpf oc "r"
|
||||
in
|
||||
let pred oc (s, pred) =
|
||||
match pred with
|
||||
| InBitSet bs -> fpf oc "BIT(%a) & %#Lx" side s bs
|
||||
| Eq id -> fpf oc "%a == %d" side s id
|
||||
| Ge id -> fpf oc "%d <= %a" id side s
|
||||
in
|
||||
let base = pfx ^ show_op op in
|
||||
let swap = c.swap in
|
||||
let ntables = ref 0 in
|
||||
let rec code i c =
|
||||
match c with
|
||||
| Return id -> pfi i "return %d;\n" id
|
||||
| Table map ->
|
||||
let name =
|
||||
if !ntables = 0 then base else
|
||||
base ^ string_of_int !ntables
|
||||
in
|
||||
incr ntables;
|
||||
if swap then
|
||||
pfi i "return %stbl[(l + l*l)/2 + r];\n" name
|
||||
else pfi i "return %stbl[l][r];\n" name
|
||||
| IfThen ({test = And (And (t1, t2), t3)} as r) ->
|
||||
code i @@ IfThen
|
||||
{r with test = And (t1, And (t2, t3))}
|
||||
| IfThen {test = And (Pred p, t); cif; cthen} ->
|
||||
pfi i "if (%a)\n" pred p;
|
||||
code i (IfThen {test = t; cif; cthen})
|
||||
| IfThen {test = Pred p; cif; cthen} ->
|
||||
pfi i "if (%a) {\n" pred p;
|
||||
code (i+1) cif;
|
||||
pfi i "}\n";
|
||||
Option.iter (code i) cthen
|
||||
in
|
||||
pfi 1 "case %s:\n" (show_op op);
|
||||
if not no_swap && c.swap then
|
||||
emit_swap oc 2;
|
||||
code 2 c.code
|
||||
|
||||
(* [emit_list ~col ~indent ~sep ~f oc l] writes the strings
 * [List.map f l] to [oc], separated by [sep] and wrapped so
 * that lines try to stay within [limit] columns.
 *
 * - [col] is the column at which the first line starts;
 * - continuation lines are indented with [indent] tabs, each
 *   counted as 8 columns;
 * - with [cut_before_sep] the separator opens the next line
 *   (without its leading space), otherwise it closes the
 *   current line (without its trailing space).
 *
 * An item too wide to fit even on a fresh continuation line
 * is written alone on that line; it used to make the
 * function loop forever.  An empty [sep] is accepted. *)
let emit_list
    ?(limit=60) ?(cut_before_sep=false)
    ~col ~indent:i ~sep ~f oc l =
  let sl = String.length sep in
  let rstripped_sep, rssl =
    if sl > 0 && sep.[sl - 1] = ' ' then
      String.sub sep 0 (sl - 1), sl - 1
    else sep, sl
  in
  let lstripped_sep, lssl =
    if sl > 0 && sep.[0] = ' ' then
      String.sub sep 1 (sl - 1), sl - 1
    else sep, sl
  in
  (* take from the list the longest prefix that fits on a
   * line starting at column [col]; returns that prefix and
   * the remaining items *)
  let rec line col acc = function
    | [] -> (List.rev acc, [])
    | s :: l ->
      let col = col + sl + String.length s in
      let no_space =
        if cut_before_sep || l = [] then
          col > limit
        else
          col + rssl > limit
      in
      if no_space then
        (List.rev acc, s :: l)
      else
        line col (s :: acc) l
  in
  (* [fresh] is true when the line being filled is a
   * continuation line that nothing was written on yet *)
  let rec go ~fresh col l =
    if l = [] then () else
    let ll, l =
      match line col [] l with
      | ([], s :: rest) when fresh ->
        (* nothing fits: force progress *)
        ([s], rest)
      | split -> split
    in
    output_string oc (String.concat sep ll);
    if l <> [] then begin
      if cut_before_sep then begin
        output_string oc "\n";
        indent oc i;
        output_string oc lstripped_sep;
        go ~fresh:true (8*i + lssl) l
      end else begin
        output_string oc rstripped_sep;
        output_string oc "\n";
        indent oc i;
        go ~fresh:true (8*i) l
      end
    end
  in
  go ~fresh:false col (List.map f l)
|
||||
|
||||
let emit_numberer opts n =
|
||||
let pf m = Printf.fprintf opts.oc m in
|
||||
let tmp = (atom_state n Tmp).id in
|
||||
let con = (atom_state n AnyCon).id in
|
||||
let nst = Array.length n.states in
|
||||
let cases =
|
||||
StateMap.by_ops n.statemap |>
|
||||
List.map (fun (op, map) ->
|
||||
(op, cgen_case tmp nst map))
|
||||
in
|
||||
let all_swap =
|
||||
List.for_all (fun (_, c) -> c.swap) cases in
|
||||
(* opn() *)
|
||||
if opts.static then pf "static ";
|
||||
pf "int\n";
|
||||
pf "%sopn(int op, int l, int r)\n" opts.pfx;
|
||||
pf "{\n";
|
||||
cases |> List.iter
|
||||
(gen_tables opts.oc tmp opts.pfx nst);
|
||||
if List.exists (fun (_, c) -> c.swap) cases then
|
||||
pf "\tint t;\n\n";
|
||||
if all_swap then emit_swap opts.oc 1;
|
||||
pf "\tswitch (op) {\n";
|
||||
cases |> List.iter
|
||||
(emit_case opts.oc opts.pfx all_swap);
|
||||
pf "\tdefault:\n";
|
||||
pf "\t\treturn %d;\n" tmp;
|
||||
pf "\t}\n";
|
||||
pf "}\n\n";
|
||||
(* refn() *)
|
||||
if opts.static then pf "static ";
|
||||
pf "int\n";
|
||||
pf "%srefn(Ref r, Num *tn, Con *con)\n" opts.pfx;
|
||||
pf "{\n";
|
||||
let cons =
|
||||
List.filter_map (function
|
||||
| (Con c, s) -> Some (c, s.id)
|
||||
| _ -> None)
|
||||
n.atoms
|
||||
in
|
||||
if cons <> [] then
|
||||
pf "\tint64_t n;\n\n";
|
||||
pf "\tswitch (rtype(r)) {\n";
|
||||
pf "\tcase RTmp:\n";
|
||||
if tmp <> 0 then begin
|
||||
assert
|
||||
(List.exists (fun (_, s) ->
|
||||
s.id = 0
|
||||
) n.atoms &&
|
||||
(* no temp should ever get state 0 *)
|
||||
List.for_all (fun (a, s) ->
|
||||
s.id <> 0 ||
|
||||
match a with
|
||||
| AnyCon | Con _ -> true
|
||||
| _ -> false
|
||||
) n.atoms);
|
||||
pf "\t\tif (!tn[r.val].n)\n";
|
||||
pf "\t\t\ttn[r.val].n = %d;\n" tmp;
|
||||
end;
|
||||
pf "\t\treturn tn[r.val].n;\n";
|
||||
pf "\tcase RCon:\n";
|
||||
if cons <> [] then begin
|
||||
pf "\t\tif (con[r.val].type != CBits)\n";
|
||||
pf "\t\t\treturn %d;\n" con;
|
||||
pf "\t\tn = con[r.val].bits.i;\n";
|
||||
cons |> inverse |> group_by_fst
|
||||
|> List.iter (fun (id, cs) ->
|
||||
pf "\t\tif (";
|
||||
emit_list ~cut_before_sep:true
|
||||
~col:20 ~indent:2 ~sep:" || "
|
||||
~f:(fun c -> "n == " ^ Int64.to_string c)
|
||||
opts.oc cs;
|
||||
pf ")\n";
|
||||
pf "\t\t\treturn %d;\n" id
|
||||
);
|
||||
end;
|
||||
pf "\t\treturn %d;\n" con;
|
||||
pf "\tdefault:\n";
|
||||
pf "\t\treturn INT_MIN;\n";
|
||||
pf "\t}\n";
|
||||
pf "}\n\n";
|
||||
(* match[]: patterns per state *)
|
||||
if opts.static then pf "static ";
|
||||
pf "bits %smatch[%d] = {\n" opts.pfx nst;
|
||||
n.states |> Array.iteri (fun sn s ->
|
||||
let tops =
|
||||
List.filter_map (function
|
||||
| Top ("$" | "%") -> None
|
||||
| Top r -> Some ("BIT(P" ^ r ^ ")")
|
||||
| _ -> None) s.point |> setify
|
||||
in
|
||||
if tops <> [] then
|
||||
pf "\t[%d] = %s,\n"
|
||||
sn (String.concat " | " tops);
|
||||
);
|
||||
pf "};\n\n"
|
||||
|
||||
(* [var_id vars f] is the zero-based position of the first
 * element of [vars] equal to [f].
 * Raises [Not_found] when [f] does not occur in [vars]. *)
let var_id vars f =
  let rec search idx = function
    | [] -> raise Not_found
    | v :: rest ->
      if compare v f = 0 then idx
      else search (idx + 1) rest
  in
  search 0 vars
|
||||
|
||||
let compile_action vars act =
|
||||
let pcs = Hashtbl.create 100 in
|
||||
let rec gen pc (act: Action.t) =
|
||||
try
|
||||
[10 + Hashtbl.find pcs act.id]
|
||||
with Not_found ->
|
||||
let code =
|
||||
match act.node with
|
||||
| Action.Stop ->
|
||||
[0]
|
||||
| Action.Push (sym, k) ->
|
||||
let c = if sym then 1 else 2 in
|
||||
[c] @ gen (pc + 1) k
|
||||
| Action.Set (v, {node = Action.Pop k; _})
|
||||
| Action.Set (v, ({node = Action.Stop; _} as k)) ->
|
||||
let v = var_id vars v in
|
||||
[3; v] @ gen (pc + 2) k
|
||||
| Action.Set _ ->
|
||||
(* for now, only atomic patterns can be
|
||||
* tied to a variable, so Set must be
|
||||
* followed by either Pop or Stop *)
|
||||
assert false
|
||||
| Action.Pop k ->
|
||||
[4] @ gen (pc + 1) k
|
||||
| Action.Switch cases ->
|
||||
let cases =
|
||||
inverse cases |> group_by_fst |>
|
||||
List.sort (fun (_, cs1) (_, cs2) ->
|
||||
let n1 = List.length cs1
|
||||
and n2 = List.length cs2 in
|
||||
compare n2 n1)
|
||||
in
|
||||
(* the last case is the one with
|
||||
* the max number of entries *)
|
||||
let cases = List.rev (List.tl cases)
|
||||
and last = fst (List.hd cases) in
|
||||
let ncases =
|
||||
List.fold_left (fun n (_, cs) ->
|
||||
List.length cs + n)
|
||||
0 cases
|
||||
in
|
||||
let body_off = 2 + 2 * ncases + 1 in
|
||||
let pc, tbl, body =
|
||||
List.fold_left
|
||||
(fun (pc, tbl, body) (a, cs) ->
|
||||
let ofs = body_off + List.length body in
|
||||
let case = gen pc a in
|
||||
let pc = pc + List.length case in
|
||||
let body = body @ case in
|
||||
let tbl =
|
||||
List.fold_left (fun tbl c ->
|
||||
tbl @ [c; ofs]
|
||||
) tbl cs
|
||||
in
|
||||
(pc, tbl, body))
|
||||
(pc + body_off, [], [])
|
||||
cases
|
||||
in
|
||||
let ofs = body_off + List.length body in
|
||||
let tbl = tbl @ [ofs] in
|
||||
assert (2 + List.length tbl = body_off);
|
||||
[5; ncases] @ tbl @ body @ gen pc last
|
||||
in
|
||||
if act.node <> Action.Stop then
|
||||
Hashtbl.replace pcs act.id pc;
|
||||
code
|
||||
in
|
||||
gen 0 act
|
||||
|
||||
let emit_matchers opts ms =
|
||||
let pf m = Printf.fprintf opts.oc m in
|
||||
if opts.static then pf "static ";
|
||||
pf "uchar *%smatcher[] = {\n" opts.pfx;
|
||||
List.iter (fun (vars, pname, m) ->
|
||||
pf "\t[P%s] = (uchar[]){\n" pname;
|
||||
pf "\t\t";
|
||||
let bytes = compile_action vars m in
|
||||
emit_list
|
||||
~col:16 ~indent:2 ~sep:","
|
||||
~f:string_of_int opts.oc bytes;
|
||||
pf "\n";
|
||||
pf "\t},\n")
|
||||
ms;
|
||||
pf "};\n\n"
|
||||
|
||||
let emit_c opts n =
|
||||
emit_numberer opts n
|
||||
@@ -0,0 +1,413 @@
|
||||
(* fuzz the tables and matchers generated *)
|
||||
open Match
|
||||
|
||||
(* A minimal growable array.
 *
 * - [create ?capacity ()] makes an empty buffer; [capacity]
 *   (default 10) is the size of the first allocation and
 *   must not be negative ([Invalid_argument] otherwise);
 * - [reset] empties the buffer while keeping its storage;
 * - [get] and [set] raise [Invalid_argument] when the index
 *   is not in [0, size);
 * - [push] appends one element, growing the storage when
 *   it is full. *)
module Buffer: sig
  type 'a t
  val create: ?capacity:int -> unit -> 'a t
  val reset: 'a t -> unit
  val size: 'a t -> int
  val get: 'a t -> int -> 'a
  val set: 'a t -> int -> 'a -> unit
  val push: 'a t -> 'a -> unit
end = struct
  type 'a t =
    { mutable size: int
    ; mutable data: 'a array
    ; capacity: int (* size of the first allocation *) }
  let create ?(capacity = 10) () =
    if capacity < 0 then invalid_arg "Buffer.create";
    (* storage is allocated on the first push, when a real
     * element is available to fill the array with; this
     * avoids forging a placeholder value with Obj.magic *)
    {size = 0; data = [||]; capacity}
  let reset b = b.size <- 0
  let size b = b.size
  let get b n =
    if n < 0 || n >= b.size then invalid_arg "Buffer.get";
    b.data.(n)
  let set b n x =
    if n < 0 || n >= b.size then invalid_arg "Buffer.set";
    b.data.(n) <- x
  let push b x =
    let cap = Array.length b.data in
    if b.size = cap then begin
      let new_cap =
        if cap = 0 then max 1 b.capacity
        else 2 * cap + 1
      in
      (* unused cells are filled with [x] *)
      let data = Array.make new_cap x in
      Array.blit b.data 0 data 0 cap;
      b.data <- data
    end;
    b.data.(b.size) <- x;
    b.size <- b.size + 1
end
|
||||
|
||||
(* State recorded in the numberer's statemap for [op] applied
 * to the states [s1] and [s2]; when the statemap has no such
 * entry, the state of the [Tmp] atom is returned instead. *)
let binop_state n op s1 s2 =
  match StateMap.find (K (op, s1, s2)) n.statemap with
  | st -> st
  | exception Not_found -> atom_state n Tmp
|
||||
|
||||
type id = int
|
||||
type term_data =
|
||||
| Binop of op * id * id
|
||||
| Leaf of atomic_pattern
|
||||
type term =
|
||||
{ id: id
|
||||
; data: term_data
|
||||
; state: p state }
|
||||
|
||||
let pp_term fmt (ta, id) =
|
||||
let fpf x = Format.fprintf fmt x in
|
||||
let rec pp _fmt id =
|
||||
match ta.(id).data with
|
||||
| Leaf (Con c) -> fpf "%Ld" c
|
||||
| Leaf AnyCon -> fpf "$%d" id
|
||||
| Leaf Tmp -> fpf "%%%d" id
|
||||
| Binop (op, id1, id2) ->
|
||||
fpf "@[(%s@%d:%d @[<hov>%a@ %a@])@]"
|
||||
(show_op op) id ta.(id).state.id
|
||||
pp id1 pp id2
|
||||
in pp fmt id
|
||||
|
||||
(* A term pool is a deduplicated set of terms
|
||||
* that maintains nodes numbering using the
|
||||
* statemap passed at creation time *)
|
||||
module TermPool = struct
|
||||
type t =
|
||||
{ terms: term Buffer.t
|
||||
; hcons: (term_data, id) Hashtbl.t
|
||||
; numbr: numberer }
|
||||
|
||||
let create numbr =
|
||||
{ terms = Buffer.create ()
|
||||
; hcons = Hashtbl.create 100
|
||||
; numbr }
|
||||
let reset tp =
|
||||
Buffer.reset tp.terms;
|
||||
Hashtbl.clear tp.hcons
|
||||
|
||||
let size tp = Buffer.size tp.terms
|
||||
let term tp id = Buffer.get tp.terms id
|
||||
|
||||
let mk_leaf tp atm =
|
||||
let data = Leaf atm in
|
||||
match Hashtbl.find tp.hcons data with
|
||||
| id -> term tp id
|
||||
| exception Not_found ->
|
||||
let id = Buffer.size tp.terms in
|
||||
let state = atom_state tp.numbr atm in
|
||||
Buffer.push tp.terms {id; data; state};
|
||||
Hashtbl.add tp.hcons data id;
|
||||
term tp id
|
||||
let mk_binop tp op t1 t2 =
|
||||
let data = Binop (op, t1.id, t2.id) in
|
||||
match Hashtbl.find tp.hcons data with
|
||||
| id -> term tp id
|
||||
| exception Not_found ->
|
||||
let id = Buffer.size tp.terms in
|
||||
let state =
|
||||
binop_state tp.numbr op t1.state t2.state
|
||||
in
|
||||
Buffer.push tp.terms {id; data; state};
|
||||
Hashtbl.add tp.hcons data id;
|
||||
term tp id
|
||||
|
||||
let rec add_pattern tp = function
|
||||
| Bnr (op, p1, p2) ->
|
||||
let t1 = add_pattern tp p1 in
|
||||
let t2 = add_pattern tp p2 in
|
||||
mk_binop tp op t1 t2
|
||||
| Atm atm -> mk_leaf tp atm
|
||||
| Var (_, atm) -> add_pattern tp (Atm atm)
|
||||
|
||||
let explode_term tp id =
|
||||
let rec aux tms n id =
|
||||
let t = term tp id in
|
||||
match t.data with
|
||||
| Leaf _ -> (n, {t with id = n} :: tms)
|
||||
| Binop (op, id1, id2) ->
|
||||
let n1, tms = aux tms n id1 in
|
||||
let n = n1 + 1 in
|
||||
let n2, tms = aux tms n id2 in
|
||||
let n = n2 + 1 in
|
||||
(n, { t with data = Binop (op, n1, n2)
|
||||
; id = n } :: tms)
|
||||
in
|
||||
let n, tms = aux [] 0 id in
|
||||
Array.of_list (List.rev tms), n
|
||||
end
|
||||
|
||||
module R = Random
|
||||
|
||||
(* uniform pick in a list *)
|
||||
(* Single pass over the list: the first element starts as the
 * choice, and the k-th element (k >= 2) takes its place when
 * [R.int k] draws 0.
 * Raises [Invalid_argument "list_pick"] on the empty list. *)
let list_pick l =
  match l with
  | [] -> invalid_arg "list_pick"
  | first :: rest ->
    let _, chosen =
      List.fold_left
        (fun (seen, current) candidate ->
          let current =
            if R.int (seen + 1) = 0 then candidate
            else current
          in
          (seen + 1, current))
        (1, first) rest
    in
    chosen
|
||||
|
||||
let term_pick ~numbr =
|
||||
let ops =
|
||||
if numbr.ops = [] then
|
||||
numbr.ops <-
|
||||
(StateMap.fold (fun k _ ops ->
|
||||
match k with
|
||||
| K (op, _, _) -> op :: ops)
|
||||
numbr.statemap [] |> setify);
|
||||
numbr.ops
|
||||
in
|
||||
let rec gen depth =
|
||||
(* exponential probability for leaves to
|
||||
* avoid skewing towards shallow terms *)
|
||||
let atm_prob = 0.75 ** float_of_int depth in
|
||||
if R.float 1.0 <= atm_prob || ops = [] then
|
||||
let atom, st = list_pick numbr.atoms in
|
||||
(st, Atm atom)
|
||||
else
|
||||
let op = list_pick ops in
|
||||
let s1, t1 = gen (depth - 1) in
|
||||
let s2, t2 = gen (depth - 1) in
|
||||
( binop_state numbr op s1 s2
|
||||
, Bnr (op, t1, t2) )
|
||||
in fun ~depth -> gen depth
|
||||
|
||||
exception FuzzError
|
||||
|
||||
(* Number of nested [Bnr] nodes on the longest path from the
 * root of the pattern down to an atom; atoms and variables
 * have depth 0. *)
let rec pattern_depth p =
  match p with
  | Atm _ | Var _ -> 0
  | Bnr (_, lhs, rhs) ->
    let dl = pattern_depth lhs in
    let dr = pattern_depth rhs in
    1 + max dl dr
|
||||
|
||||
(* [a %% b] is the ratio of the integers [a] and [b]
 * expressed as a floating-point percentage. *)
let ( %% ) num den =
  let scaled = 100. *. Float.of_int num in
  scaled /. Float.of_int den
|
||||
|
||||
let progress ?(width = 50) msg pct =
|
||||
Format.eprintf "\x1b[2K\r%!";
|
||||
let progress_bar fmt =
|
||||
let n =
|
||||
let fwidth = float_of_int width in
|
||||
1 + int_of_float (pct *. fwidth /. 1e2)
|
||||
in
|
||||
Format.fprintf fmt " %s%s %.0f%%@?"
|
||||
(String.concat "" (List.init n (fun _ -> "▒")))
|
||||
(String.make (max 0 (width - n)) '-')
|
||||
pct
|
||||
in
|
||||
Format.kfprintf progress_bar
|
||||
Format.err_formatter msg
|
||||
|
||||
let fuzz_numberer rules numbr =
|
||||
(* pick twice the max pattern depth so we
|
||||
* have a chance to find non-trivial numbers
|
||||
* for the atomic patterns in the rules *)
|
||||
let depth =
|
||||
List.fold_left (fun depth r ->
|
||||
max depth (pattern_depth r.pattern))
|
||||
0 rules * 2
|
||||
in
|
||||
(* fuzz until the term pool we are constructing
|
||||
* is no longer growing fast enough; or we just
|
||||
* went through sufficiently many iterations *)
|
||||
let max_iter = 1_000_000 in
|
||||
let low_insert_rate = 1e-2 in
|
||||
let tp = TermPool.create numbr in
|
||||
let rec loop new_stats i =
|
||||
let (_, _, insert_rate) = new_stats in
|
||||
if insert_rate <= low_insert_rate then () else
|
||||
if i >= max_iter then () else
|
||||
(* periodically update stats *)
|
||||
let new_stats =
|
||||
let (num, cnt, rate) = new_stats in
|
||||
if num land 1023 = 0 then
|
||||
let rate =
|
||||
0.5 *. (rate +. float_of_int cnt /. 1023.)
|
||||
in
|
||||
progress " insert_rate=%.1f%%"
|
||||
(i %% max_iter) (rate *. 1e2);
|
||||
(num + 1, 0, rate)
|
||||
else new_stats
|
||||
in
|
||||
(* create a term and check that its number is
|
||||
* accurate wrt the rules *)
|
||||
let st, term = term_pick ~numbr ~depth in
|
||||
let state_matched =
|
||||
List.filter_map (fun cu ->
|
||||
match cu with
|
||||
| Top ("$" | "%") -> None
|
||||
| Top name -> Some name
|
||||
| _ -> None)
|
||||
st.point |> setify
|
||||
in
|
||||
let rule_matched =
|
||||
List.filter_map (fun r ->
|
||||
if pattern_match r.pattern term then
|
||||
Some r.name
|
||||
else None)
|
||||
rules |> setify
|
||||
in
|
||||
if state_matched <> rule_matched then begin
|
||||
let open Format in
|
||||
let pp_str_list =
|
||||
let pp_sep fmt () = fprintf fmt ",@ " in
|
||||
pp_print_list ~pp_sep pp_print_string
|
||||
in
|
||||
eprintf "@.@[<v2>fuzz error for %s"
|
||||
(show_pattern term);
|
||||
eprintf "@ @[state matched: %a@]"
|
||||
pp_str_list state_matched;
|
||||
eprintf "@ @[rule matched: %a@]"
|
||||
pp_str_list rule_matched;
|
||||
eprintf "@]@.";
|
||||
raise FuzzError;
|
||||
end;
|
||||
if state_matched = [] then
|
||||
loop new_stats (i + 1)
|
||||
else
|
||||
(* add to the term pool *)
|
||||
let old_size = TermPool.size tp in
|
||||
let _ = TermPool.add_pattern tp term in
|
||||
let new_stats =
|
||||
let (num, cnt, rate) = new_stats in
|
||||
if TermPool.size tp <> old_size then
|
||||
(num + 1, cnt + 1, rate)
|
||||
else
|
||||
(num + 1, cnt, rate)
|
||||
in
|
||||
loop new_stats (i + 1)
|
||||
in
|
||||
loop (1, 0, 1.0) 0;
|
||||
Format.eprintf
|
||||
"@.@[ generated %.3fMiB of test terms@]@."
|
||||
(float_of_int (Obj.reachable_words (Obj.repr tp))
|
||||
/. 128. /. 1024.);
|
||||
tp
|
||||
|
||||
let rec run_matcher stk m (ta, id as t) =
|
||||
let state id = ta.(id).state.id in
|
||||
match m.Action.node with
|
||||
| Action.Switch cases ->
|
||||
let m =
|
||||
try List.assoc (state id) cases
|
||||
with Not_found -> failwith "no switch case"
|
||||
in
|
||||
run_matcher stk m t
|
||||
| Action.Push (sym, m) ->
|
||||
let l, r =
|
||||
match ta.(id).data with
|
||||
| Leaf _ -> failwith "push on leaf"
|
||||
| Binop (_, l, r) -> (l, r)
|
||||
in
|
||||
if sym && state l > state r
|
||||
then run_matcher (l :: stk) m (ta, r)
|
||||
else run_matcher (r :: stk) m (ta, l)
|
||||
| Action.Pop m -> begin
|
||||
match stk with
|
||||
| id :: stk -> run_matcher stk m (ta, id)
|
||||
| [] -> failwith "pop on empty stack"
|
||||
end
|
||||
| Action.Set (v, m) ->
|
||||
(v, id) :: run_matcher stk m t
|
||||
| Action.Stop -> []
|
||||
|
||||
let rec term_match p (ta, id) =
|
||||
let (|>>) x f =
|
||||
match x with None -> None | Some x -> f x
|
||||
in
|
||||
let atom_match a =
|
||||
match ta.(id).data with
|
||||
| Leaf a' -> pattern_match (Atm a) (Atm a')
|
||||
| Binop _ -> pattern_match (Atm a) (Atm Tmp)
|
||||
in
|
||||
match p with
|
||||
| Var (v, a) when atom_match a ->
|
||||
Some [(v, id)]
|
||||
| Atm a when atom_match a -> Some []
|
||||
| (Atm _ | Var _) -> None
|
||||
| Bnr (op, pl, pr) -> begin
|
||||
match ta.(id).data with
|
||||
| Binop (op', idl, idr) when op' = op ->
|
||||
term_match pl (ta, idl) |>> fun l1 ->
|
||||
term_match pr (ta, idr) |>> fun l2 ->
|
||||
Some (l1 @ l2)
|
||||
| _ -> None
|
||||
end
|
||||
|
||||
let test_matchers tp numbr rules =
|
||||
let {statemap = sm; states = sa; _} = numbr in
|
||||
let total = ref 0 in
|
||||
let matchers =
|
||||
let htbl = Hashtbl.create (Array.length sa) in
|
||||
List.map (fun r -> (r.name, r.pattern)) rules |>
|
||||
group_by_fst |>
|
||||
List.iter (fun (r, ps) ->
|
||||
total := !total + List.length ps;
|
||||
let pm = (ps, lr_matcher sm sa rules r) in
|
||||
sa |> Array.iter (fun s ->
|
||||
if List.mem (Top r) s.point then
|
||||
Hashtbl.add htbl s.id pm));
|
||||
htbl
|
||||
in
|
||||
let seen = Hashtbl.create !total in
|
||||
for id = 0 to TermPool.size tp - 1 do
|
||||
if id land 1023 = 0 ||
|
||||
id = TermPool.size tp - 1 then begin
|
||||
progress
|
||||
" coverage=%.1f%%"
|
||||
(id %% TermPool.size tp)
|
||||
(Hashtbl.length seen %% !total)
|
||||
end;
|
||||
let t = TermPool.explode_term tp id in
|
||||
Hashtbl.find_all matchers
|
||||
(TermPool.term tp id).state.id |>
|
||||
List.iter (fun (ps, m) ->
|
||||
let norm = List.fast_sort compare in
|
||||
let ok =
|
||||
match norm (run_matcher [] m t) with
|
||||
| asn -> `Match (List.exists (fun p ->
|
||||
match term_match p t with
|
||||
| None -> false
|
||||
| Some asn' ->
|
||||
if asn = norm asn' then begin
|
||||
Hashtbl.replace seen p ();
|
||||
true
|
||||
end else false) ps)
|
||||
| exception e -> `RunFailure e
|
||||
in
|
||||
if ok <> `Match true then begin
|
||||
let open Format in
|
||||
let pp_asn fmt asn =
|
||||
fprintf fmt "@[<h>";
|
||||
pp_print_list
|
||||
~pp_sep:(fun fmt () -> fprintf fmt ";@ ")
|
||||
(fun fmt (v, d) ->
|
||||
fprintf fmt "@[%s←%d@]" v d)
|
||||
fmt asn;
|
||||
fprintf fmt "@]"
|
||||
in
|
||||
eprintf "@.@[<v2>matcher error for";
|
||||
eprintf "@ @[%a@]" pp_term t;
|
||||
begin match ok with
|
||||
| `RunFailure e ->
|
||||
eprintf "@ @[exception: %s@]"
|
||||
(Printexc.to_string e)
|
||||
| `Match (* false *) _ ->
|
||||
let asn = run_matcher [] m t in
|
||||
eprintf "@ @[assignment: %a@]"
|
||||
pp_asn asn;
|
||||
eprintf "@ @[<v2>could not match";
|
||||
List.iter (fun p ->
|
||||
eprintf "@ + @[%s@]"
|
||||
(show_pattern p)) ps;
|
||||
eprintf "@]"
|
||||
end;
|
||||
eprintf "@]@.";
|
||||
raise FuzzError
|
||||
end)
|
||||
done;
|
||||
Format.eprintf "@."
|
||||
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
open Cgen
|
||||
open Match
|
||||
|
||||
let mgen ~verbose ~fuzz path lofs input oc =
|
||||
let info ?(level = 1) fmt =
|
||||
if level <= verbose then
|
||||
Printf.eprintf fmt
|
||||
else
|
||||
Printf.ifprintf stdout fmt
|
||||
in
|
||||
|
||||
let rules =
|
||||
match Sexp.(run_parser ppats) input with
|
||||
| `Error (ps, err, loc) ->
|
||||
Printf.eprintf "%s:%d:%d %s\n"
|
||||
path (lofs + ps.Sexp.line) ps.Sexp.coln err;
|
||||
Printf.eprintf "%s" loc;
|
||||
exit 1
|
||||
| `Ok rules -> rules
|
||||
in
|
||||
|
||||
info "adding ac variants...%!";
|
||||
let nparsed =
|
||||
List.fold_left
|
||||
(fun npats (_, _, ps) ->
|
||||
npats + List.length ps)
|
||||
0 rules
|
||||
in
|
||||
let varsmap = Hashtbl.create 10 in
|
||||
let rules =
|
||||
List.concat_map (fun (name, vars, patterns) ->
|
||||
(try assert (Hashtbl.find varsmap name = vars)
|
||||
with Not_found -> ());
|
||||
Hashtbl.replace varsmap name vars;
|
||||
List.map
|
||||
(fun pattern -> {name; vars; pattern})
|
||||
(List.concat_map ac_equiv patterns)
|
||||
) rules
|
||||
in
|
||||
info " %d -> %d patterns\n"
|
||||
nparsed (List.length rules);
|
||||
|
||||
let rnames =
|
||||
setify (List.map (fun r -> r.name) rules) in
|
||||
|
||||
info "generating match tables...%!";
|
||||
let sa, am, sm = generate_table rules in
|
||||
let numbr = make_numberer sa am sm in
|
||||
info " %d states, %d rules\n"
|
||||
(Array.length sa) (StateMap.cardinal sm);
|
||||
if verbose >= 2 then begin
|
||||
info "-------------\nstates:\n";
|
||||
Array.iteri (fun i s ->
|
||||
info " state %d: %s\n"
|
||||
i (show_pattern s.seen)) sa;
|
||||
info "-------------\nstatemap:\n";
|
||||
Test.print_sm stderr sm;
|
||||
info "-------------\n";
|
||||
end;
|
||||
|
||||
info "generating matchers...\n";
|
||||
let matchers =
|
||||
List.map (fun rname ->
|
||||
info "+ %s...%!" rname;
|
||||
let m = lr_matcher sm sa rules rname in
|
||||
let vars = Hashtbl.find varsmap rname in
|
||||
info " %d nodes\n" (Action.size m);
|
||||
info ~level:2 " -------------\n";
|
||||
info ~level:2 " automaton:\n";
|
||||
info ~level:2 "%s\n"
|
||||
(Format.asprintf " @[%a@]" Action.pp m);
|
||||
info ~level:2 " ----------\n";
|
||||
(vars, rname, m)
|
||||
) rnames
|
||||
in
|
||||
|
||||
if fuzz then begin
|
||||
info ~level:0 "fuzzing statemap...\n";
|
||||
let tp = Fuzz.fuzz_numberer rules numbr in
|
||||
info ~level:0 "testing %d patterns...\n"
|
||||
(List.length rules);
|
||||
Fuzz.test_matchers tp numbr rules
|
||||
end;
|
||||
|
||||
info "emitting C...\n";
|
||||
flush stderr;
|
||||
|
||||
let cgopts =
|
||||
{ pfx = ""; static = true; oc = oc } in
|
||||
emit_c cgopts numbr;
|
||||
emit_matchers cgopts matchers;
|
||||
|
||||
()
|
||||
|
||||
(* Reads [ic] until end of input, in chunks of 4096 bytes,
 * and returns everything that was read as one string.
 * The channel is left open. *)
let read_all ic =
  let chunk_size = 4096 in
  let chunk = Bytes.create chunk_size in
  let acc = Buffer.create chunk_size in
  let rec drain () =
    match input ic chunk 0 chunk_size with
    | 0 -> Buffer.contents acc
    | got ->
      Buffer.add_subbytes acc chunk 0 got;
      drain ()
  in
  drain ()
|
||||
|
||||
(* Splits the text of a C file around its mgen section.
 *
 * Lines are scanned in three phases:
 * - up to and including the first line containing
 *   "mgen generated code", lines go to the prefix;
 * - then, up to and including the first line containing
 *   "*/", lines go to the prefix, and all of them except
 *   that closing line are also collected as rules;
 * - then lines are dropped until one contains
 *   "end of generated code"; that line and everything
 *   after it form the suffix.
 *
 * Returns [(lofs, pfx, rules, sfx)] where [lofs] is the
 * number of prefix lines up to and including the opening
 * marker, and [pfx] ends with two newlines.
 * Raises [Failure] when the input ends before the last
 * marker is found. *)
let split_c src =
  let contains pat =
    let re = Str.regexp pat in
    fun text ->
      match Str.search_forward re text 0 with
      | (_ : int) -> true
      | exception Not_found -> false
  in
  let is_begin = contains "mgen generated code" in
  let is_eoc = contains "\\*/" in
  let is_end = contains "end of generated code" in
  let join = String.concat "\n" in
  let rec scan st lofs pfx rules = function
    | [] ->
      failwith
        (match st with
         | `Prefix -> "could not find mgen section"
         | `Rules -> "mgen rules not terminated"
         | `Skip -> "mgen section not terminated")
    | line :: rest ->
      begin match st with
      | `Prefix ->
        let pfx = line :: pfx in
        if is_begin line
        then scan `Rules (List.length pfx) pfx rules rest
        else scan `Prefix 0 pfx rules rest
      | `Rules ->
        let pfx = line :: pfx in
        if is_eoc line
        then scan `Skip lofs pfx rules rest
        else scan `Rules lofs pfx (line :: rules) rest
      | `Skip ->
        if is_end line then
          ( lofs
          , join (List.rev pfx) ^ "\n\n"
          , join (List.rev rules)
          , join (line :: rest) )
        else scan `Skip lofs pfx rules rest
      end
  in
  scan `Prefix 0 [] [] (String.split_on_char '\n' src)
|
||||
|
||||
(* Program entry point.
 *
 * Parses the command line, reads the single input (from
 * stdin when the path is "-") and runs [mgen] on it.
 * - If the path does not end in ".c", the whole input is
 *   handed to [mgen] and the result goes to stdout.
 * - Otherwise the input is split with [split_c], the output
 *   is assembled in "<path>.tmp" and that file is then
 *   renamed over the input file. *)
let () =
  let usage_msg =
    "mgen [--fuzz] [--verbose <N>] <file>" in

  let fuzz_arg = ref false in
  let verbose_arg = ref 0 in
  let input_paths = ref [] in

  let anon_fun filename =
    input_paths := filename :: !input_paths in

  let speclist =
    [ ( "--fuzz", Arg.Set fuzz_arg
      , " Fuzz tables and matchers" )
    ; ( "--verbose", Arg.Set_int verbose_arg
      , "<N> Set verbosity level" )
    ; ( "--", Arg.Rest_all (List.iter anon_fun)
      , " Stop argument parsing" ) ]
  in
  Arg.parse speclist anon_fun usage_msg;

  let verbose = !verbose_arg in
  let fuzz = !fuzz_arg in
  let input_path, input =
    match !input_paths with
    | ["-"] -> ("-", read_all stdin)
    | [path] ->
      let ic = open_in path in
      (* close the input even if reading fails *)
      let text =
        Fun.protect
          ~finally:(fun () -> close_in_noerr ic)
          (fun () -> read_all ic)
      in
      (path, text)
    | _ ->
      Printf.eprintf
        "%s: single input file expected\n"
        Sys.argv.(0);
      Arg.usage speclist usage_msg;
      exit 1
  in
  let run = mgen ~verbose ~fuzz in

  (* check_suffix, unlike Str.last_chars, does not raise
   * on paths shorter than the suffix (e.g. "-") *)
  if not (Filename.check_suffix input_path ".c")
  then run input_path 0 input stdout
  else begin
    let tmp_path = input_path ^ ".tmp" in
    Fun.protect
      ~finally:(fun () ->
        (* the temporary file is gone after a successful
         * rename; only that failure is ignored *)
        try Sys.remove tmp_path with Sys_error _ -> ())
      (fun () ->
        let lofs, pfx, rules, sfx = split_c input in
        let oc = open_out tmp_path in
        Fun.protect
          ~finally:(fun () -> close_out_noerr oc)
          (fun () ->
            output_string oc pfx;
            run input_path lofs rules oc;
            output_string oc sfx;
            (* explicit close so that write errors are
             * reported before the rename *)
            close_out oc);
        Sys.rename tmp_path input_path)
  end
|
||||
@@ -0,0 +1,651 @@
|
||||
type cls = Kw | Kl | Ks | Kd
|
||||
type op_base =
|
||||
| Oadd
|
||||
| Osub
|
||||
| Omul
|
||||
| Oor
|
||||
| Oshl
|
||||
| Oshr
|
||||
type op = cls * op_base
|
||||
|
||||
let op_bases =
|
||||
[Oadd; Osub; Omul; Oor; Oshl; Oshr]
|
||||
|
||||
(* True for add, mul and or, whatever the class of the
 * operation; false for sub, shl and shr. *)
let commutative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false
|
||||
|
||||
(* True for add, mul and or, whatever the class of the
 * operation; false for sub, shl and shr. *)
let associative op =
  match snd op with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false
|
||||
|
||||
type atomic_pattern =
|
||||
| Tmp
|
||||
| AnyCon
|
||||
| Con of int64
|
||||
(* Tmp < AnyCon < Con k *)
|
||||
|
||||
(* A pattern is a binary operation over two sub-patterns (Bnr), a bare
 * atom (Atm), or an atom bound to a variable name (Var). *)
type pattern =
|
||||
| Bnr of op * pattern * pattern
|
||||
| Atm of atomic_pattern
|
||||
| Var of string * atomic_pattern
|
||||
|
||||
(* [is_atomic pat] holds for leaves (Atm and Var), not for binary nodes *)
let is_atomic pat =
  match pat with
  | Bnr _ -> false
  | Atm _ | Var _ -> true
|
||||
|
||||
(* textual name of a base operation *)
let show_op_base = function
  | Oadd -> "add"
  | Osub -> "sub"
  | Omul -> "mul"
  | Oor -> "or"
  | Oshl -> "shl"
  | Oshr -> "shr"
|
||||
|
||||
(* textual name of an operation: the base name followed by a one-letter
 * class suffix *)
let show_op (k, o) =
  let suffix =
    match k with
    | Kw -> "w"
    | Kl -> "l"
    | Ks -> "s"
    | Kd -> "d"
  in
  show_op_base o ^ suffix
|
||||
|
||||
(* Render a pattern as text: % for Tmp, $ for AnyCon, the decimal value
 * for a constant, atom'name for a variable and a parenthesized prefix
 * form for binary nodes. *)
let rec show_pattern = function
  | Bnr (o, pl, pr) ->
    String.concat ""
      [ "("; show_op o; " "; show_pattern pl; " "; show_pattern pr; ")" ]
  | Var (v, a) -> String.concat "" [ show_pattern (Atm a); "'"; v ]
  | Atm (Con n) -> Int64.to_string n
  | Atm AnyCon -> "$"
  | Atm Tmp -> "%"
|
||||
|
||||
(* [get_atomic p] is the atom of a leaf pattern (dropping any variable
 * name), or None for a binary node *)
let get_atomic = function
  | Atm a -> Some a
  | Var (_, a) -> Some a
  | Bnr _ -> None
|
||||
|
||||
(* [pattern_match p w] tells whether pattern [p] accepts the term [w]
 * (itself represented as a pattern).  Variable names in [p] are ignored.
 * Tmp accepts everything but constant atoms, AnyCon accepts what Tmp
 * rejects, Con k requires [w] to be structurally equal to the atom, and
 * binary nodes must agree on the operation and match on both sides. *)
let rec pattern_match p w =
  match p, w with
  | Var (_, a), _ -> pattern_match (Atm a) w
  | Atm Tmp, _ ->
    (match get_atomic w with
     | Some (AnyCon | Con _) -> false
     | Some Tmp | None -> true)
  | Atm AnyCon, _ -> not (pattern_match (Atm Tmp) w)
  | Atm (Con _), _ -> w = p
  | Bnr (o, pl, pr), Bnr (o', wl, wr) ->
    o' = o && pattern_match pl wl && pattern_match pr wr
  | Bnr _, (Atm _ | Var _) -> false
|
||||
|
||||
(* Bnrl (o, c, r): the position is the left operand of a node o whose
 * right operand is r, the node itself sitting at cursor c.  Bnrr is the
 * mirror case for the right operand.  Top x marks the root and carries
 * the payload x.  fold_cursor rebuilds the full pattern from a cursor
 * and the subterm found at that position. *)
type +'a cursor = (* a position inside a pattern *)
|
||||
| Bnrl of op * 'a cursor * pattern
|
||||
| Bnrr of op * pattern * 'a cursor
|
||||
| Top of 'a
|
||||
|
||||
(* [fold_cursor c p] plugs the subterm [p] at the position described by
 * [c] and returns the complete pattern, walking up to the root *)
let rec fold_cursor c p =
  match c with
  | Top _ -> p
  | Bnrl (o, up, right) -> fold_cursor up (Bnr (o, p, right))
  | Bnrr (o, left, up) -> fold_cursor up (Bnr (o, left, p))
|
||||
|
||||
(* [peel p x] lists every leaf of [p] together with the cursor leading to
 * it; the root of all cursors is [Top x].  Variables are reported as
 * plain atoms.  The frontier is expanded until a pass no longer changes
 * its length, i.e. until it only holds leaves. *)
let peel p x =
  let expand acc (pat, cur) =
    match pat with
    | Bnr (o, pl, pr) ->
      (pl, Bnrl (o, cur, pr)) :: (pr, Bnrr (o, pl, cur)) :: acc
    | Var (_, a) -> (Atm a, cur) :: acc
    | Atm _ -> (pat, cur) :: acc
  in
  let rec fixpoint frontier =
    let next = List.fold_left expand [] frontier in
    if List.compare_lengths next frontier = 0
    then next
    else fixpoint next
  in
  fixpoint [ (p, Top x) ]
|
||||
|
||||
(* [fold_pairs l1 l2 ini f] folds [f] over the cartesian product of [l1]
 * and [l2]; pairs are visited with the element of [l1] varying slowest *)
let fold_pairs l1 l2 ini f =
  List.fold_left
    (fun acc a ->
      List.fold_left (fun acc b -> f (a, b) acc) acc l2)
    ini l1
|
||||
|
||||
(* [iter_pairs l f] calls [f] on every ordered pair of elements of [l],
 * including pairs of an element with itself *)
let iter_pairs l f =
  let visit pair () = f pair in
  fold_pairs l l () visit
|
||||
|
||||
(* swap the components of every pair of an association list *)
let inverse l =
  let swap (a, b) = (b, a) in
  List.map swap l
|
||||
|
||||
(* A state of the matching automaton.  id is its number (states are
 * created with id -1 and StateSet.add returns them renumbered), seen is
 * the pattern built while reaching the state, and point is the list of
 * cursors, i.e. positions inside rule patterns, attached to it. *)
type 'a state =
|
||||
{ id: int
|
||||
; seen: pattern
|
||||
; point: ('a cursor) list }
|
||||
|
||||
(* [binops side s] collects, among the cursors of state [s], those that
 * sit in the [side] operand of a binary node.  Each result pairs the
 * operation and parent cursor with the pattern of the other operand. *)
let binops side {point; _} =
  let pick cur =
    match cur, side with
    | Bnrl (o, c, r), `L -> Some ((o, c), r)
    | Bnrr (o, l, c), `R -> Some ((o, c), l)
    | _ -> None
  in
  List.filter_map pick point
|
||||
|
||||
(* [group_by_fst l] gathers the second components of [l] by first
 * component.  Keys come out in decreasing order; inside a group the
 * values appear in the reverse of their (stably sorted) input order. *)
let group_by_fst l =
  let sorted =
    List.fast_sort (fun (a, _) (b, _) -> compare a b) l
  in
  let rec collect key members groups = function
    | [] -> (key, members) :: groups
    | (k, v) :: rest when key = k ->
      collect key (v :: members) groups rest
    | (k, v) :: rest ->
      collect k [v] ((key, members) :: groups) rest
  in
  match sorted with
  | [] -> []
  | (k, v) :: rest -> collect k [v] [] rest
|
||||
|
||||
(* [sort_uniq cmp l] sorts [l] with [cmp] and keeps a single element for
 * every run of elements that compare equal: the first one of the run in
 * the stably sorted list *)
let sort_uniq cmp l =
  let rec dedup kept = function
    | [] -> List.rev kept
    | e :: rest ->
      (match kept with
       | last :: _ when cmp last e = 0 -> dedup kept rest
       | _ -> dedup (e :: kept) rest)
  in
  dedup [] (List.fast_sort cmp l)
|
||||
|
||||
(* sorted, duplicate-free version of [l] under the polymorphic order *)
let setify l =
  l |> sort_uniq compare
|
||||
|
||||
(* canonical form of a list of cursors: sorted and without duplicates *)
let normalize (point: ('a cursor) list) =
  point |> setify
|
||||
|
||||
(* [next_binary tmp s1 s2] computes the states reached when a binary node
 * has a left operand in state [s1] and a right operand in state [s2].
 * A cursor of [s1] sitting in a left operand is kept when the pattern of
 * its right sibling matches what [s2] has seen, and symmetrically for
 * [s2].  One (not yet numbered) state is returned per operation; the
 * cursors in [tmp] are always added to its point. *)
let next_binary tmp s1 s2 =
  let candidates side s other =
    binops side s |>
    List.filter_map (fun (oc, sibling) ->
      if pattern_match sibling other.seen
      then Some oc
      else None)
  in
  let from_left = candidates `L s1 s2 in
  let from_right = candidates `R s2 s1 in
  group_by_fst (from_left @ from_right) |>
  List.map (fun (o, cursors) ->
    ( o
    , { id = -1
      ; seen = Bnr (o, s1.seen, s2.seen)
      ; point = normalize (cursors @ tmp) } ))
|
||||
|
||||
(* payload stored at the Top of cursors; generate_table puts rule names
 * there *)
type p = string
|
||||
|
||||
(* A mutable set of states in which two states are the same when their
 * points are equal.  Members are numbered in insertion order from 0. *)
module StateSet : sig
  type t
  val create: unit -> t
  val add: t -> p state ->
           [> `Added | `Found ] * p state
  val iter: t -> (p state -> unit) -> unit
  val elems: t -> (p state) list
end = struct
  module H = Hashtbl.Make(struct
    type t = p state
    let equal s1 s2 = s1.point = s2.point
    let hash s = Hashtbl.hash s.point
  end)

  type t =
    { h: int H.t
    ; mutable next_id: int }

  let create () =
    { h = H.create 500; next_id = 0 }

  (* the returned state is [s] with its id set to the number stored in
   * the set; the point of [s] must already be normalized *)
  let add set s =
    assert (s.point = normalize s.point);
    match H.find_opt set.h s with
    | Some id -> `Found, {s with id}
    | None ->
      let id = set.next_id in
      set.next_id <- id + 1;
      H.add set.h s id;
      `Added, {s with id}

  let iter set f =
    H.iter (fun s id -> f {s with id}) set.h

  let elems set =
    H.fold (fun s id acc -> {s with id} :: acc) set.h []
end
|
||||
|
||||
(* key of the transition table: an operation together with the states of
 * its left and right operands *)
type table_key =
|
||||
| K of op * p state * p state
|
||||
|
||||
(* Transition table: maps (operation, left state, right state) keys to
 * the resulting state.  Keys are ordered by operation and state ids. *)
module StateMap = struct
  include Map.Make(struct
    type t = table_key
    let compare (K (o, sl, sr)) (K (o', sl', sr')) =
      Stdlib.compare
        (o, sl.id, sr.id)
        (o', sl'.id, sr'.id)
  end)

  (* [invert n sm] is an array indexed by target state id (there must be
   * at most [n] of them); each cell lists, per operation, the pairs of
   * operand state ids that lead to that state *)
  let invert n sm =
    let rmap = Array.make n [] in
    let record (K (o, sl, sr)) {id; _} =
      rmap.(id) <- (o, (sl.id, sr.id)) :: rmap.(id)
    in
    iter record sm;
    Array.map group_by_fst rmap

  (* transitions grouped by operation, as ((left id, right id), target id)
   * triples *)
  let by_ops sm =
    let collect (K (op, l, r)) s ops =
      (op, ((l.id, r.id), s.id)) :: ops
    in
    group_by_fst (fold collect sm [])
end
|
||||
|
||||
(* A named rule.  Several rules may share a name: lr_matcher gathers all
 * the patterns registered under one name.  vars is the list of variable
 * names attached to the rule. *)
type rule =
|
||||
{ name: string
|
||||
; vars: string list
|
||||
; pattern: pattern }
|
||||
|
||||
(* [generate_table rl] builds the matching automaton for the rules [rl].
 * Returns a triple: the array of states sorted by id, an association
 * list from atoms to their state, and the transition table. *)
let generate_table rl =
  let states = StateSet.create () in
  (* these atomic patterns must occur in
   * rules so that we are able to number
   * all possible refs *)
  let builtin name pattern = {name; vars = []; pattern} in
  let rl =
    builtin "$" (Atm AnyCon) :: builtin "%" (Atm Tmp) :: rl
  in
  (* initialize states: one per distinct leaf found in the rules *)
  let ground =
    rl |>
    List.concat_map (fun {name; pattern; _} -> peel pattern name) |>
    group_by_fst
  in
  let tmp = List.assoc (Atm Tmp) ground in
  let con = List.assoc (Atm AnyCon) ground in
  let atoms = ref [] in
  let seed (seen, cursors) =
    let base =
      if pattern_match (Atm Tmp) seen then tmp else con
    in
    let point = normalize (base @ cursors) in
    let _, s = StateSet.add states {id = -1; seen; point} in
    Option.iter
      (fun atm -> atoms := (atm, s) :: !atoms)
      (get_atomic seen)
  in
  List.iter seed ground;
  (* combine every pair of known states until a full pass adds no state *)
  let map = ref StateMap.empty in
  let rec saturate () =
    let grew = ref false in
    iter_pairs (StateSet.elems states) (fun (sl, sr) ->
      next_binary tmp sl sr |>
      List.iter (fun (o, s') ->
        let status, s' = StateSet.add states s' in
        if status = `Added then grew := true;
        map := StateMap.add (K (o, sl, sr)) s' !map));
    if !grew then saturate ()
  in
  saturate ();
  let states =
    StateSet.elems states |>
    List.sort (fun s s' -> compare s.id s'.id) |>
    Array.of_list
  in
  (states, !atoms, !map)
|
||||
|
||||
(* [intersperse x l] lists the results of inserting [x] at every position
 * of [l]; the insertion at the end comes first, the one at the front
 * last *)
let intersperse x l =
  let rec go before after acc =
    let acc = List.rev_append before (x :: after) :: acc in
    match after with
    | [] -> acc
    | y :: after' -> go (y :: before) after' acc
  in
  go [] l []
|
||||
|
||||
(* all the permutations of a list *)
let rec permute l =
  match l with
  | [] -> [[]]
  | x :: rest ->
    permute rest |> List.concat_map (intersperse x)
|
||||
|
||||
(* [bins build leaves] enumerates every binary tree whose leaves are,
 * from left to right, the elements of [leaves]; inner nodes are made
 * with [build].  The empty list yields no tree. *)
let rec bins build leaves =
  let combine (lt, rt) acc = build lt rt :: acc in
  (* try every split point between a non-empty prefix and suffix *)
  let rec split left right acc =
    match right with
    | [] -> acc
    | x :: right' ->
      let acc =
        fold_pairs
          (bins build left)
          (bins build right)
          acc combine
      in
      split (left @ [x]) right' acc
  in
  match leaves with
  | [] -> []
  | [leaf] -> [leaf]
  | first :: rest -> split [first] rest []
|
||||
|
||||
(* [products l ini f] folds [f] over every way of picking one element
 * from each list of [l]; a choice is passed to [f] as a list in the same
 * order as [l] *)
let products l ini f =
  let rec walk chosen remaining acc =
    match remaining with
    | [] -> f (List.rev chosen) acc
    | options :: rest ->
      List.fold_left
        (fun acc x -> walk (x :: chosen) rest acc)
        acc options
  in
  walk [] l ini
|
||||
|
||||
(* combinatorial nuke: [ac_equiv p] lists the patterns obtained from [p]
 * by rearranging operands of associative and commutative operations,
 * recursively in all subterms *)
let rec ac_equiv pat =
  (* operands of the maximal chain of nodes using operation [o] *)
  let rec alevel o = function
    | Bnr (o', l, r) when o' = o ->
      alevel o l @ alevel o r
    | leaf -> [leaf]
  in
  match pat with
  | Atm _ | Var _ -> [pat]
  | Bnr (o, l, r) ->
    let mk a b = Bnr (o, a, b) in
    if associative o then
      products
        (List.map ac_equiv (alevel o pat)) []
        (fun choice out ->
          let orders =
            if commutative o
            then permute choice
            else [choice]
          in
          List.concat_map (bins mk) orders @ out)
    else if commutative o then
      fold_pairs
        (ac_equiv l) (ac_equiv r) []
        (fun (a, b) out -> mk a b :: mk b a :: out)
    else
      fold_pairs
        (ac_equiv l) (ac_equiv r) []
        (fun (a, b) out -> mk a b :: out)
|
||||
|
||||
(* Hash-consed decision trees produced by lr_matcher.  Structurally equal
 * nodes share one id, so [equal] only compares ids. *)
module Action: sig
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop
  and t = private
    { id: int; node: node }
  val equal: t -> t -> bool
  val size: t -> int
  val stop: t
  val mk_push: sym:bool -> t -> t
  val mk_pop: t -> t
  val mk_set: string -> t -> t
  val mk_switch: int list -> (int -> t) -> t
  val pp: Format.formatter -> t -> unit
end = struct
  type node =
    | Switch of (int * t) list
    | Push of bool * t
    | Pop of t
    | Set of string * t
    | Stop
  and t =
    { id: int; node: node }

  let equal a a' = Int.equal a.id a'.id

  (* number of distinct nodes reachable from [a] *)
  let size a =
    let seen = Hashtbl.create 10 in
    let rec visit {id; node} =
      if Hashtbl.mem seen id
      then 0
      else begin
        Hashtbl.add seen id ();
        1 + children node
      end
    and children = function
      | Stop -> 0
      | (Push (_, a) | Pop a | Set (_, a)) -> visit a
      | Switch cases ->
        List.fold_left
          (fun n (_, a) -> n + visit a) 0 cases
    in
    visit a

  (* smart constructor: reuse the id of a structurally equal node *)
  let mk =
    let hcons = Hashtbl.create 100 in
    let fresh = ref 0 in
    fun node ->
      match Hashtbl.find_opt hcons node with
      | Some id -> {id; node}
      | None ->
        let id = !fresh in
        Hashtbl.add hcons node id;
        fresh := id + 1;
        {id; node}

  let stop = mk Stop

  let mk_push ~sym a = mk (Push (sym, a))

  (* popping right before stopping is pointless *)
  let mk_pop a =
    match a.node with
    | Stop -> a
    | Switch _ | Push _ | Pop _ | Set _ -> mk (Pop a)

  let mk_set v a = mk (Set (v, a))

  (* a switch whose cases are all the same action is that action *)
  let mk_switch ids f =
    match List.map f ids with
    | [] -> failwith "empty switch"
    | first :: others as cases ->
      if List.for_all (equal first) others
      then first
      else mk (Switch (List.combine ids cases))

  let rec pp_node fmt node =
    match node with
    | Stop -> Format.fprintf fmt "•"
    | Set (v, a) -> Format.fprintf fmt "set(%s)@ %a" v pp a
    | Pop a -> Format.fprintf fmt "pop@ %a" pp a
    | Push (false, a) -> Format.fprintf fmt "push@ %a" pp a
    | Push (true, a) -> Format.fprintf fmt "pushsym@ %a" pp a
    | Switch l ->
      let pp_sep fmt () = Format.fprintf fmt "," in
      let pp_ids =
        Format.pp_print_list ~pp_sep Format.pp_print_int
      in
      (* cases leading to the same action are printed together *)
      let pp_case (c, a) =
        Format.fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
          pp_ids c pp a
      in
      Format.fprintf fmt "@[<v>@[<v2>switch{";
      List.iter pp_case
        (inverse l |> group_by_fst |> inverse);
      Format.fprintf fmt "@]@,}@]"

  and pp fmt a = pp_node fmt a.node
end
|
||||
|
||||
(* a state is symmetric if (a op b) enters
 * it iff (b op a) enters it as well; [rmap] is
 * the inverted transition table *)
let symmetric rmap id =
  let balanced (_, pairs) =
    let off_diagonal =
      List.filter (fun (a, b) -> a <> b) pairs
    in
    let below, above =
      List.partition (fun (a, b) -> a < b) off_diagonal
    in
    setify below = setify (inverse above)
  in
  List.for_all balanced rmap.(id)
|
||||
|
||||
(* left-to-right matching of a set of patterns;
 * builds the matcher for the rules called [name];
 * may raise if there is no lr matcher for the
 * input rule *)
let lr_matcher statemap states rules name =
  let rmap = StateMap.invert (Array.length states) statemap in
  let exception Stuck in
  (* the list of ids represents a class of terms
   * whose root ends up being labelled with one
   * such id; the gen function generates a matcher
   * that will, given any such term, assign values
   * for the Var nodes of one pattern in pats *)
  let rec gen
    : 'a. int list -> (pattern * 'a) list
      -> (int -> (pattern * 'a) list -> Action.t)
      -> Action.t
    = fun ids pats k ->
    Action.mk_switch (setify ids) (fun id_top ->
      let sym = symmetric rmap id_top in
      (* for a symmetric state, one ordering of each pair is enough *)
      let id_ops =
        if not sym then rmap.(id_top)
        else
          List.map (fun (o, pairs) ->
            (o, List.filter (fun (a, b) -> a <= b) pairs))
            rmap.(id_top)
      in
      let has_op o =
        List.exists (fun (o', _) -> o' = o) id_ops
      in
      (* consider only the patterns that are
       * compatible with the current id *)
      let atm_pats, bin_pats =
        pats |>
        List.filter (fun (pat, _) ->
          match pat with
          | Bnr (o, _, _) -> has_op o
          | Atm _ | Var _ -> true) |>
        List.partition (fun (pat, _) -> is_atomic pat)
      in
      try
        if bin_pats = [] then raise Stuck;
        (* focus on the left arm, remembering the rest of the node *)
        let pats_l =
          List.map (fun (pat, x) ->
            match pat with
            | Bnr (o, l, r) -> (l, (o, x, r))
            | Atm _ | Var _ -> assert false)
            bin_pats
        in
        (* move the focus to the right arm *)
        let pats_r done_l =
          List.map (fun (l, (o, x, r)) ->
            (r, (o, l, x))) done_l
        in
        (* rebuild the full node once both arms are done *)
        let patstop done_r =
          List.map (fun (r, (o, l, x)) ->
            (Bnr (o, l, r), x)) done_r
        in
        let id_pairs = List.concat_map snd id_ops in
        let ids_l = List.map fst id_pairs in
        let ids_r id_left =
          List.filter_map (fun (l, r) ->
            if l = id_left then Some r else None)
            id_pairs
        in
        (* match the left arm *)
        Action.mk_push ~sym
          (gen ids_l pats_l (fun lid pats ->
            (* then the right arm, considering
             * only the remaining possible
             * patterns and knowing that the
             * left arm was numbered 'lid' *)
            Action.mk_pop
              (gen (ids_r lid) (pats_r pats) (fun _rid pats ->
                (* continue with the parent *)
                k id_top (patstop pats)))))
      with Stuck ->
        (* no binary pattern worked out: fall back on the atomic
         * patterns that accept what this state has seen *)
        let seen = states.(id_top).seen in
        let atm_pats =
          List.filter (fun (pat, _) ->
            pattern_match pat seen) atm_pats
        in
        if atm_pats = [] then raise Stuck
        else begin
          let vars =
            atm_pats |>
            List.filter_map (fun (pat, _) ->
              match pat with
              | Var (v, _) -> Some v
              | Atm _ | Bnr _ -> None) |>
            setify
          in
          match vars with
          | [] -> k id_top atm_pats
          | [v] -> Action.mk_set v (k id_top atm_pats)
          | _ :: _ :: _ -> failwith "ambiguous var match"
        end)
  in
  (* generate a matcher for the rule: start from the states in which a
   * pattern of the rule is completely matched *)
  let ids_top =
    Array.to_list states |>
    List.filter_map (fun {id; point; _} ->
      if List.exists (fun c -> Top name = c) point
      then Some id
      else None)
  in
  (* drop a pattern when a later one accepts it as a term *)
  let rec filter_dups = function
    | [] -> []
    | p :: rest when List.exists (pattern_match p) rest ->
      filter_dups rest
    | p :: rest -> p :: filter_dups rest
  in
  let pats_top =
    rules |>
    List.filter_map (fun r ->
      if r.name = name then Some r.pattern else None) |>
    filter_dups |>
    List.map (fun p -> (p, ()))
  in
  gen ids_top pats_top (fun _ pats ->
    assert (pats <> []);
    Action.stop)
|
||||
|
||||
(* Data needed to label terms with states; make_numberer fills the three
 * immutable fields from its arguments and starts with an empty ops
 * (presumably the arguments are the results of generate_table; confirm
 * against callers). *)
type numberer =
|
||||
{ atoms: (atomic_pattern * p state) list
|
||||
; statemap: p state StateMap.t
|
||||
; states: p state array
|
||||
; mutable ops: op list
|
||||
(* memoizes the list of possible operations
|
||||
* according to the statemap *) }
|
||||
|
||||
(* build a numberer from a state array, an atom table and a transition
 * table (in that order); the ops memo starts empty *)
let make_numberer states atoms statemap =
  {atoms; states; statemap; ops = []}
|
||||
|
||||
(* state recorded for the atom [atm]; raises Not_found when there is
 * none *)
let atom_state {atoms; _} atm =
  List.assoc atm atoms
|
||||
@@ -0,0 +1,292 @@
|
||||
(* Parser state: the whole input text, the current line and column (kept
 * up to date by update_pos) and the current offset into data.
 * run_parser starts at line 1, column 0, offset 0. *)
type pstate =
|
||||
{ data: string
|
||||
; line: int
|
||||
; coln: int
|
||||
; indx: int }
|
||||
|
||||
(* a parse error: a message and the parser state where it was raised *)
type perror =
|
||||
{ error: string
|
||||
; ps: pstate }
|
||||
|
||||
(* raised by parsers on failure; por catches it to try an alternative and
 * run_parser turns it into an `Error result *)
exception ParseError of perror
|
||||
|
||||
(* A parser in continuation-passing style: fn receives the current state
 * and a continuation, which it calls with the parsed value and the new
 * state.  Failure is signalled by raising ParseError. *)
type 'a parser =
|
||||
{ fn: 'r. pstate -> ('a -> pstate -> 'r) -> 'r }
|
||||
|
||||
(* [update_pos ps beg fin] advances the line and column of [ps] over the
 * characters of the input at offsets [beg] (included) to [fin]
 * (excluded).  A newline moves to the next line and resets the column
 * to 0.  The offset field is left untouched. *)
let update_pos ps beg fin =
  let rec scan i line coln =
    if i >= fin then
      { ps with line; coln }
    else if ps.data.[i] = '\n' then
      scan (i + 1) (line + 1) 0
    else
      scan (i + 1) line (coln + 1)
  in
  scan beg ps.line ps.coln
|
||||
|
||||
(* parser that consumes nothing and yields [x] *)
let pret (type a) (x: a): a parser =
  { fn = (fun ps k -> k x ps) }
|
||||
|
||||
(* parser that always fails with message [error] at the current
 * position *)
let pfail error: 'a parser =
  { fn = (fun ps _k -> raise (ParseError {error; ps})) }
|
||||
|
||||
(* [por p1 p2] tries [p1] and, if it fails, [p2] from the same position.
 * When both fail, the error raised is the one that got strictly further
 * in the input, the error of [p2] winning ties.  Note that a failure of
 * the continuation also counts as a failure of [p1]. *)
let por: 'a parser -> 'a parser -> 'a parser =
  fun p1 p2 ->
    let fn ps k =
      match p1.fn ps k with
      | res -> res
      | exception ParseError e1 ->
        (match p2.fn ps k with
         | res -> res
         | exception ParseError e2 ->
           let furthest =
             if e1.ps.indx > e2.ps.indx then e1 else e2
           in
           raise (ParseError furthest))
    in
    { fn }
|
||||
|
||||
(* monadic bind: run [p1], then the parser computed from its result *)
let pbind: 'a parser -> ('a -> 'b parser) -> 'b parser =
  fun p1 p2 ->
    { fn = (fun ps k ->
        let continue x ps' = (p2 x).fn ps' k in
        p1.fn ps continue) }
|
||||
|
||||
(* handy for recursive rules: [p x] is only
 * evaluated when the resulting parser runs *)
let papp p x =
  let arg = pret x in
  pbind arg p
|
||||
|
||||
(* run both parsers in sequence, keep the result of the second *)
let psnd: 'a parser -> 'b parser -> 'b parser =
  fun first second ->
    pbind first (fun _ -> second)
|
||||
|
||||
(* run both parsers in sequence, keep the result of the first *)
let pfst: 'a parser -> 'b parser -> 'a parser =
  fun first second ->
    pbind first (fun x ->
      pbind second (fun _ -> pret x))
|
||||
|
||||
(* operator aliases for the parser combinators *)
module Infix = struct
  (* sequence, keeping the right (resp. left) result *)
  let ( |>> ) = psnd
  let ( |<< ) = pfst
  (* ordered choice *)
  let ( ||| ) = por
  (* binding operator for pbind *)
  let ( let* ) = pbind
end
|
||||
|
||||
open Infix
|
||||
|
||||
(* [pre ?what re] matches the Str regular expression [re] at the current
 * position and yields the matched text.  [what] describes the expected
 * input in the error message; it defaults to the quoted regexp. *)
let pre: ?what:string -> string -> string parser =
  fun ?what re ->
    let what =
      Option.value what ~default:(Printf.sprintf "%S" re)
    in
    let regexp = Str.regexp re in
    let fn ps k =
      if Str.string_match regexp ps.data ps.indx then begin
        let indx = Str.match_end () in
        let matched = Str.matched_string ps.data in
        let ps' = { (update_pos ps ps.indx indx) with indx } in
        k matched ps'
      end else begin
        let error =
          Printf.sprintf "expected to match %s" what in
        raise (ParseError {error; ps})
      end
    in
    { fn }
|
||||
|
||||
(* succeeds only when the whole input has been consumed *)
let peoi: unit parser =
  let fn ps k =
    if ps.indx = String.length ps.data then
      k () ps
    else
      raise (ParseError
        { error = "expected end of input"; ps })
  in
  { fn }
|
||||
|
||||
(* Separators: spaces, carriage returns, newlines, tabs and also the
 * star character are all treated as blanks.  pws accepts an empty run
 * whereas pws1 requires at least one separator. *)
let pws = pre "[ \r\n\t*]*"
|
||||
let pws1 = pre "[ \r\n\t*]+"
|
||||
|
||||
(* run [p1] then [p2] and pair their results *)
let pthen p1 p2 =
  pbind p1 (fun x1 ->
    pbind p2 (fun x2 ->
      pret (x1, x2)))
|
||||
|
||||
(* items of a list whose opening parenthesis has already been consumed:
 * either the closing parenthesis (after optional blanks), or one item
 * followed by the rest of the list *)
let rec plist_tail: 'a parser -> ('a list) parser =
  fun pitem ->
    let close = psnd (psnd pws (pre ")")) (pret []) in
    let cons =
      pbind pitem (fun itm ->
        pbind (plist_tail pitem) (fun itms ->
          pret (itm :: itms)))
    in
    por close cons
|
||||
|
||||
(* a parenthesized list of items *)
let plist pitem =
  let opening = psnd pws (pre ~what:"a list" "(") in
  psnd opening (plist_tail pitem)
|
||||
|
||||
(* a parenthesized list whose first element is parsed by [p1] and the
 * others by [pitem]; yields the head paired with the list of items *)
let plist1p p1 pitem =
  let opening = psnd pws (pre ~what:"a list" "(") in
  psnd opening (pthen p1 (plist_tail pitem))
|
||||
|
||||
(* a parenthesized pair: [p1] then [p2], then optional blanks and the
 * closing parenthesis *)
let ppair p1 p2 =
  let opening = psnd pws (pre ~what:"a pair" "(") in
  let body = psnd opening (pthen p1 p2) in
  pfst (pfst body pws) (pre ")")
|
||||
|
||||
(* [run_parser p s] runs the parser [p] on the string [s], starting at
 * line 1, column 0.
 *
 * Returns [`Ok v] with the parsed value on success.  When the parser
 * raises ParseError, returns [`Error (ps, msg, excerpt)] where [ps] is
 * the parser state at the point of failure, [msg] the error message and
 * [excerpt] the offending input line followed by a second line with a
 * caret under the failing column.  The excerpt is the empty string when
 * the offending line is empty. *)
let run_parser p s =
  let ps =
    {data = s; line = 1; coln = 0; indx = 0} in
  match p.fn ps (fun res _ps -> res) with
  | res -> `Ok res
  | exception ParseError e ->
    let indx = e.ps.indx in
    let len = String.length s in
    (* Start of the line holding [indx]: walk back until just after the
     * previous newline.  The character before the position is inspected,
     * not the one at the position: when the error sits exactly on a
     * newline, that newline ends the current line, and taking it for a
     * line start used to yield bol > eol and a crash in String.sub. *)
    let rec bol i =
      if i = 0 || s.[i - 1] = '\n' then i else bol (i - 1)
    in
    (* end of that line: the next newline or the end of input *)
    let rec eol i =
      if i = len || s.[i] = '\n' then i else eol (i + 1)
    in
    let bol = bol indx in
    let eol = eol indx in
    (* bol <= indx <= eol and there is no newline in between *)
    let line = String.sub s bol (eol - bol) in
    let pfx = " > " in
    let excerpt =
      if line = "" then ""
      else
        let pad = String.make (indx - bol) ' ' in
        String.concat ""
          [ pfx; line; "\n"
          ; pfx; pad; "^\n" ]
    in
    `Error (e.ps, e.error, excerpt)
|
||||
|
||||
(* ---------------------------------------- *)
|
||||
(* pattern parsing *)
|
||||
(* ---------------------------------------- *)
|
||||
(* Example syntax:
|
||||
|
||||
(with-vars (a b c d)
|
||||
(patterns
|
||||
(ob (add (tmp a) (con d)))
|
||||
(bsm (add (tmp b) (mul (tmp m) (con 2 4 8)))) ))
|
||||
*)
|
||||
open Match
|
||||
|
||||
(* Parse a decimal 64-bit integer literal, with an optional minus sign
 * and underscores among the digits.  The regular expression is more
 * permissive than Int64.of_string (it accepts a lone underscore or a
 * number that does not fit in 64 bits, for instance); such text is
 * reported as a regular parse error instead of letting a Failure
 * exception escape from the parser. *)
let pint64 =
  let* s = pre "[-]?[0-9_]+" in
  match Int64.of_string_opt s with
  | Some n -> pret n
  | None ->
    pfail (Printf.sprintf "invalid integer literal %S" s)
|
||||
|
||||
(* Parse an identifier: an ASCII letter followed by letters, digits or
 * underscores.  The description used in error messages had a typo
 * (identifer), fixed here. *)
let pid =
  pre ~what:"an identifier"
    "[a-zA-Z][a-zA-Z0-9_]*"
|
||||
|
||||
(* parse the name of a base operation; the regular expression is the
 * alternation of the names given by show_op_base *)
let pop_base =
  let names = List.map show_op_base op_bases in
  let regexp = String.concat "\\|" names in
  pbind (pre ~what:"an operator" regexp) (fun s ->
    pret (List.find (fun o -> s = show_op_base o) op_bases))
|
||||
|
||||
(* parse an operation; the class is always Kl *)
let pop =
  pbind pop_base (fun ob -> pret (Kl, ob))
|
||||
|
||||
(* [ppat vs] parses one pattern expression and yields the list of
 * patterns it stands for: a (con ...) form listing several constants
 * gives one pattern per constant, and an operation gives one pattern per
 * combination of its operands.  [vs] is the list of variable names that
 * may be used; any other name is rejected as unbound. *)
let rec ppat vs =
  (* constants up to the closing parenthesis; none means any constant *)
  let pcons_tail =
    pbind (plist_tail (pws1 |>> pint64)) (function
      | [] -> pret [AnyCon]
      | cs -> pret (List.map (fun c -> Con c) cs))
  in
  let pvar =
    pbind pid (fun id ->
      if List.mem id vs then
        pret id
      else
        pfail ("unbound variable: " ^ id))
  in
  let literal =
    let* c = pint64 in pret [Atm (Con c)]
  in
  let any_con = pre "(con)" |>> pret [Atm AnyCon] in
  let anon_cons =
    let* cs = pre "(con" |>> pcons_tail in
    pret (List.map (fun c -> Atm c) cs)
  in
  let named_cons =
    let* v = pre "(con" |>> pws1 |>> pvar in
    let* cs = pcons_tail in
    pret (List.map (fun c -> Var (v, c)) cs)
  in
  let any_tmp = pre "(tmp)" |>> pret [Atm Tmp] in
  let named_tmp =
    let* v = pre "(tmp" |>> pws1 |>> pvar in
    pws |>> pre ")" |>> pret [Var (v, Tmp)]
  in
  let binop =
    let* (op, rands) =
      plist1p (pws |>> pop) (papp ppat vs) in
    if List.length rands < 2 then
      pfail ( "binary op requires at least"
            ^ " two arguments" )
    else
      let mk x y = Bnr (op, x, y) in
      (* construct a left-heavy tree for each choice of operands *)
      let left_heavy choice pats =
        List.fold_left mk (List.hd choice) (List.tl choice)
        :: pats
      in
      pret (products rands [] left_heavy)
  in
  (* alternatives are tried in this order *)
  let alternatives =
    List.fold_left ( ||| ) literal
      [ any_con; anon_cons; named_cons
      ; any_tmp; named_tmp; binop ]
  in
  pws |>> alternatives
|
||||
|
||||
(* [pwith_vars ?vs p] parses either a (with-vars (names...) body) form,
 * whose body is parsed by [p] applied to the declared names, or directly
 * what [p vs] parses, [vs] defaulting to the empty list *)
let pwith_vars ?(vs = []) p =
  let header =
    pws |>> pre "(with-vars" |>> pws |>>
    plist (pws |>> pid)
  in
  let scoped =
    pbind header (fun vs ->
      pws |>> p vs |<< pws |<< pre ")")
  in
  por scoped (p vs)
|
||||
|
||||
(* parse a (patterns ...) form, optionally wrapped in with-vars; each
 * entry is a (name pattern) pair, itself optionally wrapped in
 * with-vars, and yields the name, the variables in scope and the
 * patterns *)
let ppats =
  let entry vs =
    pbind (ppair pid (ppat vs)) (fun (n, ps) ->
      pret (n, vs, ps))
  in
  pwith_vars (fun vs ->
    psnd (pre "(patterns")
      (plist_tail (pwith_vars ~vs entry)))
|
||||
|
||||
(* ---------------------------------------- *)
|
||||
(* tests *)
|
||||
(* ---------------------------------------- *)
|
||||
|
||||
(* manual smoke test of the pattern parser; disabled, turn the condition
 * into true to print the parse of each sample *)
let () =
  if false then begin
    let show_patterns ps =
      "[" ^ String.concat "; "
        (List.map show_pattern ps) ^ "]"
    in
    let vars =
      [ "foobar"; "a"; "b"; "d"
      ; "m"; "s"; "x" ]
    in
    let pat s =
      Printf.printf "parse %s = " s;
      match run_parser (ppat vars) s with
      | `Ok p ->
        Printf.printf "%s\n" (show_patterns p)
      | `Error (_, e, _) ->
        Printf.printf "ERROR: %s\n" e
    in
    List.iter pat
      [ "42"
      ; "(tmp)"
      ; "(tmp foobar)"
      ; "(con)"
      ; "(con 1 2 3)"
      ; "(con x 1 2 3)"
      ; "(add 1 2)"
      ; "(add 1 2 3 4)"
      ; "(sub 1 2)"
      ; "(sub 1 2 3)"
      ; "(tmp unbound_var)"
      ; "(add 0)"
      ; "(add 1 (add 2 3))"
      ; "(add (tmp a) (con d))"
      ; "(add (tmp b) (mul (tmp m) (con s 2 4 8)))"
      ; "(add (con 1 2) (con 3 4))" ]
  end
|
||||
@@ -0,0 +1,134 @@
|
||||
open Match
|
||||
open Fuzz
|
||||
open Cgen
|
||||
|
||||
(* unit tests *)
|
||||
|
||||
(* Tmp rejects constants, AnyCon accepts them, and a specific constant
 * accepts neither AnyCon nor Tmp *)
let test_pattern_match =
  let con42 = Atm (Con 42L) in
  assert (not (pattern_match (Atm Tmp) con42));
  assert (pattern_match (Atm AnyCon) con42);
  assert (not (pattern_match con42 (Atm AnyCon)));
  assert (not (pattern_match con42 (Atm Tmp)))
|
||||
|
||||
(* peeling a pattern with three leaves gives three atoms, and folding
 * each cursor back over its atom rebuilds the original pattern *)
let test_peel =
  let o = Kw, Oadd in
  let p = Bnr (o, Bnr (o, Atm Tmp, Atm Tmp),
               Atm (Con 42L)) in
  let leaves = peel p () in
  assert (List.length leaves = 3);
  let atomic_p (p, _) =
    match p with Atm _ -> true | _ -> false in
  assert (List.for_all atomic_p leaves);
  let rebuilt =
    List.map (fun (p, c) -> fold_cursor c p) leaves in
  assert (List.for_all ((=) p) rebuilt)
|
||||
|
||||
(* fold_pairs over a 5-element list visits 25 pairs, all distinct *)
let test_fold_pairs =
  let l = [1; 2; 3; 4; 5] in
  let pairs =
    fold_pairs l l [] (fun pair acc -> pair :: acc) in
  assert (List.length pairs = 25);
  assert (List.length (sort_uniq compare pairs) = 25)
|
||||
|
||||
(* test pattern & state *)
|
||||
|
||||
(* [print_sm oc sm] writes one line per transition of [sm] to [oc]: the
 * operation, the ids of the operand states, the id of the target state
 * and the payloads of the Top cursors of the target *)
let print_sm oc sm =
  let tops point =
    List.fold_left (fun acc c ->
      match c with
      | Top r -> acc ^ " " ^ r
      | _ -> acc) "" point
  in
  StateMap.iter (fun (K (o, sl, sr)) s' ->
    let top = tops s'.point in
    Printf.fprintf oc
      " (%s %d %d) -> %d%s\n"
      (show_op o)
      sl.id sr.id s'.id top)
    sm
|
||||
|
||||
(* Sample rule set used by the tests.  The scrutinee of the match picks
 * between address-like patterns (base, scaled index, constant offset)
 * and a three-operand addition; patterns of the first set are expanded
 * with ac_equiv. *)
let rules =
  let oa = Kl, Oadd in
  let om = Kl, Omul in
  let tmp v = Var (v, Tmp) in
  let va = tmp "a"
  and vb = tmp "b"
  and vc = tmp "c"
  and vs = tmp "s" in
  let vars = ["a"; "b"; "c"; "s"] in
  let rule name pattern =
    ac_equiv pattern |>
    List.map (fun pattern -> {name; vars; pattern})
  in
  (* s * m for a scale m *)
  let scaled m = Bnr (om, Var ("m", Con m), vs) in
  let scales = [2L; 4L; 8L] in
  (* o + b *)
  let ob = Bnr (oa, Var ("o", AnyCon), vb) in
  match `X64Addr with
  (* ------------------------------- *)
  | `X64Addr ->
    List.concat
      [ (* o + b *)
        rule "ob" (Bnr (oa, Atm Tmp, Atm AnyCon))
      ; (* b + s * m *)
        List.concat_map
          (fun m -> rule "bsm" (Bnr (oa, vb, scaled m)))
          scales
      ; (* b + s *)
        rule "bs1" (Bnr (oa, vb, vs))
      ; (* o + s * m *)
        (* rule "osm" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp))) *) []
      ; (* o + b + s *)
        rule "obs1" (Bnr (oa, ob, vs))
      ; (* o + b + s * m *)
        List.concat_map
          (fun m -> rule "obsm" (Bnr (oa, ob, scaled m)))
          scales ]
  (* ------------------------------- *)
  | `Add3 ->
    [ { name = "add"
      ; vars = []
      ; pattern = Bnr (oa, va, Bnr (oa, vb, vc)) }
    ; { name = "add"
      ; vars = []
      ; pattern = Bnr (oa, Bnr (oa, va, vb), vc) } ]
|
||||
|
||||
(*
|
||||
|
||||
let sa, am, sm = generate_table rules
|
||||
let () =
|
||||
Array.iteri (fun i s ->
|
||||
Format.printf "@[state %d: %s@]@."
|
||||
i (show_pattern s.seen))
|
||||
sa
|
||||
let () = print_sm stdout sm; flush stdout
|
||||
|
||||
let matcher = lr_matcher sm sa rules "obsm" (* XXX *)
|
||||
let () = Format.printf "@[<v>%a@]@." Action.pp matcher
|
||||
let () = Format.printf "@[matcher size: %d@]@." (Action.size matcher)
|
||||
|
||||
let numbr = make_numberer sa am sm
|
||||
|
||||
let () =
|
||||
let opts = { pfx = ""
|
||||
; static = true
|
||||
; oc = stdout } in
|
||||
emit_c opts numbr;
|
||||
emit_matchers opts
|
||||
[ ( ["b"; "o"; "s"; "m"]
|
||||
, "obsm"
|
||||
, matcher ) ]
|
||||
|
||||
(*
|
||||
let tp = fuzz_numberer rules numbr
|
||||
let () = test_matchers tp numbr rules
|
||||
*)
|
||||
|
||||
*)
|
||||
Reference in New Issue
Block a user