diff --git a/src/runtime/c/pgf/lookup.c b/src/runtime/c/pgf/lookup.c index 700991222..fc5b0d367 100644 --- a/src/runtime/c/pgf/lookup.c +++ b/src/runtime/c/pgf/lookup.c @@ -354,6 +354,7 @@ typedef struct { GuChoice* choice; GuBuf* expr_tokens; GuBuf* ctrees; + int max_fid; PgfAbsFun** curr_absfun; GuPool* pool; } PgfLookupState; @@ -366,75 +367,6 @@ typedef struct { GuPool* out_pool; } PgfLookupEnum; -static PgfCncTree -pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat); - -static PgfCncTree -pgf_lookup_extract_app(PgfLookupState* st, - PgfCCat* ccat, GuBuf* buf, - size_t n_args, PgfMetaId* args) -{ - GuChoiceMark mark = gu_choice_mark(st->choice); - - PgfCncTree ret = gu_null_variant; - PgfCncTreeApp* capp = - gu_new_flex_variant(PGF_CNC_TREE_APP, - PgfCncTreeApp, - args, n_args, &ret, st->pool); - capp->ccat = ccat; - capp->n_vars = 0; - capp->context = NULL; - -redo:; - int index = gu_choice_next(st->choice, gu_buf_length(buf)); - if (index < 0) { - return gu_null_variant; - } - - PgfProductionApply* papply = - gu_buf_get(buf, PgfProductionApply*, index); - gu_assert(n_args == gu_seq_length(papply->args)); - - capp->fun = papply->fun; - capp->fid = 0; - capp->n_args = n_args; - - for (size_t i = 0; i < n_args; i++) { - PgfPArg* parg = gu_seq_index(papply->args, PgfPArg, i); - PgfMetaId meta_id = args[i]; - - PgfCCat* ccat = NULL; - GuBuf* coercions = - gu_map_get(st->concr->coerce_idx, parg->ccat, GuBuf*); - if (coercions == NULL) { - ccat = parg->ccat; - } else { - int index = gu_choice_next(st->choice, gu_buf_length(coercions)); - if (index < 0) { - gu_choice_reset(st->choice, mark); - if (!gu_choice_advance(st->choice)) - return gu_null_variant; - goto redo; - } - - PgfProductionCoerce* pcoerce = - gu_buf_get(coercions, PgfProductionCoerce*, index); - ccat = pcoerce->coerce; - } - - capp->args[i] = - pgf_lookup_extract(st, meta_id, ccat); - if (gu_variant_is_null(capp->args[i])) { - gu_choice_reset(st->choice, mark); - if (!gu_choice_advance(st->choice)) - return gu_null_variant; - goto redo; - } - } - - return ret; -} - static bool pgf_lookup_filter(GuBuf* join, PgfMetaId meta_id, GuSeq* counts, GuBuf* stack) { @@ -492,57 +424,222 @@ pgf_lookup_filter(GuBuf* join, PgfMetaId meta_id, GuSeq* counts, GuBuf* stack) return true; } -typedef struct { - GuMapItor fn; - int index; - PgfCCat* ccat; - GuBuf* buf; -} PgfCncItor; - static void -pgf_cnc_cat_resolve_itor(GuMapItor* fn, const void* key, void* value, GuExn* err) +gu_ccat_fini(GuFinalizer* fin) { - PgfCncItor* clo = (PgfCncItor*) fn; - PgfCCat* ccat = (PgfCCat*) key; - GuBuf* buf = *((GuBuf**) value); - - if (clo->index == 0) { - clo->ccat = ccat; - clo->buf = buf; - } - - clo->index--; + PgfCCat* cat = gu_container(fin, PgfCCat, fin); + if (cat->prods != NULL) + gu_seq_free(cat->prods); } -static PgfCncTree -pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat) +static PgfCCat* +pgf_lookup_new_ccat(PgfLookupState* st, PgfCCat* ccat) { - PgfCncTree ret = gu_null_variant; + PgfCCat* new_ccat = gu_new_flex(st->pool, PgfCCat, fin, 1); + new_ccat->cnccat = ccat->cnccat; + new_ccat->lindefs = ccat->lindefs; + new_ccat->linrefs = ccat->linrefs; + new_ccat->viterbi_prob = 0; + new_ccat->fid = st->max_fid++; + new_ccat->conts = NULL; + new_ccat->answers = NULL; + new_ccat->prods = NULL; + new_ccat->n_synprods = 0; + + new_ccat->fin[0].fn = gu_ccat_fini; + gu_pool_finally(st->pool, new_ccat->fin); + + return new_ccat; +} + +static PgfCCat* +pgf_lookup_concretize(PgfLookupState* st, GuMap* cache, PgfMetaId meta_id, PgfCCat *ccat); + +static PgfCCat* +pgf_lookup_concretize_coercions(PgfLookupState* st, GuMap* cache, + PgfMetaId meta_id, PgfCCat* ccat, + GuBuf* coercions) +{ + PgfPair pair; + pair[0] = meta_id; + pair[1] = ccat->fid; + PgfCCat** pnew_ccat = gu_map_find(cache, &pair); + if (pnew_ccat != NULL) + return *pnew_ccat; + + PgfCCat* new_ccat = NULL; + + size_t n_coercions = gu_buf_length(coercions); + for (size_t i = 0; i < n_coercions; i++) { + PgfProductionCoerce* pcoerce = + gu_buf_get(coercions, PgfProductionCoerce*, i); + + PgfCCat* new_coerce = + pgf_lookup_concretize(st, cache, meta_id, pcoerce->coerce); + if (new_coerce == NULL) + continue; + + if (new_ccat == NULL) { + new_ccat = pgf_lookup_new_ccat(st, ccat); + } + + PgfProduction cnc_prod; + PgfProductionCoerce* new_pcoerce = + gu_new_variant(PGF_PRODUCTION_COERCE, + PgfProductionCoerce, + &cnc_prod, st->pool); + new_pcoerce->coerce = new_coerce; + + if (new_ccat->prods == NULL || new_ccat->n_synprods >= gu_seq_length(new_ccat->prods)) { + new_ccat->prods = gu_realloc_seq(new_ccat->prods, PgfProduction, new_ccat->n_synprods+(n_coercions-i)); + } + gu_seq_set(new_ccat->prods, PgfProduction, new_ccat->n_synprods++, cnc_prod); + +#ifdef PGF_LOOKUP_DEBUG + { + GuPool* tmp_pool = gu_new_pool(); + GuOut* out = gu_file_out(stderr, tmp_pool); + GuExn* err = gu_exn(tmp_pool); + gu_printf(out,err,"C%d -> _[C%d]\n",new_ccat->fid,new_pcoerce->coerce->fid); + gu_pool_free(tmp_pool); + } +#endif + } + + gu_map_put(cache, &pair, PgfCCat*, new_ccat); + + return new_ccat; +} + +static PgfCCat* +pgf_lookup_concretize(PgfLookupState* st, GuMap* cache, PgfMetaId meta_id, PgfCCat *ccat) +{ + if (meta_id == 0) { + if (ccat->lindefs == NULL || gu_seq_length(ccat->lindefs) == 0) + return NULL; + return ccat; + } + + PgfPair pair; + pair[0] = meta_id; + pair[1] = ccat->fid; + PgfCCat** pnew_ccat = gu_map_find(cache, &pair); + if (pnew_ccat != NULL) + return *pnew_ccat; + + PgfCCat* new_ccat = NULL; GuBuf* id_prods = gu_buf_get(st->join, GuBuf*, meta_id); - if (id_prods == NULL || gu_buf_length(id_prods) == 0) { - PgfCncTree chunks_tree; - PgfCncTreeChunks* chunks = - gu_new_flex_variant(PGF_CNC_TREE_CHUNKS, - PgfCncTreeChunks, - args, 0, &chunks_tree, st->pool); - chunks->n_vars = 0; - chunks->context = NULL; - chunks->n_args = 0; + size_t n_id_prods = gu_buf_length(id_prods); + for (size_t i = 0; i < n_id_prods; i++) { + PgfAbsProduction* prod = + gu_buf_get(id_prods, PgfAbsProduction*, i); - if (ccat == NULL) { - return chunks_tree; - } - if (ccat->lindefs == NULL) { - return ret; - } + PgfCncOverloadMap* overl_table = + gu_map_get(st->concr->fun_indices, prod->fun->name, PgfCncOverloadMap*); + if (overl_table == NULL) + continue; + GuBuf* buf = + gu_map_get(overl_table, ccat, GuBuf*); + if (buf == NULL) + continue; + + size_t n_prods = gu_buf_length(buf); + for (size_t j = 0; j < n_prods; j++) { + PgfProductionApply* papply = + gu_buf_get(buf, PgfProductionApply*, j); + + size_t n_args = gu_seq_length(papply->args); + GuSeq* new_args = gu_new_seq(PgfPArg, n_args, st->pool); + for (size_t k = 0; k < n_args; k++) { + PgfPArg* parg = gu_seq_index(papply->args, PgfPArg, k); + PgfPArg* new_parg = gu_seq_index(new_args, PgfPArg, k); + + new_parg->hypos = parg->hypos; + + GuBuf* coercions = + gu_map_get(st->concr->coerce_idx, parg->ccat, GuBuf*); + if (coercions == NULL) { + new_parg->ccat = + pgf_lookup_concretize(st, cache, prod->args[k], parg->ccat); + } else { + new_parg->ccat = + pgf_lookup_concretize_coercions(st, cache, prod->args[k], parg->ccat, coercions); + } + + if (new_parg->ccat == NULL) + goto skip; + } + + if (new_ccat == NULL) { + new_ccat = pgf_lookup_new_ccat(st, ccat); + } + + PgfProduction cnc_prod; + PgfProductionApply* new_papp = + gu_new_variant(PGF_PRODUCTION_APPLY, + PgfProductionApply, + &cnc_prod, st->pool); + new_papp->fun = papply->fun; + new_papp->args = new_args; + + if (new_ccat->prods == NULL || new_ccat->n_synprods >= gu_seq_length(new_ccat->prods)) { + new_ccat->prods = gu_realloc_seq(new_ccat->prods, PgfProduction, new_ccat->n_synprods+(n_prods-j)); + } + gu_seq_set(new_ccat->prods, PgfProduction, new_ccat->n_synprods++, cnc_prod); + +#ifdef PGF_LOOKUP_DEBUG + { + GuPool* tmp_pool = gu_new_pool(); + GuOut* out = gu_file_out(stderr, tmp_pool); + GuExn* err = gu_exn(tmp_pool); + + gu_printf(out,err,"C%d -> F%d[",new_ccat->fid,new_papp->fun->funid); + + size_t n_args = gu_seq_length(new_papp->args); + for (size_t l = 0; l < n_args; l++) { + if (l > 0) + gu_putc(',',out,err); + + PgfPArg arg = gu_seq_get(new_papp->args, PgfPArg, l); + + if (arg.hypos != NULL) { + size_t n_hypos = gu_seq_length(arg.hypos); + for (size_t r = 0; r < n_hypos; r++) { + if (r > 0) + gu_putc(' ',out,err); + PgfCCat *hypo = gu_seq_get(arg.hypos, PgfCCat*, r); + gu_printf(out,err,"C%d",hypo->fid); + } + } + + gu_printf(out,err,"C%d",arg.ccat->fid); + } + gu_printf(out,err,"]\n"); + gu_pool_free(tmp_pool); + } +#endif + +skip:; + } + } + + gu_map_put(cache, &pair, PgfCCat*, new_ccat); + + return new_ccat; +} + +static PgfCncTree +pgf_lookup_extract(PgfLookupState* st, PgfCCat* ccat) +{ + PgfCncTree ret; + + if (ccat->fid < st->concr->total_cats) { int index = gu_choice_next(st->choice, gu_seq_length(ccat->lindefs)); - if (index < 0) { - return ret; - } + PgfCncTreeApp* capp = gu_new_flex_variant(PGF_CNC_TREE_APP, PgfCncTreeApp, @@ -553,55 +650,53 @@ pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat) capp->n_vars = 0; capp->context = NULL; capp->n_args = 1; - capp->args[0] = chunks_tree; - return ret; - } - - size_t n_id_prods = gu_buf_length(id_prods); - - size_t i = gu_choice_next(st->choice, n_id_prods); - PgfAbsProduction* prod = - gu_buf_get(id_prods, PgfAbsProduction*, i); - - size_t n_args = gu_seq_length(prod->fun->type->hypos); - - PgfCncOverloadMap* overl_table = - gu_map_get(st->concr->fun_indices, prod->fun->name, PgfCncOverloadMap*); - if (overl_table == NULL) { - return gu_null_variant; - } - - if (ccat == NULL) { - size_t n_count = gu_map_count(overl_table); - GuChoiceMark mark = gu_choice_mark(st->choice); - -redo:; - int index = gu_choice_next(st->choice, n_count); - if (index < 0) { - goto done; - } - - PgfCncItor clo = { { pgf_cnc_cat_resolve_itor }, index, NULL, NULL }; - gu_map_iter(overl_table, &clo.fn, NULL); - assert(clo.ccat != NULL && clo.buf != NULL); - - ret = pgf_lookup_extract_app(st, clo.ccat, clo.buf, n_args, prod->args); - if (gu_variant_is_null(ret)) { - gu_choice_reset(st->choice, mark); - if (gu_choice_advance(st->choice)) - goto redo; - } + PgfCncTreeChunks* chunks = + gu_new_flex_variant(PGF_CNC_TREE_CHUNKS, + PgfCncTreeChunks, + args, 0, &capp->args[0], st->pool); + chunks->n_vars = 0; + chunks->context = NULL; + chunks->n_args = 0; } else { - GuBuf* buf = - gu_map_get(overl_table, ccat, GuBuf*); - if (buf == NULL) { - goto done; - } + int index = + gu_choice_next(st->choice, ccat->n_synprods); - ret = pgf_lookup_extract_app(st, ccat, buf, n_args, prod->args); + PgfProduction prod = + gu_seq_get(ccat->prods, PgfProduction, index); + + GuVariantInfo i = gu_variant_open(prod); + switch (i.tag) { + case PGF_PRODUCTION_APPLY: { + PgfProductionApply* papply = i.data; + + size_t n_args = gu_seq_length(papply->args); + PgfCncTreeApp* capp = + gu_new_flex_variant(PGF_CNC_TREE_APP, + PgfCncTreeApp, + args, n_args, &ret, st->pool); + capp->ccat = ccat; + capp->fun = papply->fun; + capp->fid = 0; + capp->n_vars = 0; + capp->context = NULL; + capp->n_args = n_args; + + for (size_t i = 0; i < n_args; i++) { + PgfPArg* arg = gu_seq_index(papply->args, PgfPArg, i); + capp->args[i] = pgf_lookup_extract(st, arg->ccat); + } + break; + } + case PGF_PRODUCTION_COERCE: { + PgfProductionCoerce* pcoerce = i.data; + ret = pgf_lookup_extract(st, pcoerce->coerce); + break; + } + default: + gu_impossible(); + } } -done: return ret; } @@ -911,20 +1006,32 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po st.expr_tokens=gu_new_buf(PgfInputToken, work_pool); st.ctrees = gu_new_buf(PgfCncTreeScore, pool); st.curr_absfun= NULL; + st.max_fid = concr->total_cats; st.pool = pool; - GuChoiceMark mark = gu_choice_mark(st.choice); - + GuMap* cache = gu_new_map(PgfPair, pgf_pair_hasher, PgfCCat*, &gu_null_struct, pool); + double sentence_value = pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens); double max = 0; - PgfCncTreeScore* cts = gu_buf_extend(st.ctrees); - for (;;) { - cts->ctree = - pgf_lookup_extract(&st, st.start_id, NULL); - if (!gu_variant_is_null(cts->ctree)) { + PgfCncCat* cnccat = + gu_map_get(concr->cnccats, typ->cid, PgfCncCat*); + size_t n_ccats = gu_seq_length(cnccat->cats); + for (size_t i = 0; i < n_ccats; i++) { + PgfCCat* ccat = gu_seq_get(cnccat->cats, PgfCCat*, i); + PgfCCat* new_ccat = pgf_lookup_concretize(&st, cache, st.start_id, ccat); + if (new_ccat == NULL) + continue; + + GuChoiceMark mark = gu_choice_mark(st.choice); + + for (;;) { + PgfCncTreeScore* cts = gu_buf_extend(st.ctrees); + cts->ctree = + pgf_lookup_extract(&st, new_ccat); + cts->ctree = pgf_lzr_wrap_linref(cts->ctree, st.pool); pgf_lzr_linearize(concr, cts->ctree, 0, &st.funcs, st.pool); @@ -953,18 +1060,16 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po max = cts->score; } - cts = gu_buf_extend(st.ctrees); - } - - gu_choice_reset(st.choice, mark); + gu_choice_reset(st.choice, mark); - if (!gu_choice_advance(st.choice)) - break; + if (!gu_choice_advance(st.choice)) + break; + } } - gu_buf_trim(st.ctrees); gu_pool_free(work_pool); + PgfLookupEnum* lenum = gu_new(PgfLookupEnum, pool); lenum->en.next = pgf_lookup_enum_next; lenum->max = max;