the C runtime now has a type prob_t which is used only for probability values

This commit is contained in:
kr.angelov
2012-09-18 09:18:48 +00:00
parent 91ca7c9a1b
commit a307ed6c75
5 changed files with 35 additions and 31 deletions

View File

@@ -188,11 +188,13 @@ GU_DEFINE_TYPE(
GU_MEMBER(PgfCatFun, prob, double), GU_MEMBER(PgfCatFun, prob, double),
GU_MEMBER(PgfCatFun, fun, PgfCId)); GU_MEMBER(PgfCatFun, fun, PgfCId));
static float inf_float = INFINITY; static prob_t inf_prob = INFINITY;
GU_DEFINE_TYPE(prob_t, GuFloating, _);
GU_DEFINE_TYPE(PgfMetaChildMap, GuMap, GU_DEFINE_TYPE(PgfMetaChildMap, GuMap,
gu_type(PgfCat), NULL, gu_type(PgfCat), NULL,
gu_type(float), &inf_float); gu_type(prob_t), &inf_prob);
GU_DEFINE_TYPE( GU_DEFINE_TYPE(
PgfCat, struct, PgfCat, struct,

View File

@@ -123,8 +123,10 @@ struct PgfPGF {
extern GU_DECLARE_TYPE(PgfPGF, struct); extern GU_DECLARE_TYPE(PgfPGF, struct);
typedef float prob_t;
typedef struct { typedef struct {
float prob; prob_t prob;
PgfExpr expr; PgfExpr expr;
} PgfExprProb; } PgfExprProb;
@@ -151,8 +153,8 @@ struct PgfCat {
PgfCId name; PgfCId name;
PgfHypos context; PgfHypos context;
float meta_prob; prob_t meta_prob;
float meta_token_prob; prob_t meta_token_prob;
PgfMetaChildMap* meta_child_probs; PgfMetaChildMap* meta_child_probs;
GuLength n_functions; GuLength n_functions;

View File

@@ -64,7 +64,7 @@ struct PgfExprState {
typedef struct { typedef struct {
PgfExprState *st; PgfExprState *st;
float prob; prob_t prob;
} PgfExprQState; } PgfExprQState;
typedef struct PgfParseResult PgfParseResult; typedef struct PgfParseResult PgfParseResult;
@@ -96,8 +96,8 @@ struct PgfItem {
uint16_t seq_idx; uint16_t seq_idx;
uint8_t tok_idx; uint8_t tok_idx;
uint8_t alt; uint8_t alt;
float inside_prob; prob_t inside_prob;
float outside_prob; prob_t outside_prob;
}; };
typedef struct { typedef struct {
@@ -415,8 +415,8 @@ cmp_item_prob(GuOrder* self, const void* a, const void* b)
PgfItem *item1 = *((PgfItem **) a); PgfItem *item1 = *((PgfItem **) a);
PgfItem *item2 = *((PgfItem **) b); PgfItem *item2 = *((PgfItem **) b);
float prob1 = item1->inside_prob + item1->outside_prob; prob_t prob1 = item1->inside_prob + item1->outside_prob;
float prob2 = item2->inside_prob + item2->outside_prob; prob_t prob2 = item2->inside_prob + item2->outside_prob;
if (prob1 < prob2) if (prob1 < prob2)
return -1; return -1;
@@ -474,7 +474,7 @@ pgf_parsing_get_conts(PgfContsMap* conts_map,
static PgfCCat* static PgfCCat*
pgf_parsing_create_completed(PgfParseState* state, PgfItemBuf* conts, pgf_parsing_create_completed(PgfParseState* state, PgfItemBuf* conts,
float viterbi_prob, PgfCncCat* cnccat) prob_t viterbi_prob, PgfCncCat* cnccat)
{ {
PgfCCat* cat = gu_new(PgfCCat, state->pool); PgfCCat* cat = gu_new(PgfCCat, state->pool);
cat->cnccat = cnccat; cat->cnccat = cnccat;
@@ -535,7 +535,7 @@ pgf_item_set_curr_symbol(PgfItem* item, GuPool* pool)
static PgfItem* static PgfItem*
pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx, pgf_new_item(int pos, PgfCCat* ccat, size_t lin_idx,
PgfProduction prod, PgfItemBuf* conts, PgfProduction prod, PgfItemBuf* conts,
float delta_prob, prob_t delta_prob,
GuPool* pool) GuPool* pool)
{ {
PgfItemBase* base = gu_new(PgfItemBase, pool); PgfItemBase* base = gu_new(PgfItemBase, pool);
@@ -712,7 +712,7 @@ pgf_parsing_combine(PgfParseState* before, PgfParseState* after,
item->base->ccat->cnccat->abscat->meta_child_probs; item->base->ccat->cnccat->abscat->meta_child_probs;
item->inside_prob += item->inside_prob +=
cat->viterbi_prob+ cat->viterbi_prob+
gu_map_get(meta_child_probs, cat->cnccat->abscat, float); gu_map_get(meta_child_probs, cat->cnccat->abscat, prob_t);
PgfSymbol prev = item->curr_sym; PgfSymbol prev = item->curr_sym;
PgfSymbolCat* scat = (PgfSymbolCat*) PgfSymbolCat* scat = (PgfSymbolCat*)
@@ -736,7 +736,7 @@ static void
pgf_parsing_production(PgfParseState* state, pgf_parsing_production(PgfParseState* state,
PgfCCat* ccat, size_t lin_idx, PgfCCat* ccat, size_t lin_idx,
PgfProduction prod, PgfItemBuf* conts, PgfProduction prod, PgfItemBuf* conts,
float delta_prob) prob_t delta_prob)
{ {
PgfItem* item = PgfItem* item =
pgf_new_item(state->offset, ccat, lin_idx, prod, conts, delta_prob, state->pool); pgf_new_item(state->offset, ccat, lin_idx, prod, conts, delta_prob, state->pool);
@@ -933,7 +933,7 @@ pgf_parsing_complete(PgfParseState* before, PgfParseState* after,
static void static void
pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after, pgf_parsing_td_predict(PgfParseState* before, PgfParseState* after,
PgfItem* item, PgfCCat* ccat, size_t lin_idx, PgfItem* item, PgfCCat* ccat, size_t lin_idx,
float delta_prob) prob_t delta_prob)
{ {
gu_enter("-> cat: %d", ccat->fid); gu_enter("-> cat: %d", ccat->fid);
if (gu_seq_is_null(ccat->prods)) { if (gu_seq_is_null(ccat->prods)) {
@@ -1043,7 +1043,7 @@ pgf_parsing_meta_predict(GuMapItor* fn, const void* key, void* value, GuExn* err
(void) (err); (void) (err);
PgfCat* abscat = (PgfCat*) key; PgfCat* abscat = (PgfCat*) key;
float meta_prob = *((float*) value); prob_t meta_prob = *((prob_t*) value);
PgfMetaPredictFn* clo = (PgfMetaPredictFn*) fn; PgfMetaPredictFn* clo = (PgfMetaPredictFn*) fn;
PgfParseState* before = clo->before; PgfParseState* before = clo->before;
PgfParseState* after = clo->after; PgfParseState* after = clo->after;
@@ -1065,12 +1065,12 @@ pgf_parsing_meta_predict(GuMapItor* fn, const void* key, void* value, GuExn* err
} }
} }
static float static prob_t
pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after, pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after,
PgfItemBuf* index, PgfItem* meta_item, PgfItemBuf* index, PgfItem* meta_item,
PgfItemBuf* agenda) PgfItemBuf* agenda)
{ {
float prob = INFINITY; prob_t prob = INFINITY;
PgfMetaChildMap* meta_child_probs = PgfMetaChildMap* meta_child_probs =
meta_item->base->ccat->cnccat->abscat->meta_child_probs; meta_item->base->ccat->cnccat->abscat->meta_child_probs;
@@ -1084,17 +1084,17 @@ pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after,
for (size_t i = 0; i < n_items; i++) { for (size_t i = 0; i < n_items; i++) {
PgfItem *item = gu_buf_get(index, PgfItem*, i); PgfItem *item = gu_buf_get(index, PgfItem*, i);
float meta_prob = prob_t meta_prob =
meta_item->inside_prob+ meta_item->inside_prob+
meta_item->outside_prob+ meta_item->outside_prob+
gu_map_get(meta_child_probs, item->base->ccat->cnccat->abscat, float); gu_map_get(meta_child_probs, item->base->ccat->cnccat->abscat, prob_t);
PgfItemBuf* conts = PgfItemBuf* conts =
pgf_parsing_get_conts(before->conts_map, pgf_parsing_get_conts(before->conts_map,
item->base->ccat, item->base->lin_idx, item->base->ccat, item->base->lin_idx,
before->pool, before->pool); before->pool, before->pool);
if (gu_buf_length(conts) == 0) { if (gu_buf_length(conts) == 0) {
float outside_prob = prob_t outside_prob =
pgf_parsing_bu_predict(before, after, pgf_parsing_bu_predict(before, after,
item->base->conts, meta_item, item->base->conts, meta_item,
conts); conts);
@@ -1122,7 +1122,7 @@ pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after,
else else
gu_buf_push(agenda, PgfItem*, copy); gu_buf_push(agenda, PgfItem*, copy);
float item_prob = prob_t item_prob =
copy->inside_prob+copy->outside_prob; copy->inside_prob+copy->outside_prob;
if (prob > item_prob) if (prob > item_prob)
prob = item_prob; prob = item_prob;
@@ -1133,7 +1133,7 @@ pgf_parsing_bu_predict(PgfParseState* before, PgfParseState* after,
for (size_t i = 0; i < n_items; i++) { for (size_t i = 0; i < n_items; i++) {
PgfItem *item = gu_buf_get(conts, PgfItem*, i); PgfItem *item = gu_buf_get(conts, PgfItem*, i);
float item_prob = prob_t item_prob =
item->inside_prob+item->outside_prob; item->inside_prob+item->outside_prob;
if (prob > item_prob) if (prob > item_prob)
prob = item_prob; prob = item_prob;
@@ -1458,7 +1458,7 @@ pgf_parsing_item(PgfParseState* before, PgfParseState* after, PgfItem* item)
if (after != NULL) { if (after != NULL) {
if (after->ts->lexicon_idx == NULL) { if (after->ts->lexicon_idx == NULL) {
float meta_token_prob = prob_t meta_token_prob =
item->base->ccat->cnccat->abscat->meta_token_prob; item->base->ccat->cnccat->abscat->meta_token_prob;
if (meta_token_prob == INFINITY) if (meta_token_prob == INFINITY)
break; break;
@@ -1508,14 +1508,14 @@ pgf_parsing_proceed(PgfParseState* state, void** output) {
if (state->ps->item_count > state->ps->concr->item_quota) if (state->ps->item_count > state->ps->concr->item_quota)
break; break;
float best_prob = INFINITY; prob_t best_prob = INFINITY;
PgfParseState* before = NULL; PgfParseState* before = NULL;
PgfParseState* st = state; PgfParseState* st = state;
while (st != NULL) { while (st != NULL) {
if (gu_buf_length(st->agenda) > 0) { if (gu_buf_length(st->agenda) > 0) {
PgfItem* item = gu_buf_get(st->agenda, PgfItem*, 0); PgfItem* item = gu_buf_get(st->agenda, PgfItem*, 0);
float item_prob = item->inside_prob+item->outside_prob; prob_t item_prob = item->inside_prob+item->outside_prob;
if (item_prob < best_prob) { if (item_prob < best_prob) {
best_prob = item_prob; best_prob = item_prob;
before = st; before = st;
@@ -1643,7 +1643,7 @@ pgf_expr_qstate_order = { cmp_expr_qstate };
static void static void
pgf_result_cat_init(PgfParseResult* pr, pgf_result_cat_init(PgfParseResult* pr,
PgfExprState* cont, float cont_prob, PgfCCat* ccat) PgfExprState* cont, prob_t cont_prob, PgfCCat* ccat)
{ {
// Checking for loops in the chart // Checking for loops in the chart
if (cont != NULL) { if (cont != NULL) {
@@ -1751,7 +1751,7 @@ pgf_parse_result_next(PgfParseResult* pr, GuPool* pool)
if (q.st->arg_idx < gu_seq_length(q.st->args)) { if (q.st->arg_idx < gu_seq_length(q.st->args)) {
PgfPArg* arg = gu_seq_index(q.st->args, PgfPArg, q.st->arg_idx); PgfPArg* arg = gu_seq_index(q.st->args, PgfPArg, q.st->arg_idx);
float cont_prob = q.prob - arg->ccat->viterbi_prob; prob_t cont_prob = q.prob - arg->ccat->viterbi_prob;
if (arg->ccat->fid < pr->state->ps->concr->total_cats) { if (arg->ccat->fid < pr->state->ps->concr->total_cats) {
q.st->expr = q.st->expr =
gu_new_variant_i(pool, PGF_EXPR_APP, gu_new_variant_i(pool, PGF_EXPR_APP,

View File

@@ -655,11 +655,11 @@ pgf_compute_meta_probs(GuMapItor* fn, const void* key, void* value, GuExn* err)
cat->name = name; cat->name = name;
double mass = 0; prob_t mass = 0;
for (size_t i = 0; i < cat->n_functions; i++) { for (size_t i = 0; i < cat->n_functions; i++) {
mass += cat->functions[i].prob; mass += cat->functions[i].prob;
} }
cat->meta_prob = - log(fabs(1 - mass)); cat->meta_prob = (mass > 1) ? INFINITY : - log(1 - mass);
cat->meta_token_prob = INFINITY; cat->meta_token_prob = INFINITY;
cat->meta_child_probs = NULL; cat->meta_child_probs = NULL;
} }

View File

@@ -87,7 +87,7 @@ int main(int argc, char* argv[]) {
goto fail_read; goto fail_read;
} }
if (!pgf_load_meta_child_probs(pgf, "../../../examples/PennTreebank/test2.probs", pool)) { if (!pgf_load_meta_child_probs(pgf, "../../../examples/PennTreebank/ParseEngAbs2.probs", pool)) {
fprintf(stderr, "Loading meta child probs failed\n"); fprintf(stderr, "Loading meta child probs failed\n");
status = EXIT_FAILURE; status = EXIT_FAILURE;
goto fail_read; goto fail_read;