diff --git a/src/runtime/c/pgf/parser.c b/src/runtime/c/pgf/parser.c index e69ff173e..6f02ed730 100644 --- a/src/runtime/c/pgf/parser.c +++ b/src/runtime/c/pgf/parser.c @@ -55,6 +55,7 @@ typedef struct { int prod_full_count; #endif PgfItem* free_item; + prob_t beam_size; } PgfParsing; typedef struct { @@ -75,6 +76,7 @@ GU_DEFINE_TYPE(PgfProductionIdx, GuMap, typedef struct { PgfToken tok; PgfProductionIdx* lexicon_idx; + prob_t lexical_prob; } PgfTokenState; struct PgfParseState { @@ -1483,8 +1485,12 @@ pgf_parsing_proceed(PgfParseState* state) { } } - delta_prob += - (st->viterbi_prob-(st->next ? st->next->viterbi_prob : 0))*0.95; + prob_t state_delta = + (st->viterbi_prob-(st->next ? st->next->viterbi_prob : 0))* + state->ps->beam_size; + prob_t lexical_prob = + st->ts ? st->ts->lexical_prob : 0; + delta_prob += fmax(state_delta, lexical_prob); st = st->next; } @@ -1532,6 +1538,7 @@ pgf_new_parsing(PgfConcr* concr, GuPool* pool) ps->prod_full_count = 0; #endif ps->free_item = NULL; + ps->beam_size = 0.95; PgfExprMeta *expr_meta = gu_new_variant(PGF_EXPR_META, @@ -1569,6 +1576,38 @@ pgf_new_parse_state(PgfParsing* ps, return state; } +typedef struct { + GuMapItor fn; + PgfTokenState* ts; +} PgfLexiconFn; + +static void +pgf_parser_compute_lexicon_prob(GuMapItor* fn, const void* key, void* value, GuExn* err) +{ + PgfTokenState* ts = ((PgfLexiconFn*) fn)->ts; + PgfProductionSeq prods = *((PgfProductionSeq*) value); + + if (gu_seq_is_null(prods)) + return; + + size_t n_prods = gu_seq_length(prods); + for (size_t i = 0; i < n_prods; i++) { + PgfProduction prod = + gu_seq_get(prods, PgfProduction, i); + + GuVariantInfo pi = gu_variant_open(prod); + switch (pi.tag) { + case PGF_PRODUCTION_APPLY: { + PgfProductionApply* papp = pi.data; + if (ts->lexical_prob > papp->fun->ep->prob) { + ts->lexical_prob = papp->fun->ep->prob; + } + break; + } + } + } +} + static PgfTokenState* pgf_new_token_state(PgfConcr *concr, PgfToken tok, GuPool* pool) { @@ -1576,6 +1615,11 @@ pgf_new_token_state(PgfConcr *concr, PgfToken tok, GuPool* pool) ts->tok = tok; ts->lexicon_idx = gu_map_get(concr->leftcorner_tok_idx, &tok, PgfProductionIdx*); + ts->lexical_prob = INFINITY; + PgfLexiconFn clo = { { pgf_parser_compute_lexicon_prob }, ts }; + gu_map_iter(ts->lexicon_idx, &clo.fn, NULL); + if (ts->lexical_prob == INFINITY) + ts->lexical_prob = 0; return ts; } @@ -1969,6 +2013,12 @@ pgf_parser_init_state(PgfConcr* concr, PgfCId cat, size_t lin_idx, GuPool* pool) return state; } +void +pgf_parser_set_beam_size(PgfParseState* state, double beam_size) +{ + state->ps->beam_size = beam_size; +} + void pgf_parser_add_literal(PgfConcr *concr, PgfCId cat, PgfLiteralCallback* callback) diff --git a/src/runtime/c/pgf/parser.h b/src/runtime/c/pgf/parser.h index 6952d33ab..c4c394aad 100644 --- a/src/runtime/c/pgf/parser.h +++ b/src/runtime/c/pgf/parser.h @@ -67,6 +67,9 @@ pgf_parser_next_state(PgfParseState* prev, PgfToken tok, * the pool used to create \parse. */ +void +pgf_parser_set_beam_size(PgfParseState* state, double beam_size); + void pgf_parser_add_literal(PgfConcr *concr, PgfCId cat, PgfLiteralCallback* callback);