From 9f44fa89f28521f72fdbdc1b2a0d4a3034716470 Mon Sep 17 00:00:00 2001 From: krasimir Date: Wed, 24 May 2017 19:44:32 +0000 Subject: [PATCH] an optimized string kernel --- src/runtime/c/pgf/lookup.c | 77 +++++++++++++++----------------------- 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/src/runtime/c/pgf/lookup.c b/src/runtime/c/pgf/lookup.c index 434c0f9d8..8a101e931 100644 --- a/src/runtime/c/pgf/lookup.c +++ b/src/runtime/c/pgf/lookup.c @@ -597,63 +597,48 @@ pgf_lookup_tokenize(GuString buf, size_t len, GuPool* pool) return tokens; } -static long -pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, - long* matrix, size_t i, size_t j); - -static long -pgf_lookup_compute_kernel_helper2(GuBuf* sentence_tokens, GuBuf* expr_tokens, - long* matrix, size_t i, size_t j) +static double +pgf_lookup_compute_kernel_helper(GuBuf* sentence_tokens, GuBuf* expr_tokens, + double* matrix, size_t i, size_t j) { - size_t n_expr_tokens = gu_buf_length(expr_tokens); + size_t dim = gu_buf_length(sentence_tokens)+1; - if (j >= n_expr_tokens) - return 0; - - GuString sentence_token = gu_buf_get(sentence_tokens, GuString, i); - GuString expr_token = gu_buf_get(expr_tokens, GuString, j); - if (strcmp(sentence_token, expr_token) == 0) { - return 1 + - pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, - matrix, i+1, j+1) + - pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, - matrix, i, j+1); - } else { - return pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, matrix, i, j+1); - } -} + double score = matrix[i + dim*j]; -static long -pgf_lookup_compute_kernel_helper1(GuBuf* sentence_tokens, GuBuf* expr_tokens, - long* matrix, size_t i, size_t j) -{ - size_t n_sentence_tokens = gu_buf_length(sentence_tokens); + if (score < 0) { + score = 0; + for (size_t l = 0; l < i; l++) { + matrix[l + dim*j] = score; + for (size_t k = j; k > 0; k--) { + GuString sentence_token = gu_buf_get(sentence_tokens, GuString, l); + GuString expr_token = gu_buf_get(expr_tokens, GuString, k-1); - long score = matrix[i+n_sentence_tokens*j]; - if (score == -1) { - if (i >= n_sentence_tokens) - score = 0; - else - score = pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, - matrix, i+1, j) - + pgf_lookup_compute_kernel_helper2(sentence_tokens, expr_tokens, - matrix, i, j); - matrix[i + n_sentence_tokens*j] = score; + if (strcmp(sentence_token, expr_token) == 0) { + score += 1 + pgf_lookup_compute_kernel_helper(sentence_tokens, expr_tokens, matrix, l, k-1); + } + } + } + matrix[i + dim*j] = score; } return score; } -static long +static double pgf_lookup_compute_kernel(GuBuf* sentence_tokens, GuBuf* expr_tokens) { size_t n_sentence_tokens = gu_buf_length(sentence_tokens); size_t n_expr_tokens = gu_buf_length(expr_tokens); - size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1)*sizeof(long); - long* matrix = alloca(size); - memset(matrix, -1, size); + size_t size = (n_sentence_tokens+1)*(n_expr_tokens+1); + double* matrix = alloca(size*sizeof(double)); - return pgf_lookup_compute_kernel_helper1(sentence_tokens, expr_tokens, matrix, 0, 0); + for (size_t i = 0; i < size; i++) { + matrix[i] = -1; + } + + return + pgf_lookup_compute_kernel_helper(sentence_tokens,expr_tokens,matrix, + n_sentence_tokens,n_expr_tokens); } typedef struct { @@ -832,7 +817,7 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po GuChoiceMark mark = gu_choice_mark(st.choice); - long sentence_value = + double sentence_value = pgf_lookup_compute_kernel(sentence_tokens, sentence_tokens); double max = 0; @@ -846,8 +831,8 @@ pgf_lookup_sentence(PgfConcr* concr, PgfType* typ, GuString sentence, GuPool* po pgf_lzr_linearize(concr, cts->ctree, 0, &st.funcs, st.pool); cts->score = - ((double) pgf_lookup_compute_kernel(sentence_tokens, st.expr_tokens)) / - sqrt(((double) sentence_value) * ((double) pgf_lookup_compute_kernel(st.expr_tokens, st.expr_tokens))); + pgf_lookup_compute_kernel(sentence_tokens, st.expr_tokens) / + sqrt(sentence_value * pgf_lookup_compute_kernel(st.expr_tokens, st.expr_tokens)); gu_buf_flush(st.expr_tokens);