reintroduce counts as a prefilter before applying cosine similarity

This commit is contained in:
krasimir
2017-06-01 09:06:35 +00:00
parent 234a0d5e99
commit 8176470e2a

View File

@@ -20,6 +20,7 @@ typedef struct {
// One production attached to a node (meta id) of the lookup forest:
// an abstract function together with the node ids of its arguments.
typedef struct {
// The abstract function this production applies.
PgfAbsFun* fun;
// Occurrence score used by the count pre-filter. It is 0 for a freshly
// created production, 1 for a spine leaf, and the sum of both inputs when
// two productions with the same function are merged; the filter later
// overwrites it with the accumulated score of the production.
size_t count;
// Trailing variable-length array, allocated with gu_new_flex, holding one
// meta id per hypothesis of fun's type. Entries are initialized to 0, and
// the filter treats id 0 as trivially acceptable.
PgfMetaId args[0];
} PgfAbsProduction;
@@ -34,6 +35,7 @@ pgf_print_abs_production(PgfMetaId id,
for (size_t i = 0; i < n_hypos; i++) {
gu_printf(out,err," ?%d", prod->args[i]);
}
gu_printf(out,err," [%d]\n",prod->count);
gu_putc('\n',out,err);
}
@@ -118,6 +120,7 @@ pgf_lookup_new_production(PgfAbsFun* fun, GuPool *pool)
size_t n_hypos = gu_seq_length(fun->type->hypos);
PgfAbsProduction* prod = gu_new_flex(pool, PgfAbsProduction, args, n_hypos);
prod->fun = fun;
prod->count = 0;
for (size_t i = 0; i < n_hypos; i++) {
prod->args[i] = 0;
}
@@ -166,6 +169,7 @@ pgf_lookup_add_spine_leaf(PgfSpineBuilder* builder, PgfAbsFun *fun)
{
PgfMetaId id = pgf_lookup_add_spine_nodes(builder, fun->type->cid);
PgfAbsProduction* prod = pgf_lookup_new_production(fun, builder->pool);
prod->count = 1;
pgf_lookup_add_production(builder, id, prod);
}
@@ -257,6 +261,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs,
if (prod1->fun == prod2->fun) {
PgfAbsProduction* prod =
pgf_lookup_new_production(prod1->fun, pool);
prod->count = prod1->count+prod2->count;
size_t n_hypos = gu_seq_length(prod->fun->type->hypos);
for (size_t l = 0; l < n_hypos; l++) {
prod->args[l] =
@@ -274,6 +279,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs,
if (count == 0) {
PgfAbsProduction* prod =
pgf_lookup_new_production(prod1->fun, pool);
prod->count = prod1->count;
size_t n_hypos = gu_seq_length(prod->fun->type->hypos);
for (size_t l = 0; l < n_hypos; l++) {
prod->args[l] =
@@ -304,6 +310,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs,
if (!found) {
PgfAbsProduction* prod =
pgf_lookup_new_production(prod2->fun, pool);
prod->count = prod2->count;
size_t n_hypos = gu_seq_length(prod->fun->type->hypos);
for (size_t l = 0; l < n_hypos; l++) {
prod->args[l] =
@@ -346,7 +353,6 @@ typedef struct {
GuBuf* join;
PgfMetaId start_id;
GuChoice* choice;
GuBuf* stack;
GuBuf* expr_tokens;
GuBuf* ctrees;
PgfAbsFun** curr_absfun;
@@ -430,6 +436,63 @@ redo:;
return ret;
}
/*
 * Count-based pre-filter over the production forest held in `join`.
 *
 * For the node `meta_id`, every production is scored as its own count
 * plus the recorded counts of all of its argument nodes (each argument
 * is filtered recursively first). Only the productions that reach the
 * highest score are kept in the node's buffer; that score is stored in
 * counts[meta_id] and also written back into each visited production.
 *
 * join    - buffer of per-node production buffers, indexed by meta id
 * meta_id - node to filter; 0 is accepted without any work
 * counts  - per-node best score; a non-zero entry means "already done"
 * stack   - meta ids currently being expanded, used to detect cycles
 *
 * Returns false only if `meta_id` is already on `stack` (a cycle);
 * a production with such an argument gets the score 0. Returns true
 * in every other case.
 */
static bool
pgf_lookup_filter(GuBuf* join, PgfMetaId meta_id, GuSeq* counts, GuBuf* stack)
{
	if (meta_id == 0)
		return true;

	// A non-zero score means this node was filtered before.
	size_t known = gu_seq_get(counts, size_t, meta_id);
	if (known > 0)
		return true;

	// Refuse to re-enter a node that is still being expanded.
	size_t depth = gu_buf_length(stack);
	size_t level = 0;
	while (level < depth) {
		PgfMetaId open_id = gu_buf_get(stack, PgfMetaId, level);
		if (open_id == meta_id)
			return false;
		level++;
	}

	gu_buf_push(stack, PgfMetaId, meta_id);

	size_t n_kept = 0;   // number of productions retained so far
	size_t best = 0;     // highest score seen so far

	GuBuf* prods = gu_buf_get(join, GuBuf*, meta_id);
	size_t n_prods = gu_buf_length(prods);
	for (size_t p = 0; p < n_prods; p++) {
		PgfAbsProduction* prod =
			gu_buf_get(prods, PgfAbsProduction*, p);
		size_t arity = gu_seq_length(prod->fun->type->hypos);

		size_t score = prod->count;
		for (size_t a = 0; a < arity; a++) {
			PgfMetaId arg_id = prod->args[a];
			if (!pgf_lookup_filter(join, arg_id, counts, stack)) {
				// Cyclic argument: the production is worth nothing.
				score = 0;
				break;
			}
			score += gu_seq_get(counts, size_t, arg_id);
		}

		if (score >= best) {
			if (score > best) {
				// New leader: forget everything kept until now.
				best = score;
				n_kept = 0;
			}
			// Compact the survivors towards the front of the buffer.
			gu_buf_set(prods, PgfAbsProduction*, n_kept, prod);
			n_kept++;
		}

		prod->count = score;
	}

	gu_seq_set(counts, size_t, meta_id, best);
	// Drop the tail that is left over after compaction.
	gu_buf_trim_n(prods, n_prods - n_kept);

	gu_buf_pop(stack, PgfMetaId);
	return true;
}
typedef struct {
GuMapItor fn;
int index;
@@ -496,15 +559,6 @@ pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat)
return ret;
}
size_t n_stack = gu_buf_length(st->stack);
for (size_t i = 0; i < n_stack; i++) {
PgfMetaId id = gu_buf_get(st->stack, PgfMetaId, i);
if (meta_id == id) {
return gu_null_variant;
}
}
gu_buf_push(st->stack, PgfMetaId, meta_id);
size_t n_id_prods = gu_buf_length(id_prods);
size_t i = gu_choice_next(st->choice, n_id_prods);
@@ -550,7 +604,6 @@ redo:;
}
done:
gu_buf_pop(st->stack, PgfMetaId);
return ret;
}
@@ -827,6 +880,20 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
join = pgf_lookup_merge(meta_id1, join, meta_id2, spine, &meta_id1, work_pool, pool);
}
size_t n_cats = gu_buf_length(join);
GuBuf* stack = gu_new_buf(PgfMetaId, work_pool);
GuSeq* counts = gu_new_seq(size_t, n_cats, work_pool);
for (size_t i = 0; i < n_cats; i++) {
gu_seq_set(counts, size_t, i, 0);
}
pgf_lookup_filter(join, meta_id1, counts, stack);
for (size_t i = 1; i < n_cats; i++) {
if (gu_seq_get(counts, size_t, i) == 0) {
GuBuf* id_prods = gu_buf_get(join, GuBuf*, i);
gu_buf_flush(id_prods);
}
}
#ifdef PGF_LOOKUP_DEBUG
GuPool* tmp_pool = gu_new_pool();
GuOut* out = gu_file_out(stderr, tmp_pool);
@@ -842,7 +909,6 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
st.join = join;
st.start_id= meta_id1;
st.choice = gu_new_choice(work_pool);
st.stack = gu_new_buf(PgfMetaId, work_pool);
st.expr_tokens=gu_new_buf(PgfInputToken, work_pool);
st.ctrees = gu_new_buf(PgfCncTreeScore, pool);
st.curr_absfun= NULL;