respect the depth in the exhaustive generator

This commit is contained in:
Krasimir Angelov
2023-03-17 11:38:13 +01:00
parent 6c3a4f5dcd
commit 8bda030854
2 changed files with 35 additions and 14 deletions

View File

@@ -391,10 +391,11 @@ void PgfExhaustiveGenerator::predict_literal(ref<PgfText> cat, Result *res)
return;
}
res->exprs.push_back(std::pair<PgfExpr,prob_t>(expr,0));
ExprInstance p(expr,0,1);
res->exprs.push_back(p);
for (State1 *state : res->states) {
state->combine(this,res->scope,expr,0);
state->combine(this,res->scope,p);
}
}
@@ -438,6 +439,7 @@ bool PgfExhaustiveGenerator::State0::process(PgfExhaustiveGenerator *gen)
state->type = fun->type;
state->n_args = 0;
state->expr = expr;
state->depth = 1;
if (state->process(gen)) {
State::release(state,gen->u);
@@ -520,18 +522,22 @@ again: {
// predict global functions
gen->push_left_states(gen->pgf->abstract.funs_by_cat, g.first, arg_res, outside_prob);
} else {
for (std::pair<PgfExpr,prob_t> p : arg_res->exprs) {
this->combine(gen,arg_res->scope,p.first,p.second);
for (ExprInstance p : arg_res->exprs) {
this->combine(gen,arg_res->scope,p);
}
}
return false;
}
void PgfExhaustiveGenerator::State1::combine(PgfExhaustiveGenerator *gen,
Scope *scope, PgfExpr expr, prob_t prob)
void PgfExhaustiveGenerator::State1::combine(PgfExhaustiveGenerator *gen,
Scope *scope, ExprInstance &p)
{
if (p.depth+1 > gen->depth)
return;
Scope *s = scope;
PgfExpr expr = p.expr;
while (s != res->scope) {
PgfExpr abs = gen->u->eabs(s->bind_type, &s->var, expr);
if (s != scope) {
@@ -564,10 +570,11 @@ void PgfExhaustiveGenerator::State1::combine(PgfExhaustiveGenerator *gen,
State1 *app_state = new State1();
app_state->res = res;
app_state->prob = this->prob + prob;
app_state->prob = this->prob + p.prob;
app_state->type = type;
app_state->n_args = n_args+1;
app_state->expr = app;
app_state->depth = std::max(this->depth,p.depth+1);
gen->queue.push(app_state);
}
@@ -580,9 +587,10 @@ void PgfExhaustiveGenerator::State1::complete(PgfExhaustiveGenerator *gen)
outside_prob = res->states[0]->prob;
prob_t inside_prob = prob-outside_prob;
res->exprs.push_back(std::pair<PgfExpr,prob_t>(expr,inside_prob));
ExprInstance p(expr,inside_prob,depth);
res->exprs.push_back(p);
for (State1 *state : res->states) {
state->combine(gen,res->scope,expr,inside_prob);
state->combine(gen,res->scope,p);
}
}
@@ -601,8 +609,8 @@ PgfExpr PgfExhaustiveGenerator::fetch(PgfDB *db, prob_t *prob)
for (;;) {
if (top_res_index < top_res->exprs.size()) {
auto pair = top_res->exprs[top_res_index++];
*prob = pair.second;
return pair.first;
*prob = pair.prob;
return pair.expr;
}
if (queue.empty())
@@ -626,7 +634,7 @@ PgfExhaustiveGenerator::~PgfExhaustiveGenerator()
for (auto i : results) {
for (auto j : i.second->exprs) {
free_ref(j.first);
free_ref(j.expr);
}
}

View File

@@ -100,15 +100,28 @@ class PGF_INTERNAL_DECL PgfExhaustiveGenerator : public PgfGenerator, public Pgf
virtual bool process(PgfExhaustiveGenerator *gen);
};
struct ExprInstance {
PgfExpr expr;
prob_t prob;
size_t depth;
ExprInstance(PgfExpr expr, prob_t prob, size_t depth) {
this->expr = expr;
this->prob = prob;
this->depth = depth;
}
};
struct State1 : State {
ref<PgfDTyp> type;
size_t n_args;
PgfExpr expr;
size_t depth;
virtual bool process(PgfExhaustiveGenerator *gen);
virtual void free_refs(PgfUnmarshaller *u);
void combine(PgfExhaustiveGenerator *gen,
Scope *scope, PgfExpr expr, prob_t prob);
Scope *scope, ExprInstance &p);
void complete(PgfExhaustiveGenerator *gen);
};
@@ -119,7 +132,7 @@ class PGF_INTERNAL_DECL PgfExhaustiveGenerator : public PgfGenerator, public Pgf
Scope *scope;
size_t scope_len;
std::vector<State1*> states;
std::vector<std::pair<PgfExpr,prob_t>> exprs;
std::vector<ExprInstance> exprs;
Result() {
this->ref_count = 0;