1
0
forked from GitHub/gf-core

compute the right word probability

This commit is contained in:
kr.angelov
2014-03-12 15:36:40 +00:00
parent ca8dd1e8cc
commit ae1512c926
5 changed files with 40 additions and 47 deletions

View File

@@ -2532,7 +2532,9 @@ pgf_morpho_iter(PgfProductionIdx* idx,
PgfCId lemma = entry->papp->fun->absfun->name; PgfCId lemma = entry->papp->fun->absfun->name;
GuString analysis = entry->ccat->cnccat->labels[entry->lin_idx]; GuString analysis = entry->ccat->cnccat->labels[entry->lin_idx];
prob_t prob = entry->papp->fun->absfun->ep.prob;
prob_t prob = entry->ccat->cnccat->abscat->prob +
entry->papp->fun->absfun->ep.prob;
callback->callback(callback, callback->callback(callback,
lemma, analysis, prob, err); lemma, analysis, prob, err);
if (!gu_ok(err)) if (!gu_ok(err))

View File

@@ -4,6 +4,7 @@
#include <gu/mem.h> #include <gu/mem.h>
#include <gu/exn.h> #include <gu/exn.h>
#include <gu/utf8.h> #include <gu/utf8.h>
#include <math.h>
#include <jni.h> #include <jni.h>
#ifndef __MINGW32__ #ifndef __MINGW32__
#include <alloca.h> #include <alloca.h>
@@ -504,6 +505,7 @@ Java_org_grammaticalframework_pgf_Concr_tabularLinearize(JNIEnv* env, jobject se
typedef struct { typedef struct {
PgfMorphoCallback fn; PgfMorphoCallback fn;
jobject analyses; jobject analyses;
prob_t prob;
JNIEnv* env; JNIEnv* env;
jmethodID addId; jmethodID addId;
jclass an_class; jclass an_class;
@@ -530,6 +532,8 @@ jpgf_collect_morpho(PgfMorphoCallback* self,
(*env)->DeleteLocalRef(env, jan); (*env)->DeleteLocalRef(env, jan);
(*env)->DeleteLocalRef(env, janalysis); (*env)->DeleteLocalRef(env, janalysis);
(*env)->DeleteLocalRef(env, jlemma); (*env)->DeleteLocalRef(env, jlemma);
callback->prob += exp(-prob);
} }
JNIEXPORT jobject JNICALL JNIEXPORT jobject JNICALL
@@ -548,7 +552,7 @@ Java_org_grammaticalframework_pgf_Concr_lookupMorpho(JNIEnv* env, jobject self,
GuExn* err = gu_new_exn(NULL, gu_kind(type), tmp_pool); GuExn* err = gu_new_exn(NULL, gu_kind(type), tmp_pool);
JMorphoCallback callback = { { jpgf_collect_morpho }, analyses, env, addId, an_class, an_constrId }; JMorphoCallback callback = { { jpgf_collect_morpho }, analyses, 0, env, addId, an_class, an_constrId };
pgf_lookup_morpho(get_ref(env, self), j2gu_string(env, sentence, tmp_pool), pgf_lookup_morpho(get_ref(env, self), j2gu_string(env, sentence, tmp_pool),
&callback.fn, err); &callback.fn, err);
if (!gu_ok(err)) { if (!gu_ok(err)) {
@@ -604,21 +608,10 @@ Java_org_grammaticalframework_pgf_FullFormIterator_fetchFullFormEntry
GuString form = pgf_fullform_get_string(entry); GuString form = pgf_fullform_get_string(entry);
jclass entry_class = (*env)->FindClass(env, "org/grammaticalframework/pgf/FullFormEntry");
jmethodID entry_constrId = (*env)->GetMethodID(env, entry_class, "<init>", "(Ljava/lang/String;JLorg/grammaticalframework/pgf/Concr;)V");
jobject jentry = (*env)->NewObject(env, entry_class, entry_constrId, gu2j_string(env,form), p2l(entry), jconcr);
return jentry;
}
JNIEXPORT jobject JNICALL
Java_org_grammaticalframework_pgf_FullFormEntry_getAnalyses
(JNIEnv* env, jobject self)
{
jclass list_class = (*env)->FindClass(env, "java/util/ArrayList"); jclass list_class = (*env)->FindClass(env, "java/util/ArrayList");
jmethodID list_constrId = (*env)->GetMethodID(env, list_class, "<init>", "()V"); jmethodID list_constrId = (*env)->GetMethodID(env, list_class, "<init>", "()V");
jobject analyses = (*env)->NewObject(env, list_class, list_constrId); jobject analyses = (*env)->NewObject(env, list_class, list_constrId);
jmethodID addId = (*env)->GetMethodID(env, list_class, "add", "(Ljava/lang/Object;)Z"); jmethodID addId = (*env)->GetMethodID(env, list_class, "add", "(Ljava/lang/Object;)Z");
jclass an_class = (*env)->FindClass(env, "org/grammaticalframework/pgf/MorphoAnalysis"); jclass an_class = (*env)->FindClass(env, "org/grammaticalframework/pgf/MorphoAnalysis");
@@ -627,8 +620,8 @@ Java_org_grammaticalframework_pgf_FullFormEntry_getAnalyses
GuPool* tmp_pool = gu_local_pool(); GuPool* tmp_pool = gu_local_pool();
GuExn* err = gu_new_exn(NULL, gu_kind(type), tmp_pool); GuExn* err = gu_new_exn(NULL, gu_kind(type), tmp_pool);
JMorphoCallback callback = { { jpgf_collect_morpho }, analyses, env, addId, an_class, an_constrId }; JMorphoCallback callback = { { jpgf_collect_morpho }, analyses, 0, env, addId, an_class, an_constrId };
pgf_fullform_get_analyses(get_ref(env, self), &callback.fn, err); pgf_fullform_get_analyses(entry, &callback.fn, err);
if (!gu_ok(err)) { if (!gu_ok(err)) {
if (gu_exn_caught(err) == gu_type(PgfExn)) { if (gu_exn_caught(err) == gu_type(PgfExn)) {
GuString msg = (GuString) gu_exn_caught_data(err); GuString msg = (GuString) gu_exn_caught_data(err);
@@ -641,7 +634,11 @@ Java_org_grammaticalframework_pgf_FullFormEntry_getAnalyses
gu_pool_free(tmp_pool); gu_pool_free(tmp_pool);
return analyses; jclass entry_class = (*env)->FindClass(env, "org/grammaticalframework/pgf/FullFormEntry");
jmethodID entry_constrId = (*env)->GetMethodID(env, entry_class, "<init>", "(Ljava/lang/String;DLjava/util/List;)V");
jobject jentry = (*env)->NewObject(env, entry_class, entry_constrId, gu2j_string(env,form), - log(callback.prob), analyses);
return jentry;
} }
JNIEXPORT jboolean JNICALL JNIEXPORT jboolean JNICALL

View File

@@ -4,18 +4,24 @@ import java.util.List;
public class FullFormEntry { public class FullFormEntry {
private String form; private String form;
private long ref; private double prob;
private Concr concr; private List<MorphoAnalysis> analyses;
public FullFormEntry(String form, long ref, Concr concr) { public FullFormEntry(String form, double prob, List<MorphoAnalysis> analyses) {
this.form = form; this.form = form;
this.ref = ref; this.prob = prob;
this.concr = concr; this.analyses = analyses;
} }
public String getForm() { public String getForm() {
return form; return form;
} }
public native List<MorphoAnalysis> getAnalyses(); public double getProb() {
return prob;
}
public List<MorphoAnalysis> getAnalyses() {
return analyses;
}
} }

View File

@@ -1515,7 +1515,7 @@ Concr_bracketedLinearize(ConcrObject* self, PyObject *args)
state.funcs = &pgf_bracket_lin_funcs; state.funcs = &pgf_bracket_lin_funcs;
state.stack = gu_new_buf(PyObject*, tmp_pool); state.stack = gu_new_buf(PyObject*, tmp_pool);
state.list = list; state.list = list;
pgf_lzr_linearize(self->concr, ctree, 0, &state.funcs); pgf_lzr_linearize(self->concr, ctree, 0, &state.funcs, tmp_pool);
gu_pool_free(tmp_pool); gu_pool_free(tmp_pool);

View File

@@ -324,28 +324,16 @@ public class Translator {
return getSourceConcr().lookupMorpho(sentence); return getSourceConcr().lookupMorpho(sentence);
} }
private static class WordProb implements Comparable<WordProb> {
String word;
double prob;
@Override
public int compareTo(WordProb another) {
return Double.compare(prob, another.prob);
}
}
public CompletionInfo[] lookupWordPrefix(String prefix) { public CompletionInfo[] lookupWordPrefix(String prefix) {
PriorityQueue<WordProb> queue = new PriorityQueue<WordProb>(); PriorityQueue<FullFormEntry> queue =
new PriorityQueue<FullFormEntry>(500, new Comparator<FullFormEntry>() {
@Override
public int compare(FullFormEntry lhs, FullFormEntry rhs) {
return Double.compare(lhs.getProb(), rhs.getProb());
}
});
for (FullFormEntry entry : getSourceConcr().lookupWordPrefix(prefix)) { for (FullFormEntry entry : getSourceConcr().lookupWordPrefix(prefix)) {
WordProb wp = new WordProb(); queue.add(entry);
wp.word = entry.getForm();
wp.prob = 0;
for (MorphoAnalysis an : entry.getAnalyses()) {
wp.prob += an.getProb();
}
queue.add(wp);
if (queue.size() >= 1000) if (queue.size() >= 1000)
break; break;
} }
@@ -353,7 +341,7 @@ public class Translator {
CompletionInfo[] completions = new CompletionInfo[Math.min(queue.size(), 5)+1]; CompletionInfo[] completions = new CompletionInfo[Math.min(queue.size(), 5)+1];
completions[0] = new CompletionInfo(0, 0, prefix); completions[0] = new CompletionInfo(0, 0, prefix);
for (int i = 1; i < completions.length; i++) { for (int i = 1; i < completions.length; i++) {
completions[i] = new CompletionInfo(i,i,queue.poll().word); completions[i] = new CompletionInfo(i,i,queue.poll().getForm());
} }
if (completions.length > 1) { if (completions.length > 1) {