an optimized string kernel

This commit is contained in:
krasimir
2017-05-24 19:44:32 +00:00
parent aa836aa86d
commit 9f44fa89f2

View File

@@ -597,63 +597,48 @@ pgf_lookup_tokenize(GuString buf, size_t len, GuPool* pool)
return tokens;
}
static long
pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens,
long* matrix, size_t i, size_t j);
static long
pgf_lookup_compute_kernel_helper2(GuBuf* sentence_tokens, GuBuf* expr_tokens,
long* matrix, size_t i, size_t j)
static double
pgf_lookup_compute_kernel_helper(GuBuf* sentence_tokens, GuBuf* expr_tokens,
double* matrix, size_t i, size_t j)
{
size_t n_expr_tokens = gu_buf_length(expr_tokens);
size_t dim = gu_buf_length(sentence_tokens)+1;
if (j >= n_expr_tokens)
return 0;
GuString sentence_token = gu_buf_get(sentence_tokens, GuString, i);
GuString expr_token = gu_buf_get(expr_tokens, GuString, j);
if (strcmp(sentence_token, expr_token) == 0) {
return 1 +
pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens,
matrix, i+1, j+1) +
pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens,
matrix, i, j+1);
} else {
return pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, matrix, i, j+1);
}
}
double score = matrix[i + dim*j];
static long
pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens,
long* matrix, size_t i, size_t j)
{
size_t n_sentence_tokens = gu_buf_length(sentence_tokens);
if (score < 0) {
score = 0;
for (size_t l = 0; l < i; l++) {
matrix[l + dim*j] = score;
for (size_t k = j; k > 0; k--) {
GuString sentence_token = gu_buf_get(sentence_tokens, GuString, l);
GuString expr_token = gu_buf_get(expr_tokens, GuString, k-1);
long score = matrix[i+n_sentence_tokens*j];
if (score == -1) {
if (i >= n_sentence_tokens)
score = 0;
else
score = pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens,
matrix, i+1, j)
+ pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens,
matrix, i, j);
matrix[i + n_sentence_tokens*j] = score;
if (strcmp(sentence_token, expr_token) == 0) {
score += 1 + pgf_lookup_compute_kernel_helper(sentence_tokens, expr_tokens, matrix, l, k-1);
}
}
}
matrix[i + dim*j] = score;
}
return score;
}
static long
static double
pgf_lookup_compute_kernel(GuBuf* sentence_tokens, GuBuf* expr_tokens)
{
size_t n_sentence_tokens = gu_buf_length(sentence_tokens);
size_t n_expr_tokens = gu_buf_length(expr_tokens);
size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1)*sizeof(long);
long* matrix = alloca(size);
memset(matrix, -1, size);
size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1);
double* matrix = alloca(size*sizeof(double));
return pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, matrix, 0, 0);
for (size_t i = 0; i < size; i++) {
matrix[i] = -1;
}
return
pgf_lookup_compute_kernel_helper(sentence_tokens,expr_tokens,matrix,
n_sentence_tokens,n_expr_tokens);
}
typedef struct {
@@ -832,7 +817,7 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
GuChoiceMark mark = gu_choice_mark(st.choice);
long sentence_value =
double sentence_value =
pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens);
double max = 0;
@@ -846,8 +831,8 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
pgf_lzr_linearize(concr, cts->ctree, 0, &st.funcs, st.pool);
cts->score =
((double) pgf_lookup_compute_kernel(sentence_tokens, st.expr_tokens)) /
sqrt(((double) sentence_value) * ((double) pgf_lookup_compute_kernel(st.expr_tokens, st.expr_tokens)));
pgf_lookup_compute_kernel(sentence_tokens, st.expr_tokens) /
sqrt(sentence_value * pgf_lookup_compute_kernel(st.expr_tokens, st.expr_tokens));
gu_buf_flush(st.expr_tokens);