diff --git a/src/runtime/c/pgf/data.c b/src/runtime/c/pgf/data.c index 36729b23f..74dba9cb8 100644 --- a/src/runtime/c/pgf/data.c +++ b/src/runtime/c/pgf/data.c @@ -3,6 +3,7 @@ #include #include #include +#include bool pgf_tokens_equal(PgfTokens t1, PgfTokens t2) @@ -184,6 +185,12 @@ GU_DEFINE_TYPE( GU_MEMBER(PgfCatFun, prob, double), GU_MEMBER(PgfCatFun, fun, PgfCId)); +static float inf_float = INFINITY; + +GU_DEFINE_TYPE(PgfMetaChildMap, GuMap, + gu_type(PgfCat), NULL, + gu_type(float), &inf_float); + GU_DEFINE_TYPE( PgfCat, struct, GU_MEMBER(PgfCat, context, PgfHypos), diff --git a/src/runtime/c/pgf/data.h b/src/runtime/c/pgf/data.h index 63c26d318..7fe2fc7d3 100644 --- a/src/runtime/c/pgf/data.h +++ b/src/runtime/c/pgf/data.h @@ -145,11 +145,16 @@ struct PgfCatFun { PgfCId fun; }; +typedef GuMap PgfMetaChildMap; +extern GU_DECLARE_TYPE(PgfMetaChildMap, GuMap); + struct PgfCat { // TODO: Add cid here PgfHypos context; float meta_prob; + float meta_token_prob; + PgfMetaChildMap* meta_child_probs; GuLength n_functions; PgfCatFun functions[]; // XXX: resolve to PgfFunDecl*? 
diff --git a/src/runtime/c/pgf/parser.c b/src/runtime/c/pgf/parser.c index 3d97b5a39..a05600884 100644 --- a/src/runtime/c/pgf/parser.c +++ b/src/runtime/c/pgf/parser.c @@ -492,7 +492,9 @@ pgf_item_set_curr_symbol(PgfItem* item, GuPool* pool) static PgfItem* pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx, - PgfProduction prod, PgfItemBuf* conts, GuPool* pool) + PgfProduction prod, PgfItemBuf* conts, + float delta_prob, + GuPool* pool) { PgfItemBase* base = gu_new(PgfItemBase, pool); base->ccat = ccat; @@ -557,6 +559,7 @@ pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx, best_cont->inside_prob-ccat->viterbi_prob+ best_cont->outside_prob; } + item->outside_prob += delta_prob; pgf_item_set_curr_symbol(item, pool); return item; @@ -650,7 +653,12 @@ pgf_parsing_combine(PgfParseState* before, PgfParseState* after, nargs * sizeof(PgfPArg)); gu_seq_set(item->args, PgfPArg, nargs, ((PgfPArg) { .hypos = NULL, .ccat = cat })); - item->inside_prob += cat->viterbi_prob; + + PgfCIdMap* meta_child_probs = + item->base->ccat->cnccat->abscat->meta_child_probs; + item->inside_prob += + cat->viterbi_prob+ + gu_map_get(meta_child_probs, cat->cnccat->abscat, float); PgfSymbol prev = item->curr_sym; PgfSymbolCat* scat = (PgfSymbolCat*) @@ -673,10 +681,11 @@ pgf_parsing_combine(PgfParseState* before, PgfParseState* after, static void pgf_parsing_production(PgfParseState* state, PgfCCat* ccat, size_t lin_idx, - PgfProduction prod, PgfItemBuf* conts) + PgfProduction prod, PgfItemBuf* conts, + float delta_prob) { PgfItem* item = - pgf_new_item(state->offset, ccat, lin_idx, prod, conts, state->pool); + pgf_new_item(state->offset, ccat, lin_idx, prod, conts, delta_prob, state->pool); gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item); } @@ -798,7 +807,7 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after, * i.e. process it. 
*/ if (conts2) { pgf_parsing_production(before, cat, i, - prod, conts2); + prod, conts2, 0); } } @@ -818,7 +827,7 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after, * i.e. process it. */ if (conts2) { pgf_parsing_production(state, cat, i, - prod, conts2); + prod, conts2, 0); } } @@ -835,14 +844,15 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after, static void pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after, - PgfItem* item, PgfCCat* ccat, size_t lin_idx) + PgfItem* item, PgfCCat* ccat, size_t lin_idx, + float delta_prob) { gu_enter("-> cat: %d", ccat->fid); if (gu_seq_is_null(ccat->prods)) { // Empty category return; } - + PgfItemBuf* conts = pgf_parsing_get_conts(before->conts_map, ccat, lin_idx, before->pool, before->pool); @@ -856,17 +866,17 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after, PgfProductionSeq prods = ccat->prods; for (size_t i = 0; i < ccat->n_synprods; i++) { PgfProduction prod = - gu_seq_get(prods, PgfProduction, i); - pgf_parsing_production(before, ccat, lin_idx, prod, conts); + gu_seq_get(prods, PgfProduction, i); + pgf_parsing_production(before, ccat, lin_idx, prod, conts, delta_prob); } - + if (ccat->cnccat->abscat->meta_prob != INFINITY && ccat->fid < before->ps->concr->total_cats) { // Top-down prediction for meta rules PgfItem *item = - pgf_new_item(before->offset, ccat, lin_idx, before->ps->meta_prod, conts, before->pool); + pgf_new_item(before->offset, ccat, lin_idx, before->ps->meta_prod, conts, 0, before->pool); item->inside_prob = - 1000000 + ccat->cnccat->abscat->meta_prob * 1000; + ccat->cnccat->abscat->meta_prob; gu_buf_heap_push(before->agenda, &pgf_item_prob_order, &item); } @@ -880,7 +890,7 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after, new_item->base->lin_idx == lin_idx && gu_seq_length(new_item->args) == 0) { pgf_parsing_production(before, ccat, lin_idx, - new_item->base->prod, conts); + new_item->base->prod, conts, 0); } } } 
@@ -901,7 +911,7 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after, PgfProductionApply* papp = i.data; if (gu_seq_length(papp->args) == 0) { pgf_parsing_production(before, ccat, lin_idx, - prod, conts); + prod, conts, 0); } break; } @@ -931,67 +941,107 @@ pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after, gu_exit("<-"); } -static void +static float pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after, - PgfItem* item, PgfItem* meta_item, PgfItemBuf* agenda, - bool print) + PgfItemBuf* index, PgfItem* meta_item, + PgfItemBuf* agenda) { - PgfItemBuf* conts = - pgf_parsing_get_conts(before->conts_map, - item->base->ccat, item->base->lin_idx, - before->pool, before->pool); - gu_buf_push(conts, PgfItem*, meta_item); - if (gu_buf_length(conts) == 1) { - PgfItem* copy = pgf_item_copy(item, after->pool); - copy->base = pgf_item_base_copy(item->base, after->pool); - copy->base->conts = conts; - copy->outside_prob = - meta_item->inside_prob+meta_item->outside_prob; - -#ifdef PGF_PARSER_DEBUG - copy->start = before->offset; - copy->end = before->offset; + float prob = INFINITY; - if (print) { - GuPool* tmp_pool = gu_new_pool(); - GuOut* out = gu_file_out(stderr, tmp_pool); - GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool); - GuExn* err = gu_exn(NULL, type, tmp_pool); - pgf_print_item(copy, wtr, err, tmp_pool); - gu_pool_free(tmp_pool); - } else { - copy->end = after->offset; - } + PgfMetaChildMap* meta_child_probs = + meta_item->base->ccat->cnccat->abscat->meta_child_probs; + if (meta_child_probs == NULL) + return prob; + + if (!gu_map_has(before->generated_cats, index)) { + gu_map_put(before->generated_cats, index, PgfCCat*, NULL); + + size_t n_items = gu_buf_length(index); + for (size_t i = 0; i < n_items; i++) { + PgfItem *item = gu_buf_get(index, PgfItem*, i); + + float meta_prob = + meta_item->inside_prob+ + meta_item->outside_prob+ + gu_map_get(meta_child_probs, item->base->ccat->cnccat->abscat, float); + + 
PgfItemBuf* conts = + pgf_parsing_get_conts(before->conts_map, + item->base->ccat, item->base->lin_idx, + before->pool, before->pool); + if (gu_buf_length(conts) == 0) { + float outside_prob = + pgf_parsing_bu_predict(before, after, + item->base->conts, meta_item, + conts); + + if (outside_prob > meta_prob) + outside_prob = meta_prob; + + for (size_t j = i; j < n_items; j++) { + PgfItem *item_ = gu_buf_get(index, PgfItem*, j); + + if (item->base->conts == item_->base->conts) { + PgfItem* copy = pgf_item_copy(item_, after->pool); + copy->base = pgf_item_base_copy(item_->base, after->pool); + copy->base->conts = conts; + copy->outside_prob = outside_prob; +#ifdef PGF_PARSER_DEBUG + copy->start = before->offset; + copy->end = (agenda == NULL) + ? after->offset + : before->offset; #endif - gu_buf_push(agenda, PgfItem*, copy); + if (agenda == NULL) + pgf_parsing_add_transition(before, after, after->ts->tok, copy); + else + gu_buf_push(agenda, PgfItem*, copy); - size_t n_items = gu_buf_length(item->base->conts); - for (size_t i = 0; i < n_items; i++) { - PgfItem *item_ = gu_buf_get(item->base->conts, PgfItem*, i); - pgf_parsing_bu_predict(before, after, item_, meta_item, conts, true); - } - } else { - /* If it has already been completed, combine. */ + float item_prob = + copy->inside_prob+copy->outside_prob; + if (prob > item_prob) + prob = item_prob; + } + } + } else { + size_t n_items = gu_buf_length(conts); + for (size_t i = 0; i < n_items; i++) { + PgfItem *item = gu_buf_get(conts, PgfItem*, i); + + float item_prob = + item->inside_prob+item->outside_prob; + if (prob > item_prob) + prob = item_prob; + } + prob += item->inside_prob; - /*PgfCCat* completed = - pgf_parsing_get_completed(before, conts); - if (completed) { - pgf_parsing_combine(before, after, meta_item, completed, item->base->lin_idx); - }*/ + /* If it has already been completed, combine. 
*/ - PgfParseState* state = after; - while (state != NULL) { - PgfCCat* completed = - pgf_parsing_get_completed(state, conts); - if (completed) { - pgf_parsing_combine(state, state->next, meta_item, completed, item->base->lin_idx); + /*PgfCCat* completed = + pgf_parsing_get_completed(before, conts); + if (completed) { + pgf_parsing_combine(before, after, meta_item, completed, item->base->lin_idx); + }*/ + + PgfParseState* state = after; + while (state != NULL) { + PgfCCat* completed = + pgf_parsing_get_completed(state, conts); + if (completed) { + pgf_parsing_combine(state, state->next, meta_item, completed, item->base->lin_idx); + } + + state = state->next; + } } - - state = state->next; + + if (meta_prob != INFINITY) + gu_buf_push(conts, PgfItem*, meta_item); } } + return prob; } static void @@ -1002,7 +1052,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after, PgfSymbolCat* scat = gu_variant_data(sym); PgfPArg* parg = gu_seq_index(item->args, PgfPArg, scat->d); gu_assert(!parg->hypos || !parg->hypos->len); - pgf_parsing_td_predict(before, after, item, parg->ccat, scat->r); + pgf_parsing_td_predict(before, after, item, parg->ccat, scat->r, 0); break; } case PGF_SYMBOL_KS: { @@ -1105,7 +1155,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after, if (parg->ccat->fid > 0 && parg->ccat->fid >= before->ps->concr->total_cats) - pgf_parsing_td_predict(before, after, item, parg->ccat, slit->r); + pgf_parsing_td_predict(before, after, item, parg->ccat, slit->r, 0); else { PgfItemBuf* conts = pgf_parsing_get_conts(before->conts_map, @@ -1133,7 +1183,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after, pext->callback = callback; pgf_parsing_production(before, parg->ccat, slit->r, - prod, conts); + prod, conts, 0); } } else { /* If it has already been completed, combine. 
*/ @@ -1168,6 +1218,7 @@ pgf_parsing_symbol(PgfParseState* before, PgfParseState* after, } PgfParseState *meta_after = NULL; +static PgfLiteralCallback pgf_meta_callback; static void pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item) @@ -1202,7 +1253,7 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item) case 0: pgf_parsing_td_predict(before, after, item, pcoerce->coerce, - item->base->lin_idx); + item->base->lin_idx, 0); break; case 1: pgf_parsing_complete(before, after, item, NULL); @@ -1241,6 +1292,14 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item) pgf_parsing_complete(before, after, item, before->meta_ep); if (accepted && after != NULL) { + if (pext->callback == &pgf_meta_callback) { + float meta_token_prob = + item->base->ccat->cnccat->abscat->meta_token_prob; + if (meta_token_prob == INFINITY) + break; + item->inside_prob += meta_token_prob; + } + PgfSymbol prev = item->curr_sym; PgfSymbolKS* sks = (PgfSymbolKS*) gu_alloc_variant(PGF_SYMBOL_KS, @@ -1265,6 +1324,48 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item) } }
+typedef struct {
+	GuMapItor fn;
+	PgfParseState* before;
+	PgfParseState* after;
+	PgfItem* meta_item;
+} PgfMetaPredictFn;
+
+// GuMapItor callback (closure: PgfMetaPredictFn) run for each entry of a
+// meta_child_probs map: for every concrete category registered for the
+// entry and every linearization index, start a top-down prediction below
+// the meta item, passing the entry's value as delta_prob.
+static void
+pgf_parsing_meta_predict(GuMapItor* fn, const void* key, void* value, GuExn* err)
+{
+	(void) (err);
+
+	// NOTE(review): PgfMetaChildMap is declared with PgfCat keys and is
+	// filled with PgfCat pointers, yet the key is read here as a PgfCId;
+	// PgfCat carries no cid field yet (see its TODO) -- confirm this
+	// lookup resolves the intended category.
+	PgfCId abscat = *((PgfCId*) key);
+	float meta_prob = *((float*) value);
+	PgfMetaPredictFn* clo = (PgfMetaPredictFn*) fn;
+	PgfParseState* before = clo->before;
+	PgfParseState* after = clo->after;
+	PgfItem* meta_item = clo->meta_item;
+	PgfCncCat* cnccat =
+		gu_map_get(before->ps->concr->cnccats, &abscat, PgfCncCat*);
+	if (cnccat == NULL)
+		return;
+
+	size_t n_cats = gu_list_length(cnccat->cats);
+	for (size_t i = 0; i < n_cats; i++) {
+		PgfCCat* ccat = gu_list_index(cnccat->cats, i);
+
+		for (size_t lin_idx = 0; lin_idx < cnccat->n_lins; lin_idx++) {
+			pgf_parsing_td_predict(before, after,
+			                       meta_item, ccat, lin_idx, meta_prob);
+		}
+	}
+}
+
 static bool pgf_match_meta(PgfConcr* concr, PgfItem *item, PgfToken tok, PgfExprProb** out_ep, GuPool *pool) @@ -1298,14 +1399,18 @@ pgf_match_meta(PgfConcr* concr, PgfItem *item, PgfToken tok, PgfParseState* before = gu_container(out_ep, PgfParseState, meta_ep); - size_t n_items = gu_buf_length(after->ts->lexicon_idx); - for (size_t i = 0; i < n_items; i++) { - PgfItem* item_ = - gu_buf_get(after->ts->lexicon_idx, PgfItem*, i); - pgf_parsing_bu_predict(before, after, - item_, item, after->agenda, false); - after->ps->target = item_; + PgfCIdMap* meta_child_probs = + item->base->ccat->cnccat->abscat->meta_child_probs; + if (meta_child_probs != NULL) { + PgfMetaPredictFn clo = { { pgf_parsing_meta_predict }, before, after, item }; + gu_map_iter(meta_child_probs, &clo.fn, NULL); } +/* + fprintf(stderr, "------------------------------------\n"); + pgf_parsing_bu_predict(before, after, + after->ts->lexicon_idx, item, + NULL); + fprintf(stderr, "------------------------------------\n");*/ return false; } } @@ -1651,14 +1756,14 @@ pgf_parser_init_state(PgfConcr* concr, PgfCId cat, size_t lin_idx, GuPool* pool) PgfProduction prod = gu_seq_get(prods, PgfProduction, i); PgfItem* item = - pgf_new_item(0, ccat, lin_idx, prod, conts, pool); + pgf_new_item(0, ccat, lin_idx, prod, conts, 0, pool); gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item); } PgfItem *item = - pgf_new_item(0, ccat, lin_idx, ps->meta_prod, conts, pool); + pgf_new_item(0, ccat, lin_idx, ps->meta_prod, conts, 0, pool); item->inside_prob = - 1000000 + ccat->cnccat->abscat->meta_prob * 1000; + ccat->cnccat->abscat->meta_prob; gu_buf_heap_push(state->agenda, &pgf_item_prob_order, &item); } } @@
