This commit is contained in:
2026-06-11 10:59:54 -06:00
commit 8650a71f67
159 changed files with 78653 additions and 0 deletions
+3
View File
@@ -0,0 +1,3 @@
*.cm[iox]
*.o
mgen
+1
View File
@@ -0,0 +1 @@
match_clause=4
+16
View File
@@ -0,0 +1,16 @@
# name of the executable produced by the build
BIN = mgen
# sources are handed to ocamlopt in this order, so each
# module is listed after the modules it opens (Match
# first, Main last)
SRC = \
match.ml \
fuzz.ml \
cgen.ml \
sexp.ml \
test.ml \
main.ml
# build the native executable, linking the str library
$(BIN): $(SRC)
ocamlopt -o $(BIN) -g str.cmxa $(SRC)
# remove compiled artifacts and the executable
clean:
rm -f *.cm? *.o $(BIN)
.PHONY: clean
+420
View File
@@ -0,0 +1,420 @@
open Match
(* code generation options:
 * pfx is prepended to the generated C identifiers,
 * static requests the static keyword in front of
 * the emitted definitions, and oc is the channel
 * the C code is written to *)
type options =
{ pfx: string
; static: bool
; oc: out_channel }
(* operand selector; printed as l or r in the
 * generated C code *)
type side = L | R
(* a test on the state number of one operand:
 * membership in a bit mask, lower bound, or
 * equality *)
type id_pred =
| InBitSet of Int64.t
| Ge of int
| Eq of int
(* a predicate applied to one side, or a
 * conjunction of tests *)
and id_test =
| Pred of (side * id_pred)
| And of id_test * id_test
(* code computing the result state of one operation:
 * Table maps (left, right) state pairs to a state,
 * IfThen emits cif inside the body of the if and
 * cthen, when present, right after it, and Return
 * yields a fixed state number *)
type case_code =
| Table of ((int * int) * int) list
| IfThen of
{ test: id_test
; cif: case_code
; cthen: case_code option }
| Return of int
(* code for one operation; swap tells whether the
 * operands are first reordered by emit_swap *)
type case =
{ swap: bool
; code: case_code }
(* choose how to emit the transition map of one
 * operation; map associates (left, right) state
 * pairs with the resulting state, tmp is the
 * state returned when nothing matches, nstates
 * is the total number of states.
 * when a single result state is reachable and one
 * of the two sides admits a single state the map
 * becomes an if test, otherwise a table is used *)
let cgen_case tmp nstates map =
  let cgen_test = function
    | [id] -> Eq id
    | ids ->
      let lowest = List.fold_left min max_int ids in
      if List.length ids = nstates - lowest then
        (* ids covers every state from lowest up *)
        Ge lowest
      else begin
        assert (nstates <= 64);
        let add_bit bits id =
          Int64.logor bits (Int64.shift_left 1L id)
        in
        InBitSet (List.fold_left add_bit 0L ids)
      end
  in
  let symmetric =
    let flip ((l, r), x) = ((r, l), x) in
    setify map = setify (List.map flip map)
  in
  let map =
    if symmetric then
      (* keep the half where r <= l only *)
      List.filter (fun ((l, r), _) -> r <= l) map
    else map
  in
  let table = {swap = symmetric; code = Table map} in
  match setify (List.map snd map) with
  | [st] ->
    (* the operation considered can only
     * generate a single state *)
    let ls, rs = List.split (List.map fst map) in
    let ls = setify ls and rs = setify rs in
    if List.length ls > 1 && List.length rs > 1 then
      table
    else
      let test =
        And (Pred (L, cgen_test ls), Pred (R, cgen_test rs))
      in
      { swap = symmetric
      ; code =
          IfThen
            { test
            ; cif = Return st
            ; cthen = Some (Return tmp) } }
  | _ -> table
(* C-side name of an operation: the letter O
 * followed by the base name; the class is ignored *)
let show_op (_cls, op) =
  String.concat "" ["O"; show_op_base op]
(* write i tab characters to oc; any non-negative
 * depth is accepted (a negative one raises
 * Invalid_argument) *)
let indent oc i =
  output_string oc (String.make i '\t')
(* emit, at indentation i, the C statement that
 * exchanges l and r when l < r *)
let emit_swap oc i =
  indent oc i;
  output_string oc "if (l < r)\n";
  indent oc (i + 1);
  output_string oc "t = l, l = r, r = t;\n"
(* emit the static lookup tables needed by the
 * code of one operation.
 * oc: output channel; tmp: state stored for pairs
 * missing from the map; pfx: identifier prefix;
 * nstates: number of states; (op, c): operation
 * and its case.
 * with c.swap set only the entries with r <= l
 * are written, as a flat triangular array *)
let gen_tables oc tmp pfx nstates (op, c) =
  let i = 1 in
  let pf m = Printf.fprintf oc m in
  let pfi n m = indent oc n; pf m in
  let ntables = ref 0 in
  (* we must follow the order in which
   * we visit code in emit_case, or
   * else ntables goes out of sync *)
  let base = pfx ^ show_op op in
  let swap = c.swap in
  let rec gen c =
    match c with
    | Table map ->
      let name =
        if !ntables = 0 then base
        else base ^ string_of_int !ntables
      in
      (* count the table, exactly like emit_case
       * does, so that a second table of the same
       * operation gets its own name *)
      incr ntables;
      (* state numbers are stored in uchar cells *)
      assert (nstates <= 256);
      if swap then begin
        let n = nstates * (nstates + 1) / 2 in
        pfi i "static uchar %stbl[%d] = {\n" name n
      end else
        pfi i "static uchar %stbl[%d][%d] = {\n"
          name nstates nstates;
      for l = 0 to nstates - 1 do
        pfi (i+1) "";
        for r = 0 to nstates - 1 do
          if not swap || r <= l then
            begin
              pf "%d"
                (try List.assoc (l,r) map
                 with Not_found -> tmp);
              pf ",";
            end
        done;
        pf "\n";
      done;
      pfi i "};\n"
    | IfThen {cif; cthen; _} ->
      gen cif;
      Option.iter gen cthen
    | Return _ -> ()
  in
  gen c.code
(* emit the switch case of one operation.
 * oc: output channel; pfx: identifier prefix;
 * no_swap: when true the operand swap is not
 * emitted here even if the case asks for it;
 * (op, c): operation and its case *)
let emit_case oc pfx no_swap (op, c) =
  let out fmt = Printf.fprintf oc fmt in
  let out_at depth fmt = indent oc depth; out fmt in
  let pp_side chan = function
    | L -> output_string chan "l"
    | R -> output_string chan "r"
  in
  let pp_pred chan (s, p) =
    match p with
    | InBitSet bits ->
      Printf.fprintf chan "BIT(%a) & %#Lx" pp_side s bits
    | Eq id -> Printf.fprintf chan "%a == %d" pp_side s id
    | Ge id -> Printf.fprintf chan "%d <= %a" id pp_side s
  in
  let base = pfx ^ show_op op in
  let swap = c.swap in
  (* tables are named in visit order *)
  let ntables = ref 0 in
  let rec code depth = function
    | Return id -> out_at depth "return %d;\n" id
    | Table _ ->
      let name =
        if !ntables = 0 then base
        else base ^ string_of_int !ntables
      in
      incr ntables;
      if swap then
        out_at depth "return %stbl[(l + l*l)/2 + r];\n" name
      else
        out_at depth "return %stbl[l][r];\n" name
    | IfThen ({test = And (And (t1, t2), t3); _} as r) ->
      (* reassociate conjunctions to the right *)
      code depth
        (IfThen {r with test = And (t1, And (t2, t3))})
    | IfThen {test = And (Pred p, rest); cif; cthen} ->
      out_at depth "if (%a)\n" pp_pred p;
      code depth (IfThen {test = rest; cif; cthen})
    | IfThen {test = Pred p; cif; cthen} ->
      out_at depth "if (%a) {\n" pp_pred p;
      code (depth + 1) cif;
      out_at depth "}\n";
      Option.iter (code depth) cthen
  in
  out_at 1 "case %s:\n" (show_op op);
  if c.swap && not no_swap then
    emit_swap oc 2;
  code 2 c.code
(* print the elements of l, rendered by f and
 * joined by sep, breaking lines so that they stay
 * within limit columns.
 * col is the column the output starts at and i is
 * the indentation (in tabs, counted as 8 columns)
 * of continuation lines. with cut_before_sep the
 * line break is placed before the separator,
 * otherwise after it; the space of sep adjacent
 * to the line break is dropped.
 * an item too wide for the limit is printed alone
 * on its line instead of stalling the layout, and
 * an empty sep is accepted *)
let emit_list
    ?(limit=60) ?(cut_before_sep=false)
    ~col ~indent:i ~sep ~f oc l =
  let sl = String.length sep in
  let rstripped_sep, rssl =
    if sl > 0 && sep.[sl - 1] = ' ' then
      String.sub sep 0 (sl - 1), sl - 1
    else sep, sl
  in
  let lstripped_sep, lssl =
    if sl > 0 && sep.[0] = ' ' then
      String.sub sep 1 (sl - 1), sl - 1
    else sep, sl
  in
  (* split off the items fitting on the current
   * line; at least one item is always taken so
   * that every line makes progress *)
  let rec line col acc = function
    | [] -> (List.rev acc, [])
    | s :: l ->
      let col = col + sl + String.length s in
      let no_space =
        if cut_before_sep || l = [] then
          col > limit
        else
          col + rssl > limit
      in
      if no_space && acc <> [] then
        (List.rev acc, s :: l)
      else
        line col (s :: acc) l
  in
  let rec go col l =
    if l = [] then () else
    let ll, l = line col [] l in
    Printf.fprintf oc "%s" (String.concat sep ll);
    if l <> [] && cut_before_sep then begin
      Printf.fprintf oc "\n";
      indent oc i;
      Printf.fprintf oc "%s" lstripped_sep;
      go (8*i + lssl) l
    end else if l <> [] then begin
      Printf.fprintf oc "%s\n" rstripped_sep;
      indent oc i;
      go (8*i) l
    end else ()
  in
  go col (List.map f l)
(* emit the C definitions derived from the
 * numberer n, in this order:
 * - opn(op, l, r), a switch over operations that
 *   returns a state number from the state numbers
 *   of the two operands;
 * - refn(r, tn, con), which returns the state
 *   number of a reference;
 * - match[], which holds for each state the bit
 *   set of the top-level patterns it matches.
 * all identifiers are prefixed with opts.pfx and
 * everything is written to opts.oc *)
let emit_numberer opts n =
let pf m = Printf.fprintf opts.oc m in
(* state numbers of the two generic atoms *)
let tmp = (atom_state n Tmp).id in
let con = (atom_state n AnyCon).id in
let nst = Array.length n.states in
let cases =
StateMap.by_ops n.statemap |>
List.map (fun (op, map) ->
(op, cgen_case tmp nst map))
in
(* when every case swaps, the swap is emitted
 * once before the switch instead of per case *)
