faster lookup

This commit is contained in:
krasimir
2017-06-23 17:52:43 +00:00
parent fcde05237a
commit e8726b1cfa

View File

@@ -354,6 +354,7 @@ typedef struct {
GuChoice* choice;
GuBuf* expr_tokens;
GuBuf* ctrees;
int max_fid;
PgfAbsFun** curr_absfun;
GuPool* pool;
} PgfLookupState;
@@ -366,75 +367,6 @@ typedef struct {
GuPool* out_pool;
} PgfLookupEnum;
static PgfCncTree
pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat);
/*
 * Build a PGF_CNC_TREE_APP node for `ccat` out of one of the candidate
 * productions stored in `buf`.
 *
 *   st      lookup state; supplies the choice stream, the concrete
 *           grammar (for coerce_idx) and the pool for the new node.
 *   ccat    category recorded on the resulting application node.
 *   buf     buffer of PgfProductionApply* candidates.
 *   n_args  number of arguments; must equal the arity of every
 *           candidate (checked with gu_assert).
 *   args    meta ids of the arguments, one per argument position.
 *
 * The production and, for arguments whose category has an entry in
 * coerce_idx, the coercion are both selected through st->choice.  When a
 * coercion choice is exhausted or an argument cannot be extracted, the
 * choice stream is rewound to where it stood on entry and advanced, and
 * the whole selection is retried.
 *
 * Returns the filled-in tree, or gu_null_variant when the production
 * choice is exhausted or the choice stream cannot be advanced further.
 */
static PgfCncTree
pgf_lookup_extract_app(PgfLookupState* st,
                       PgfCCat* ccat, GuBuf* buf,
                       size_t n_args, PgfMetaId* args)
{
	GuChoiceMark entry_mark = gu_choice_mark(st->choice);

	/* The node is allocated once and refilled on every retry. */
	PgfCncTree tree = gu_null_variant;
	PgfCncTreeApp* node =
		gu_new_flex_variant(PGF_CNC_TREE_APP,
		                    PgfCncTreeApp,
		                    args, n_args, &tree, st->pool);
	node->ccat = ccat;
	node->n_vars = 0;
	node->context = NULL;

	for (;;) {
		int prod_idx = gu_choice_next(st->choice, gu_buf_length(buf));
		if (prod_idx < 0) {
			return gu_null_variant;
		}

		PgfProductionApply* papply =
			gu_buf_get(buf, PgfProductionApply*, prod_idx);
		gu_assert(n_args == gu_seq_length(papply->args));

		node->fun = papply->fun;
		node->fid = 0;
		node->n_args = n_args;

		/* Number of arguments extracted so far for this attempt. */
		size_t filled = 0;
		while (filled < n_args) {
			PgfPArg* parg = gu_seq_index(papply->args, PgfPArg, filled);
			PgfMetaId arg_meta = args[filled];

			GuBuf* coercions =
				gu_map_get(st->concr->coerce_idx, parg->ccat, GuBuf*);

			PgfCCat* arg_ccat = NULL;
			if (coercions != NULL) {
				int coerce_idx =
					gu_choice_next(st->choice, gu_buf_length(coercions));
				if (coerce_idx < 0) {
					break;
				}

				PgfProductionCoerce* pcoerce =
					gu_buf_get(coercions, PgfProductionCoerce*, coerce_idx);
				arg_ccat = pcoerce->coerce;
			} else {
				arg_ccat = parg->ccat;
			}

			node->args[filled] =
				pgf_lookup_extract(st, arg_meta, arg_ccat);
			if (gu_variant_is_null(node->args[filled])) {
				break;
			}

			filled++;
		}

		if (filled == n_args) {
			return tree;
		}

		/* This attempt failed: rewind to the entry point and move on
		 * to the next alternative, if there is one. */
		gu_choice_reset(st->choice, entry_mark);
		if (!gu_choice_advance(st->choice)) {
			return gu_null_variant;
		}
	}
}
static bool
pgf_lookup_filter(GuBuf* join, PgfMetaId meta_id, GuSeq* counts, GuBuf* stack)
{
@@ -492,57 +424,222 @@ pgf_lookup_filter(GuBuf* join, PgfMetaId meta_id, GuSeq* counts, GuBuf* stack)
return true;
}
/*
 * Closure handed to gu_map_iter() in order to pick a single entry of a
 * map by its position in iteration order.
 */
