faster expression extraction

This commit is contained in:
Krasimir Angelov
2022-09-30 11:34:04 +02:00
parent 106d963d39
commit 6b63c2f779
2 changed files with 34 additions and 30 deletions

View File

@@ -1,6 +1,7 @@
#include "data.h" #include "data.h"
#include "printer.h" #include "printer.h"
#include "parser.h" #include "parser.h"
#include "math.h"
#include <type_traits> #include <type_traits>
#include <map> #include <map>
#include <vector> #include <vector>
@@ -140,7 +141,7 @@ public:
size_t d, Choice *choice) size_t d, Choice *choice)
{ {
this->outside_prob = item->outside_prob; this->outside_prob = item->outside_prob;
this->inside_prob = item->inside_prob; this->inside_prob = item->inside_prob+choice->viterbi_prob;
this->conts = item->conts; this->conts = item->conts;
this->lin = item->lin; this->lin = item->lin;
this->seq_index = item->seq_index; this->seq_index = item->seq_index;
@@ -264,7 +265,7 @@ public:
} }
} }
virtual bool proceed(PgfParser *parser, PgfUnmarshaller *u) virtual State *proceed(PgfParser *parser, PgfUnmarshaller *u)
{ {
ref<PgfSequence> seq = lin->seqs->data[seq_index]; ref<PgfSequence> seq = lin->seqs->data[seq_index];
@@ -275,11 +276,12 @@ public:
symbol(parser, sym); symbol(parser, sym);
} }
return true; return NULL;
} }
virtual void combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t prob, PgfUnmarshaller *u) virtual bool combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t prob, PgfUnmarshaller *u)
{ {
return false;
} }
virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m) virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m)
@@ -379,7 +381,7 @@ public:
this->inside_prob += prob; this->inside_prob += prob;
} }
virtual bool proceed(PgfParser *parser, PgfUnmarshaller *u) virtual State *proceed(PgfParser *parser, PgfUnmarshaller *u)
{ {
size_t n_args = prod->lin->absfun->type->hypos->len; size_t n_args = prod->lin->absfun->type->hypos->len;
while (arg_index < n_args) { while (arg_index < n_args) {
@@ -398,7 +400,7 @@ public:
combine(parser,choice->conts,ep.first,ep.second,u); combine(parser,choice->conts,ep.first,ep.second,u);
} }
} }
return true; return parser->fetch_state;
} }
PgfExpr arg = u->emeta(0); PgfExpr arg = u->emeta(0);
@@ -407,17 +409,21 @@ public:
arg_index++; arg_index++;
} }
State *prev = parser->fetch_state;
parent->exprs.push_back(std::pair<PgfExpr,prob_t>(expr,inside_prob)); parent->exprs.push_back(std::pair<PgfExpr,prob_t>(expr,inside_prob));
for (auto item : parent->items) { for (auto item : parent->items) {
item->combine(parser,parent->conts,expr,inside_prob,u); if (item->combine(parser,parent->conts,expr,inside_prob,u)) {
prev = parent->conts->state;
}
} }
return true; return prev;
} }
virtual void combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t prob, PgfUnmarshaller *u) virtual bool combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t prob, PgfUnmarshaller *u)
{ {
parser->fetch_state->queue.push(new ExprItem(this,expr,prob,u)); parser->fetch_state->queue.push(new ExprItem(this,expr,prob,u));
return false;
} }
virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m) virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m)
@@ -478,10 +484,10 @@ public:
this->next = next; this->next = next;
} }
virtual bool proceed(PgfParser *parser, PgfUnmarshaller *u) virtual State *proceed(PgfParser *parser, PgfUnmarshaller *u)
{ {
if (state->prev == NULL) if (state->prev == NULL)
return false; return NULL;
if (state->choices.size() == 0) { if (state->choices.size() == 0) {
State *prev = state; State *prev = state;
@@ -499,6 +505,7 @@ public:
prev->queue.push(new MetaItem(prev, expr, prev->queue.push(new MetaItem(prev, expr,
inside_prob, inside_prob,
this)); this));
return prev;
} else { } else {
for (auto it : state->choices) { for (auto it : state->choices) {
ParseItemConts *conts = it.first; ParseItemConts *conts = it.first;
@@ -512,8 +519,8 @@ public:
if (choice->items.size() == 1) { if (choice->items.size() == 1) {
prob_t prob = conts->state->viterbi_prob+inside_prob; prob_t prob = conts->state->viterbi_prob+inside_prob;
for (Production *prod : choice->prods) { for (Production *prod : choice->prods) {
conts->state->queue.push(new ExprItem(choice, parser->fetch_state->queue.push(new ExprItem(choice,
prod, prob+prod->lin->lincat->abscat->prob, u)); prod, prob+prod->lin->lincat->abscat->prob, u));
} }
} else { } else {
for (auto ep : choice->exprs) { for (auto ep : choice->exprs) {
@@ -521,16 +528,17 @@ public:
} }
} }
} }
return parser->fetch_state;
} }
return false;
} }
virtual void combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t prob, PgfUnmarshaller *u) virtual bool combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t prob, PgfUnmarshaller *u)
{ {
conts->state->queue.push(new MetaItem(conts->state, conts->state->queue.push(new MetaItem(conts->state,
expr, expr,
this->inside_prob+conts->field->lincat->abscat->prob+prob, this->inside_prob+conts->field->lincat->abscat->prob+prob,
this)); this));
return true;
} }
virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m) virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m)
@@ -713,22 +721,18 @@ PgfExpr PgfParser::fetch(PgfDB *db, PgfUnmarshaller *u, prob_t *prob)
fetch_state = fetch_state->next; fetch_state = fetch_state->next;
} }
while (fetch_state != NULL) { while (!fetch_state->queue.empty()) {
while (!fetch_state->queue.empty()) { Item *item = fetch_state->queue.top();
Item *item = fetch_state->queue.top(); fetch_state->queue.pop();
fetch_state->queue.pop();
item->trace(after,m); item->trace(after,m);
if (!item->proceed(this,u)) {
if (fetch_state->prev == NULL) { if (fetch_state->prev == NULL) {
*prob = item->get_prob(); *prob = item->get_prob();
return item->get_expr(u); return item->get_expr(u);
}
break;
}
} }
fetch_state = fetch_state->prev; fetch_state = item->proceed(this,u);
} }
return 0; return 0;

View File

@@ -27,8 +27,8 @@ private:
public: public:
prob_t get_prob() { return inside_prob + outside_prob; }; prob_t get_prob() { return inside_prob + outside_prob; };
virtual bool proceed(PgfParser *parser, PgfUnmarshaller *u) = 0; virtual State *proceed(PgfParser *parser, PgfUnmarshaller *u) = 0;
virtual void combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t inside_prob, PgfUnmarshaller *u) = 0; virtual bool combine(PgfParser *parser, ParseItemConts *conts, PgfExpr expr, prob_t inside_prob, PgfUnmarshaller *u) = 0;
virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m) = 0; virtual void print1(PgfPrinter *printer, State *state, PgfMarshaller *m) = 0;
virtual void print2(PgfPrinter *printer, State *state, int x, PgfMarshaller *m) = 0; virtual void print2(PgfPrinter *printer, State *state, int x, PgfMarshaller *m) = 0;
virtual PgfExpr get_expr(PgfUnmarshaller *u) = 0; virtual PgfExpr get_expr(PgfUnmarshaller *u) = 0;