let all_swap =
List.for_all (fun (_, c) -> c.swap) cases in
(* opn() *)
if opts.static then pf "static ";
pf "int\n";
pf "%sopn(int op, int l, int r)\n" opts.pfx;
pf "{\n";
cases |> List.iter
(gen_tables opts.oc tmp opts.pfx nst);
(* t is the scratch variable used by the swap *)
if List.exists (fun (_, c) -> c.swap) cases then
pf "\tint t;\n\n";
if all_swap then emit_swap opts.oc 1;
pf "\tswitch (op) {\n";
cases |> List.iter
(emit_case opts.oc opts.pfx all_swap);
pf "\tdefault:\n";
pf "\t\treturn %d;\n" tmp;
pf "\t}\n";
pf "}\n\n";
(* refn() *)
if opts.static then pf "static ";
pf "int\n";
pf "%srefn(Ref r, Num *tn, Con *con)\n" opts.pfx;
pf "{\n";
(* specific constants that own a state *)
let cons =
List.filter_map (function
| (Con c, s) -> Some (c, s.id)
| _ -> None)
n.atoms
in
if cons <> [] then
pf "\tint64_t n;\n\n";
pf "\tswitch (rtype(r)) {\n";
pf "\tcase RTmp:\n";
(* when the state of temporaries is not 0, the
 * generated code replaces a stored 0 by that
 * state before returning it *)
if tmp <> 0 then begin
assert
(List.exists (fun (_, s) ->
s.id = 0
) n.atoms &&
(* no temp should ever get state 0 *)
List.for_all (fun (a, s) ->
s.id <> 0 ||
match a with
| AnyCon | Con _ -> true
| _ -> false
) n.atoms);
pf "\t\tif (!tn[r.val].n)\n";
pf "\t\t\ttn[r.val].n = %d;\n" tmp;
end;
pf "\t\treturn tn[r.val].n;\n";
pf "\tcase RCon:\n";
if cons <> [] then begin
pf "\t\tif (con[r.val].type != CBits)\n";
pf "\t\t\treturn %d;\n" con;
pf "\t\tn = con[r.val].bits.i;\n";
(* one if per state, testing all the constants
 * that share this state *)
cons |> inverse |> group_by_fst
|> List.iter (fun (id, cs) ->
pf "\t\tif (";
emit_list ~cut_before_sep:true
~col:20 ~indent:2 ~sep:" || "
~f:(fun c -> "n == " ^ Int64.to_string c)
opts.oc cs;
pf ")\n";
pf "\t\t\treturn %d;\n" id
);
end;
pf "\t\treturn %d;\n" con;
pf "\tdefault:\n";
pf "\t\treturn INT_MIN;\n";
pf "\t}\n";
pf "}\n\n";
(* match[]: patterns per state *)
if opts.static then pf "static ";
pf "bits %smatch[%d] = {\n" opts.pfx nst;
n.states |> Array.iteri (fun sn s ->
(* top-level cursors of the state, leaving out
 * the two builtin names $ and % *)
let tops =
List.filter_map (function
| Top ("$" | "%") -> None
| Top r -> Some ("BIT(P" ^ r ^ ")")
| _ -> None) s.point |> setify
in
if tops <> [] then
pf "\t[%d] = %s,\n"
sn (String.concat " | " tops);
);
pf "};\n\n"
(* position of the first occurrence of f in vars;
 * raises Not_found when f does not occur *)
let var_id vars f =
  let rec find pos = function
    | [] -> raise Not_found
    | x :: rest ->
      if compare x f = 0 then pos
      else find (pos + 1) rest
  in
  find 0 vars
(* flatten the action automaton act into a list
 * of small integers; vars gives the index of
 * each variable name.
 * encoding produced here:
 *   0              stop
 *   1 / 2          push, symmetric / plain
 *   3 v            set the variable of index v
 *   4              pop
 *   5 n tbl body   switch: n (state, offset)
 *                  pairs, then the offset of the
 *                  last case, then the case bodies
 *   10 + pc        stands for the action that was
 *                  already compiled at position pc
 * NOTE(review): the consumer of this encoding is
 * not visible in this file *)
let compile_action vars act =
(* position at which each action id was compiled *)
let pcs = Hashtbl.create 100 in
let rec gen pc (act: Action.t) =
try
[10 + Hashtbl.find pcs act.id]
with Not_found ->
let code =
match act.node with
| Action.Stop ->
[0]
| Action.Push (sym, k) ->
let c = if sym then 1 else 2 in
[c] @ gen (pc + 1) k
| Action.Set (v, {node = Action.Pop k; _})
| Action.Set (v, ({node = Action.Stop; _} as k)) ->
let v = var_id vars v in
[3; v] @ gen (pc + 2) k
| Action.Set _ ->
(* for now, only atomic patterns can be
 * tied to a variable, so Set must be
 * followed by either Pop or Stop *)
assert false
| Action.Pop k ->
[4] @ gen (pc + 1) k
| Action.Switch cases ->
(* group the states by target action, the
 * largest group first *)
let cases =
inverse cases |> group_by_fst |>
List.sort (fun (_, cs1) (_, cs2) ->
let n1 = List.length cs1
and n2 = List.length cs2 in
compare n2 n1)
in
(* the last case is the one with
 * the max number of entries *)
let cases = List.rev (List.tl cases)
and last = fst (List.hd cases) in
let ncases =
List.fold_left (fun n (_, cs) ->
List.length cs + n)
0 cases
in
(* size of the switch header: opcode, count,
 * the pairs and the offset of the last case *)
let body_off = 2 + 2 * ncases + 1 in
let pc, tbl, body =
List.fold_left
(fun (pc, tbl, body) (a, cs) ->
(* offsets count from the start of the switch *)
let ofs = body_off + List.length body in
let case = gen pc a in
let pc = pc + List.length case in
let body = body @ case in
let tbl =
List.fold_left (fun tbl c ->
tbl @ [c; ofs]
) tbl cs
in
(pc, tbl, body))
(pc + body_off, [], [])
cases
in
let ofs = body_off + List.length body in
let tbl = tbl @ [ofs] in
assert (2 + List.length tbl = body_off);
[5; ncases] @ tbl @ body @ gen pc last
in
(* Stop is never shared through a reference *)
if act.node <> Action.Stop then
Hashtbl.replace pcs act.id pc;
code
in
gen 0 act
(* emit the matcher[] array: one compiled action
 * per pattern name; ms lists triples made of the
 * variables, the pattern name and the action *)
let emit_matchers opts ms =
  let pf m = Printf.fprintf opts.oc m in
  let emit_one (vars, pname, m) =
    pf "\t[P%s] = (uchar[]){\n" pname;
    pf "\t\t";
    let code = compile_action vars m in
    emit_list
      ~col:16 ~indent:2 ~sep:","
      ~f:string_of_int opts.oc code;
    pf "\n";
    pf "\t},\n"
  in
  if opts.static then pf "static ";
  pf "uchar *%smatcher[] = {\n" opts.pfx;
  List.iter emit_one ms;
  pf "};\n\n"
(* emit the C code of the numberer; currently a
 * plain alias of emit_numberer *)
let emit_c opts numbr =
  emit_numberer opts numbr
+413
View File
@@ -0,0 +1,413 @@
(* fuzz the tables and matchers generated *)
open Match
(* growable array with amortized constant time
 * push; shadows the standard Buffer module in
 * the rest of this file *)
module Buffer: sig
  type 'a t
  val create: ?capacity:int -> unit -> 'a t
  val reset: 'a t -> unit
  val size: 'a t -> int
  val get: 'a t -> int -> 'a
  val set: 'a t -> int -> 'a -> unit
  val push: 'a t -> 'a -> unit