typedef struct {
/* Iterator header. It has to stay the first member because the callback
 * casts the GuMapItor* it receives back to PgfCncItor*. */
GuMapItor fn;
/* Countdown: the callback records the entry it sees while this is 0 and
 * decrements it on every call. */
int index;
/* Key of the selected entry (set by the callback). */
PgfCCat* ccat;
/* Value of the selected entry (set by the callback). */
GuBuf* buf;
} PgfCncItor;
/*
 * NOTE(review): this span appears to interleave two definitions from the
 * diff view, sharing one pair of braces: the map-iteration callback
 * pgf_cnc_cat_resolve_itor and the pool finalizer gu_ccat_fini.
 * Confirm against the real file which of the two survives.
 */
static void
pgf_cnc_cat_resolve_itor(GuMapItor* fn, const void* key, void* value, GuExn* err)
gu_ccat_fini(GuFinalizer* fin)
{
/* Callback part: recover the closure from the iterator header and view
 * the map entry as a (PgfCCat*, GuBuf*) pair. */
PgfCncItor* clo = (PgfCncItor*) fn;
PgfCCat* ccat = (PgfCCat*) key;
GuBuf* buf = *((GuBuf**) value);
/* Record the entry reached when the countdown is at zero. */
if (clo->index == 0) {
clo->ccat = ccat;
clo->buf = buf;
}
clo->index--;
/* Finalizer part: recover the PgfCCat that embeds `fin` and release its
 * production sequence when one was allocated. */
PgfCCat* cat = gu_container(fin, PgfCCat, fin);
if (cat->prods != NULL)
gu_seq_free(cat->prods);
}
/*
 * pgf_lookup_new_ccat: create a fresh category derived from `ccat`.
 *
 * The new PgfCCat is allocated from st->pool, shares cnccat, lindefs and
 * linrefs with `ccat`, receives the next id from st->max_fid, and starts
 * with no productions.  Returns the new category.
 *
 * NOTE(review): the first signature below and the `PgfCncTree ret` line
 * look like leftovers of the previous pgf_lookup_extract definition in
 * this diff rendering - confirm against the real file.
 */
static PgfCncTree
pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat)
static PgfCCat*
pgf_lookup_new_ccat(PgfLookupState* st, PgfCCat* ccat)
{
PgfCncTree ret = gu_null_variant;
/* Room for one element of the trailing `fin` member. */
PgfCCat* new_ccat = gu_new_flex(st->pool, PgfCCat, fin, 1);
new_ccat->cnccat = ccat->cnccat;
new_ccat->lindefs = ccat->lindefs;
new_ccat->linrefs = ccat->linrefs;
new_ccat->viterbi_prob = 0;
/* Take the current max_fid as the id, then bump the counter. */
new_ccat->fid = st->max_fid++;
new_ccat->conts = NULL;
new_ccat->answers = NULL;
new_ccat->prods = NULL;
new_ccat->n_synprods = 0;
/* Hook gu_ccat_fini up to the pool; presumably it runs when st->pool is
 * freed - verify against gu_pool_finally. */
new_ccat->fin[0].fn = gu_ccat_fini;
gu_pool_finally(st->pool, new_ccat->fin);
return new_ccat;
}
static PgfCCat*
pgf_lookup_concretize(PgfLookupState* st, GuMap* cache, PgfMetaId meta_id, PgfCCat *ccat);
/*
 * Concretize `meta_id` against every coercion target of `ccat`.
 *
 *   st         lookup state (pool, fid counter).
 *   cache      map from (meta_id, fid) pairs to already computed results.
 *   meta_id    meta variable being concretized.
 *   ccat       category whose coercions are listed in `coercions`.
 *   coercions  buffer of PgfProductionCoerce* for `ccat`.
 *
 * Each coercion target is concretized with pgf_lookup_concretize().  For
 * every target that yields a category, a coerce production pointing to it
 * is appended to a new category derived from `ccat`; that new category is
 * only created once the first target succeeds.
 *
 * Returns the new category, or NULL when no target could be concretized.
 * The result (including NULL) is stored in `cache` before returning.
 */
