From 8176470e2ae1b3e546804a77e1f1ae14f4b06f96 Mon Sep 17 00:00:00 2001 From: krasimir Date: Thu, 1 Jun 2017 09:06:35 +0000 Subject: [PATCH] reintroduce counts as a prefilter before applying cosine similarity --- src/runtime/c/pgf/lookup.c | 90 +++++++++++++++++++++++++++++++++----- 1 file changed, 78 insertions(+), 12 deletions(-) diff --git a/src/runtime/c/pgf/lookup.c b/src/runtime/c/pgf/lookup.c index 9da522587..3e37bb2c9 100644 --- a/src/runtime/c/pgf/lookup.c +++ b/src/runtime/c/pgf/lookup.c @@ -20,6 +20,7 @@ typedef struct { typedef struct { PgfAbsFun* fun; + size_t count; PgfMetaId args[0]; } PgfAbsProduction; @@ -34,6 +35,7 @@ pgf_print_abs_production(PgfMetaId id, for (size_t i = 0; i < n_hypos; i++) { gu_printf(out,err," ?%d", prod->args[i]); } + gu_printf(out,err," [%d]\n",prod->count); gu_putc('\n',out,err); } @@ -118,6 +120,7 @@ pgf_lookup_new_production(PgfAbsFun* fun, GuPool *pool) size_t n_hypos = gu_seq_length(fun->type->hypos); PgfAbsProduction* prod = gu_new_flex(pool, PgfAbsProduction, args, n_hypos); prod->fun = fun; + prod->count = 0; for (size_t i = 0; i < n_hypos; i++) { prod->args[i] = 0; } @@ -166,6 +169,7 @@ pgf_lookup_add_spine_leaf(PgfSpineBuilder* builder, PgfAbsFun *fun) { PgfMetaId id = pgf_lookup_add_spine_nodes(builder, fun->type->cid); PgfAbsProduction* prod = pgf_lookup_new_production(fun, builder->pool); + prod->count = 1; pgf_lookup_add_production(builder, id, prod); } @@ -257,6 +261,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, if (prod1->fun == prod2->fun) { PgfAbsProduction* prod = pgf_lookup_new_production(prod1->fun, pool); + prod->count = prod1->count+prod2->count; size_t n_hypos = gu_seq_length(prod->fun->type->hypos); for (size_t l = 0; l < n_hypos; l++) { prod->args[l] = @@ -274,6 +279,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, if (count == 0) { PgfAbsProduction* prod = pgf_lookup_new_production(prod1->fun, pool); + prod->count = prod1->count; size_t n_hypos = gu_seq_length(prod->fun->type->hypos); for (size_t l = 0; l < n_hypos; l++) { prod->args[l] = @@ -304,6 +310,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, if (!found) { PgfAbsProduction* prod = pgf_lookup_new_production(prod2->fun, pool); + prod->count = prod2->count; size_t n_hypos = gu_seq_length(prod->fun->type->hypos); for (size_t l = 0; l < n_hypos; l++) { prod->args[l] = @@ -346,7 +353,6 @@ typedef struct { GuBuf* join; PgfMetaId start_id; GuChoice* choice; - GuBuf* stack; GuBuf* expr_tokens; GuBuf* ctrees; PgfAbsFun** curr_absfun; @@ -430,6 +436,63 @@ redo:; return ret; } +static bool +pgf_lookup_filter(GuBuf* join, PgfMetaId meta_id, GuSeq* counts, GuBuf* stack) +{ + if (meta_id == 0) + return true; + + size_t count = gu_seq_get(counts, size_t, meta_id); + if (count > 0) + return true; + + size_t n_stack = gu_buf_length(stack); + for (size_t i = 0; i < n_stack; i++) { + PgfMetaId id = gu_buf_get(stack, PgfMetaId, i); + if (meta_id == id) { + return false; + } + } + gu_buf_push(stack, PgfMetaId, meta_id); + + size_t pos = 0; + size_t maximum = 0; + GuBuf* id_prods = gu_buf_get(join, GuBuf*, meta_id); + size_t n_id_prods = gu_buf_length(id_prods); + for (size_t i = 0; i < n_id_prods; i++) { + PgfAbsProduction* prod = + gu_buf_get(id_prods, PgfAbsProduction*, i); + + size_t n_args = gu_seq_length(prod->fun->type->hypos); + size_t sum = prod->count; + for (size_t j = 0; j < n_args; j++) { + if (!pgf_lookup_filter(join, prod->args[j], counts, stack)) { + sum = 0; + break; + } + sum += gu_seq_get(counts, size_t, prod->args[j]); + } + + if (sum > maximum) { + maximum = sum; + pos = 0; + } + if (sum == maximum) { + gu_buf_set(id_prods, PgfAbsProduction*, pos, prod); + pos++; + } + + prod->count = sum; + } + + gu_seq_set(counts, size_t, meta_id, maximum); + gu_buf_trim_n(id_prods, n_id_prods-pos); + + gu_buf_pop(stack, PgfMetaId); + + return true; +} + typedef struct { GuMapItor fn; int index; @@ -496,15 +559,6 @@ pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat) return ret; } - size_t n_stack = gu_buf_length(st->stack); - for (size_t i = 0; i < n_stack; i++) { - PgfMetaId id = gu_buf_get(st->stack, PgfMetaId, i); - if (meta_id == id) { - return gu_null_variant; - } - } - gu_buf_push(st->stack, PgfMetaId, meta_id); - size_t n_id_prods = gu_buf_length(id_prods); size_t i = gu_choice_next(st->choice, n_id_prods); @@ -550,7 +604,6 @@ redo:; } done: - gu_buf_pop(st->stack, PgfMetaId); return ret; } @@ -827,6 +880,20 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po join = pgf_lookup_merge(meta_id1, join, meta_id2, spine, &meta_id1, work_pool, pool); } + size_t n_cats = gu_buf_length(join); + GuBuf* stack = gu_new_buf(PgfMetaId, work_pool); + GuSeq* counts = gu_new_seq(size_t, n_cats, work_pool); + for (size_t i = 0; i < n_cats; i++) { + gu_seq_set(counts, size_t, i, 0); + } + pgf_lookup_filter(join, meta_id1, counts, stack); + for (size_t i = 1; i < n_cats; i++) { + if (gu_seq_get(counts, size_t, i) == 0) { + GuBuf* id_prods = gu_buf_get(join, GuBuf*, i); + gu_buf_flush(id_prods); + } + } + #ifdef PGF_LOOKUP_DEBUG GuPool* tmp_pool = gu_new_pool(); GuOut* out = gu_file_out(stderr, tmp_pool); @@ -842,7 +909,6 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po st.join = join; st.start_id= meta_id1; st.choice = gu_new_choice(work_pool); - st.stack = gu_new_buf(PgfMetaId, work_pool); st.expr_tokens=gu_new_buf(PgfInputToken, work_pool); st.ctrees = gu_new_buf(PgfCncTreeScore, pool); st.curr_absfun= NULL;