an optimized expression extraction in the lookup

This commit is contained in:
krasimir
2017-05-23 21:06:17 +00:00
parent fc3be2b937
commit 9ec86417cb
3 changed files with 344 additions and 131 deletions

View File

@@ -70,44 +70,6 @@ typedef struct {
} PgfCnc; } PgfCnc;
//
// PgfCncTree
//
typedef enum {
PGF_CNC_TREE_APP,
PGF_CNC_TREE_CHUNKS,
PGF_CNC_TREE_LIT,
} PgfCncTreeTag;
typedef struct {
PgfCCat* ccat;
PgfCncFun* fun;
int fid;
size_t n_vars;
PgfPrintContext* context;
size_t n_args;
PgfCncTree args[];
} PgfCncTreeApp;
typedef struct {
size_t n_vars;
PgfPrintContext* context;
size_t n_args;
PgfCncTree args[];
} PgfCncTreeChunks;
typedef struct {
size_t n_vars;
PgfPrintContext* context;
int fid;
PgfLiteral lit;
} PgfCncTreeLit;
#ifdef PGF_LINEARIZER_DEBUG #ifdef PGF_LINEARIZER_DEBUG
static void static void
pgf_print_cnc_tree_vars(size_t n_vars, PgfPrintContext* context, pgf_print_cnc_tree_vars(size_t n_vars, PgfPrintContext* context,
@@ -128,7 +90,7 @@ pgf_print_cnc_tree_vars(size_t n_vars, PgfPrintContext* context,
} }
} }
static void PGF_INTERNAL void
pgf_print_cnc_tree(PgfCncTree ctree, GuOut* out, GuExn* err) pgf_print_cnc_tree(PgfCncTree ctree, GuOut* out, GuExn* err)
{ {
GuVariantInfo ti = gu_variant_open(ctree); GuVariantInfo ti = gu_variant_open(ctree);

View File

@@ -4,23 +4,48 @@
#include <gu/enum.h> #include <gu/enum.h>
/// Linearization of abstract syntax trees. /// Linearization of abstract syntax trees.
/// @file
/** @}
*
* @name Enumerating concrete syntax trees
*
* Because of the \c variants construct in GF, there may be several
* possible concrete syntax trees that correspond to a given abstract
* syntax tree. These can be enumerated with #pgf_concretize.
*
* @{
*/
//
// PgfCncTree
//
/// A concrete syntax tree /// A concrete syntax tree
typedef GuVariant PgfCncTree; typedef GuVariant PgfCncTree;
typedef enum {
PGF_CNC_TREE_APP,
PGF_CNC_TREE_CHUNKS,
PGF_CNC_TREE_LIT,
} PgfCncTreeTag;
typedef struct {
PgfCCat* ccat;
PgfCncFun* fun;
int fid;
size_t n_vars;
PgfPrintContext* context;
size_t n_args;
PgfCncTree args[];
} PgfCncTreeApp;
typedef struct {
size_t n_vars;
PgfPrintContext* context;
size_t n_args;
PgfCncTree args[];
} PgfCncTreeChunks;
typedef struct {
size_t n_vars;
PgfPrintContext* context;
int fid;
PgfLiteral lit;
} PgfCncTreeLit;
/// An enumeration of #PgfCncTree trees. /// An enumeration of #PgfCncTree trees.
typedef GuEnum PgfCncTreeEnum; typedef GuEnum PgfCncTreeEnum;

View File

@@ -5,11 +5,13 @@
#include <gu/string.h> #include <gu/string.h>
#include <gu/choice.h> #include <gu/choice.h>
#include <pgf/data.h> #include <pgf/data.h>
#include <pgf/linearizer.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <math.h> #include <math.h>
//#define PGF_LOOKUP_DEBUG //#define PGF_LOOKUP_DEBUG
//#define PGF_LINEARIZER_DEBUG
typedef struct { typedef struct {
PgfAbsFun* fun; PgfAbsFun* fun;
@@ -53,6 +55,11 @@ pgf_print_abs_productions(GuBuf* prods,
} }
#endif #endif
#ifdef PGF_LINEARIZER_DEBUG
PGF_INTERNAL_DECL void
pgf_print_cnc_tree(PgfCncTree ctree, GuOut* out, GuExn* err);
#endif
static void static void
pgf_lookup_index_syms(GuMap* lexicon_idx, PgfSymbols* syms, PgfProductionIdx* idx, GuPool* pool) { pgf_lookup_index_syms(GuMap* lexicon_idx, PgfSymbols* syms, PgfProductionIdx* idx, GuPool* pool) {
size_t n_syms = gu_seq_length(syms); size_t n_syms = gu_seq_length(syms);
@@ -344,40 +351,168 @@ pgf_lookup_merge(PgfMetaId meta_id1, GuBuf* spine1,
} }
typedef struct { typedef struct {
PgfLinFuncs* funcs;
PgfConcr* concr;
GuBuf* join; GuBuf* join;
PgfMetaId start_id; PgfMetaId start_id;
GuChoice* choice; GuChoice* choice;
GuBuf* stack; GuBuf* stack;
GuBuf* exprs; GuBuf* expr_tokens;
GuPool* out_pool; GuBuf* ctrees;
int fid;
GuPool* pool;
} PgfLookupState; } PgfLookupState;
typedef struct { typedef struct {
GuEnum en; GuEnum en;
double max; double max;
size_t index; size_t index;
GuBuf* exprs; GuBuf* ctrees;
GuPool* out_pool;
} PgfLookupEnum; } PgfLookupEnum;
static bool static PgfCncTree
pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfExprProb* ep) pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat);
static PgfCncTree
pgf_lookup_extract_app(PgfLookupState* st,
PgfCCat* ccat, GuBuf* buf,
size_t n_args, PgfMetaId* args)
{ {
GuChoiceMark mark = gu_choice_mark(st->choice);
int save_fid = st->fid;
PgfCncTree ret = gu_null_variant;
PgfCncTreeApp* capp =
gu_new_flex_variant(PGF_CNC_TREE_APP,
PgfCncTreeApp,
args, n_args, &ret, st->pool);
capp->ccat = ccat;
capp->n_vars = 0;
capp->context = NULL;
redo:;
int index = gu_choice_next(st->choice, gu_buf_length(buf));
if (index < 0) {
return gu_null_variant;
}
PgfProductionApply* papply =
gu_buf_get(buf, PgfProductionApply*, index);
gu_assert(n_args == gu_seq_length(papply->args));
capp->fun = papply->fun;
capp->fid = 0;
capp->n_args = n_args;
for (size_t i = 0; i < n_args; i++) {
PgfPArg* parg = gu_seq_index(papply->args, PgfPArg, i);
PgfMetaId meta_id = args[i];
PgfCCat* ccat = NULL;
GuBuf* coercions =
gu_map_get(st->concr->coerce_idx, parg->ccat, GuBuf*);
if (coercions == NULL) {
ccat = parg->ccat;
} else {
int index = gu_choice_next(st->choice, gu_buf_length(coercions));
if (index < 0) {
st->fid = save_fid;
gu_choice_reset(st->choice, mark);
if (!gu_choice_advance(st->choice))
return gu_null_variant;
goto redo;
}
PgfProductionCoerce* pcoerce =
gu_buf_get(coercions, PgfProductionCoerce*, index);
ccat = pcoerce->coerce;
}
capp->args[i] =
pgf_lookup_extract(st, meta_id, ccat);
if (gu_variant_is_null(capp->args[i])) {
gu_choice_reset(st->choice, mark);
if (!gu_choice_advance(st->choice))
return gu_null_variant;
goto redo;
}
}
return ret;
}
typedef struct {
GuMapItor fn;
int index;
PgfCCat* ccat;
GuBuf* buf;
} PgfCncItor;
static void
pgf_cnc_cat_resolve_itor(GuMapItor* fn, const void* key, void* value, GuExn* err)
{
PgfCncItor* clo = (PgfCncItor*) fn;
PgfCCat* ccat = (PgfCCat*) key;
GuBuf* buf = *((GuBuf**) value);
if (clo->index == 0) {
clo->ccat = ccat;
clo->buf = buf;
}
clo->index--;
}
static PgfCncTree
pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfCCat *ccat)
{
PgfCncTree ret = gu_null_variant;
GuBuf* id_prods = gu_buf_get(st->join, GuBuf*, meta_id); GuBuf* id_prods = gu_buf_get(st->join, GuBuf*, meta_id);
if (id_prods == NULL || gu_buf_length(id_prods) == 0) { if (id_prods == NULL || gu_buf_length(id_prods) == 0) {
ep->expr = gu_new_variant_i(st->out_pool, PgfCncTree chunks_tree;
PGF_EXPR_META, PgfCncTreeChunks* chunks =
PgfExprMeta, gu_new_flex_variant(PGF_CNC_TREE_CHUNKS,
meta_id); PgfCncTreeChunks,
ep->prob = 0; args, 0, &chunks_tree, st->pool);
return true; chunks->n_vars = 0;
chunks->context = NULL;
chunks->n_args = 0;
if (ccat == NULL) {
return chunks_tree;
}
if (ccat->lindefs == NULL) {
return ret;
}
int index =
gu_choice_next(st->choice, gu_seq_length(ccat->lindefs));
if (index < 0) {
return ret;
}
PgfCncTreeApp* capp =
gu_new_flex_variant(PGF_CNC_TREE_APP,
PgfCncTreeApp,
args, 1, &ret, st->pool);
capp->ccat = ccat;
capp->fun = gu_seq_get(ccat->lindefs, PgfCncFun*, index);
capp->fid = st->fid++;
capp->n_vars = 0;
capp->context = NULL;
capp->n_args = 1;
capp->args[0] = chunks_tree;
return ret;
} }
size_t n_stack = gu_buf_length(st->stack); size_t n_stack = gu_buf_length(st->stack);
for (size_t i = 0; i < n_stack; i++) { for (size_t i = 0; i < n_stack; i++) {
PgfMetaId id = gu_buf_get(st->stack, PgfMetaId, i); PgfMetaId id = gu_buf_get(st->stack, PgfMetaId, i);
if (meta_id == id) { if (meta_id == id) {
return false; return gu_null_variant;
} }
} }
gu_buf_push(st->stack, PgfMetaId, meta_id); gu_buf_push(st->stack, PgfMetaId, meta_id);
@@ -388,25 +523,47 @@ pgf_lookup_extract(PgfLookupState* st, PgfMetaId meta_id, PgfExprProb* ep)
PgfAbsProduction* prod = PgfAbsProduction* prod =
gu_buf_get(id_prods, PgfAbsProduction*, i); gu_buf_get(id_prods, PgfAbsProduction*, i);
*ep = prod->fun->ep;
bool res = true;
size_t n_args = gu_seq_length(prod->fun->type->hypos); size_t n_args = gu_seq_length(prod->fun->type->hypos);
for (size_t j = 0; j < n_args; j++) {
PgfExprProb ep_arg; PgfCncOverloadMap* overl_table =
if (!pgf_lookup_extract(st, prod->args[j], &ep_arg)) { gu_map_get(st->concr->fun_indices, prod->fun->name, PgfCncOverloadMap*);
res = false; if (overl_table == NULL) {
break; return gu_null_variant;
}
if (ccat == NULL) {
size_t n_count = gu_map_count(overl_table);
GuChoiceMark mark = gu_choice_mark(st->choice);
redo:;
int index = gu_choice_next(st->choice, n_count);
if (index < 0) {
goto done;
} }
ep->expr = gu_new_variant_i(st->out_pool, PgfCncItor clo = { { pgf_cnc_cat_resolve_itor }, index, NULL, NULL };
PGF_EXPR_APP, gu_map_iter(overl_table, &clo.fn, NULL);
PgfExprApp, assert(clo.ccat != NULL && clo.buf != NULL);
ep->expr, ep_arg.expr);
ep->prob += ep_arg.prob; ret = pgf_lookup_extract_app(st, clo.ccat, clo.buf, n_args, prod->args);
if (gu_variant_is_null(ret)) {
gu_choice_reset(st->choice, mark);
if (gu_choice_advance(st->choice))
goto redo;
}
} else {
GuBuf* buf =
gu_map_get(overl_table, ccat, GuBuf*);
if (buf == NULL) {
goto done;
}
ret = pgf_lookup_extract_app(st, ccat, buf, n_args, prod->args);
} }
done:
gu_buf_pop(st->stack, PgfMetaId); gu_buf_pop(st->stack, PgfMetaId);
return res; return ret;
} }
static GuBuf* static GuBuf*
@@ -440,17 +597,16 @@ pgf_lookup_tokenize(GuString buf, size_t len, GuPool* pool)
return tokens; return tokens;
} }
static int static long
pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens,
int* matrix, size_t i, size_t j); long* matrix, size_t i, size_t j);
static int static long
pgf_lookup_compute_kernel_helper2(GuBuf* sentence_tokens, GuBuf* expr_tokens, pgf_lookup_compute_kernel_helper2(GuBuf* sentence_tokens, GuBuf* expr_tokens,
int* matrix, size_t i, size_t j) long* matrix, size_t i, size_t j)
{ {
// size_t n_sentence_tokens = gu_buf_length(sentence_tokens);
size_t n_expr_tokens = gu_buf_length(expr_tokens); size_t n_expr_tokens = gu_buf_length(expr_tokens);
if (j >= n_expr_tokens) if (j >= n_expr_tokens)
return 0; return 0;
@@ -467,14 +623,13 @@ pgf_lookup_compute_kernel_helper2(GuBuf* sentence_tokens, GuBuf* expr_tokens,
} }
} }
static int static long
pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens,
int* matrix, size_t i, size_t j) long* matrix, size_t i, size_t j)
{ {
size_t n_sentence_tokens = gu_buf_length(sentence_tokens); size_t n_sentence_tokens = gu_buf_length(sentence_tokens);
// size_t n_expr_tokens = gu_buf_length(expr_tokens);
int score = matrix[i+n_sentence_tokens*j]; long score = matrix[i+n_sentence_tokens*j];
if (score == -1) { if (score == -1) {
if (i >= n_sentence_tokens) if (i >= n_sentence_tokens)
score = 0; score = 0;
@@ -483,40 +638,95 @@ pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens,
matrix, i+1, j) matrix, i+1, j)
+ pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, + pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens,
matrix, i, j); matrix, i, j);
matrix[n_sentence_tokens*i + j] = score; matrix[i + n_sentence_tokens*j] = score;
} }
return score; return score;
} }
static int static long
pgf_lookup_compute_kernel(GuBuf* sentence_tokens, GuBuf* expr_tokens) pgf_lookup_compute_kernel(GuBuf* sentence_tokens, GuBuf* expr_tokens)
{ {
size_t n_sentence_tokens = gu_buf_length(sentence_tokens); size_t n_sentence_tokens = gu_buf_length(sentence_tokens);
size_t n_expr_tokens = gu_buf_length(expr_tokens); size_t n_expr_tokens = gu_buf_length(expr_tokens);
size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1)*sizeof(int); size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1)*sizeof(long);
int* matrix = alloca(size); long* matrix = alloca(size);
memset(matrix, -1, size); memset(matrix, -1, size);
return pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, matrix, 0, 0); return pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, matrix, 0, 0);
} }
typedef struct { typedef struct {
PgfExprProb ep; PgfCncTree ctree;
double score; double score;
} PgfExprScore; } PgfCncTreeScore;
static void
pgf_lookup_ctree_to_expr(PgfCncTree ctree, PgfExprProb* ep,
GuPool* out_pool)
{
size_t n_args = 0;
PgfCncTree* args = NULL;
GuVariantInfo cti = gu_variant_open(ctree);
switch (cti.tag) {
case PGF_CNC_TREE_APP: {
PgfCncTreeApp* fapp = cti.data;
*ep = fapp->fun->absfun->ep;
n_args = fapp->n_args;
args = fapp->args;
break;
}
case PGF_CNC_TREE_CHUNKS: {
PgfCncTreeChunks* fchunks = cti.data;
n_args = fchunks->n_args;
args = fchunks->args;
ep->expr = gu_new_variant_i(out_pool,
PGF_EXPR_META, PgfExprMeta,
.id = 0);
ep->prob = 0;
break;
}
/* case PGF_CNC_TREE_LIT: {
PgfCncTreeLit* flit = cti.data;
break;
}*/
default:
gu_impossible();
}
if (gu_variant_is_null(ep->expr)) {
gu_assert(n_args==1);
pgf_lookup_ctree_to_expr(args[0], ep, out_pool);
} else {
for (size_t i = 0; i < n_args; i++) {
PgfExprProb ep_arg;
pgf_lookup_ctree_to_expr(args[i], &ep_arg, out_pool);
ep->expr = gu_new_variant_i(out_pool,
PGF_EXPR_APP,
PgfExprApp,
ep->expr, ep_arg.expr);
ep->prob += ep_arg.prob;
}
}
}
static void static void
pgf_lookup_enum_next(GuEnum* self, void* to, GuPool* pool) pgf_lookup_enum_next(GuEnum* self, void* to, GuPool* pool)
{ {
PgfLookupEnum* st = gu_container(self, PgfLookupEnum, en); PgfLookupEnum* st = gu_container(self, PgfLookupEnum, en);
PgfExprScore* es = NULL; PgfCncTreeScore* cts = NULL;
while (st->index < gu_buf_length(st->exprs)) { while (st->index < gu_buf_length(st->ctrees)) {
es = gu_buf_index(st->exprs, PgfExprScore, st->index); cts = gu_buf_index(st->ctrees, PgfCncTreeScore, st->index);
st->index++; st->index++;
if (fabs(es->score - st->max) < 0.00005) { if (cts->score == st->max) {
*((PgfExprProb**) to) = &es->ep; PgfExprProb* ep = gu_new(PgfExprProb, st->out_pool);
pgf_lookup_ctree_to_expr(cts->ctree, ep, st->out_pool);
*((PgfExprProb**) to) = ep;
return; return;
} }
} }
@@ -524,6 +734,22 @@ pgf_lookup_enum_next(GuEnum* self, void* to, GuPool* pool)
*((PgfExprProb**) to) = NULL; *((PgfExprProb**) to) = NULL;
} }
static void
pgf_lookup_symbol_token(PgfLinFuncs** funcs, PgfToken tok)
{
PgfLookupState* st = gu_container(funcs, PgfLookupState, funcs);
gu_buf_push(st->expr_tokens, PgfToken, tok);
}
static PgfLinFuncs pgf_lookup_lin_funcs = {
.symbol_token = pgf_lookup_symbol_token,
.begin_phrase = NULL,
.end_phrase = NULL,
.symbol_ne = NULL,
.symbol_bind = NULL,
.symbol_capit = NULL
};
PGF_API GuEnum* PGF_API GuEnum*
pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* pool, GuPool* out_pool) pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* pool, GuPool* out_pool)
{ {
@@ -593,63 +819,62 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
#endif #endif
PgfLookupState st; PgfLookupState st;
st.funcs = &pgf_lookup_lin_funcs;
st.concr = concr;
st.join = join; st.join = join;
st.start_id= meta_id1; st.start_id= meta_id1;
st.choice = gu_new_choice(work_pool); st.choice = gu_new_choice(work_pool);
st.stack = gu_new_buf(PgfMetaId, work_pool); st.stack = gu_new_buf(PgfMetaId, work_pool);
st.exprs = gu_new_buf(PgfExprScore, pool); st.expr_tokens=gu_new_buf(GuString, work_pool);
st.out_pool= out_pool; st.ctrees = gu_new_buf(PgfCncTreeScore, pool);
st.fid = 0;
st.pool = pool;
GuChoiceMark mark = gu_choice_mark(st.choice); GuChoiceMark mark = gu_choice_mark(st.choice);
long sentence_value =
pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens);
double max = 0; double max = 0;
PgfExprScore* es = gu_buf_extend(st.exprs); PgfCncTreeScore* cts = gu_buf_extend(st.ctrees);
for (;;) { for (;;) {
bool res = pgf_lookup_extract(&st, st.start_id, &es->ep); cts->ctree =
pgf_lookup_extract(&st, st.start_id, NULL);
gu_choice_reset(st.choice, mark); if (!gu_variant_is_null(cts->ctree)) {
cts->ctree = pgf_lzr_wrap_linref(cts->ctree, st.pool);
pgf_lzr_linearize(concr, cts->ctree, 0, &st.funcs, st.pool);
if (!gu_choice_advance(st.choice)) cts->score =
break; ((double) pgf_lookup_compute_kernel(sentence_tokens, st.expr_tokens)) /
sqrt(((double) sentence_value) * ((double) pgf_lookup_compute_kernel(st.expr_tokens, st.expr_tokens)));
if (res) { gu_buf_flush(st.expr_tokens);
GuExn* err = gu_exn(work_pool);
GuStringBuf* sbuf = gu_new_string_buf(work_pool);
GuOut* out = gu_string_buf_out(sbuf);
pgf_linearize(concr, es->ep.expr, out, err); #ifdef PGF_LINEARIZER_DEBUG
if (!gu_ok(err)) {
continue;
}
GuBuf* expr_tokens =
pgf_lookup_tokenize(gu_string_buf_data(sbuf),
gu_string_buf_length(sbuf),
work_pool);
es->score =
((double) pgf_lookup_compute_kernel(sentence_tokens, expr_tokens)) /
sqrt(((double) pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens)) * ((double) pgf_lookup_compute_kernel(expr_tokens, expr_tokens)));
#ifdef PGF_LOOKUP_DEBUG
{ {
GuPool* tmp_pool = gu_new_pool(); GuPool* tmp_pool = gu_new_pool();
GuOut* out = gu_file_out(stderr, tmp_pool); GuOut* out = gu_file_out(stderr, tmp_pool);
GuExn* err = gu_exn(tmp_pool); GuExn* err = gu_exn(tmp_pool);
pgf_print_expr(es->ep.expr, NULL, 0, out, err); pgf_print_cnc_tree(cts->ctree, out, err);
gu_printf(out, err, " [%f]\n", es->score); gu_printf(out, err, " [%f]\n", cts->score);
gu_pool_free(tmp_pool); gu_pool_free(tmp_pool);
} }
#endif #endif
if (es->score > max) { if (cts->score > max) {
max = es->score; max = cts->score;
} }
es = gu_buf_extend(st.exprs); cts = gu_buf_extend(st.ctrees);
} }
gu_choice_reset(st.choice, mark);
if (!gu_choice_advance(st.choice))
break;
} }
gu_buf_trim(st.exprs); gu_buf_trim(st.ctrees);
gu_pool_free(work_pool); gu_pool_free(work_pool);
@@ -657,6 +882,7 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
lenum->en.next = pgf_lookup_enum_next; lenum->en.next = pgf_lookup_enum_next;
lenum->max = max; lenum->max = max;
lenum->index = 0; lenum->index = 0;
lenum->exprs = st.exprs; lenum->ctrees = st.ctrees;
lenum->out_pool= out_pool;
return &lenum->en; return &lenum->en;
} }