static PgfCCat*
pgf_lookup_concretize_coercions(PgfLookupState* st, GuMap* cache,
                                PgfMetaId meta_id, PgfCCat* ccat,
                                GuBuf* coercions)
{
	PgfPair key;
	key[0] = meta_id;
	key[1] = ccat->fid;

	PgfCCat** cached = gu_map_find(cache, &key);
	if (cached != NULL) {
		return *cached;
	}

	PgfCCat* result = NULL;
	size_t total = gu_buf_length(coercions);

	for (size_t pos = 0; pos < total; pos++) {
		PgfProductionCoerce* source =
			gu_buf_get(coercions, PgfProductionCoerce*, pos);

		PgfCCat* target =
			pgf_lookup_concretize(st, cache, meta_id, source->coerce);
		if (target != NULL) {
			if (result == NULL) {
				result = pgf_lookup_new_ccat(st, ccat);
			}

			PgfProduction new_prod;
			PgfProductionCoerce* new_pcoerce =
				gu_new_variant(PGF_PRODUCTION_COERCE,
				               PgfProductionCoerce,
				               &new_prod, st->pool);
			new_pcoerce->coerce = target;

			/* Grow the sequence when it is missing or full; the new
			 * size leaves room for all coercions still to be visited. */
			if (result->prods == NULL ||
			    result->n_synprods >= gu_seq_length(result->prods)) {
				result->prods =
					gu_realloc_seq(result->prods, PgfProduction,
					               result->n_synprods+(total-pos));
			}
			gu_seq_set(result->prods, PgfProduction,
			           result->n_synprods++, new_prod);

#ifdef PGF_LOOKUP_DEBUG
			{
				GuPool* tmp_pool = gu_new_pool();
				GuOut* out = gu_file_out(stderr, tmp_pool);
				GuExn* err = gu_exn(tmp_pool);
				gu_printf(out,err,"C%d -> _[C%d]\n",result->fid,new_pcoerce->coerce->fid);
				gu_pool_free(tmp_pool);
			}
#endif
		}
	}

	gu_map_put(cache, &key, PgfCCat*, result);
	return result;
}
/*
 * pgf_lookup_concretize: build (or fetch from `cache`) the category that
 * collects the concrete productions of `ccat` usable for `meta_id`.
 *
 *   st       lookup state; provides st->join, st->concr and st->pool.
 *   cache    map from (meta_id, ccat->fid) pairs to earlier results.
 *   meta_id  meta variable; 0 is handled specially (see below).
 *   ccat     concrete category to specialise.
 *
 * Returns a category holding the productions whose arguments could all
 * be concretized, `ccat` itself for meta_id == 0 when it has lindefs, or
 * NULL when nothing applies.  The result (including NULL) is cached.
 *
 * NOTE(review): this span is taken from a diff view and contains lines
 * that appear to belong to the previous pgf_lookup_extract body (they are
 * marked below); as shown, the braces do not balance.
 */
