1
0
forked from GitHub/gf-core

now the robust parser is purely top-down and the meta rules compete on a fair basis with the grammar rules

This commit is contained in:
kr.angelov
2012-06-12 09:29:51 +00:00
parent d989005e01
commit b27a440ef3
6 changed files with 256 additions and 79 deletions

View File

@@ -3,6 +3,7 @@
#include <gu/type.h>
#include <gu/variant.h>
#include <gu/assert.h>
#include <math.h>
bool
pgf_tokens_equal(PgfTokens t1, PgfTokens t2)
@@ -184,6 +185,12 @@ GU_DEFINE_TYPE(
GU_MEMBER(PgfCatFun, prob, double),
GU_MEMBER(PgfCatFun, fun, PgfCId));
static float inf_float = INFINITY;
GU_DEFINE_TYPE(PgfMetaChildMap, GuMap,
gu_type(PgfCat), NULL,
gu_type(float), &inf_float);
GU_DEFINE_TYPE(
PgfCat, struct,
GU_MEMBER(PgfCat, context, PgfHypos),

View File

@@ -145,11 +145,16 @@ struct PgfCatFun {
PgfCId fun;
};
typedef GuMap PgfMetaChildMap;
extern GU_DECLARE_TYPE(PgfMetaChildMap, GuMap);
struct PgfCat {
// TODO: Add cid here
PgfHypos context;
float meta_prob;
float meta_token_prob;
PgfMetaChildMap* meta_child_probs;
GuLength n_functions;
PgfCatFun functions[]; // XXX: resolve to PgfFunDecl*?

View File

@@ -492,7 +492,9 @@ pgf_item_set_curr_symbol(PgfItem* item, GuPool* pool)
static PgfItem*
pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx,
PgfProduction prod, PgfItemBuf* conts, GuPool* pool)
PgfProduction prod, PgfItemBuf* conts,
float delta_prob,
GuPool* pool)
{
PgfItemBase* base = gu_new(PgfItemBase, pool);
base->ccat = ccat;
@@ -557,6 +559,7 @@ pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx,
best_cont->inside_prob-ccat->viterbi_prob+
best_cont->outside_prob;
}
item->outside_prob += delta_prob;
pgf_item_set_curr_symbol(item, pool);
return item;
@@ -650,7 +653,12 @@ pgf_parsing_combine(PgfParseState* before, PgfParseState* after,
nargs * sizeof(PgfPArg));
gu_seq_set(item->args, PgfPArg, nargs,
((PgfPArg) { .hypos = NULL, .ccat = cat }));
item->inside_prob += cat->viterbi_prob;
PgfCIdMap* meta_child_probs =
item->base->ccat->cnccat->abscat->meta_child_probs;
item->inside_prob +=
cat->viterbi_prob+
gu_map_get(meta_child_probs, cat->cnccat->abscat, float);
PgfSymbol prev = item->curr_sym;
PgfSymbolCat* scat = (PgfSymbolCat*)
@@ -673,10 +681,11 @@ pgf_parsing_combine(PgfParseState* before, PgfParseState* after,
static void
pgf_parsing_production(PgfParseState* state,
PgfCCat* ccat, size_t lin_idx,
PgfProduction prod, PgfItemBuf* conts)
PgfProduction prod, PgfItemBuf* conts,
float delta_prob)
{
PgfItem* item =
pgf_new_item(state->offset, ccat, lin_idx, prod, conts, state->pool);
pgf_new_item(state->offset, ccat, lin_idx, prod, conts, delta_prob, state->pool);
gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item);
}
@@ -798,7 +807,7 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after,
* i.e. process it. */
if (conts2) {
pgf_parsing_production(before, cat, i,
prod, conts2);
prod, conts2, 0);
}
}
@@ -818,7 +827,7 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after,
* i.e. process it. */
if (conts2) {
pgf_parsing_production(state, cat, i,
prod, conts2);
prod, conts2, 0);
}
}
@@ -835,14 +844,15 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after,
static void
pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
PgfItem* item, PgfCCat* ccat, size_t lin_idx)
PgfItem* item, PgfCCat* ccat, size_t lin_idx,
float delta_prob)
{
gu_enter("-> cat: %d", ccat->fid);
if (gu_seq_is_null(ccat->prods)) {
// Empty category
return;
}
PgfItemBuf* conts =
pgf_parsing_get_conts(before->conts_map, ccat, lin_idx,
before->pool, before->pool);
@@ -856,17 +866,17 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
PgfProductionSeq prods = ccat->prods;
for (size_t i = 0; i < ccat->n_synprods; i++) {
PgfProduction prod =
gu_seq_get(prods, PgfProduction, i);
pgf_parsing_production(before, ccat, lin_idx, prod, conts);
gu_seq_get(prods, PgfProduction, i);
pgf_parsing_production(before, ccat, lin_idx, prod, conts, delta_prob);
}
if (ccat->cnccat->abscat->meta_prob != INFINITY &&
ccat->fid < before->ps->concr->total_cats) {
// Top-down prediction for meta rules
PgfItem *item =
pgf_new_item(before->offset, ccat, lin_idx, before->ps->meta_prod, conts, before->pool);
pgf_new_item(before->offset, ccat, lin_idx, before->ps->meta_prod, conts, 0, before->pool);
item->inside_prob =
1000000 + ccat->cnccat->abscat->meta_prob * 1000;
ccat->cnccat->abscat->meta_prob;
gu_buf_heap_push(before->agenda, &pgf_item_prob_order, &item);
}
@@ -880,7 +890,7 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
new_item->base->lin_idx == lin_idx &&
gu_seq_length(new_item->args) == 0) {
pgf_parsing_production(before, ccat, lin_idx,
new_item->base->prod, conts);
new_item->base->prod, conts, 0);
}
}
}
@@ -901,7 +911,7 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
PgfProductionApply* papp = i.data;
if (gu_seq_length(papp->args) == 0) {
pgf_parsing_production(before, ccat, lin_idx,
prod, conts);
prod, conts, 0);
}
break;
}
@@ -931,67 +941,107 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
gu_exit("<-");
}
static void
static float
pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after,
PgfItem* item, PgfItem* meta_item, PgfItemBuf* agenda,
bool print)
PgfItemBuf* index, PgfItem* meta_item,
PgfItemBuf* agenda)
{
PgfItemBuf* conts =
pgf_parsing_get_conts(before->conts_map,
item->base->ccat, item->base->lin_idx,
before->pool, before->pool);
gu_buf_push(conts, PgfItem*, meta_item);
if (gu_buf_length(conts) == 1) {
PgfItem* copy = pgf_item_copy(item, after->pool);
copy->base = pgf_item_base_copy(item->base, after->pool);
copy->base->conts = conts;
copy->outside_prob =
meta_item->inside_prob+meta_item->outside_prob;
#ifdef PGF_PARSER_DEBUG
copy->start = before->offset;
copy->end = before->offset;
float prob = INFINITY;
if (print) {
GuPool* tmp_pool = gu_new_pool();
GuOut* out = gu_file_out(stderr, tmp_pool);
GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool);
GuExn* err = gu_exn(NULL, type, tmp_pool);
pgf_print_item(copy, wtr, err, tmp_pool);
gu_pool_free(tmp_pool);
} else {
copy->end = after->offset;
}
PgfMetaChildMap* meta_child_probs =
meta_item->base->ccat->cnccat->abscat->meta_child_probs;
if (meta_child_probs == NULL)
return prob;
if (!gu_map_has(before->generated_cats, index)) {
gu_map_put(before->generated_cats, index, PgfCCat*, NULL);
size_t n_items = gu_buf_length(index);
for (size_t i = 0; i < n_items; i++) {
PgfItem *item = gu_buf_get(index, PgfItem*, i);
float meta_prob =
meta_item->inside_prob+
meta_item->outside_prob+
gu_map_get(meta_child_probs, item->base->ccat->cnccat->abscat, float);
PgfItemBuf* conts =
pgf_parsing_get_conts(before->conts_map,
item->base->ccat, item->base->lin_idx,
before->pool, before->pool);
if (gu_buf_length(conts) == 0) {
float outside_prob =
pgf_parsing_bu_predict(before, after,
item->base->conts, meta_item,
conts);
if (outside_prob > meta_prob)
outside_prob = meta_prob;
for (size_t j = i; j < n_items; j++) {
PgfItem *item_ = gu_buf_get(index, PgfItem*, j);
if (item->base->conts == item_->base->conts) {
PgfItem* copy = pgf_item_copy(item_, after->pool);
copy->base = pgf_item_base_copy(item_->base, after->pool);
copy->base->conts = conts;
copy->outside_prob = outside_prob;
#ifdef PGF_PARSER_DEBUG
copy->start = before->offset;
copy->end = (agenda == NULL)
? after->offset
: before->offset;
#endif
gu_buf_push(agenda, PgfItem*, copy);
if (agenda == NULL)
pgf_parsing_add_transition(before, after, after->ts->tok, copy);
else
gu_buf_push(agenda, PgfItem*, copy);
size_t n_items = gu_buf_length(item->base->conts);
for (size_t i = 0; i < n_items; i++) {
PgfItem *item_ = gu_buf_get(item->base->conts, PgfItem*, i);
pgf_parsing_bu_predict(before, after, item_, meta_item, conts, true);
}
} else {
/* If it has already been completed, combine. */
float item_prob =
copy->inside_prob+copy->outside_prob;
if (prob > item_prob)
prob = item_prob;
}
}
} else {
size_t n_items = gu_buf_length(conts);
for (size_t i = 0; i < n_items; i++) {
PgfItem *item = gu_buf_get(conts, PgfItem*, i);
float item_prob =
item->inside_prob+item->outside_prob;
if (prob > item_prob)
prob = item_prob;
}
prob += item->inside_prob;
/*PgfCCat* completed =
pgf_parsing_get_completed(before, conts);
if (completed) {
pgf_parsing_combine(before, after, meta_item, completed, item->base->lin_idx);
}*/
/* If it has already been completed, combine. */
PgfParseState* state = after;
while (state != NULL) {
PgfCCat* completed =
pgf_parsing_get_completed(state, conts);
if (completed) {
pgf_parsing_combine(state, state->next, meta_item, completed, item->base->lin_idx);
/*PgfCCat* completed =
pgf_parsing_get_completed(before, conts);
if (completed) {
pgf_parsing_combine(before, after, meta_item, completed, item->base->lin_idx);
}*/
PgfParseState* state = after;
while (state != NULL) {
PgfCCat* completed =
pgf_parsing_get_completed(state, conts);
if (completed) {
pgf_parsing_combine(state, state->next, meta_item, completed, item->base->lin_idx);
}
state = state->next;
}
}
state = state->next;
if (meta_prob != INFINITY)
gu_buf_push(conts, PgfItem*, meta_item);
}
}
return prob;
}
static void
@@ -1002,7 +1052,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
PgfSymbolCat* scat = gu_variant_data(sym);
PgfPArg* parg = gu_seq_index(item->args, PgfPArg, scat->d);
gu_assert(!parg->hypos || !parg->hypos->len);
pgf_parsing_td_predict(before, after, item, parg->ccat, scat->r);
pgf_parsing_td_predict(before, after, item, parg->ccat, scat->r, 0);
break;
}
case PGF_SYMBOL_KS: {
@@ -1105,7 +1155,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
if (parg->ccat->fid > 0 &&
parg->ccat->fid >= before->ps->concr->total_cats)
pgf_parsing_td_predict(before, after, item, parg->ccat, slit->r);
pgf_parsing_td_predict(before, after, item, parg->ccat, slit->r, 0);
else {
PgfItemBuf* conts =
pgf_parsing_get_conts(before->conts_map,
@@ -1133,7 +1183,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
pext->callback = callback;
pgf_parsing_production(before, parg->ccat, slit->r,
prod, conts);
prod, conts, 0);
}
} else {
/* If it has already been completed, combine. */
@@ -1168,6 +1218,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after,
}
PgfParseState *meta_after = NULL;
static PgfLiteralCallback pgf_meta_callback;
static void
pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
@@ -1202,7 +1253,7 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
case 0:
pgf_parsing_td_predict(before, after, item,
pcoerce->coerce,
item->base->lin_idx);
item->base->lin_idx, 0);
break;
case 1:
pgf_parsing_complete(before, after, item, NULL);
@@ -1241,6 +1292,14 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
pgf_parsing_complete(before, after, item, before->meta_ep);
if (accepted && after != NULL) {
if (pext->callback == &pgf_meta_callback) {
float meta_token_prob =
item->base->ccat->cnccat->abscat->meta_token_prob;
if (meta_token_prob == INFINITY)
break;
item->inside_prob += meta_token_prob;
}
PgfSymbol prev = item->curr_sym;
PgfSymbolKS* sks = (PgfSymbolKS*)
gu_alloc_variant(PGF_SYMBOL_KS,
@@ -1265,6 +1324,48 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
}
}
typedef struct {
GuMapItor fn;
PgfParseState* before;
PgfParseState* after;
PgfItem* meta_item;
} PgfMetaPredictFn;
static void
pgf_parsing_meta_predict(GuMapItor* fn, const void* key, void* value, GuExn* err)
{
(void) (err);
PgfCId abscat = *((PgfCId*) key);
float meta_prob = *((float*) value);
PgfMetaPredictFn* clo = (PgfMetaPredictFn*) fn;
PgfParseState* before = clo->before;
PgfParseState* after = clo->after;
PgfItem* meta_item = clo->meta_item;
{
GuPool* tmp_pool = gu_new_pool();
GuOut* out = gu_file_out(stdout, tmp_pool);
GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool);
GuExn* err = gu_exn(NULL, type, tmp_pool);
gu_string_write(abscat, wtr, err);
gu_pool_free(tmp_pool);
}
PgfCncCat* cnccat =
gu_map_get(before->ps->concr->cnccats, &abscat, PgfCncCat*);
if (cnccat == NULL)
return;
size_t n_cats = gu_list_length(cnccat->cats);
for (size_t i = 0; i < n_cats; i++) {
PgfCCat* ccat = gu_list_index(cnccat->cats, i);
for (size_t lin_idx = 0; lin_idx < cnccat->n_lins; lin_idx++) {
pgf_parsing_td_predict(before, after,
meta_item, ccat, lin_idx, meta_prob);
}
}
}
static bool
pgf_match_meta(PgfConcr* concr, PgfItem *item, PgfToken tok,
PgfExprProb** out_ep, GuPool *pool)
@@ -1298,14 +1399,18 @@ pgf_match_meta(PgfConcr* concr, PgfItem *item, PgfToken tok,
PgfParseState* before =
gu_container(out_ep, PgfParseState, meta_ep);
size_t n_items = gu_buf_length(after->ts->lexicon_idx);
for (size_t i = 0; i < n_items; i++) {
PgfItem* item_ =
gu_buf_get(after->ts->lexicon_idx, PgfItem*, i);
pgf_parsing_bu_predict(before, after,
item_, item, after->agenda, false);
after->ps->target = item_;
PgfCIdMap* meta_child_probs =
item->base->ccat->cnccat->abscat->meta_child_probs;
if (meta_child_probs != NULL) {
PgfMetaPredictFn clo = { { pgf_parsing_meta_predict }, before, after, item };
gu_map_iter(meta_child_probs, &clo.fn, NULL);
}
/*
fprintf(stderr, "------------------------------------\n");
pgf_parsing_bu_predict(before, after,
after->ts->lexicon_idx, item,
NULL);
fprintf(stderr, "------------------------------------\n");*/
return false;
}
}
@@ -1651,14 +1756,14 @@ pgf_parser_init_state(PgfConcr* concr, PgfCId cat, size_t lin_idx, GuPool* pool)
PgfProduction prod =
gu_seq_get(prods, PgfProduction, i);
PgfItem* item =
pgf_new_item(0, ccat, lin_idx, prod, conts, pool);
pgf_new_item(0, ccat, lin_idx, prod, conts, 0, pool);
gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item);
}
PgfItem *item =
pgf_new_item(0, ccat, lin_idx, ps->meta_prod, conts, pool);
pgf_new_item(0, ccat, lin_idx, ps->meta_prod, conts, 0, pool);
item->inside_prob =
1000000 + ccat->cnccat->abscat->meta_prob * 1000;
ccat->cnccat->abscat->meta_prob;
gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item);
}
}
@@ -1896,7 +2001,7 @@ pgf_parser_bu_index(PgfConcr* concr, PgfCCat* ccat, PgfProduction prod,
pgf_parsing_get_conts(conts_map, ccat, lin_idx,
pool, tmp_pool);
PgfItem* item =
pgf_new_item(0, ccat, lin_idx, prod, conts, pool);
pgf_new_item(0, ccat, lin_idx, prod, conts, 0, pool);
pgf_parser_bu_item(concr, item, conts_map, pool, tmp_pool);
}

