From ed45bf9ebd615f2ef598ce64dfc5629a35f3da16 Mon Sep 17 00:00:00 2001 From: Krasimir Angelov Date: Mon, 13 Mar 2023 13:30:17 +0100 Subject: [PATCH] HOAS in exhaustive generation --- src/runtime/c/pgf/data.h | 2 +- src/runtime/c/pgf/generator.cxx | 171 ++++++++++++++++++++++++-------- src/runtime/c/pgf/generator.h | 89 ++++++++++++----- 3 files changed, 196 insertions(+), 66 deletions(-) diff --git a/src/runtime/c/pgf/data.h b/src/runtime/c/pgf/data.h index a16612345..292aa9aab 100644 --- a/src/runtime/c/pgf/data.h +++ b/src/runtime/c/pgf/data.h @@ -7,7 +7,7 @@ #include #include #include -#include +#include #include "pgf.h" diff --git a/src/runtime/c/pgf/generator.cxx b/src/runtime/c/pgf/generator.cxx index 887e6df51..5072538da 100644 --- a/src/runtime/c/pgf/generator.cxx +++ b/src/runtime/c/pgf/generator.cxx @@ -2,6 +2,7 @@ #include "data.h" #include "generator.h" + PgfRandomGenerator::PgfRandomGenerator(ref pgf, size_t depth, uint64_t *seed, PgfMarshaller *m, PgfUnmarshaller *u) @@ -159,11 +160,11 @@ again: { PgfVarGenerator v_gen(this, index, cat, n_exprs, exprs); expr = m->match_type(&v_gen, sc->type); if (expr != 0) { - if (rand() < VAR_PROB) { - prob += -log(VAR_PROB); + if (rand() < Scope::VAR_PROB) { + prob += -log(Scope::VAR_PROB); break; } else { - prob += -log(1-VAR_PROB); + prob += -log(1-Scope::VAR_PROB); if (var_expr != 0) u->free_ref(var_expr); var_expr = expr; @@ -199,7 +200,7 @@ again: { ref fun = probspace_random(pgf->abstract.funs_by_cat, cat, rand_value); if (fun == 0) { if (var_expr != 0) { - prob += -log(VAR_PROB/(1-VAR_PROB)); + prob += -log(Scope::VAR_PROB/(1-Scope::VAR_PROB)); expr = var_expr; } } else { @@ -293,9 +294,9 @@ PgfExhaustiveGenerator::PgfExhaustiveGenerator(ref pgf, PgfLiteral lint = u->lint(1,&value); PgfExpr expr = u->elit(lint); u->free_ref(lint); - Result *res = new Result(); + Result *res = new Result(ref::from_ptr(&cat_Int->name)); res->exprs.push_back(std::pair(expr,0)); - results[ref::from_ptr(&cat_Int->name)] = res; + results._M_insert_unique(res); } PgfText *text_Float = string2text("Float"); @@ -306,9 +307,9 @@ PgfExhaustiveGenerator::PgfExhaustiveGenerator(ref pgf, PgfLiteral lflt = u->lflt(3.14); PgfExpr expr = u->elit(lflt); u->free_ref(lflt); - Result *res = new Result(); + Result *res = new Result(ref::from_ptr(&cat_Float->name)); res->exprs.push_back(std::pair(expr,0)); - results[ref::from_ptr(&cat_Float->name)] = res; + results._M_insert_unique(res); } PgfText *text_String = string2text("String"); @@ -323,9 +324,9 @@ PgfExhaustiveGenerator::PgfExhaustiveGenerator(ref pgf, PgfLiteral lstr = u->lstr(value); PgfExpr expr = u->elit(lstr); u->free_ref(lstr); - Result *res = new Result(); + Result *res = new Result(ref::from_ptr(&cat_String->name)); res->exprs.push_back(std::pair(expr,0)); - results[ref::from_ptr(&cat_String->name)] = res; + results._M_insert_unique(res); } } @@ -384,7 +385,7 @@ PgfLiteral PgfExhaustiveGenerator::lstr(PgfText *v) return 0; } -void PgfExhaustiveGenerator::push_left_states(PgfProbspace space, PgfText *cat, Result *res) +void PgfExhaustiveGenerator::push_left_states(PgfProbspace space, PgfText *cat, Result *res, prob_t outside_prob) { while (space != 0) { int cmp = textcmp(cat,&(*space->value.cat)); @@ -398,12 +399,12 @@ void PgfExhaustiveGenerator::push_left_states(PgfProbspace space, PgfText *cat, State0 *state = new State0(); state->res = res; - state->prob = res->outside_prob(this) + + state->prob = outside_prob + space->value.fun->prob; state->space = space; queue.push(state); } else { - push_left_states(space->right, cat, res); + push_left_states(space->right, cat, res, outside_prob); } space = space->left; } @@ -419,13 +420,17 @@ PgfType PgfExhaustiveGenerator::dtyp(size_t n_hypos, PgfTypeHypo *hypos, if (abscat == 0) return 0; - Result *&res = results[ref::from_ptr(&abscat->name)]; - if (res == NULL) { - res = new Result(); - } - top_res = res; + Goal g(ref::from_ptr(&abscat->name)); - push_left_states(pgf->abstract.funs_by_cat, cat, top_res); + auto i = results.lower_bound(g); + if (i == results.end() || results.key_comp()(g, **i)) { + top_res = new Result(g); + results._M_emplace_hint_unique(i, top_res); + } else { + top_res = *i; + } + + push_left_states(pgf->abstract.funs_by_cat, cat, top_res, 0); return 0; } @@ -455,9 +460,11 @@ void PgfExhaustiveGenerator::State::release(State *state, PgfUnmarshaller *u) bool PgfExhaustiveGenerator::State0::process(PgfExhaustiveGenerator *gen, PgfUnmarshaller *u) { - gen->push_left_states(space->right, &(*space->value.cat), res); - ref fun = space->value.fun; + prob_t outside_prob = this->prob-fun->prob; + + gen->push_left_states(space->right, &(*space->value.cat), res, outside_prob); + PgfExpr expr = u->efun(&fun->name); res->ref_count++; @@ -483,19 +490,76 @@ bool PgfExhaustiveGenerator::State1::process(PgfExhaustiveGenerator *gen, PgfUnm return true; } - PgfDTyp *arg_type = vector_elem(type->hypos, n_args)->type; - Result *&res = gen->results[ref::from_ptr(&arg_type->name)]; - Result *tmp = res; - if (res == NULL) { - res = new Result(); + ref arg_type = vector_elem(type->hypos, n_args)->type; + Goal g(ref::from_ptr(&arg_type->name), *res); + for (size_t i = 0; i < arg_type->hypos->len; i++) { + ref hypo = vector_elem(arg_type->hypos, i); + + size_t buf_size = 16; + Scope *new_scope = (Scope *) malloc(sizeof(Scope)+buf_size); + new_scope->next = g.scope; + new_scope->type = hypo->type.as_object(); + new_scope->m = &gen->i_m; + new_scope->bind_type = hypo->bind_type; + + size_t out; +again: { + new_scope->var.size = + snprintf(new_scope->var.text, buf_size, "v%zu", g.scope_len+1); + if (new_scope->var.size >= buf_size) { + buf_size = new_scope->var.size+1; + new_scope = (Scope*) + realloc(new_scope,sizeof(Scope)+buf_size); + goto again; + } + } + + gen->scopes.push_back(new_scope); + g.scope = new_scope; + g.scope_len++; } - res->states.push_back(this); + + Result *arg_res; + Result *tmp = NULL; + auto i = gen->results.lower_bound(g); + if (i == gen->results.end() || gen->results.key_comp()(g, **i)) { + arg_res = new Result(g); + gen->results._M_emplace_hint_unique(i, res); + } else { + arg_res = *i; + tmp = res; + } + + arg_res->states.push_back(this); if (tmp == NULL) { - gen->push_left_states(gen->pgf->abstract.funs_by_cat, &arg_type->name, res); + // predict local variables + size_t index = 0; + Scope *s = g.scope; + prob_t outside_prob = this->prob; + while (s != NULL) { + ref type = s->type; + if (textcmp(&type->name, g.cat) == 0) { + State1 *var_state = new State1(); + var_state->res = arg_res; + var_state->prob = outside_prob - log(Scope::VAR_PROB); + var_state->type = type; + var_state->n_args = 0; + var_state->expr = u->evar(index); + gen->queue.push(var_state); + + outside_prob += -log(1-Scope::VAR_PROB); + } + + index++; + s = s->next; + } + + // predict global functions + gen->push_left_states(gen->pgf->abstract.funs_by_cat, g.cat, arg_res, outside_prob); } else { - for (std::pair p : res->exprs) { - this->combine(gen,p.first,p.second,u); + for (std::pair p : arg_res->exprs) { + this->combine(gen,arg_res->scope,p.first,p.second,u); } } @@ -503,18 +567,35 @@ bool PgfExhaustiveGenerator::State1::process(PgfExhaustiveGenerator *gen, PgfUnm } void PgfExhaustiveGenerator::State1::combine(PgfExhaustiveGenerator *gen, - PgfExpr expr, prob_t prob, + Scope *scope, PgfExpr expr, prob_t prob, PgfUnmarshaller *u) { + Scope *s = scope; + while (s != res->scope) { + PgfExpr abs = u->eabs(s->bind_type, &s->var, expr); + if (s != scope) { + // if expr is a lambda created in the previous iteration + u->free_ref(expr); + } + expr = abs; + s = s->next; + } + PgfBindType bind_type = vector_elem(type->hypos, n_args)->bind_type; if (bind_type == PGF_BIND_TYPE_IMPLICIT) { - expr = u->eimplarg(expr); + PgfExpr implarg = u->eimplarg(expr); + if (scope != res->scope) { + // if expr is a lambda created in the previous loop + u->free_ref(expr); + } + expr = implarg; } PgfExpr app = u->eapp(this->expr, expr); - if (bind_type == PGF_BIND_TYPE_IMPLICIT) { + if (bind_type == PGF_BIND_TYPE_IMPLICIT || scope != res->scope) { + // if expr is either a lambda or an implicit argument u->free_ref(expr); } @@ -531,10 +612,16 @@ void PgfExhaustiveGenerator::State1::combine(PgfExhaustiveGenerator *gen, void PgfExhaustiveGenerator::State1::complete(PgfExhaustiveGenerator *gen, PgfUnmarshaller *u) { - prob_t inside_prob = prob-res->outside_prob(gen); + prob_t outside_prob; + if (res == gen->top_res) + outside_prob = 0; + else + outside_prob = res->states[0]->prob; + + prob_t inside_prob = prob-outside_prob; res->exprs.push_back(std::pair(expr,inside_prob)); for (State1 *state : res->states) { - state->combine(gen,expr,inside_prob,u); + state->combine(gen,res->scope,expr,inside_prob,u); } } @@ -543,15 +630,13 @@ void PgfExhaustiveGenerator::State1::free_refs(PgfUnmarshaller *u) u->free_ref(expr); } -PgfExhaustiveGenerator::Result::Result() -{ - ref_count = 0; -} - PgfExpr PgfExhaustiveGenerator::fetch(PgfDB *db, PgfUnmarshaller *u, prob_t *prob) { DB_scope scope(db, READER_SCOPE); + if (top_res == NULL) + return 0; + for (;;) { if (top_res_index < top_res->exprs.size()) { auto pair = top_res->exprs[top_res_index++]; @@ -579,7 +664,7 @@ void PgfExhaustiveGenerator::free_refs(PgfUnmarshaller *u) } for (auto i : results) { - for (auto j : i.second->exprs) { + for (auto j : i->exprs) { free_ref(j.first); } } @@ -587,4 +672,8 @@ void PgfExhaustiveGenerator::free_refs(PgfUnmarshaller *u) PgfExhaustiveGenerator::~PgfExhaustiveGenerator() { + while (!scopes.empty()) { + Scope *scope = scopes.back(); scopes.pop_back(); + delete scope; + } } diff --git a/src/runtime/c/pgf/generator.h b/src/runtime/c/pgf/generator.h index 2a9c88116..2aee30d0d 100644 --- a/src/runtime/c/pgf/generator.h +++ b/src/runtime/c/pgf/generator.h @@ -1,10 +1,18 @@ #ifndef GENERATOR_H #define GENERATOR_H +struct PGF_INTERNAL_DECL Scope { + constexpr static prob_t VAR_PROB = 0.1; + + Scope *next; + PgfType type; + PgfMarshaller *m; + PgfBindType bind_type; + PgfText var; +}; + class PGF_INTERNAL_DECL PgfRandomGenerator : public PgfUnmarshaller { - const static int VAR_PROB = 0.1; - ref pgf; size_t depth; uint64_t *seed; @@ -13,14 +21,6 @@ class PGF_INTERNAL_DECL PgfRandomGenerator : public PgfUnmarshaller PgfInternalMarshaller i_m; PgfUnmarshaller *u; - struct Scope { - Scope *next; - PgfType type; - PgfMarshaller *m; - PgfBindType bind_type; - PgfText var; - }; - Scope *scope; size_t scope_len; @@ -61,6 +61,7 @@ class PGF_INTERNAL_DECL PgfExhaustiveGenerator : public PgfUnmarshaller, public ref pgf; size_t depth; PgfMarshaller *m; + PgfInternalMarshaller i_m; Result *top_res; size_t top_res_index; @@ -85,22 +86,50 @@ class PGF_INTERNAL_DECL PgfExhaustiveGenerator : public PgfUnmarshaller, public virtual bool process(PgfExhaustiveGenerator *gen, PgfUnmarshaller *u); virtual void free_refs(PgfUnmarshaller *u); void combine(PgfExhaustiveGenerator *gen, - PgfExpr expr, prob_t prob, + Scope *scope, PgfExpr expr, prob_t prob, PgfUnmarshaller *u); void complete(PgfExhaustiveGenerator *gen, PgfUnmarshaller *u); }; - struct Result { + struct Goal { + ref cat; + Scope *scope; + size_t scope_len; + + Goal(ref cat) { + this->cat = cat; + this->scope = NULL; + this->scope_len = 0; + } + + Goal(ref cat, Goal &g) { + this->cat = cat; + this->scope = g.scope; + this->scope_len = g.scope_len; + } + + Goal(Goal &g) { + this->cat = g.cat; + this->scope = g.scope; + this->scope_len = g.scope_len; + } + }; + + struct Result : Goal { size_t ref_count; std::vector states; std::vector> exprs; - Result(); + Result(ref cat) : Goal(cat) { + this->ref_count = 0; + } - prob_t outside_prob(PgfExhaustiveGenerator *gen) { - if (this == gen->top_res) - return 0; - return states[0]->prob; + Result(ref cat, Goal &g) : Goal(cat,g) { + this->ref_count = 0; + } + + Result(Goal &g) : Goal(g) { + this->ref_count = 0; } }; @@ -111,17 +140,29 @@ class PGF_INTERNAL_DECL PgfExhaustiveGenerator : public PgfUnmarshaller, public } }; - class CompareText : public std::less> { - public: - bool operator() (const ref t1, const ref t2) const { - return textcmp(t1, t2) < 0; + struct CompareGoal : public std::less { + bool operator() (const Goal &g1, const Goal &g2) const { + int cmp = textcmp(g1.cat, g2.cat); + if (cmp < 0) + return true; + else if (cmp > 0) + return false; + else + return (g1.scope < g2.scope); } }; - std::map, Result*, CompareText> results; - std::priority_queue, CompareState> queue; + struct Result2Goal { + Goal &operator()(Result *res) { + return *res; + } + }; - void push_left_states(PgfProbspace space, PgfText *cat, Result *res); + std::_Rb_tree results; + std::priority_queue, CompareState> queue; + std::vector scopes; + + void push_left_states(PgfProbspace space, PgfText *cat, Result *res, prob_t outside_prob); public: PgfExhaustiveGenerator(ref pgf, size_t depth,