static PgfCCat*
pgf_lookup_concretize(PgfLookupState* st, GuMap* cache, PgfMetaId meta_id, PgfCCat *ccat)
{
/* Meta id 0: accept the category unchanged, but only if it has at least
 * one lindef. */
if (meta_id == 0) {
if (ccat->lindefs == NULL || gu_seq_length(ccat->lindefs) == 0)
return NULL;
return ccat;
}
/* Cache key: the meta id paired with the category id. */
PgfPair pair;
pair[0] = meta_id;
pair[1] = ccat->fid;
PgfCCat** pnew_ccat = gu_map_find(cache, &pair);
if (pnew_ccat != NULL)
return *pnew_ccat;
/* Created lazily, once the first production is accepted. */
PgfCCat* new_ccat = NULL;
GuBuf* id_prods = gu_buf_get(st->join, GuBuf*, meta_id);
/* NOTE(review): from this `if` through `chunks->n_args = 0;` the lines
 * look like leftovers of the older code; the brace opened here is never
 * closed within this span. */
if (id_prods == NULL || gu_buf_length(id_prods) == 0) {
PgfCncTree chunks_tree;
PgfCncTreeChunks* chunks =
gu_new_flex_variant(PGF_CNC_TREE_CHUNKS,
PgfCncTreeChunks,
args, 0, &chunks_tree, st->pool);
chunks->n_vars = 0;
chunks->context = NULL;
chunks->n_args = 0;
/* Walk over every abstract production recorded for this meta id. */
size_t n_id_prods = gu_buf_length(id_prods);
for (size_t i = 0; i < n_id_prods; i++) {
PgfAbsProduction* prod =
gu_buf_get(id_prods, PgfAbsProduction*, i);
/* NOTE(review): the two `if` blocks below return a PgfCncTree / use
 * `ret`, neither of which fits this function - presumably older lines. */
if (ccat == NULL) {
return chunks_tree;
}
if (ccat->lindefs == NULL) {
return ret;
}
/* Concrete productions of this function, indexed by result category. */
PgfCncOverloadMap* overl_table =
gu_map_get(st->concr->fun_indices, prod->fun->name, PgfCncOverloadMap*);
if (overl_table == NULL)
continue;
GuBuf* buf =
gu_map_get(overl_table, ccat, GuBuf*);
if (buf == NULL)
continue;
size_t n_prods = gu_buf_length(buf);
for (size_t j = 0; j < n_prods; j++) {
PgfProductionApply* papply =
gu_buf_get(buf, PgfProductionApply*, j);
/* Rebuild the argument list with concretized categories. */
size_t n_args = gu_seq_length(papply->args);
GuSeq* new_args = gu_new_seq(PgfPArg, n_args, st->pool);
for (size_t k = 0; k < n_args; k++) {
PgfPArg* parg = gu_seq_index(papply->args, PgfPArg, k);
PgfPArg* new_parg = gu_seq_index(new_args, PgfPArg, k);
new_parg->hypos = parg->hypos;
/* Argument categories listed in coerce_idx are resolved through
 * their coercions, all others directly. */
GuBuf* coercions =
gu_map_get(st->concr->coerce_idx, parg->ccat, GuBuf*);
if (coercions == NULL) {
new_parg->ccat =
pgf_lookup_concretize(st, cache, prod->args[k], parg->ccat);
} else {
new_parg->ccat =
pgf_lookup_concretize_coercions(st, cache, prod->args[k], parg->ccat, coercions);
}
/* One argument that cannot be concretized rules out the production. */
if (new_parg->ccat == NULL)
goto skip;
}
if (new_ccat == NULL) {
new_ccat = pgf_lookup_new_ccat(st, ccat);
}
PgfProduction cnc_prod;
PgfProductionApply* new_papp =
gu_new_variant(PGF_PRODUCTION_APPLY,
PgfProductionApply,
&cnc_prod, st->pool);
new_papp->fun = papply->fun;
new_papp->args = new_args;
/* Grow the production sequence when it is missing or full; the new size
 * leaves room for the productions of `buf` still to be visited. */
if (new_ccat->prods == NULL || new_ccat->n_synprods >= gu_seq_length(new_ccat->prods)) {
new_ccat->prods = gu_realloc_seq(new_ccat->prods, PgfProduction, new_ccat->n_synprods+(n_prods-j));
}
gu_seq_set(new_ccat->prods, PgfProduction, new_ccat->n_synprods++, cnc_prod);
#ifdef PGF_LOOKUP_DEBUG
{
/* Trace the new production to stderr as "C<fid> -> F<funid>[args]". */
GuPool* tmp_pool = gu_new_pool();
GuOut* out = gu_file_out(stderr, tmp_pool);
GuExn* err = gu_exn(tmp_pool);
gu_printf(out,err,"C%d -> F%d[",new_ccat->fid,new_papp->fun->funid);
size_t n_args = gu_seq_length(new_papp->args);
for (size_t l = 0; l < n_args; l++) {
if (l > 0)
gu_putc(',',out,err);
PgfPArg arg = gu_seq_get(new_papp->args, PgfPArg, l);
if (arg.hypos != NULL) {
size_t n_hypos = gu_seq_length(arg.hypos);
for (size_t r = 0; r < n_hypos; r++) {
if (r > 0)
gu_putc(' ',out,err);
PgfCCat *hypo = gu_seq_get(arg.hypos, PgfCCat*, r);
gu_printf(out,err,"C%d",hypo->fid);
}
}
gu_printf(out,err,"C%d",arg.ccat->fid);
}
gu_printf(out,err,"]\n");
gu_pool_free(tmp_pool);
}
#endif
/* Landing point for productions rejected in the argument loop. */
skip:;
}
}
/* Remember the outcome, NULL included, for later lookups of this pair. */
gu_map_put(cache, &pair, PgfCCat*, new_ccat);
return new_ccat;
}
static PgfCncTree
pgf_lookup_extract(PgfLookupState* st, PgfCCat* ccat)
{
PgfCncTree ret;
if (ccat->fid < st->concr->total_cats) {
int index =
gu_choice_next(st->choice, gu_seq_length(ccat->lindefs));
if (index < 0) {
return ret;
}
PgfCncTreeApp* capp =
gu_new_flex_variant(PGF_CNC_TREE_APP,
PgfCncTreeApp,
@@ -553,55 +650,53 @@ pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat)
capp->n_vars = 0;
capp->context = NULL;
capp->n_args = 1;
capp->args[0] = chunks_tree;
return ret;
}
size_t n_id_prods = gu_buf_length(id_prods);
size_t i = gu_choice_next(st->choice, n_id_prods);
PgfAbsProduction* prod =
gu_buf_get(id_prods, PgfAbsProduction*, i);
size_t n_args = gu_seq_length(prod->fun->type->hypos);
PgfCncOverloadMap* overl_table =
gu_map_get(st->concr->fun_indices, prod->fun->name, PgfCncOverloadMap*);
if (overl_table == NULL) {
return gu_null_variant;
}
if (ccat == NULL) {
size_t n_count = gu_map_count(overl_table);
GuChoiceMark mark = gu_choice_mark(st->choice);
redo:;
int index = gu_choice_next(st->choice, n_count);
if (index < 0) {
goto done;
}
PgfCncItor clo = { { pgf_cnc_cat_resolve_itor }, index, NULL, NULL };
gu_map_iter(overl_table, &clo.fn, NULL);
assert(clo.ccat != NULL && clo.buf != NULL);
ret = pgf_lookup_extract_app(st, clo.ccat, clo.buf, n_args, prod->args);
if (gu_variant_is_null(ret)) {
gu_choice_reset(st->choice, mark);
if (gu_choice_advance(st->choice))
goto redo;
}
PgfCncTreeChunks* chunks =
gu_new_flex_variant(PGF_CNC_TREE_CHUNKS,
PgfCncTreeChunks,
args, 0, &capp->args[0], st->pool);
chunks->n_vars = 0;
chunks->context = NULL;
chunks->n_args = 0;
} else {
GuBuf* buf =
gu_map_get(overl_table, ccat, GuBuf*);
if (buf == NULL) {
goto done;
}
int index =
gu_choice_next(st->choice, ccat->n_synprods);
ret = pgf_lookup_extract_app(st, ccat, buf, n_args, prod->args);
PgfProduction prod =
gu_seq_get(ccat->prods, PgfProduction, index);
GuVariantInfo i = gu_variant_open(prod);
switch (i.tag) {
case PGF_PRODUCTION_APPLY: {
PgfProductionApply* papply = i.data;
size_t n_args = gu_seq_length(papply->args);
PgfCncTreeApp* capp =
gu_new_flex_variant(PGF_CNC_TREE_APP,
PgfCncTreeApp,
args, n_args, &ret, st->pool);
capp->ccat = ccat;
capp->fun = papply->fun;
capp->fid = 0;
capp->n_vars = 0;
capp->context = NULL;
capp->n_args = n_args;
for (size_t i = 0; i < n_args; i++) {
PgfPArg* arg = gu_seq_index(papply->args, PgfPArg, i);
capp->args[i] = pgf_lookup_extract(st, arg->ccat);
}
break;
}
case PGF_PRODUCTION_COERCE: {
PgfProductionCoerce* pcoerce = i.data;
ret = pgf_lookup_extract(st, pcoerce->coerce);
break;
}
default:
gu_impossible();
}
}
done:
return ret;
}
@@ -911,20 +1006,32 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
st.expr_tokens=gu_new_buf(PgfInputToken, work_pool);
st.ctrees = gu_new_buf(PgfCncTreeScore, pool);
st.curr_absfun= NULL;
st.max_fid = concr->total_cats;
st.pool = pool;
GuChoiceMark mark = gu_choice_mark(st.choice);
GuMap* cache = gu_new_map(PgfPair, pgf_pair_hasher, PgfCCat*, &gu_null_struct, pool);
double sentence_value =
pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens);
double max = 0;
PgfCncTreeScore* cts = gu_buf_extend(st.ctrees);
for (;;) {
cts->ctree =
pgf_lookup_extract(&st, st.start_id, NULL);
if (!gu_variant_is_null(cts->ctree)) {
PgfCncCat* cnccat =
gu_map_get(concr->cnccats, typ->cid, PgfCncCat*);
size_t n_ccats = gu_seq_length(cnccat->cats);
for (size_t i = 0; i < n_ccats; i++) {
PgfCCat* ccat = gu_seq_get(cnccat->cats, PgfCCat*, i);
PgfCCat* new_ccat = pgf_lookup_concretize(&st, cache, st.start_id, ccat);
if (new_ccat == NULL)
continue;
GuChoiceMark mark = gu_choice_mark(st.choice);
for (;;) {
PgfCncTreeScore* cts = gu_buf_extend(st.ctrees);
cts->ctree =
pgf_lookup_extract(&st, new_ccat);
cts->ctree = pgf_lzr_wrap_linref(cts->ctree, st.pool);
pgf_lzr_linearize(concr, cts->ctree, 0, &st.funcs, st.pool);
@@ -953,18 +1060,16 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
max = cts->score;
}
cts = gu_buf_extend(st.ctrees);
}
gu_choice_reset(st.choice, mark);
gu_choice_reset(st.choice, mark);
if (!gu_choice_advance(st.choice))
break;
if (!gu_choice_advance(st.choice))
break;
}
}
gu_buf_trim(st.ctrees);
gu_pool_free(work_pool);
PgfLookupEnum* lenum = gu_new(PgfLookupEnum, pool);
lenum->en.next = pgf_lookup_enum_next;
lenum->max = max;