View File

@@ -69,6 +69,9 @@ pgf_read(GuIn* in, GuPool* pool, GuExn* err);
*/
bool
pgf_load_meta_child_probs(PgfPGF*, const char* fpath, GuPool* pool);
#include <gu/type.h>
extern GU_DECLARE_TYPE(PgfPGF, struct);

View File

@@ -30,6 +30,7 @@
#include <gu/exn.h>
#include <gu/utf8.h>
#include <math.h>
#include <stdio.h>
#define GU_LOG_ENABLE
#include <gu/log.h>
@@ -656,6 +657,8 @@ pgf_compute_meta_probs(GuMapItor* fn, const void* key, void* value, GuExn* err)
mass += cat->functions[i].prob;
}
cat->meta_prob = - log(fabs(1 - mass));
cat->meta_token_prob = INFINITY;
cat->meta_child_probs = NULL;
}
static void
@@ -936,3 +939,51 @@ pgf_read(GuIn* in, GuPool* pool, GuExn* err)
gu_return_on_exn(err, NULL);
return pgf;
}
bool
pgf_load_meta_child_probs(PgfPGF* pgf, const char* fpath, GuPool* pool)
{
FILE *fp = fopen(fpath, "r");
if (!fp)
return false;
GuPool* tmp_pool = gu_new_pool();
for (;;) {
char cat1_s[21];
char cat2_s[21];
float prob;
if (fscanf(fp, "%20s\t%20s\t%f", cat1_s, cat2_s, &prob) < 3)
break;
prob = - log(prob);
GuString cat1 = gu_str_string(cat1_s, tmp_pool);
PgfCat* abscat1 =
gu_map_get(pgf->abstract.cats, &cat1, PgfCat*);
if (abscat1 == NULL)
return false;
if (strcmp(cat2_s, "_") == 0) {
abscat1->meta_token_prob = prob;
} else {
GuString cat2 = gu_str_string(cat2_s, tmp_pool);
PgfCat* abscat2 = gu_map_get(pgf->abstract.cats, &cat2, PgfCat*);
if (abscat2 == NULL)
return false;
if (abscat1->meta_child_probs == NULL) {
abscat1->meta_child_probs =
gu_map_type_new(PgfMetaChildMap, pool);
}
gu_map_put(abscat1->meta_child_probs, abscat2, float, prob);
}
}
gu_pool_free(tmp_pool);
fclose(fp);
return true;
}

View File

@@ -87,6 +87,12 @@ int main(int argc, char* argv[]) {
goto fail_read;
}
if (!pgf_load_meta_child_probs(pgf, "../../../examples/PennTreebank/test2.probs", pool)) {
fprintf(stderr, "Loading meta child probs failed\n");
status = EXIT_FAILURE;
goto fail_read;
}
// Look up the source and destination concrete categories
PgfConcr* from_concr =
gu_map_get(pgf->concretes, &from_lang, PgfConcr*);