-1896,7 +2001,7 @@ pgf_parser_bu_index(PgfConcr* concr, PgfCCat* ccat, PgfProduction prod, pgf_parsing_get_conts(conts_map, ccat, lin_idx, pool, tmp_pool); PgfItem* item = - pgf_new_item(0, ccat, lin_idx, prod, conts, pool); + pgf_new_item(0, ccat, lin_idx, prod, conts, 0, pool); pgf_parser_bu_item(concr, item, conts_map, pool, tmp_pool); } diff --git a/src/runtime/c/pgf/pgf.h b/src/runtime/c/pgf/pgf.h index 91659d95e..e14b4c8c8 100644 --- a/src/runtime/c/pgf/pgf.h +++ b/src/runtime/c/pgf/pgf.h @@ -69,6 +69,9 @@ pgf_read(GuIn* in, GuPool* pool, GuExn* err); */ +bool +pgf_load_meta_child_probs(PgfPGF*, const char* fpath, GuPool* pool); + #include extern GU_DECLARE_TYPE(PgfPGF, struct); diff --git a/src/runtime/c/pgf/reader.c b/src/runtime/c/pgf/reader.c index 1fee45f83..08cc16096 100644 --- a/src/runtime/c/pgf/reader.c +++ b/src/runtime/c/pgf/reader.c @@ -30,6 +30,7 @@ #include #include #include +#include #define GU_LOG_ENABLE #include @@ -656,6 +657,8 @@ pgf_compute_meta_probs(GuMapItor* fn, const void* key, void* value, GuExn* err) mass += cat->functions[i].prob; } cat->meta_prob = - log(fabs(1 - mass)); + cat->meta_token_prob = INFINITY; + cat->meta_child_probs = NULL; } static void @@ -936,3 +939,74 @@ pgf_read(GuIn* in, GuPool* pool, GuExn* err) gu_return_on_exn(err, NULL); return pgf; }
+
+/* Load meta-rule probabilities for abstract categories from the text
+ * file at fpath.
+ *
+ * Every record holds three whitespace-separated fields: a parent
+ * category name, a child category name (or "_" to set the parent's
+ * meta_token_prob) and a probability, which is stored as its negative
+ * logarithm. At most 20 characters are read per name. Reading stops at
+ * the first record that does not yield all three fields.
+ *
+ * Child maps are created in pool on first use. Returns false when the
+ * file cannot be opened, a category is not in pgf->abstract.cats, or a
+ * probability is negative or NaN; returns true otherwise. The stream
+ * and the scratch pool are released on every path.
+ */
+bool
+pgf_load_meta_child_probs(PgfPGF* pgf, const char* fpath, GuPool* pool)
+{
+	FILE *fp = fopen(fpath, "r");
+	if (!fp)
+		return false;
+
+	GuPool* tmp_pool = gu_new_pool();
+	bool ok = false;
+
+	for (;;) {
+		char cat1_s[21];
+		char cat2_s[21];
+		float prob;
+
+		if (fscanf(fp, "%20s\t%20s\t%f", cat1_s, cat2_s, &prob) < 3)
+			break;
+
+		// log() of a negative value is NaN, which would poison
+		// every cost it is later added to.
+		if (!(prob >= 0))
+			goto done;
+		prob = - log(prob);
+
+		GuString cat1 = gu_str_string(cat1_s, tmp_pool);
+		PgfCat* abscat1 =
+			gu_map_get(pgf->abstract.cats, &cat1, PgfCat*);
+		if (abscat1 == NULL)
+			goto done;
+
+		if (strcmp(cat2_s, "_") == 0) {
+			abscat1->meta_token_prob = prob;
+		} else {
+			GuString cat2 = gu_str_string(cat2_s, tmp_pool);
+			PgfCat* abscat2 =
+				gu_map_get(pgf->abstract.cats, &cat2, PgfCat*);
+			if (abscat2 == NULL)
+				goto done;
+
+			if (abscat1->meta_child_probs == NULL) {
+				abscat1->meta_child_probs =
+					gu_map_type_new(PgfMetaChildMap, pool);
+			}
+
+			gu_map_put(abscat1->meta_child_probs, abscat2, float, prob);
+		}
+	}
+	ok = true;
+
+done:
+	// Single exit so the stream and the scratch pool are released
+	// on failure as well as on success.
+	gu_pool_free(tmp_pool);
+	fclose(fp);
+	return ok;
+}
diff --git a/src/runtime/c/utils/pgf-translate.c b/src/runtime/c/utils/pgf-translate.c index aae09e70d..a740d3204 100644 --- a/src/runtime/c/utils/pgf-translate.c +++ b/src/runtime/c/utils/pgf-translate.c @@ -87,6 +87,12 @@ int main(int argc, char* argv[]) { goto fail_read; } + if (!pgf_load_meta_child_probs(pgf, "../../../examples/PennTreebank/test2.probs", pool)) { + fprintf(stderr, "Loading meta child probs failed\n"); + status = EXIT_FAILURE; + goto fail_read; + } + // Look up the source and destination concrete categories PgfConcr* from_concr = gu_map_get(pgf->concretes, &from_lang, PgfConcr*);