end = struct
  type 'a t =
    { mutable size: int
    ; mutable data: 'a array }

  (* cells beyond size hold a dummy value that
   * is never handed out *)
  let mk_array n = Array.make n (Obj.magic 0)

  let create ?(capacity = 10) () =
    if capacity < 0 then invalid_arg "Buffer.make"
    else {size = 0; data = mk_array capacity}

  (* forget the contents, keep the storage *)
  let reset b = b.size <- 0

  let size {size; _} = size

  let get b n =
    if n < b.size then b.data.(n)
    else invalid_arg "Buffer.get"

  let set b n x =
    if n < b.size then b.data.(n) <- x
    else invalid_arg "Buffer.set"

  let push b x =
    let cap = Array.length b.data in
    if b.size = cap then begin
      (* storage is full: move to a larger array *)
      let grown = mk_array (2 * cap + 1) in
      Array.blit b.data 0 grown 0 cap;
      b.data <- grown
    end;
    let slot = b.size in
    b.size <- slot + 1;
    b.data.(slot) <- x
end
(* state reached by applying op to operands in
 * states s1 and s2; pairs missing from the
 * statemap yield the state of Tmp *)
let binop_state n op s1 s2 =
  match StateMap.find (K (op, s1, s2)) n.statemap with
  | found -> found
  | exception Not_found -> atom_state n Tmp
(* terms are referred to by integer ids *)
type id = int
(* shape of a term: a binary operation applied to
 * two terms given by id, or an atomic leaf *)
type term_data =
| Binop of op * id * id
| Leaf of atomic_pattern
(* a term together with its own id and the state
 * the numberer assigns to it *)
type term =
{ id: id
; data: term_data
; state: p state }
(* print the term of index id in the array ta;
 * leaves print as a constant, $id or %id, and
 * binary nodes as (op@id:state left right) *)
let pp_term fmt (ta, id) =
  let rec pp ppf id =
    let node = ta.(id) in
    match node.data with
    | Leaf (Con c) -> Format.fprintf ppf "%Ld" c
    | Leaf AnyCon -> Format.fprintf ppf "$%d" id
    | Leaf Tmp -> Format.fprintf ppf "%%%d" id
    | Binop (op, id1, id2) ->
      Format.fprintf ppf "@[(%s@%d:%d @[<hov>%a@ %a@])@]"
        (show_op op) id node.state.id
        pp id1 pp id2
  in
  pp fmt id
(* A term pool is a deduplicated set of terms
 * that maintains node numbering using the
 * statemap passed at creation time *)
module TermPool = struct
  (* terms holds the terms by id, hcons maps the
   * shape of a term to its id, numbr computes
   * the states *)
  type t =
    { terms: term Buffer.t
    ; hcons: (term_data, id) Hashtbl.t
    ; numbr: numberer }

  let create numbr =
    { terms = Buffer.create ()
    ; hcons = Hashtbl.create 100
    ; numbr }

  (* empty the pool *)
  let reset {terms; hcons; _} =
    Buffer.reset terms;
    Hashtbl.clear hcons

  let size tp = Buffer.size tp.terms

  let term tp id = Buffer.get tp.terms id

  (* term for the atom atm, created on first use *)
  let mk_leaf tp atm =
    let data = Leaf atm in
    let id =
      match Hashtbl.find_opt tp.hcons data with
      | Some id -> id
      | None ->
        let id = Buffer.size tp.terms in
        let state = atom_state tp.numbr atm in
        Buffer.push tp.terms {id; data; state};
        Hashtbl.add tp.hcons data id;
        id
    in
    term tp id

  (* term for op applied to t1 and t2, created
   * on first use *)
  let mk_binop tp op t1 t2 =
    let data = Binop (op, t1.id, t2.id) in
    let id =
      match Hashtbl.find_opt tp.hcons data with
      | Some id -> id
      | None ->
        let id = Buffer.size tp.terms in
        let state =
          binop_state tp.numbr op t1.state t2.state
        in
        Buffer.push tp.terms {id; data; state};
        Hashtbl.add tp.hcons data id;
        id
    in
    term tp id

  (* add a pattern and all its subpatterns;
   * variables are treated like their atom *)
  let rec add_pattern tp = function
    | Atm atm | Var (_, atm) -> mk_leaf tp atm
    | Bnr (op, p1, p2) ->
      let t1 = add_pattern tp p1 in
      let t2 = add_pattern tp p2 in
      mk_binop tp op t1 t2

  (* copy the term id into a standalone array
   * where no node is shared and every node sits
   * at the index equal to its new id; returns
   * the array and the index of the root *)
  let explode_term tp id =
    let rec copy acc next id =
      let t = term tp id in
      match t.data with
      | Leaf _ -> (next, {t with id = next} :: acc)
      | Binop (op, id1, id2) ->
        let left, acc = copy acc next id1 in
        let right, acc = copy acc (left + 1) id2 in
        let self = right + 1 in
        ( self
        , {t with data = Binop (op, left, right); id = self}
          :: acc )
    in
    let root, nodes = copy [] 0 id in
    Array.of_list (List.rev nodes), root
end
module R = Random
(* uniform pick in a list *)
(* pick one element of l; the k-th element seen
 * replaces the current choice when R.int k is 0.
 * raises Invalid_argument on the empty list *)
let list_pick = function
  | [] -> invalid_arg "list_pick"
  | first :: rest ->
    let step (chosen, seen) y =
      if R.int (seen + 1) = 0 then (y, seen + 1)
      else (chosen, seen + 1)
    in
    fst (List.fold_left step (first, 1) rest)
(* random pattern generator for the numberer
 * numbr; the result takes the depth budget and
 * returns a pattern with its state. the list of
 * operations is computed once from the statemap
 * and cached in numbr.ops *)
let term_pick ~numbr =
  let ops =
    if numbr.ops = [] then begin
      let collect (K (op, _, _)) _ acc = op :: acc in
      numbr.ops <-
        setify (StateMap.fold collect numbr.statemap [])
    end;
    numbr.ops
  in
  let rec gen depth =
    (* exponential probability for leaves to
     * avoid skewing towards shallow terms *)
    let atm_prob = 0.75 ** float_of_int depth in
    let want_leaf = R.float 1.0 <= atm_prob in
    if want_leaf || ops = [] then
      let atom, st = list_pick numbr.atoms in
      (st, Atm atom)
    else
      let op = list_pick ops in
      let s1, t1 = gen (depth - 1) in
      let s2, t2 = gen (depth - 1) in
      (binop_state numbr op s1 s2, Bnr (op, t1, t2))
  in
  fun ~depth -> gen depth
(* raised by the fuzzing functions below once a
 * mismatch has been reported on stderr *)
exception FuzzError
(* height of a pattern; atoms and variables
 * have height 0 *)
let rec pattern_depth = function
  | Atm _ | Var _ -> 0
  | Bnr (_, lhs, rhs) ->
    1 + max (pattern_depth lhs) (pattern_depth rhs)
(* a as a percentage of b, as a float *)
let ( %% ) a b =
  let scaled = 1e2 *. float_of_int a in
  scaled /. float_of_int b
(* redraw a one-line progress report on stderr:
 * the current line is erased, msg is printed
 * with its format arguments (passed after pct)
 * and followed by a bar of about width cells and
 * the percentage pct.
 * the completed part of the bar is drawn with #
 * and the remaining part with - *)
let progress ?(width = 50) msg pct =
  (* erase the line and go back to column 0 *)
  Format.eprintf "\x1b[2K\r%!";
  let progress_bar fmt =
    let n =
      let fwidth = float_of_int width in
      1 + int_of_float (pct *. fwidth /. 1e2)
    in
    (* never ask for a negative number of cells *)
    let n = max 0 n in
    Format.fprintf fmt " %s%s %.0f%%@?"
      (String.make n '#')
      (String.make (max 0 (width - n)) '-')
      pct
  in
  Format.kfprintf progress_bar
    Format.err_formatter msg
(* check the numberer numbr against rules on
 * random terms: for every generated term, the
 * rule names found in its state must be exactly
 * the names of the rules whose pattern matches
 * it. a mismatch is reported on stderr and
 * raises FuzzError. terms matching at least one
 * rule are collected in a term pool, which is
 * returned *)
let fuzz_numberer rules numbr =
(* pick twice the max pattern depth so we
 * have a chance to find non-trivial numbers
 * for the atomic patterns in the rules *)
let depth =
List.fold_left (fun depth r ->
max depth (pattern_depth r.pattern))
0 rules * 2
in
(* fuzz until the term pool we are constructing
 * is no longer growing fast enough; or we just
 * went through sufficiently many iterations *)
let max_iter = 1_000_000 in
let low_insert_rate = 1e-2 in
let tp = TermPool.create numbr in
(* new_stats is (num, cnt, rate): a counter that
 * triggers the periodic update, the number of
 * insertions since the last update and the
 * smoothed insertion rate; i counts iterations *)
let rec loop new_stats i =
let (_, _, insert_rate) = new_stats in
if insert_rate <= low_insert_rate then () else
if i >= max_iter then () else
(* periodically update stats *)
let new_stats =
let (num, cnt, rate) = new_stats in
if num land 1023 = 0 then
(* average the previous rate with the rate
 * observed since the last update *)
let rate =
0.5 *. (rate +. float_of_int cnt /. 1023.)
in
progress " insert_rate=%.1f%%"
(i %% max_iter) (rate *. 1e2);
(num + 1, 0, rate)
else new_stats
in
(* create a term and check that its number is
 * accurate wrt the rules *)
let st, term = term_pick ~numbr ~depth in
(* rule names recorded in the state, leaving out
 * the two builtin names $ and % *)
let state_matched =
List.filter_map (fun cu ->
match cu with
| Top ("$" | "%") -> None
| Top name -> Some name
| _ -> None)
st.point |> setify
in
(* rule names obtained by direct matching *)
let rule_matched =
List.filter_map (fun r ->
if pattern_match r.pattern term then
Some r.name
else None)
rules |> setify
in
if state_matched <> rule_matched then begin
let open Format in
let pp_str_list =
let pp_sep fmt () = fprintf fmt ",@ " in
pp_print_list ~pp_sep pp_print_string
in
eprintf "@.@[<v2>fuzz error for %s"
(show_pattern term);
eprintf "@ @[state matched: %a@]"
pp_str_list state_matched;
eprintf "@ @[rule matched: %a@]"
pp_str_list rule_matched;
eprintf "@]@.";
raise FuzzError;
end;
(* terms matching no rule are not kept *)
if state_matched = [] then
loop new_stats (i + 1)
else
(* add to the term pool *)
let old_size = TermPool.size tp in
let _ = TermPool.add_pattern tp term in
let new_stats =
let (num, cnt, rate) = new_stats in
if TermPool.size tp <> old_size then
(num + 1, cnt + 1, rate)
else
(num + 1, cnt, rate)
in
loop new_stats (i + 1)
in
loop (1, 0, 1.0) 0;
(* the size is derived from the number of words
 * reachable from the pool; the divisors assume
 * 8-byte words *)
Format.eprintf
"@.@[ generated %.3fMiB of test terms@]@."
(float_of_int (Obj.reachable_words (Obj.repr tp))
/. 128. /. 1024.);
tp
(* interpret the action m on the term of index id
 * in the array ta, with stk as the stack of
 * pending term indices; returns the (variable,
 * term index) bindings made by Set actions.
 * raises Failure when the action cannot be
 * followed on this term *)
let rec run_matcher stk m (ta, id as t) =
  let state_of id = ta.(id).state.id in
  match m.Action.node with
  | Action.Stop -> []
  | Action.Set (v, k) ->
    (v, id) :: run_matcher stk k t
  | Action.Pop k ->
    begin match stk with
    | [] -> failwith "pop on empty stack"
    | top :: rest -> run_matcher rest k (ta, top)
    end
  | Action.Push (sym, k) ->
    begin match ta.(id).data with
    | Leaf _ -> failwith "push on leaf"
    | Binop (_, l, r) ->
      (* a symmetric push visits the operand with
       * the smaller state first when the left
       * state is the larger one *)
      if sym && state_of l > state_of r
      then run_matcher (l :: stk) k (ta, r)
      else run_matcher (r :: stk) k (ta, l)
    end
  | Action.Switch cases ->
    let k =
      match List.assoc_opt (state_of id) cases with
      | Some k -> k
      | None -> failwith "no switch case"
    in
    run_matcher stk k t
(* match the pattern p against the term of index
 * id in the array ta; on success returns the
 * (variable, term index) bindings, left operand
 * first. a binary term is treated like Tmp when
 * matched against an atom *)
let rec term_match p (ta, id) =
  let atom_match a =
    match ta.(id).data with
    | Leaf a' -> pattern_match (Atm a) (Atm a')
    | Binop _ -> pattern_match (Atm a) (Atm Tmp)
  in
  match p with
  | Var (v, a) ->
    if atom_match a then Some [(v, id)] else None
  | Atm a ->
    if atom_match a then Some [] else None
  | Bnr (op, pl, pr) ->
    begin match ta.(id).data with
    | Binop (op', idl, idr) when op' = op ->
      begin match term_match pl (ta, idl) with
      | None -> None
      | Some l1 ->
        begin match term_match pr (ta, idr) with
        | None -> None
        | Some l2 -> Some (l1 @ l2)
        end
      end
    | _ -> None
    end
(* run the matcher of every rule on every term of
 * the pool tp whose state contains the rule, and
 * compare the bindings it produces with those of
 * term_match on the patterns of the rule. any
 * disagreement or failure of the matcher is
 * reported on stderr and raises FuzzError.
 * progress, with the share of patterns matched
 * at least once, is displayed on stderr *)
let test_matchers tp numbr rules =
let {statemap = sm; states = sa; _} = numbr in
(* total number of patterns, over all rules *)
let total = ref 0 in
(* for each state id, the (patterns, matcher) of
 * the rules present at top level in the state *)
let matchers =
let htbl = Hashtbl.create (Array.length sa) in
List.map (fun r -> (r.name, r.pattern)) rules |>
group_by_fst |>
List.iter (fun (r, ps) ->
total := !total + List.length ps;
let pm = (ps, lr_matcher sm sa rules r) in
sa |> Array.iter (fun s ->
if List.mem (Top r) s.point then
Hashtbl.add htbl s.id pm));
htbl
in
(* patterns that matched at least one term *)
let seen = Hashtbl.create !total in
for id = 0 to TermPool.size tp - 1 do
if id land 1023 = 0 ||
id = TermPool.size tp - 1 then begin
progress
" coverage=%.1f%%"
(id %% TermPool.size tp)
(Hashtbl.length seen %% !total)
end;
let t = TermPool.explode_term tp id in
Hashtbl.find_all matchers
(TermPool.term tp id).state.id |>
List.iter (fun (ps, m) ->
(* bindings are compared up to ordering *)
let norm = List.fast_sort compare in
let ok =
match norm (run_matcher [] m t) with
| asn -> `Match (List.exists (fun p ->
match term_match p t with
| None -> false
| Some asn' ->
if asn = norm asn' then begin
Hashtbl.replace seen p ();
true
end else false) ps)
| exception e -> `RunFailure e
in
if ok <> `Match true then begin
let open Format in
let pp_asn fmt asn =
fprintf fmt "@[<h>";
pp_print_list
~pp_sep:(fun fmt () -> fprintf fmt ";@ ")
(fun fmt (v, d) ->
fprintf fmt "@[%s←%d@]" v d)
fmt asn;
fprintf fmt "@]"
in
eprintf "@.@[<v2>matcher error for";
eprintf "@ @[%a@]" pp_term t;
begin match ok with
| `RunFailure e ->
eprintf "@ @[exception: %s@]"
(Printexc.to_string e)
| `Match (* false *) _ ->
(* run the matcher again to display the
 * bindings it produced *)
let asn = run_matcher [] m t in
eprintf "@ @[assignment: %a@]"
pp_asn asn;
eprintf "@ @[<v2>could not match";
List.iter (fun p ->
eprintf "@ + @[%s@]"
(show_pattern p)) ps;
eprintf "@]"
end;
eprintf "@]@.";
raise FuzzError
end)
done;
Format.eprintf "@."
+214
View File
@@ -0,0 +1,214 @@
open Cgen
open Match
(* generate the C code for the rules found in
 * input and write it to oc.
 * verbose: messages whose level is at most this
 * value are printed on stderr; fuzz: also fuzz
 * the tables and the matchers; path and lofs:
 * file name and line offset used when reporting
 * a parse error, after which the program exits
 * with status 1 *)
let mgen ~verbose ~fuzz path lofs input oc =
(* print on stderr when level allows it, and
 * discard the message otherwise *)
let info ?(level = 1) fmt =
if level <= verbose then
Printf.eprintf fmt
else
Printf.ifprintf stdout fmt
in
let rules =
match Sexp.(run_parser ppats) input with
| `Error (ps, err, loc) ->
Printf.eprintf "%s:%d:%d %s\n"
path (lofs + ps.Sexp.line) ps.Sexp.coln err;
Printf.eprintf "%s" loc;
exit 1
| `Ok rules -> rules
in
info "adding ac variants...%!";
let nparsed =
List.fold_left
(fun npats (_, _, ps) ->
npats + List.length ps)
0 rules
in
(* variables of each rule name; a name appearing
 * several times must keep the same variables *)
let varsmap = Hashtbl.create 10 in
let rules =
List.concat_map (fun (name, vars, patterns) ->
(try assert (Hashtbl.find varsmap name = vars)
with Not_found -> ());
Hashtbl.replace varsmap name vars;
List.map
(fun pattern -> {name; vars; pattern})
(List.concat_map ac_equiv patterns)
) rules
in
info " %d -> %d patterns\n"
nparsed (List.length rules);
let rnames =
setify (List.map (fun r -> r.name) rules) in
info "generating match tables...%!";
let sa, am, sm = generate_table rules in
let numbr = make_numberer sa am sm in
info " %d states, %d rules\n"
(Array.length sa) (StateMap.cardinal sm);
if verbose >= 2 then begin
info "-------------\nstates:\n";
Array.iteri (fun i s ->
info " state %d: %s\n"
i (show_pattern s.seen)) sa;
info "-------------\nstatemap:\n";
Test.print_sm stderr sm;
info "-------------\n";
end;
info "generating matchers...\n";
(* one matcher per rule name *)
let matchers =
List.map (fun rname ->
info "+ %s...%!" rname;
let m = lr_matcher sm sa rules rname in
let vars = Hashtbl.find varsmap rname in
info " %d nodes\n" (Action.size m);
info ~level:2 " -------------\n";
info ~level:2 " automaton:\n";
info ~level:2 "%s\n"
(Format.asprintf " @[%a@]" Action.pp m);
info ~level:2 " ----------\n";
(vars, rname, m)
) rnames
in
if fuzz then begin
info ~level:0 "fuzzing statemap...\n";
let tp = Fuzz.fuzz_numberer rules numbr in
info ~level:0 "testing %d patterns...\n"
(List.length rules);
Fuzz.test_matchers tp numbr rules
end;
info "emitting C...\n";
flush stderr;
let cgopts =
{ pfx = ""; static = true; oc = oc } in
emit_c cgopts numbr;
emit_matchers cgopts matchers;
()
(* read the channel ic up to the end of file and
 * return everything as one string *)
let read_all ic =
  let chunk = 4096 in
  let scratch = Bytes.create chunk in
  let acc = Buffer.create chunk in
  let rec drain () =
    match input ic scratch 0 chunk with
    | 0 -> ()
    | got ->
      Buffer.add_subbytes acc scratch 0 got;
      drain ()
  in
  drain ();
  Buffer.contents acc
(* split a C source around its mgen section.
 * returns (lofs, pfx, rules, sfx): lofs is the
 * number of lines up to and including the marker
 * line, pfx is the text up to the end of the
 * comment holding the rules, rules is the text
 * of the rules and sfx starts at the line that
 * ends the generated code. what lies between pfx
 * and sfx is dropped.
 * raises Failure when a marker is missing *)
let split_c src =
  let contains regexp str =
    match Str.search_forward regexp str 0 with
    | (_ : int) -> true
    | exception Not_found -> false
  in
  let begin_re = Str.regexp "mgen generated code" in
  let eoc_re = Str.regexp "\\*/" in
  let end_re = Str.regexp "end of generated code" in
  (* pfx and rules are accumulated in reverse *)
  let rec go st lofs pfx rules = function
    | [] ->
      failwith
        (match st with
         | `Prefix -> "could not find mgen section"
         | `Rules -> "mgen rules not terminated"
         | `Skip -> "mgen section not terminated")
    | line :: lines ->
      begin match st with
      | `Prefix ->
        let pfx = line :: pfx in
        if contains begin_re line
        then go `Rules (List.length pfx) pfx rules lines
        else go `Prefix 0 pfx rules lines
      | `Rules ->
        let pfx = line :: pfx in
        if contains eoc_re line
        then go `Skip lofs pfx rules lines
        else go `Rules lofs pfx (line :: rules) lines
      | `Skip ->
        if contains end_re line then
          let join = String.concat "\n" in
          ( lofs
          , join (List.rev pfx) ^ "\n\n"
          , join (List.rev rules)
          , join (line :: lines) )
        else go `Skip lofs pfx rules lines
      end
  in
  go `Prefix 0 [] [] (String.split_on_char '\n' src)
(* entry point: parse the command line, read the
 * single input (a path, or - for stdin) and run
 * mgen on it. an input whose name ends in .c is
 * rewritten in place through a temporary file;
 * for any other input the generated code goes to
 * stdout *)
let () =
  let usage_msg =
    "mgen [--fuzz] [--verbose <N>] <file>" in
  let fuzz_arg = ref false in
  let verbose_arg = ref 0 in
  let input_paths = ref [] in
  let anon_fun filename =
    input_paths := filename :: !input_paths in
  let speclist =
    [ ( "--fuzz", Arg.Set fuzz_arg
      , " Fuzz tables and matchers" )
    ; ( "--verbose", Arg.Set_int verbose_arg
      , "<N> Set verbosity level" )
    ; ( "--", Arg.Rest_all (List.iter anon_fun)
      , " Stop argument parsing" ) ]
  in
  Arg.parse speclist anon_fun usage_msg;
  let input_paths = !input_paths in
  let verbose = !verbose_arg in
  let fuzz = !fuzz_arg in
  let input_path, input =
    match input_paths with
    | ["-"] -> ("-", read_all stdin)
    | [path] ->
      (* release the channel once the file is read *)
      let ic = open_in path in
      let data =
        Fun.protect
          ~finally:(fun () -> close_in_noerr ic)
          (fun () -> read_all ic)
      in
      (path, data)
    | _ ->
      Printf.eprintf
        "%s: single input file expected\n"
        Sys.argv.(0);
      Arg.usage speclist usage_msg; exit 1
  in
  let mgen = mgen ~verbose ~fuzz in
  (* check_suffix copes with names shorter than the
   * suffix, such as - *)
  if not (Filename.check_suffix input_path ".c")
  then mgen input_path 0 input stdout
  else
    let tmp_path = input_path ^ ".tmp" in
    Fun.protect
      ~finally:(fun () ->
        (* after a successful rename there is
         * nothing left to remove *)
        try Sys.remove tmp_path with Sys_error _ -> ())
      (fun () ->
        let lofs, pfx, rules, sfx = split_c input in
        let oc = open_out tmp_path in
        Fun.protect
          ~finally:(fun () -> close_out_noerr oc)
          (fun () ->
            output_string oc pfx;
            mgen input_path lofs rules oc;
            output_string oc sfx;
            close_out oc);
        Sys.rename tmp_path input_path;
        ());
    ()
+651
View File
@@ -0,0 +1,651 @@
(* operand class of an operation; shown as the
 * suffix w, l, s or d by show_op *)
type cls = Kw | Kl | Ks | Kd
(* the binary operations known to the generator *)
type op_base =
| Oadd
| Osub
| Omul
| Oor
| Oshl
| Oshr
(* an operation is a base operation at a class *)
type op = cls * op_base
(* every base operation, in declaration order *)
let op_bases =
[Oadd; Osub; Omul; Oor; Oshl; Oshr]
(* whether the operands of an operation can be
 * exchanged; the class does not matter *)
let commutative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false
(* whether nested uses of an operation can be
 * regrouped; the class does not matter *)
let associative (_, base) =
  match base with
  | Oadd | Omul | Oor -> true
  | Osub | Oshl | Oshr -> false
(* leaves of patterns: a temporary, any constant,
 * or one specific constant *)
type atomic_pattern =
| Tmp
| AnyCon
| Con of int64
(* Tmp < AnyCon < Con k *)
(* a pattern is a binary operation on two
 * patterns, an atom, or an atom bound to a
 * variable name *)
type pattern =
| Bnr of op * pattern * pattern
| Atm of atomic_pattern
| Var of string * atomic_pattern
(* true for atoms and variables, false for
 * binary patterns *)
let is_atomic p =
  match p with
  | Atm _ | Var _ -> true
  | Bnr _ -> false
(* lowercase name of a base operation *)
let show_op_base = function
  | Oadd -> "add"
  | Osub -> "sub"
  | Omul -> "mul"
  | Oor -> "or"
  | Oshl -> "shl"
  | Oshr -> "shr"
(* name of an operation: the base name followed
 * by the letter of its class *)
let show_op (k, o) =
  let suffix =
    match k with
    | Kw -> "w"
    | Kl -> "l"
    | Ks -> "s"
    | Kd -> "d"
  in
  show_op_base o ^ suffix
(* textual form of a pattern: % for a temporary,
 * $ for any constant, the number for a constant,
 * the atom followed by a quote and the name for
 * a variable, and (op left right) for a binary
 * pattern *)
let rec show_pattern = function
  | Atm Tmp -> "%"
  | Atm AnyCon -> "$"
  | Atm (Con n) -> Int64.to_string n
  | Var (v, a) ->
    String.concat "" [show_pattern (Atm a); "'"; v]
  | Bnr (o, pl, pr) ->
    String.concat ""
      [ "("; show_op o
      ; " "; show_pattern pl
      ; " "; show_pattern pr; ")" ]
(* the atom of an atomic pattern or variable;
 * None for a binary pattern *)
let get_atomic = function
  | Atm a | Var (_, a) -> Some a
  | Bnr _ -> None
(* whether the pattern p accepts the pattern w.
 * Tmp accepts everything that is not a constant
 * atom (binary patterns included), AnyCon accepts
 * what Tmp rejects, a specific constant accepts
 * only a pattern equal to itself, and binary
 * patterns are compared operand by operand *)
let rec pattern_match p w =
  match p, w with
  | Var (_, a), _ ->
    pattern_match (Atm a) w
  | Atm Tmp, _ ->
    begin match get_atomic w with
    | Some (Con _ | AnyCon) -> false
    | Some Tmp | None -> true
    end
  | Atm (Con _), _ -> w = p
  | Atm AnyCon, _ ->
    not (pattern_match (Atm Tmp) w)
  | Bnr (o, pl, pr), Bnr (o', wl, wr) ->
    o' = o &&
    pattern_match pl wl &&
    pattern_match pr wr
  | Bnr _, (Atm _ | Var _) -> false
(* a cursor records what surrounds a position in
 * a pattern: Bnrl means the position is the left
 * operand of an operation, with the enclosing
 * cursor and the right sibling; Bnrr is the
 * mirror case; Top marks the root and carries a
 * payload *)
type +'a cursor = (* a position inside a pattern *)
| Bnrl of op * 'a cursor * pattern
| Bnrr of op * pattern * 'a cursor
| Top of 'a
(* rebuild the whole pattern obtained by placing
 * p at the position described by the cursor c *)
let rec fold_cursor c p =
  match c with
  | Top _ -> p
  | Bnrl (o, up, right) ->
    fold_cursor up (Bnr (o, p, right))
  | Bnrr (o, left, up) ->
    fold_cursor up (Bnr (o, left, p))
(* list the atomic leaves of p, each paired with
 * the cursor locating it under a root tagged
 * with x; variables are replaced by their atom *)
let peel p x =
  let expand acc (pat, cur) =
    match pat with
    | Var (_, a) -> (Atm a, cur) :: acc
    | Atm _ -> (pat, cur) :: acc
    | Bnr (o, pl, pr) ->
      (pl, Bnrl (o, cur, pr)) ::
      (pr, Bnrr (o, pl, cur)) :: acc
  in
  (* expand until the list stops growing *)
  let rec fixpoint items =
    let next = List.fold_left expand [] items in
    if List.length next = List.length items
    then next
    else fixpoint next
  in
  fixpoint [(p, Top x)]
(* fold f over every pair made of an element of
 * l1 and an element of l2, l1 varying slowest,
 * starting from ini *)
let fold_pairs l1 l2 ini f =
  List.fold_left
    (fun acc a ->
      List.fold_left (fun acc b -> f (a, b) acc) acc l2)
    ini l1
(* apply f to every ordered pair of elements of
 * l, the first component varying slowest *)
let iter_pairs l f =
  List.iter
    (fun a -> List.iter (fun b -> f (a, b)) l)
    l
(* exchange the two components of every pair *)
let inverse l =
  let flip (a, b) = (b, a) in
  List.map flip l
(* a state of the matching automaton: id is its
 * number (-1 is used below for states that are
 * not registered yet), seen is a pattern
 * describing the term that led to the state, and
 * point is the list of cursors of the state *)
type 'a state =
{ id: int
; seen: pattern
; point: ('a cursor) list }
(* cursors of a state sitting on the given side
 * (`L or `R) of a binary operation; each one is
 * returned as the operation with its enclosing
 * cursor, paired with the sibling pattern *)
let binops side {point; _} =
  let pick c =
    match c, side with
    | Bnrl (o, c, r), `L -> Some ((o, c), r)
    | Bnrr (o, l, c), `R -> Some ((o, c), l)
    | _ -> None
  in
  List.filter_map pick point
(* group the second components of the pairs of l
 * by first component. the groups come out with
 * the largest key first, and inside a group the
 * values are in the reverse of the sorted order *)
let group_by_fst l =
  let by_key (a, _) (b, _) = compare a b in
  let rec collect key members groups = function
    | [] -> (key, members) :: groups
    | (k, v) :: rest ->
      if key = k then
        collect key (v :: members) groups rest
      else
        collect k [v] ((key, members) :: groups) rest
  in
  match List.fast_sort by_key l with
  | [] -> []
  | (k, v) :: rest -> collect k [v] [] rest
(* sort l with cmp and keep, for every run of
 * elements that compare equal, the first one of
 * the sorted list *)
let sort_uniq cmp l =
  let rec dedup kept = function
    | [] -> List.rev kept
    | e :: rest ->
      begin match kept with
      | last :: _ when cmp last e = 0 -> dedup kept rest
      | _ -> dedup (e :: kept) rest
      end
  in
  dedup [] (List.fast_sort cmp l)
(* sorted list of the distinct elements of l,
 * using the polymorphic comparison *)
let setify l =
  l |> sort_uniq compare
(* canonical form of a list of cursors: sorted
 * and without duplicates *)
let normalize (point: 'a cursor list) =
  sort_uniq compare point
(* states reached by applying a binary operation
 * to operands in states s1 (left) and s2 (right).
 * returns one (operation, state) pair for each
 * operation that has a cursor on the left of s1
 * or on the right of s2 whose sibling pattern
 * accepts the other operand; the cursors tmp are
 * added to every resulting state, whose id is
 * left at -1 *)
let next_binary tmp s1 s2 =
  let matching witness side st =
    binops side st
    |> List.filter (fun (_, p) -> pattern_match p witness)
    |> List.map fst
  in
  let from_left = matching s2.seen `L s1 in
  let from_right = matching s1.seen `R s2 in
  group_by_fst (from_left @ from_right)
  |> List.map (fun (o, l) ->
      ( o
      , { id = -1
        ; seen = Bnr (o, s1.seen, s2.seen)
        ; point = normalize (l @ tmp) } ))
(* payload of Top cursors: a rule name *)
type p = string
(* mutable set of states in which a state is
 * identified by its cursor list only; the id of
 * a state is assigned by the set *)
module StateSet : sig
type t
val create: unit -> t
val add: t -> p state ->
[> `Added | `Found ] * p state
val iter: t -> (p state -> unit) -> unit
val elems: t -> (p state) list
end = struct
(* hash table keyed by states compared on point;
 * this open brings its create, find, add and
 * iter into scope, which the definitions below
 * use and then shadow *)
open Hashtbl.Make(struct
type t = p state
let equal s1 s2 = s1.point = s2.point
let hash s = Hashtbl.hash s.point
end)
(* h maps each state to its id; next_id is the
 * id given to the next new state *)
type nonrec t =
{ h: int t
; mutable next_id: int }
let create () =
{ h = create 500; next_id = 0 }
(* look the state up; returns `Found with the id
 * already assigned, or registers it under a
 * fresh id and returns `Added. the cursor list
 * must already be normalized *)
let add set s =
assert (s.point = normalize s.point);
try
let id = find set.h s in
`Found, {s with id}
with Not_found -> begin
let id = set.next_id in
set.next_id <- id + 1;
add set.h s id;
`Added, {s with id}
end
(* visit every state, with its id filled in *)
let iter set f =
let f s id = f {s with id} in
iter f set.h
(* all the states, in no particular order *)
let elems set =
let res = ref [] in
iter set (fun s -> res := s :: !res);
!res
end
(* key of the transition table: an operation with
 * the states of its left and right operands *)
type table_key =
| K of op * p state * p state
(* transition table from keys to states *)
module StateMap = struct
  (* keys are ordered on the operation and on the
   * ids of the two operand states *)
  include Map.Make(struct
    type t = table_key
    let compare (K (o, sl, sr)) (K (o', sl', sr')) =
      compare (o, sl.id, sr.id) (o', sl'.id, sr'.id)
  end)

  (* for each state id below n, the operations
   * leading to that state, each with the list of
   * operand id pairs producing it *)
  let invert n sm =
    let rmap = Array.make n [] in
    let record (K (o, sl, sr)) {id; _} =
      rmap.(id) <- (o, (sl.id, sr.id)) :: rmap.(id)
    in
    iter record sm;
    Array.map group_by_fst rmap

  (* the table grouped by operation: each
   * operation comes with its list of
   * ((left id, right id), result id) entries *)
  let by_ops sm =
    let entry (K (op, l, r)) s acc =
      (op, ((l.id, r.id), s.id)) :: acc
    in
    group_by_fst (fold entry sm [])
end
(* a rule associates a name and a list of
 * variable names with one pattern; several rules
 * may share the same name *)
type rule =
{ name: string
; vars: string list
; pattern: pattern }
(* build the states and the transition table for
 * the rules rl. returns the array of states
 * sorted by id, the list associating atoms with
 * their state, and the transition map *)
let generate_table rl =
let states = StateSet.create () in
let rl =
(* these atomic patterns must occur in
 * rules so that we are able to number
 * all possible refs *)
[ { name = "$"; vars = []
; pattern = Atm AnyCon }
; { name = "%"; vars = []
; pattern = Atm Tmp } ] @ rl
in
(* initialize states *)
(* the cursors of all leaves of all rules,
 * grouped by the atom found at the leaf *)
let ground =
List.concat_map
(fun r -> peel r.pattern r.name) rl |>
group_by_fst
in
let tmp = List.assoc (Atm Tmp) ground in
let con = List.assoc (Atm AnyCon) ground in
let atoms = ref [] in
let () =
List.iter (fun (seen, l) ->
(* a state for an atom also holds the cursors
 * of the generic atom that accepts it *)
let point =
if pattern_match (Atm Tmp) seen
then normalize (tmp @ l)
else normalize (con @ l)
in
let s = {id = -1; seen; point} in
let _, s = StateSet.add states s in
match get_atomic seen with
| Some atm -> atoms := (atm, s) :: !atoms
| None -> ()
) ground
in
(* setup loop state *)
let map = ref StateMap.empty in
let map_add k s' =
map := StateMap.add k s' !map
in
(* flag stays `Added as long as the last pass
 * discovered a new state *)
