From 622274180af2e3bd64ced3e191319dc5be3aa4d1 Mon Sep 17 00:00:00 2001 From: krasimir Date: Tue, 23 May 2017 14:07:55 +0000 Subject: [PATCH] a naive buf working ranking in the sentence lookup --- src/runtime/c/pgf/lookup.c | 413 +++++++++++++++++++++++-------------- 1 file changed, 253 insertions(+), 160 deletions(-) diff --git a/src/runtime/c/pgf/lookup.c b/src/runtime/c/pgf/lookup.c index 53a70f36c..8fa94e19c 100644 --- a/src/runtime/c/pgf/lookup.c +++ b/src/runtime/c/pgf/lookup.c @@ -6,6 +6,8 @@ #include #include #include +#include +#include //#define PGF_LOOKUP_DEBUG @@ -98,7 +100,7 @@ pgf_lookup_index_syms(GuMap* lexicon_idx, PgfSymbols* syms, PgfProductionIdx* id typedef struct { GuMap* function_idx; - GuMap* cat_ids; + GuMap* meta_ids; GuBuf* spine; GuPool* pool; } PgfSpineBuilder; @@ -124,15 +126,15 @@ pgf_lookup_add_production(PgfSpineBuilder* builder, PgfMetaId id, PgfAbsProducti static PgfMetaId pgf_lookup_add_spine_nodes(PgfSpineBuilder* builder, PgfCId cat) { - PgfMetaId cat_id = gu_map_get(builder->cat_ids, cat, PgfMetaId); - if (cat_id != 0) { - return cat_id; + PgfMetaId meta_id = gu_map_get(builder->meta_ids, cat, PgfMetaId); + if (meta_id != 0) { + return meta_id; } - cat_id = gu_buf_length(builder->spine); + meta_id = gu_buf_length(builder->spine); gu_buf_push(builder->spine, GuBuf*, gu_new_buf(PgfAbsProduction*, builder->pool)); - gu_map_put(builder->cat_ids, cat, PgfMetaId, cat_id); + gu_map_put(builder->meta_ids, cat, PgfMetaId, meta_id); GuBuf* entries = gu_map_get(builder->function_idx, cat, GuBuf*); if (entries != NULL) { @@ -143,13 +145,13 @@ pgf_lookup_add_spine_nodes(PgfSpineBuilder* builder, PgfCId cat) { PgfMetaId id = pgf_lookup_add_spine_nodes(builder, entry->fun->type->cid); PgfAbsProduction* prod = pgf_lookup_new_production(entry->fun, builder->pool); - prod->args[entry->arg_idx] = cat_id; + prod->args[entry->arg_idx] = meta_id; pgf_lookup_add_production(builder, id, prod); } } - return cat_id; + return meta_id; } static void @@ -164,12 +166,12 @@ pgf_lookup_add_spine_leaf(PgfSpineBuilder* builder, PgfAbsFun *fun) static GuBuf* pgf_lookup_build_spine(GuMap* lexicon_idx, GuMap* function_idx, - GuString tok, PgfType* typ, PgfMetaId* cat_id, + GuString tok, PgfType* typ, PgfMetaId* meta_id, GuPool* pool) { PgfSpineBuilder builder; builder.function_idx = function_idx; - builder.cat_ids = gu_new_string_map(PgfMetaId, &gu_null_struct, pool); + builder.meta_ids = gu_new_string_map(PgfMetaId, &gu_null_struct, pool); builder.spine = gu_new_buf(GuBuf*, pool); builder.pool = pool; @@ -185,7 +187,7 @@ pgf_lookup_build_spine(GuMap* lexicon_idx, GuMap* function_idx, } } - *cat_id = gu_map_get(builder.cat_ids, typ->cid, PgfMetaId); + *meta_id = gu_map_get(builder.meta_ids, typ->cid, PgfMetaId); return builder.spine; } @@ -219,32 +221,32 @@ pgf_pair_hasher[1] = { static PgfMetaId pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, - PgfMetaId cat_id1, GuBuf* spine1, - PgfMetaId cat_id2, GuBuf* spine2, + PgfMetaId meta_id1, GuBuf* spine1, + PgfMetaId meta_id2, GuBuf* spine2, GuPool* pool) { - if (cat_id1 == 0 && cat_id2 == 0) + if (meta_id1 == 0 && meta_id2 == 0) return 0; PgfPair pair; - pair[0] = cat_id1; - pair[1] = cat_id2; - PgfMetaId cat_id = gu_map_get(pairs, &pair, PgfMetaId); - if (cat_id != 0) - return cat_id; + pair[0] = meta_id1; + pair[1] = meta_id2; + PgfMetaId meta_id = gu_map_get(pairs, &pair, PgfMetaId); + if (meta_id != 0) + return meta_id; - cat_id = gu_buf_length(spine); + meta_id = gu_buf_length(spine); GuBuf* id_prods = gu_new_buf(PgfAbsProduction*, pool); gu_buf_push(spine, GuBuf*, id_prods); - gu_map_put(pairs, &pair, PgfMetaId, cat_id); + gu_map_put(pairs, &pair, PgfMetaId, meta_id); - GuBuf* id_prods1 = gu_buf_get(spine1, GuBuf*, cat_id1); - GuBuf* id_prods2 = gu_buf_get(spine2, GuBuf*, cat_id2); - size_t n_id_prods1 = (cat_id1 == 0) ? 0 : gu_buf_length(id_prods1); - size_t n_id_prods2 = (cat_id2 == 0) ? 0 : gu_buf_length(id_prods2); + GuBuf* id_prods1 = gu_buf_get(spine1, GuBuf*, meta_id1); + GuBuf* id_prods2 = gu_buf_get(spine2, GuBuf*, meta_id2); + size_t n_id_prods1 = (meta_id1 == 0) ? 0 : gu_buf_length(id_prods1); + size_t n_id_prods2 = (meta_id2 == 0) ? 0 : gu_buf_length(id_prods2); - if (cat_id1 != 0) { + if (meta_id1 != 0) { for (size_t i = 0; i < n_id_prods1; i++) { PgfAbsProduction* prod1 = gu_buf_get(id_prods1, PgfAbsProduction*, i); @@ -287,7 +289,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, } } - if (cat_id2 != 0) { + if (meta_id2 != 0) { for (size_t i = 0; i < n_id_prods2; i++) { PgfAbsProduction* prod2 = gu_buf_get(id_prods2, PgfAbsProduction*, i); @@ -318,13 +320,13 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, } } - return cat_id; + return meta_id; } static GuBuf* -pgf_lookup_merge(PgfMetaId cat_id1, GuBuf* spine1, - PgfMetaId cat_id2, GuBuf* spine2, - PgfMetaId* cat_id, +pgf_lookup_merge(PgfMetaId meta_id1, GuBuf* spine1, + PgfMetaId meta_id2, GuBuf* spine2, + PgfMetaId* meta_id, GuPool* pool, GuPool* out_pool) { GuBuf* spine = gu_new_buf(GuBuf*, out_pool); @@ -332,105 +334,69 @@ pgf_lookup_merge(PgfMetaId cat_id1, GuBuf* spine1, GuMap* pairs = gu_new_map(PgfPair, pgf_pair_hasher, PgfMetaId, &gu_null_struct, pool); - *cat_id = + *meta_id = pgf_lookup_merge_cats(spine, pairs, - cat_id1, spine1, - cat_id2, spine2, + meta_id1, spine1, + meta_id2, spine2, out_pool); return spine; } -static bool -pgf_lookup_filter(GuBuf* join, PgfMetaId cat_id, GuSeq* counts, GuBuf* stack) -{ - if (cat_id == 0) - return true; - - size_t count = gu_seq_get(counts, size_t, cat_id); - if (count > 0) - return true; - - size_t n_stack = gu_buf_length(stack); - for (size_t i = 0; i < n_stack; i++) { - PgfMetaId id = gu_buf_get(stack, PgfMetaId, i); - if (cat_id == id) { - return false; - } - } - gu_buf_push(stack, PgfMetaId, cat_id); - - size_t pos = 0; - size_t maximum = 0; - GuBuf* id_prods = gu_buf_get(join, GuBuf*, cat_id); - size_t n_id_prods = gu_buf_length(id_prods); - for (size_t i = 0; i < n_id_prods; i++) { - PgfAbsProduction* prod = - gu_buf_get(id_prods, PgfAbsProduction*, i); - - size_t n_args = gu_seq_length(prod->fun->type->hypos); - size_t sum = prod->count; - for (size_t j = 0; j < n_args; j++) { - if (!pgf_lookup_filter(join, prod->args[j], counts, stack)) { - sum = 0; - break; - } - sum += gu_seq_get(counts, size_t, prod->args[j]); - } - - if (sum > maximum) { - maximum = sum; - pos = 0; - } - if (sum == maximum) { - gu_buf_set(id_prods, PgfAbsProduction*, pos, prod); - pos++; - } - - prod->count = sum; - } - - gu_seq_set(counts, size_t, cat_id, maximum); - gu_buf_trim_n(id_prods, n_id_prods-pos); - - gu_buf_pop(stack, PgfMetaId); - - return true; -} - typedef struct { - GuEnum en; GuBuf* join; PgfMetaId start_id; GuChoice* choice; + GuBuf* stack; + GuBuf* exprs; GuPool* out_pool; +} PgfLookupState; + +typedef struct { + GuEnum en; + double max; + size_t index; + GuBuf* exprs; } PgfLookupEnum; -static void -pgf_lookup_extract(PgfLookupEnum* st, PgfMetaId cat_id, PgfExprProb* ep) +static bool +pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfExprProb* ep) { - GuBuf* id_prods = gu_buf_get(st->join, GuBuf*, cat_id); - + GuBuf* id_prods = gu_buf_get(st->join, GuBuf*, meta_id); + if (id_prods == NULL || gu_buf_length(id_prods) == 0) { ep->expr = gu_new_variant_i(st->out_pool, PGF_EXPR_META, PgfExprMeta, - cat_id); + meta_id); ep->prob = 0; - return; + return true; } + size_t n_stack = gu_buf_length(st->stack); + for (size_t i = 0; i < n_stack; i++) { + PgfMetaId id = gu_buf_get(st->stack, PgfMetaId, i); + if (meta_id == id) { + return false; + } + } + gu_buf_push(st->stack, PgfMetaId, meta_id); + size_t n_id_prods = gu_buf_length(id_prods); size_t i = gu_choice_next(st->choice, n_id_prods); - PgfAbsProduction* prod = + PgfAbsProduction* prod = gu_buf_get(id_prods, PgfAbsProduction*, i); *ep = prod->fun->ep; + bool res = true; size_t n_args = gu_seq_length(prod->fun->type->hypos); for (size_t j = 0; j < n_args; j++) { PgfExprProb ep_arg; - pgf_lookup_extract(st, prod->args[j], &ep_arg); + if (!pgf_lookup_extract(st, prod->args[j], &ep_arg)) { + res = false; + break; + } ep->expr = gu_new_variant_i(st->out_pool, PGF_EXPR_APP, @@ -438,29 +404,124 @@ pgf_lookup_extract(PgfLookupEnum* st, PgfMetaId cat_id, PgfExprProb* ep) ep->expr, ep_arg.expr); ep->prob += ep_arg.prob; } + + gu_buf_pop(st->stack, PgfMetaId); + return res; } +static GuBuf* +pgf_lookup_tokenize(GuString buf, size_t len, GuPool* pool) +{ + GuBuf* tokens = gu_new_buf(GuString, pool); + + GuUCS c = ' '; + const uint8_t* p = (const uint8_t*) buf; + for (;;) { + while (gu_ucs_is_space(c)) { + c = gu_utf8_decode(&p); + } + if (c == 0) + break; + + const uint8_t* start = p-1; + while (c != 0 && !gu_ucs_is_space(c)) { + c = gu_utf8_decode(&p); + } + const uint8_t* end = p-1; + + size_t len = end-start; + GuString tok = gu_malloc(pool, len+1); + memcpy((uint8_t*) tok, start, len); + ((uint8_t*) tok)[len] = 0; + + gu_buf_push(tokens, GuString, tok); + } + + return tokens; +} + +static int +pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, + int* matrix, size_t i, size_t j); + +static int +pgf_lookup_compute_kernel_helper2(GuBuf* sentence_tokens, GuBuf* expr_tokens, + int* matrix, size_t i, size_t j) +{ +// size_t n_sentence_tokens = gu_buf_length(sentence_tokens); + size_t n_expr_tokens = gu_buf_length(expr_tokens); + + if (j >= n_expr_tokens) + return 0; + + GuString sentence_token = gu_buf_get(sentence_tokens, GuString, i); + GuString expr_token = gu_buf_get(expr_tokens, GuString, j); + if (strcmp(sentence_token, expr_token) == 0) { + return 1 + + pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, + matrix, i+1, j+1) + + pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, + matrix, i, j+1); + } else { + return pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, matrix, i, j+1); + } +} + +static int +pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, + int* matrix, size_t i, size_t j) +{ + size_t n_sentence_tokens = gu_buf_length(sentence_tokens); +// size_t n_expr_tokens = gu_buf_length(expr_tokens); + + int score = matrix[i+n_sentence_tokens*j]; + if (score == -1) { + if (i >= n_sentence_tokens) + score = 0; + else + score = pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, + matrix, i+1, j) + + pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, + matrix, i, j); + matrix[n_sentence_tokens*i + j] = score; + } + + return score; +} + +static int +pgf_lookup_compute_kernel(GuBuf* sentence_tokens, GuBuf* expr_tokens) +{ + size_t n_sentence_tokens = gu_buf_length(sentence_tokens); + size_t n_expr_tokens = gu_buf_length(expr_tokens); + size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1)*sizeof(int); + int* matrix = alloca(size); + memset(matrix, -1, size); + + return pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, matrix, 0, 0); +} + +typedef struct { + PgfExprProb ep; + double score; +} PgfExprScore; + static void pgf_lookup_enum_next(GuEnum* self, void* to, GuPool* pool) { PgfLookupEnum* st = gu_container(self, PgfLookupEnum, en); - - if (st->choice == NULL) { - *((PgfExprProb**) to) = NULL; - return; + PgfExprScore* es = NULL; + + while (st->index < gu_buf_length(st->exprs)) { + es = gu_buf_index(st->exprs, PgfExprScore, st->index); + st->index++; + if (fabs(es->score - st->max) < 0.00005) { + *((PgfExprProb**) to) = &es->ep; + return; + } } - GuChoiceMark mark = gu_choice_mark(st->choice); - - PgfExprProb* ep = gu_new(PgfExprProb, pool); - pgf_lookup_extract(st, st->start_id, ep); - *((PgfExprProb**) to) = ep; - - gu_choice_reset(st->choice, mark); - - if (!gu_choice_advance(st->choice)) { - st->choice = NULL; - } + *((PgfExprProb**) to) = NULL; } PGF_API GuEnum* @@ -500,52 +561,26 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po GuPool *work_pool = gu_new_pool(); - PgfMetaId cat_id1 = 0; + GuBuf* sentence_tokens = + pgf_lookup_tokenize(sentence, + strlen(sentence), + work_pool); + + PgfMetaId meta_id1 = 0; GuBuf* join = gu_new_buf(GuBuf*, pool); gu_buf_push(join, GuBuf*, NULL); - GuUCS c = ' '; - const uint8_t* p = (const uint8_t*) sentence; - for (;;) { - while (gu_ucs_is_space(c)) { - c = gu_utf8_decode(&p); - } - if (c == 0) - break; + size_t n_tokens = gu_buf_length(sentence_tokens); + for (size_t i = 0; i < n_tokens; i++) { + GuString tok = gu_buf_get(sentence_tokens, GuString, i); - const uint8_t* start = p-1; - while (c != 0 && !gu_ucs_is_space(c)) { - c = gu_utf8_decode(&p); - } - const uint8_t* end = p-1; - - size_t len = end-start; - GuString tok = gu_malloc(work_pool, len+1); - memcpy((uint8_t*) tok, start, len); - ((uint8_t*) tok)[len] = 0; - - PgfMetaId cat_id2 = 0; + PgfMetaId meta_id2 = 0; GuBuf* spine = pgf_lookup_build_spine(lexicon_idx, function_idx, - tok, typ, &cat_id2, + tok, typ, &meta_id2, work_pool); - join = pgf_lookup_merge(cat_id1, join, cat_id2, spine, &cat_id1, work_pool, pool); - } - - - size_t n_cats = gu_buf_length(join); - GuBuf* stack = gu_new_buf(PgfMetaId, work_pool); - GuSeq* counts = gu_new_seq(size_t, n_cats, work_pool); - for (size_t i = 0; i < n_cats; i++) { - gu_seq_set(counts, size_t, i, 0); - } - pgf_lookup_filter(join, cat_id1, counts, stack); - for (size_t i = 1; i < n_cats; i++) { - if (gu_seq_get(counts, size_t, i) == 0) { - GuBuf* id_prods = gu_buf_get(join, GuBuf*, i); - gu_buf_flush(id_prods); - } + join = pgf_lookup_merge(meta_id1, join, meta_id2, spine, &meta_id1, work_pool, pool); } #ifdef PGF_LOOKUP_DEBUG @@ -557,13 +592,71 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po gu_pool_free(tmp_pool); #endif + PgfLookupState st; + st.join = join; + st.start_id= meta_id1; + st.choice = gu_new_choice(work_pool); + st.stack = gu_new_buf(PgfMetaId, work_pool); + st.exprs = gu_new_buf(PgfExprScore, pool); + st.out_pool= out_pool; + + GuChoiceMark mark = gu_choice_mark(st.choice); + + double max = 0; + PgfExprScore* es = gu_buf_extend(st.exprs); + for (;;) { + bool res = pgf_lookup_extract(&st, st.start_id, &es->ep); + + gu_choice_reset(st.choice, mark); + + if (!gu_choice_advance(st.choice)) + break; + + if (res) { + GuExn* err = gu_exn(work_pool); + GuStringBuf* sbuf = gu_new_string_buf(work_pool); + GuOut* out = gu_string_buf_out(sbuf); + + pgf_linearize(concr, es->ep.expr, out, err); + if (!gu_ok(err)) { + continue; + } + + GuBuf* expr_tokens = + pgf_lookup_tokenize(gu_string_buf_data(sbuf), + gu_string_buf_length(sbuf), + work_pool); + + es->score = + ((double) pgf_lookup_compute_kernel(sentence_tokens, expr_tokens)) / + sqrt(((double) pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens)) * ((double) pgf_lookup_compute_kernel(expr_tokens, expr_tokens))); + +#ifdef PGF_LOOKUP_DEBUG + { + GuPool* tmp_pool = gu_new_pool(); + GuOut* out = gu_file_out(stderr, tmp_pool); + GuExn* err = gu_exn(tmp_pool); + pgf_print_expr(es->ep.expr, NULL, 0, out, err); + gu_printf(out, err, " [%f]\n", es->score); + gu_pool_free(tmp_pool); + } +#endif + + if (es->score > max) { + max = es->score; + } + + es = gu_buf_extend(st.exprs); + } + } + gu_buf_trim(st.exprs); + gu_pool_free(work_pool); - PgfLookupEnum* st = gu_new(PgfLookupEnum, pool); - st->en.next = pgf_lookup_enum_next; - st->join = join; - st->start_id= cat_id1; - st->choice = gu_new_choice(pool); - st->out_pool= out_pool; - return &st->en; + PgfLookupEnum* lenum = gu_new(PgfLookupEnum, pool); + lenum->en.next = pgf_lookup_enum_next; + lenum->max = max; + lenum->index = 0; + lenum->exprs = st.exprs; + return &lenum->en; }