diff --git a/src/runtime/c/pgf/lookup.c b/src/runtime/c/pgf/lookup.c index 70d936f08..5e5074c4d 100644 --- a/src/runtime/c/pgf/lookup.c +++ b/src/runtime/c/pgf/lookup.c @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -15,6 +16,7 @@ typedef struct { typedef struct { PgfAbsFun* fun; + size_t count; PgfMetaId args[0]; } PgfAbsProduction; @@ -29,7 +31,7 @@ pgf_print_abs_production(PgfMetaId id, for (size_t i = 0; i < n_hypos; i++) { gu_printf(out,err," ?%d", prod->args[i]); } - gu_putc('\n',out,err); + gu_printf(out,err," (%d)\n",(int) prod->count); } static void @@ -105,7 +107,8 @@ static PgfAbsProduction* pgf_lookup_new_production(PgfAbsFun* fun, GuPool *pool) { size_t n_hypos = gu_seq_length(fun->type->hypos); PgfAbsProduction* prod = gu_new_flex(pool, PgfAbsProduction, args, n_hypos); - prod->fun = fun; + prod->fun = fun; + prod->count = 0; for (size_t i = 0; i < n_hypos; i++) { prod->args[i] = 0; } @@ -154,7 +157,8 @@ pgf_lookup_add_spine_leaf(PgfSpineBuilder* builder, PgfAbsFun *fun) { PgfMetaId id = pgf_lookup_add_spine_nodes(builder, fun->type->cid); PgfAbsProduction* prod = pgf_lookup_new_production(fun, builder->pool); - + prod->count = 1; + pgf_lookup_add_production(builder, id, prod); } @@ -251,6 +255,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, if (prod1->fun == prod2->fun) { PgfAbsProduction* prod = pgf_lookup_new_production(prod1->fun, pool); + prod->count = prod1->count+prod2->count; size_t n_hypos = gu_seq_length(prod->fun->type->hypos); for (size_t l = 0; l < n_hypos; l++) { prod->args[l] = @@ -268,6 +273,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, if (count == 0) { PgfAbsProduction* prod = pgf_lookup_new_production(prod1->fun, pool); + prod->count = prod1->count; size_t n_hypos = gu_seq_length(prod->fun->type->hypos); for (size_t l = 0; l < n_hypos; l++) { prod->args[l] = @@ -298,6 +304,7 @@ pgf_lookup_merge_cats(GuBuf* spine, GuMap* pairs, if (!found) { PgfAbsProduction* prod = 
pgf_lookup_new_production(prod2->fun, pool); + prod->count = prod2->count; size_t n_hypos = gu_seq_length(prod->fun->type->hypos); for (size_t l = 0; l < n_hypos; l++) { prod->args[l] = @@ -334,6 +341,117 @@ pgf_lookup_merge(PgfMetaId cat_id1, GuBuf* spine1, return spine; } +static bool +pgf_lookup_filter(GuBuf* join, PgfMetaId cat_id, GuSeq* counts) +{ + if (cat_id == 0) + return true; + + size_t count = gu_seq_get(counts, size_t, cat_id); + if (count != 0) + return false; + gu_seq_set(counts, size_t, cat_id, 1); + + size_t pos = 0; + size_t maximum = 0; + GuBuf* id_prods = gu_buf_get(join, GuBuf*, cat_id); + size_t n_id_prods = gu_buf_length(id_prods); + for (size_t i = 0; i < n_id_prods; i++) { + PgfAbsProduction* prod = + gu_buf_get(id_prods, PgfAbsProduction*, i); + + size_t n_args = gu_seq_length(prod->fun->type->hypos); + size_t sum = prod->count; + for (size_t j = 0; j < n_args; j++) { + if (!pgf_lookup_filter(join, prod->args[j], counts)) { + sum = 0; + break; + } + sum += gu_seq_get(counts, size_t, prod->args[j]); + } + + if (sum > maximum) { + maximum = sum; + pos = 0; + } + if (sum == maximum) { + gu_buf_set(id_prods, PgfAbsProduction*, pos, prod); + pos++; + } + + prod->count = sum; + } + + gu_seq_set(counts, size_t, cat_id, maximum); + gu_buf_trim_n(id_prods, n_id_prods-pos); + return true; +} + +typedef struct { + GuEnum en; + GuBuf* join; + PgfMetaId start_id; + GuChoice* choice; + GuPool* out_pool; +} PgfLookupEnum; + +static void +pgf_lookup_extract(PgfLookupEnum* st, PgfMetaId cat_id, PgfExprProb* ep) +{ + GuBuf* id_prods = gu_buf_get(st->join, GuBuf*, cat_id); + + if (id_prods == NULL || gu_buf_length(id_prods) == 0) { + ep->expr = gu_new_variant_i(st->out_pool, + PGF_EXPR_META, + PgfExprMeta, + cat_id); + ep->prob = 0; + return; + } + + size_t n_id_prods = gu_buf_length(id_prods); + + size_t i = gu_choice_next(st->choice, n_id_prods); + PgfAbsProduction* prod = + gu_buf_get(id_prods, PgfAbsProduction*, i); + + *ep = prod->fun->ep; + size_t 
n_args = gu_seq_length(prod->fun->type->hypos); + for (size_t j = 0; j < n_args; j++) { + PgfExprProb ep_arg; + pgf_lookup_extract(st, prod->args[j], &ep_arg); + + ep->expr = gu_new_variant_i(st->out_pool, + PGF_EXPR_APP, + PgfExprApp, + ep->expr, ep_arg.expr); + ep->prob += ep_arg.prob; + } +} + +static void +pgf_lookup_enum_next(GuEnum* self, void* to, GuPool* pool) +{ + PgfLookupEnum* st = gu_container(self, PgfLookupEnum, en); + + if (st->choice == NULL) { + *((PgfExprProb**) to) = NULL; + return; + } + + GuChoiceMark mark = gu_choice_mark(st->choice); + + PgfExprProb* ep = gu_new(PgfExprProb, pool); + pgf_lookup_extract(st, st->start_id, ep); + *((PgfExprProb**) to) = ep; + + gu_choice_reset(st->choice, mark); + + if (!gu_choice_advance(st->choice)) { + st->choice = NULL; + } +} + PGF_API GuEnum* pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* pool, GuPool* out_pool) { @@ -361,7 +479,7 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po funs = gu_new_buf(PgfAbsBottomUpEntry, pool); gu_map_put(function_idx, hypo->type->cid, GuBuf*, funs); } - + PgfAbsBottomUpEntry* entry = gu_buf_extend(funs); entry->fun = fun; entry->arg_idx = j; @@ -372,7 +490,7 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po GuPool *work_pool = gu_new_pool(); PgfMetaId cat_id1 = 0; - GuBuf* join = gu_new_buf(GuBuf*, work_pool); + GuBuf* join = gu_new_buf(GuBuf*, pool); gu_buf_push(join, GuBuf*, NULL); GuUCS c = ' '; @@ -389,7 +507,7 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po c = gu_utf8_decode(&p); } const uint8_t* end = p-1; - + size_t len = end-start; GuString tok = gu_malloc(work_pool, len+1); memcpy((uint8_t*) tok, start, len); @@ -401,13 +519,20 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po tok, typ, &cat_id2, work_pool); - GuPool *work_pool2 = gu_new_pool(); - - join = pgf_lookup_merge(cat_id1, join, cat_id2, 
spine, &cat_id1, work_pool, work_pool2); - - gu_pool_free(work_pool); - - work_pool = work_pool2; + join = pgf_lookup_merge(cat_id1, join, cat_id2, spine, &cat_id1, work_pool, pool); + } + + size_t n_cats = gu_buf_length(join); + GuSeq* counts = gu_new_seq(size_t, n_cats, work_pool); + for (size_t i = 0; i < n_cats; i++) { + gu_seq_set(counts, size_t, i, 0); + } + pgf_lookup_filter(join, cat_id1, counts); + for (size_t i = 1; i < n_cats; i++) { + if (gu_seq_get(counts, size_t, i) == 0) { + GuBuf* id_prods = gu_buf_get(join, GuBuf*, i); + gu_buf_flush(id_prods); + } } #ifdef PGF_LOOKUP_DEBUG @@ -421,5 +546,11 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po gu_pool_free(work_pool); - return NULL; + PgfLookupEnum* st = gu_new(PgfLookupEnum, pool); + st->en.next = pgf_lookup_enum_next; + st->join = join; + st->start_id= cat_id1; + st->choice = gu_new_choice(pool); + st->out_pool= out_pool; + return &st->en; }