let flag = ref `Added in
let flagmerge = function
| `Added -> flag := `Added
| _ -> ()
in
(* iterate until fixpoint *)
while !flag = `Added do
flag := `Stop;
let statel = StateSet.elems states in
iter_pairs statel (fun (sl, sr) ->
next_binary tmp sl sr |>
List.iter (fun (o, s') ->
let flag', s' =
StateSet.add states s' in
flagmerge flag';
map_add (K (o, sl, sr)) s';
));
done;
let states =
StateSet.elems states |>
List.sort (fun s s' -> compare s.id s'.id) |>
Array.of_list
in
(states, !atoms, !map)
(* all the lists obtained by inserting x at one
 * position of l; the insertion at the end comes
 * first and the insertion at the front last *)
let intersperse x l =
  let rec insert_at before after acc =
    let acc =
      List.rev_append before (x :: after) :: acc
    in
    match after with
    | [] -> acc
    | y :: rest -> insert_at (y :: before) rest acc
  in
  insert_at [] l []
(* all the permutations of a list *)
let rec permute = function
  | [] -> [[]]
  | x :: rest ->
    List.concat_map (intersperse x) (permute rest)
(* build all binary trees with ordered
 * leaves l; build makes a node out of a
 * left and a right subtree *)
let rec bins build l =
  (* try every way to cut the leaves into a
   * non-empty left part and a non-empty right
   * part *)
  let rec split left right acc =
    match right with
    | [] -> acc
    | x :: right' ->
      let acc =
        fold_pairs
          (bins build left)
          (bins build right)
          acc
          (fun (lt, rt) acc -> build lt rt :: acc)
      in
      split (left @ [x]) right' acc
  in
  match l with
  | [] -> []
  | [x] -> [x]
  | x :: rest -> split [x] rest []
(* fold f over the cartesian product of the lists
 * in l: f receives each choice of one element
 * per list, in the order of l, together with the
 * accumulator, which starts at ini *)
let products l ini f =
  let rec walk acc chosen remaining =
    match remaining with
    | [] -> f (List.rev chosen) acc
    | options :: rest ->
      List.fold_left
        (fun acc x -> walk acc (x :: chosen) rest)
        acc options
  in
  walk ini [] l
(* combinatorial nuke... *)
(* all the patterns equivalent to a pattern up to
 * associativity and commutativity of its
 * operations *)
let rec ac_equiv p =
  (* operands of the maximal chain of o at the
   * root of a pattern *)
  let rec alevel o = function
    | Bnr (o', l, r) when o' = o ->
      alevel o l @ alevel o r
    | x -> [x]
  in
  match p with
  | Bnr (o, _, _) when associative o ->
    (* pick a variant of every operand, order
     * them, then rebuild every tree shape *)
    let arrange choice =
      let orders =
        if commutative o
        then permute choice
        else [choice]
      in
      List.concat_map
        (bins (fun l r -> Bnr (o, l, r)))
        orders
    in
    products
      (List.map ac_equiv (alevel o p)) []
      (fun choice out -> arrange choice @ out)
  | Bnr (o, l, r) when commutative o ->
    fold_pairs
      (ac_equiv l) (ac_equiv r) []
      (fun (l, r) out ->
        Bnr (o, l, r) ::
        Bnr (o, r, l) :: out)
  | Bnr (o, l, r) ->
    fold_pairs
      (ac_equiv l) (ac_equiv r) []
      (fun (l, r) out ->
        Bnr (o, l, r) :: out)
  | (Atm _ | Var _) as x -> [x]
(* actions of the matchers. nodes are built
 * through the mk_ functions only (the record is
 * private), which give the same id to nodes that
 * are structurally equal *)
module Action: sig
type node =
| Switch of (int * t) list
| Push of bool * t
| Pop of t
| Set of string * t
| Stop
and t = private
{ id: int; node: node }
val equal: t -> t -> bool
val size: t -> int
val stop: t
val mk_push: sym:bool -> t -> t
val mk_pop: t -> t
val mk_set: string -> t -> t
val mk_switch: int list -> (int -> t) -> t
val pp: Format.formatter -> t -> unit
end = struct
type node =
| Switch of (int * t) list
| Push of bool * t
| Pop of t
| Set of string * t
| Stop
and t =
{ id: int; node: node }
(* actions are compared through their ids *)
let equal a a' = a.id = a'.id
(* number of distinct nodes reachable from a *)
let size a =
let seen = Hashtbl.create 10 in
let rec node_size = function
| Switch l ->
List.fold_left
(fun n (_, a) -> n + size a) 0 l
| (Push (_, a) | Pop a | Set (_, a)) ->
size a
| Stop -> 0
and size {id; node} =
if Hashtbl.mem seen id
then 0
else begin
Hashtbl.add seen id ();
1 + node_size node
end
in
size a
(* hash-consing constructor: a node already seen
 * gets its previous id, a new one gets the next
 * fresh id *)
let mk =
let hcons = Hashtbl.create 100 in
let fresh = ref 0 in
fun node ->
let id =
try Hashtbl.find hcons node
with Not_found ->
let id = !fresh in
Hashtbl.add hcons node id;
fresh := id + 1;
id
in
{id; node}
let stop = mk Stop
let mk_push ~sym a = mk (Push (sym, a))
(* popping right before a stop is dropped *)
let mk_pop a =
match a.node with
| Stop -> a
| _ -> mk (Pop a)
let mk_set v a = mk (Set (v, a))
(* build a switch over ids, with f giving the
 * action of each id; when all the cases are the
 * same action that action is returned directly.
 * fails on an empty list of ids *)
let mk_switch ids f =
match List.map f ids with
| [] -> failwith "empty switch";
| c :: cs as cases ->
if List.for_all (equal c) cs then c
else
let cases = List.combine ids cases in
mk (Switch cases)
open Format
(* pretty-printer; the cases of a switch that
 * lead to the same action are printed together *)
let rec pp_node fmt = function
| Switch l ->
fprintf fmt "@[<v>@[<v2>switch{";
let pp_case (c, a) =
let pp_sep fmt () = fprintf fmt "," in
fprintf fmt "@,@[<2>→%a:@ @[%a@]@]"
(pp_print_list ~pp_sep pp_print_int)
c pp a
in
inverse l |> group_by_fst |> inverse |>
List.iter pp_case;
fprintf fmt "@]@,}@]"
| Push (true, a) -> fprintf fmt "pushsym@ %a" pp a
| Push (false, a) -> fprintf fmt "push@ %a" pp a
| Pop a -> fprintf fmt "pop@ %a" pp a
| Set (v, a) -> fprintf fmt "set(%s)@ %a" v pp a
| Stop -> fprintf fmt ""
and pp fmt a = pp_node fmt a.node
end
(* a state is commutative if (a op b) enters
 * it iff (b op a) enters it as well; this is checked for every
 * operator entry of rmap.(id), ignoring the pairs (a, a) *)
let symmetric rmap id =
  let mirrored (_, pairs) =
    let ascending, descending =
      pairs
      |> List.filter (fun (a, b) -> a <> b)
      |> List.partition (fun (a, b) -> a < b)
    in
    setify ascending = setify (inverse descending)
  in
  List.for_all mirrored rmap.(id)
(* left-to-right matching of a set of patterns;
* may raise if there is no lr matcher for the
* input rule *)
(* Arguments, as used below:
 * - statemap: the state map, inverted with StateMap.invert;
 * - states: array of states indexed by state id (states.(id).seen
 *   and the point field are read);
 * - rules: the rule list, only rules whose name is [name] are used;
 * - name: the rule to build a matcher for.
 * Returns an Action.t. Failures surface as the local exception Stuck
 * escaping, or as Failure ("ambiguous var match", "empty switch"). *)
let lr_matcher statemap states rules name =
(* rmap.(id) is a list of (operator, pair list) entries;
 * presumably each pair holds the ids of the left and right operand
 * states that lead to state id under that operator — confirm
 * against StateMap.invert *)
let rmap =
let nstates = Array.length states in
StateMap.invert nstates statemap
in
let exception Stuck in
(* the list of ids represents a class of terms
* whose root ends up being labelled with one
* such id; the gen function generates a matcher
* that will, given any such term, assign values
* for the Var nodes of one pattern in pats *)
(* k is the continuation: it receives the id selected by the switch
 * and the patterns still compatible with it. The explicit 'a.
 * annotation is needed because gen calls itself with different
 * payload types attached to the patterns. *)
let rec gen
: 'a. int list -> (pattern * 'a) list
-> (int -> (pattern * 'a) list -> Action.t)
-> Action.t
= fun ids pats k ->
Action.mk_switch (setify ids) @@ fun id_top ->
let sym = symmetric rmap id_top in
(* for a symmetric state only the pairs (a, b) with a <= b
 * are kept *)
let id_ops =
if sym then
let ordered (a, b) = a <= b in
List.map (fun (o, l) ->
(o, List.filter ordered l))
rmap.(id_top)
else rmap.(id_top)
in
(* consider only the patterns that are
* compatible with the current id *)
(* a binary pattern is kept only if its operator has an entry in
 * id_ops; the survivors are then split into atomic and binary *)
let atm_pats, bin_pats =
List.filter (function
| Bnr (o, _, _), _ ->
List.exists
(fun (o', _) -> o' = o)
id_ops
| _ -> true) pats |>
List.partition
(fun (pat, _) -> is_atomic pat)
in
(* Stuck raised anywhere while the binary matcher below is being
 * generated (including inside the continuations) lands in the
 * handler at the end of this try, which falls back to the atomic
 * patterns *)
try
if bin_pats = [] then raise Stuck;
(* pats_l focuses on the left operand and stashes the rest of the
 * node in the payload; pats_r then moves the focus to the right
 * operand and patstop rebuilds the Bnr node with the original
 * payload *)
let pats_l =
List.map (function
| (Bnr (o, l, r), x) ->
(l, (o, x, r))
| _ -> assert false)
bin_pats
and pats_r =
List.map (fun (l, (o, x, r)) ->
(r, (o, l, x)))
and patstop =
List.map (fun (r, (o, l, x)) ->
(Bnr (o, l, r), x))
in
let id_pairs = List.concat_map snd id_ops in
(* ids_l: first components of all pairs; ids_r id_left: second
 * components of the pairs whose first component is id_left *)
let ids_l = List.map fst id_pairs
and ids_r id_left =
List.filter_map (fun (l, r) ->
if l = id_left then Some r else None)
id_pairs
in
(* match the left arm *)
Action.mk_push ~sym
(gen ids_l pats_l
@@ fun lid pats ->
(* then the right arm, considering
* only the remaining possible
* patterns and knowing that the
* left arm was numbered 'lid' *)
Action.mk_pop
(gen (ids_r lid) (pats_r pats)
@@ fun _rid pats ->
(* continue with the parent *)
k id_top (patstop pats)))
with Stuck ->
(* keep the atomic patterns for which pattern_match accepts the
 * pattern recorded in the seen field of the current state *)
let atm_pats =
let seen = states.(id_top).seen in
List.filter (fun (pat, _) ->
pattern_match pat seen) atm_pats
in
(* nothing matches here either: let the caller deal with it *)
if atm_pats = [] then raise Stuck else
(* distinct variable names bound by the remaining patterns *)
let vars =
List.filter_map (function
| (Var (v, _), _) -> Some v
| _ -> None) atm_pats |> setify
in
match vars with
| [] -> k id_top atm_pats
| [v] -> Action.mk_set v (k id_top atm_pats)
| _ -> failwith "ambiguous var match"
in
(* generate a matcher for the rule *)
(* ids of the states whose point list contains Top name *)
let ids_top =
Array.to_list states |>
List.filter_map (fun {id; point = p; _} ->
if List.exists ((=) (Top name)) p then
Some id
else None)
in
(* drops a pattern p when pattern_match p q holds for some pattern q
 * located later in the list *)
let rec filter_dups pats =
match pats with
| p :: pats ->
if List.exists (pattern_match p) pats
then filter_dups pats
else p :: filter_dups pats
| [] -> []
in
(* patterns of the rules called name, each paired with a unit
 * payload *)
let pats_top =
List.filter_map (fun r ->
if r.name = name then
Some r.pattern
else None) rules |>
filter_dups |>
List.map (fun p -> (p, ()))
in
(* once the root is matched, at least one pattern must remain *)
gen ids_top pats_top (fun _ pats ->
assert (pats <> []);
Action.stop)
(* Bundle of the tables needed to number terms, as assembled by
 * make_numberer:
 * - atoms: association list from atomic patterns to their state
 *   (searched with List.assoc by atom_state);
 * - statemap: the state map;
 * - states: the array of states;
 * - ops: mutable cache, initially empty. *)
type numberer =
{ atoms: (atomic_pattern * p state) list
; statemap: p state StateMap.t
; states: p state array
; mutable ops: op list
(* memoizes the list of possible operations
* according to the statemap *) }
(* [make_numberer sa am sm] packs the state array [sa], the atom
 * association list [am] and the state map [sm] into a numberer
 * whose operation cache starts out empty *)
let make_numberer sa am sm =
  { states = sa; atoms = am; statemap = sm; ops = [] }
(* state associated with the atomic pattern [atm] in the numberer
 * [n]; raises Not_found when [atm] has no entry *)
let atom_state n atm =
  let { atoms; _ } = n in
  List.assoc atm atoms
+292
View File
@@ -0,0 +1,292 @@
(* Parser state: [data] is the whole input string and [indx] the
 * offset in [data] where parsing continues; [line] and [coln] track
 * the position of [indx] (run_parser starts at line 1, column 0 and
 * update_pos resets the column to 0 after each newline). *)
type pstate =
{ data: string
; line: int
; coln: int
; indx: int }
(* A parse failure: the message and the state at which it arose. *)
type perror =
{ error: string
; ps: pstate }
exception ParseError of perror
(* Parsers are written in continuation-passing style: [fn ps k]
 * parses from state [ps] and hands the parsed value together with
 * the new state to the continuation [k]; failures are reported by
 * raising ParseError. The result type 'r of the continuation is
 * universally quantified. *)
type 'a parser =
{ fn: 'r. pstate -> ('a -> pstate -> 'r) -> 'r }
(* [update_pos ps beg fin] advances the line/column counters of [ps]
 * over the characters ps.data.[beg] .. ps.data.[fin - 1]: a newline
 * bumps the line and resets the column to 0, any other character
 * bumps the column. The index of [ps] is left untouched. *)
let update_pos ps beg fin =
  let rec scan i line coln =
    if i >= fin then { ps with line; coln }
    else if ps.data.[i] = '\n' then scan (i + 1) (line + 1) 0
    else scan (i + 1) line (coln + 1)
  in
  scan beg ps.line ps.coln
(* parser that consumes nothing and yields [x] *)
let pret (type a) (x: a): a parser =
  { fn = (fun ps k -> k x ps) }
(* parser that always fails with message [error] at the current
 * state *)
let pfail error: 'a parser =
  { fn = (fun ps _ -> raise (ParseError { error; ps })) }
(* ordered choice with backtracking: [p2] is tried from the same
 * state when [p1] (or anything run in its continuation) raises
 * ParseError. When both fail, the error raised is the one located
 * strictly further in the input, the error of [p2] winning ties. *)
let por: 'a parser -> 'a parser -> 'a parser =
  fun p1 p2 ->
    let fn ps k =
      match p1.fn ps k with
      | res -> res
      | exception ParseError e1 ->
        (match p2.fn ps k with
         | res -> res
         | exception ParseError e2 ->
           let furthest =
             if e1.ps.indx > e2.ps.indx then e1 else e2 in
           raise (ParseError furthest))
    in { fn }
(* monadic bind: run [p1], then run the parser that [p2] builds from
 * its result, starting where [p1] stopped *)
let pbind: 'a parser -> ('a -> 'b parser) -> 'b parser =
  fun p1 p2 ->
    let fn ps k =
      let next x ps' = (p2 x).fn ps' k in
      p1.fn ps next
    in { fn }
(* handy for recursive rules: [p x] is only computed when the
 * resulting parser runs, not when it is built *)
let papp p x =
  pbind (pret x) (fun arg -> p arg)
(* sequence two parsers, keeping the result of the second one *)
let psnd: 'a parser -> 'b parser -> 'b parser =
  fun first second -> pbind first (fun _ -> second)
(* sequence two parsers, keeping the result of the first one *)
let pfst: 'a parser -> 'b parser -> 'a parser =
  fun first second ->
    pbind first (fun kept ->
      pbind second (fun _ -> pret kept))
(* Operator aliases for the combinators above:
 *   let*   is pbind (binding-operator syntax),
 *   |||    is por   (ordered choice),
 *   |<<    is pfst  (sequence, keep the left result),
 *   |>>    is psnd  (sequence, keep the right result).
 * The three infix operators all start with '|', so OCaml gives them
 * the same precedence and makes them left-associative; alternatives
 * combined with ||| need explicit parentheses. *)
module Infix = struct
let ( let* ) = pbind
let ( ||| ) = por
let ( |<< ) = pfst
let ( |>> ) = psnd
end
(* the rest of the file uses the operators unqualified *)
open Infix
(* [pre ?what re] matches the Str regular expression [re] at the
 * current index and yields the matched text; the regexp is compiled
 * once, when the parser is built. On failure the error reads
 * "expected to match <what>", where [what] defaults to the quoted
 * source of the regexp. *)
let pre: ?what:string -> string -> string parser =
  fun ?what re ->
    let compiled = Str.regexp re in
    let what =
      match what with
      | Some descr -> descr
      | None -> Printf.sprintf "%S" re
    in
    let fn ps k =
      if Str.string_match compiled ps.data ps.indx then begin
        let indx = Str.match_end () in
        let matched = Str.matched_string ps.data in
        (* move past the match, keeping line/column in sync *)
        let ps' = { (update_pos ps ps.indx indx) with indx } in
        k matched ps'
      end else
        let error =
          Printf.sprintf "expected to match %s" what in
        raise (ParseError { error; ps })
    in { fn }
(* succeeds only when the whole input has been consumed *)
let peoi: unit parser =
  let fn ps k =
    if ps.indx = String.length ps.data then k () ps
    else
      raise (ParseError
        { error = "expected end of input"; ps })
  in { fn }
(* Blank skipping: pws accepts zero or more, pws1 one or more
 * characters among space, carriage return, newline, tab and '*'.
 * NOTE(review): '*' is treated as a blank, presumably so that
 * patterns can be written inside decorated block comments — confirm
 * against the inputs fed to this parser. *)
let pws = pre "[ \r\n\t*]*"
let pws1 = pre "[ \r\n\t*]+"
(* run [p1] then [p2] and pair up their results *)
let pthen p1 p2 =
  pbind p1 (fun x1 ->
    pbind p2 (fun x2 ->
      pret (x1, x2)))
(* parses the remainder of a list whose opening parenthesis has
 * already been consumed: either optional blanks and ")" (yielding
 * what was collected), or one more item followed by the rest *)
let rec plist_tail: 'a parser -> ('a list) parser =
  fun pitem ->
    let close = psnd (psnd pws (pre ")")) (pret []) in
    let more =
      pbind pitem (fun itm ->
        (* the recursive parser is only built once an item was read *)
        pbind (plist_tail pitem) (fun itms ->
          pret (itm :: itms)))
    in
    por close more
(* a parenthesized list of items: optional blanks, "(", then items
 * up to the closing ")" *)
let plist pitem =
  let opening = psnd pws (pre ~what:"a list" "(") in
  psnd opening (plist_tail pitem)
(* a parenthesized list whose first element is parsed by [p1] and
 * the following ones by [pitem]; yields the pair (head, items) *)
let plist1p p1 pitem =
  let opening = psnd pws (pre ~what:"a list" "(") in
  psnd opening (pthen p1 (plist_tail pitem))
(* a parenthesized pair: "(" then [p1] then [p2], then optional
 * blanks and ")" *)
let ppair p1 p2 =
  let opening = psnd pws (pre ~what:"a pair" "(") in
  let body = psnd opening (pthen p1 p2) in
  pfst (pfst body pws) (pre ")")
(* [run_parser p s] runs the parser [p] on the string [s].
 * Returns `Ok result on success. On failure, returns
 * `Error (state, message, excerpt) where [state] and [message] come
 * from the ParseError and [excerpt] shows the offending line with a
 * caret under the error position (the excerpt is empty when that
 * line is empty).
 *
 * The line containing the error starts right after the last newline
 * located strictly before the error index. The previous code tested
 * the character at the index itself, so an error sitting on a
 * newline produced a start offset past the end offset and
 * String.sub raised Invalid_argument. *)
let run_parser p s =
  let ps =
    {data = s; line = 1; coln = 0; indx = 0} in
  try `Ok (p.fn ps (fun res _ps -> res))
  with ParseError e ->
    let len = String.length s in
    (* beginning of the line containing offset i *)
    let rec bol i =
      if i = 0 || s.[i - 1] = '\n' then i
      else bol (i - 1)
    in
    (* offset of the newline (or end of input) ending that line *)
    let rec eol i =
      if i = len || s.[i] = '\n' then i
      else eol (i + 1)
    in
    let bol = bol e.ps.indx
    and eol = eol e.ps.indx in
    let line = String.sub s bol (eol - bol) in
    let msg =
      if line = "" then ""
      else
        let pfx = " > " in
        let pad = String.make (e.ps.indx - bol) ' ' in
        pfx ^ line ^ "\n" ^ pfx ^ pad ^ "^\n"
    in
    `Error (e.ps, e.error, msg)
(* ---------------------------------------- *)
(* pattern parsing *)
(* ---------------------------------------- *)
(* Example syntax (every variable used by a pattern must be listed
   in the enclosing with-vars):
   (with-vars (a b c d m)
     (patterns
       (ob (add (tmp a) (con d)))
       (bsm (add (tmp b) (mul (tmp m) (con 2 4 8)))) ))
*)
open Match
(* parses a 64-bit integer literal: an optional minus sign followed
 * by digits and underscores. The regular expression also accepts
 * strings that are not valid integers (only underscores, or a value
 * that does not fit in 64 bits); those are reported as a parse
 * error instead of letting Int64.of_string raise Failure. *)
let pint64 =
  let* s = pre "[-]?[0-9_]+" in
  match Int64.of_string_opt s with
  | Some n -> pret n
  | None ->
    pfail (Printf.sprintf "invalid 64-bit integer %S" s)
(* parses an identifier: a letter followed by letters, digits or
 * underscores (the description used in error messages previously
 * read "an identifer") *)
let pid =
  pre ~what:"an identifier"
    "[a-zA-Z][a-zA-Z0-9_]*"
(* parses the name of a base operation: the regexp is the
 * alternation of the printed names (show_op_base) of all op_bases,
 * and the matched text is mapped back to the operation printing to
 * it *)
let pop_base =
  let names = List.map show_op_base op_bases in
  let* s =
    pre ~what:"an operator" (String.concat "\\|" names) in
  pret (List.find (fun o -> show_op_base o = s) op_bases)
(* parses an operator; the class attached to the parsed base
 * operation is always Kl *)
let pop =
  pbind pop_base (fun ob -> pret (Kl, ob))
(* [ppat vs] parses one pattern expression and yields the list of
 * patterns it denotes; [vs] is the list of variable names in scope.
 * The alternatives are tried in the order in which they are listed
 * at the bottom of the function. *)
let rec ppat vs =
  (* integers up to the closing parenthesis; no integer at all
   * stands for any constant *)
  let pcons_tail =
    pbind (plist_tail (psnd pws1 pint64)) (function
      | [] -> pret [AnyCon]
      | cs -> pret (List.map (fun c -> Con c) cs))
  in
  (* an identifier that must be one of vs *)
  let pvar =
    pbind pid (fun id ->
      if List.mem id vs then pret id
      else pfail ("unbound variable: " ^ id))
  in
  (* 42 *)
  let literal =
    pbind pint64 (fun c -> pret [Atm (Con c)]) in
  (* (con) *)
  let any_con =
    psnd (pre "(con)") (pret [Atm AnyCon]) in
  (* (con 1 2 3) *)
  let con_list =
    pbind (psnd (pre "(con") pcons_tail) (fun cs ->
      pret (List.map (fun c -> Atm c) cs))
  in
  (* (con x 1 2 3) *)
  let con_var =
    pbind (psnd (psnd (pre "(con") pws1) pvar) (fun v ->
      pbind pcons_tail (fun cs ->
        pret (List.map (fun c -> Var (v, c)) cs)))
  in
  (* (tmp) *)
  let any_tmp =
    psnd (pre "(tmp)") (pret [Atm Tmp]) in
  (* (tmp x) *)
  let tmp_var =
    pbind (psnd (psnd (pre "(tmp") pws1) pvar) (fun v ->
      psnd (psnd pws (pre ")")) (pret [Var (v, Tmp)]))
  in
  (* (op p1 p2 ...); papp delays the recursive call until parse
   * time *)
  let application =
    pbind (plist1p (psnd pws pop) (papp ppat vs))
      (fun (op, rands) ->
        if List.length rands < 2 then
          pfail ( "binary op requires at least"
                ^ " two arguments" )
        else
          let mk x y = Bnr (op, x, y) in
          pret
            (products rands []
               (fun choice pats ->
                 (* construct a left-heavy tree *)
                 let r0 = List.hd choice in
                 let rs = List.tl choice in
                 List.fold_left mk r0 rs :: pats)))
  in
  psnd pws
    (List.fold_left por literal
       [ any_con; con_list; con_var
       ; any_tmp; tmp_var; application ])
(* [pwith_vars ?vs p] parses either "(with-vars (v1 v2 ...) X)",
 * where X is parsed by [p] applied to the listed variables, or
 * directly what [p vs] parses ([vs] defaults to the empty list) *)
let pwith_vars ?(vs = []) p =
  let bare = p vs in
  let scoped =
    let header =
      psnd
        (psnd (psnd pws (pre "(with-vars")) pws)
        (plist (psnd pws pid))
    in
    pbind header (fun vs ->
      pfst (pfst (psnd pws (p vs)) pws) (pre ")"))
  in
  por scoped bare
(* parses "(patterns (name pattern) ...)", optionally wrapped in a
 * with-vars form; each entry may itself be wrapped in a with-vars
 * form. Yields one (name, variables, patterns) triple per entry. *)
let ppats =
  let prule vs =
    pbind (ppair pid (ppat vs)) (fun (n, ps) ->
      pret (n, vs, ps))
  in
  pwith_vars (fun vs ->
    psnd (pre "(patterns")
      (plist_tail (pwith_vars ~vs prule)))
(* ---------------------------------------- *)
(* tests *)
(* ---------------------------------------- *)
(* manual smoke test of the pattern parser; disabled by the
 * constant condition, flip it to print the parse of each sample *)
let () =
  if false then begin
    let show_patterns ps =
      "[" ^ String.concat "; "
        (List.map show_pattern ps) ^ "]"
    in
    let vars =
      [ "foobar"; "a"; "b"; "d"
      ; "m"; "s"; "x" ]
    in
    let pat s =
      Printf.printf "parse %s = " s;
      match run_parser (ppat vars) s with
      | `Ok p ->
        Printf.printf "%s\n" (show_patterns p)
      | `Error (_, e, _) ->
        Printf.printf "ERROR: %s\n" e
    in
    List.iter pat
      [ "42"
      ; "(tmp)"
      ; "(tmp foobar)"
      ; "(con)"
      ; "(con 1 2 3)"
      ; "(con x 1 2 3)"
      ; "(add 1 2)"
      ; "(add 1 2 3 4)"
      ; "(sub 1 2)"
      ; "(sub 1 2 3)"
      ; "(tmp unbound_var)"
      ; "(add 0)"
      ; "(add 1 (add 2 3))"
      ; "(add (tmp a) (con d))"
      ; "(add (tmp b) (mul (tmp m) (con s 2 4 8)))"
      ; "(add (con 1 2) (con 3 4))" ]
  end
+134
View File
@@ -0,0 +1,134 @@
open Match
open Fuzz
open Cgen
(* unit tests *)
(* pattern_match on atoms: a temporary does not match a constant,
 * any-constant matches a specific constant but not the other way
 * around, and a constant does not match a temporary *)
let test_pattern_match =
  let accepts pat subject = pattern_match pat subject in
  let rejects pat subject = not (pattern_match pat subject) in
  assert (rejects (Atm Tmp) (Atm (Con 42L)));
  assert (accepts (Atm AnyCon) (Atm (Con 42L)));
  assert (rejects (Atm (Con 42L)) (Atm AnyCon));
  assert (rejects (Atm (Con 42L)) (Atm Tmp))
(* peeling a pattern with three leaves must give three atomic
 * patterns, each with a cursor that folds back to the original
 * pattern *)
let test_peel =
  let o = Kw, Oadd in
  let p =
    Bnr (o, Bnr (o, Atm Tmp, Atm Tmp), Atm (Con 42L)) in
  let leaves = peel p () in
  assert (List.length leaves = 3);
  let is_atom = function
    | (Atm _, _) -> true
    | _ -> false
  in
  assert (List.for_all is_atom leaves);
  let rebuilt =
    List.map
      (fun (leaf, cursor) -> fold_cursor cursor leaf)
      leaves
  in
  assert (List.for_all (fun q -> q = p) rebuilt)
(* fold_pairs over a 5-element list paired with itself must visit
 * 25 pairs, all of them distinct *)
let test_fold_pairs =
  let l = [1; 2; 3; 4; 5] in
  let pairs =
    fold_pairs l l [] (fun pair acc -> pair :: acc) in
  assert (List.length pairs = 25);
  let distinct = sort_uniq compare pairs in
  assert (List.length distinct = 25)
(* test pattern & state *)
(* [print_sm oc sm] prints one line per entry of the state map [sm]
 * on [oc]: the operator, the ids of the two operand states, the id
 * of the resulting state and the names of the Top points of that
 * state *)
let print_sm oc =
  let tops point =
    List.fold_left (fun acc c ->
      match c with
      | Top r -> acc ^ " " ^ r
      | _ -> acc) "" point
  in
  StateMap.iter (fun k s' ->
    match k with
    | K (o, sl, sr) ->
      Printf.fprintf oc
        " (%s %d %d) -> %d%s\n"
        (show_op o)
        sl.id sr.id s'.id (tops s'.point))
(* rule set used by the tests; the variant scrutinized by the match
 * selects which set is built (x64 addressing modes by default).
 * Every x64 rule is expanded through ac_equiv. *)
let rules =
  let oa = Kl, Oadd and om = Kl, Omul in
  let tmp_var v = Var (v, Tmp) in
  let va = tmp_var "a"
  and vb = tmp_var "b"
  and vc = tmp_var "c"
  and vs = tmp_var "s" in
  let vars = ["a"; "b"; "c"; "s"] in
  let rule name pattern =
    ac_equiv pattern
    |> List.map (fun pattern -> {name; vars; pattern})
  in
  let add l r = Bnr (oa, l, r) in
  (* m * s for a fixed scale m *)
  let scaled m = Bnr (om, Var ("m", Con m), vs) in
  let scales = [2L; 4L; 8L] in
  (* o + b *)
  let ob = add (Var ("o", AnyCon)) vb in
  match `X64Addr with
  (* ------------------------------- *)
  | `X64Addr ->
    List.concat
      [ (* o + b *)
        rule "ob" (add (Atm Tmp) (Atm AnyCon))
      ; (* b + s * m *)
        List.concat_map
          (fun m -> rule "bsm" (add vb (scaled m))) scales
      ; (* b + s *)
        rule "bs1" (add vb vs)
        (* o + s * m *)
        (* rule "osm" (Bnr (oa, Atm AnyCon, Bnr (om, Atm (Con 4L), Atm Tmp))) *)
      ; (* o + b + s *)
        rule "obs1" (add ob vs)
      ; (* o + b + s * m *)
        List.concat_map
          (fun m -> rule "obsm" (add ob (scaled m))) scales ]
  (* ------------------------------- *)
  | `Add3 ->
    [ { name = "add"
      ; vars = []
      ; pattern = add va (add vb vc) }
    ; { name = "add"
      ; vars = []
      ; pattern = add (add va vb) vc } ]
(*
let sa, am, sm = generate_table rules
let () =
Array.iteri (fun i s ->
Format.printf "@[state %d: %s@]@."
i (show_pattern s.seen))
sa
let () = print_sm stdout sm; flush stdout
let matcher = lr_matcher sm sa rules "obsm" (* XXX *)
let () = Format.printf "@[<v>%a@]@." Action.pp matcher
let () = Format.printf "@[matcher size: %d@]@." (Action.size matcher)
let numbr = make_numberer sa am sm
let () =
let opts = { pfx = ""
; static = true
; oc = stdout } in
emit_c opts numbr;
emit_matchers opts
[ ( ["b"; "o"; "s"; "m"]
, "obsm"
, matcher ) ]
(*
let tp = fuzz_numberer rules numbr
let () = test_matchers tp numbr rules
*)
*)