an optimized string kernel

This commit is contained in:
krasimir
2017-05-24 19:44:32 +00:00
parent aa836aa86d
commit 9f44fa89f2

View File

@@ -597,63 +597,48 @@ pgf_lookup_tokenize(GuString buf, size_t len, GuPool* pool)
return tokens; return tokens;
} }
/*
 * NOTE(review): this span is a side-by-side diff rendering, not compilable C.
 * A line showing two statements holds the pre-change text first and the
 * post-change text second; a line showing one statement belongs to one side
 * only.
 *
 * Pre-change side: two mutually recursive functions returning long,
 * pgf_lookup_compute_kernel_helper1 and pgf_lookup_compute_kernel_helper2.
 * helper1 reads matrix[i + n_sentence_tokens*j]; when the slot holds -1 it
 * computes the score, stores it back into that slot and returns it.
 *
 * Post-change side: one function returning double,
 * pgf_lookup_compute_kernel_helper. It reads matrix[i + dim*j], where dim is
 * the sentence token count plus one. A negative slot is treated as "not yet
 * computed": the score is then rebuilt from 0 by visiting every l < i and
 * every k from j down to 1 and, whenever sentence token l equals expression
 * token k-1 (strcmp), adding 1 plus the recursive result for (l, k-1). The
 * result is written to matrix[i + dim*j] and returned.
 */
static long static double
pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, pgf_lookup_compute_kernel_helper(GuBuf* sentence_tokens, GuBuf* expr_tokens,
long* matrix, size_t i, size_t j); double* matrix, size_t i, size_t j)
static long
pgf_lookup_compute_kernel_helper2(GuBuf* sentence_tokens, GuBuf* expr_tokens,
long* matrix, size_t i, size_t j)
{ {
size_t n_expr_tokens = gu_buf_length(expr_tokens); size_t dim = gu_buf_length(sentence_tokens)+1;
if (j >= n_expr_tokens) double score = matrix[i + dim*j];
return 0;
GuString sentence_token = gu_buf_get(sentence_tokens, GuString, i);
GuString expr_token = gu_buf_get(expr_tokens, GuString, j);
if (strcmp(sentence_token, expr_token) == 0) {
return 1 +
pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens,
matrix, i+1, j+1) +
pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens,
matrix, i, j+1);
} else {
return pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, matrix, i, j+1);
}
}
/* Pre-change side: helper1's definition starts here.
 * Post-change side: the "slot is negative, so compute it" branch starts here. */
static long if (score < 0) {
pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, score = 0;
long* matrix, size_t i, size_t j) for (size_t l = 0; l < i; l++) {
{ matrix[l + dim*j] = score;
size_t n_sentence_tokens = gu_buf_length(sentence_tokens); for (size_t k = j; k > 0; k--) {
GuString sentence_token = gu_buf_get(sentence_tokens, GuString, l);
GuString expr_token = gu_buf_get(expr_tokens, GuString, k-1);
long score = matrix[i+n_sentence_tokens*j]; if (strcmp(sentence_token, expr_token) == 0) {
if (score == -1) { score += 1 + pgf_lookup_compute_kernel_helper(sentence_tokens, expr_tokens, matrix, l, k-1);
if (i >= n_sentence_tokens) }
score = 0; }
else }
/* Post-change side: the finished score is stored with row stride dim
 * (sentence length + 1); the pre-change side used stride n_sentence_tokens. */
score = pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, matrix[i + dim*j] = score;
matrix, i+1, j)
+ pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens,
matrix, i, j);
matrix[i + n_sentence_tokens*j] = score;
} }
return score; return score;
} }
/*
 * NOTE(review): side-by-side diff rendering; on two-statement lines the
 * pre-change text comes first and the post-change text second.
 *
 * pgf_lookup_compute_kernel builds a scratch table with
 * (n_sentence_tokens+1)*(n_expr_tokens+1) entries via alloca, marks every
 * entry as -1, and returns the score computed by the recursive helper.
 *
 * Pre-change side: the table holds long, is filled with memset(matrix, -1, ...)
 * and the recursion starts at (0, 0) through pgf_lookup_compute_kernel_helper1.
 *
 * Post-change side: the table holds double, is filled by an explicit loop
 * assigning -1 to each element, and the recursion starts at
 * (n_sentence_tokens, n_expr_tokens) through pgf_lookup_compute_kernel_helper.
 *
 * NOTE(review): the alloca size grows with the product of the two token
 * counts and no upper bound is visible in this span; confirm the inputs are
 * small enough for the stack.
 */
static long static double
pgf_lookup_compute_kernel(GuBuf* sentence_tokens, GuBuf* expr_tokens) pgf_lookup_compute_kernel(GuBuf* sentence_tokens, GuBuf* expr_tokens)
{ {
size_t n_sentence_tokens = gu_buf_length(sentence_tokens); size_t n_sentence_tokens = gu_buf_length(sentence_tokens);
size_t n_expr_tokens = gu_buf_length(expr_tokens); size_t n_expr_tokens = gu_buf_length(expr_tokens);
/* Pre-change: size is a byte count. Post-change: size is an element count
 * and the byte count is formed at the alloca call. */
size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1)*sizeof(long); size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1);
long* matrix = alloca(size); double* matrix = alloca(size*sizeof(double));
memset(matrix, -1, size);
return pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, matrix, 0, 0); for (size_t i = 0; i < size; i++) {
matrix[i] = -1;
}
return
pgf_lookup_compute_kernel_helper(sentence_tokens,expr_tokens,matrix,
n_sentence_tokens,n_expr_tokens);
} }
typedef struct { typedef struct {
@@ -832,7 +817,7 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
GuChoiceMark mark = gu_choice_mark(st.choice); GuChoiceMark mark = gu_choice_mark(st.choice);
long sentence_value = double sentence_value =
pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens); pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens);
double max = 0; double max = 0;
@@ -846,8 +831,8 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po
pgf_lzr_linearize(concr, cts->ctree, 0, &st.funcs, st.pool); pgf_lzr_linearize(concr, cts->ctree, 0, &st.funcs, st.pool);
cts->score = cts->score =
((double) pgf_lookup_compute_kernel(sentence_tokens, st.expr_tokens)) / pgf_lookup_compute_kernel(sentence_tokens, st.expr_tokens) /
sqrt(((double) sentence_value) * ((double) pgf_lookup_compute_kernel(st.expr_tokens, st.expr_tokens))); sqrt(sentence_value * pgf_lookup_compute_kernel(st.expr_tokens, st.expr_tokens));
gu_buf_flush(st.expr_tokens); gu_buf_flush(st.expr_tokens);