Now there is a just-in-time compiler which generates native code for proof search. This is already used by the exhaustive generator. The time to generate 10000 abstract trees with ParseEng went down from 4.43 sec to 0.29 sec.

This commit is contained in:
kr.angelov
2013-06-25 19:22:42 +00:00
parent 6aafb6ccbb
commit 16584d4368
8 changed files with 577 additions and 188 deletions

View File

@@ -104,6 +104,7 @@ libpgf_la_SOURCES = \
pgf/expr.h \
pgf/parser.c \
pgf/parser.h \
pgf/jit.c \
pgf/parseval.c \
pgf/lexer.c \
pgf/lexer.h \

View File

@@ -76,6 +76,7 @@ typedef struct {
int arity;
PgfEquations defns; // maybe null
PgfExprProb ep;
void* predicate;
} PgfAbsFun;
extern GU_DECLARE_TYPE(PgfAbsFun, abstract);
@@ -91,7 +92,8 @@ typedef struct {
prob_t meta_token_prob;
PgfMetaChildMap* meta_child_probs;
GuBuf* functions; // -->PgfFunDecl
GuBuf* functions; // -->PgfAbsFun
void* predicate;
} PgfAbsCat;
extern GU_DECLARE_TYPE(PgfAbsCat, abstract);

290
src/runtime/c/pgf/jit.c Normal file
View File

@@ -0,0 +1,290 @@
#include <gu/seq.h>
#include <gu/file.h>
#include <pgf/data.h>
#include <pgf/jit.h>
#include <pgf/reasoner.h>
#include "sys/mman.h"
#include "lightning.h"
//#define PGF_JIT_DEBUG
struct PgfJitState {
GuPool* tmp_pool;
GuPool* pool;
jit_state jit;
jit_insn *buf;
GuBuf* patches;
};
#define _jit (state->jit)
typedef struct {
PgfCId cid;
jit_insn *ref;
} PgfCallPatch;
// Between two calls to pgf_jit_make_space we are not allowed
// to emit more that JIT_CODE_WINDOW bytes. This is not quite
// safe but this is how GNU lightning is designed.
#define JIT_CODE_WINDOW 128
typedef struct {
GuFinalizer fin;
void *page;
} PgfPageFinalizer;
static void
pgf_jit_finalize_page(GuFinalizer* self)
{
PgfPageFinalizer* fin = gu_container(self, PgfPageFinalizer, fin);
free(fin->page);
}
static size_t total_size = 0;
static void
pgf_jit_alloc_page(PgfJitState* state)
{
void *page;
size_t page_size = sysconf(_SC_PAGESIZE);
total_size += page_size;
if (posix_memalign(&page, page_size, page_size) != 0) {
gu_fatal("Memory allocation failed");
}
if (mprotect(page, page_size,
PROT_READ | PROT_WRITE | PROT_EXEC) != 0) {
gu_fatal("mprotect failed");
}
PgfPageFinalizer* fin = gu_new(PgfPageFinalizer, state->pool);
fin->fin.fn = pgf_jit_finalize_page;
fin->page = page;
gu_pool_finally(state->pool, &fin->fin);
state->buf = page;
jit_set_ip(state->buf);
}
PgfJitState*
pgf_jit_init(GuPool* tmp_pool, GuPool* pool)
{
PgfJitState* state = gu_new(PgfJitState, tmp_pool);
state->tmp_pool = tmp_pool;
state->pool = pool;
state->buf = NULL;
state->patches = gu_new_buf(PgfCallPatch, tmp_pool);
pgf_jit_alloc_page(state);
return state;
}
static void
pgf_jit_make_space(PgfJitState* state)
{
size_t page_size = sysconf(_SC_PAGESIZE);
if (jit_get_ip().ptr + JIT_CODE_WINDOW > ((char*) state->buf) + page_size) {
jit_flush_code(state->buf, jit_get_ip().ptr);
pgf_jit_alloc_page(state);
}
}
void
pgf_jit_predicate(PgfJitState* state,
PgfCIdMap* abscats, PgfAbsCat* abscat)
{
#ifdef PGF_JIT_DEBUG
GuPool* tmp_pool = gu_new_pool();
GuOut* out = gu_file_out(stderr, tmp_pool);
GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool);
GuExn* err = gu_exn(NULL, type, tmp_pool);
gu_string_write(abscat->name, wtr, err);
gu_puts(":\n", wtr, err);
int label = 0;
#endif
size_t n_funs = gu_buf_length(abscat->functions);
pgf_jit_make_space(state);
abscat->predicate = (PgfPredicate) jit_get_ip().ptr;
jit_prolog(2);
if (n_funs > 0) {
PgfAbsFun* absfun =
gu_buf_get(abscat->functions, PgfAbsFun*, 0);
#ifdef PGF_JIT_DEBUG
gu_puts(" TRY_FIRST ", wtr, err);
gu_string_write(absfun->name, wtr, err);
gu_puts("\n", wtr, err);
#endif
int rs_arg = jit_arg_p();
int parent_arg = jit_arg_p();
jit_getarg_p(JIT_V1, rs_arg);
jit_getarg_p(JIT_V2, parent_arg);
// compile TRY_FIRST
jit_prepare(3);
jit_movi_p(JIT_V0,absfun);
jit_pusharg_p(JIT_V0);
jit_pusharg_p(JIT_V2);
jit_pusharg_p(JIT_V1);
jit_finish(pgf_try_first);
}
#ifdef PGF_JIT_DEBUG
gu_puts(" RET\n", wtr, err);
#endif
// compile RET
jit_ret();
#ifdef PGF_JIT_DEBUG
if (n_funs > 0) {
PgfAbsFun* absfun =
gu_buf_get(abscat->functions, PgfAbsFun*, 0);
gu_string_write(absfun->name, wtr, err);
gu_puts(":\n", wtr, err);
}
#endif
for (size_t i = 0; i < n_funs; i++) {
PgfAbsFun* absfun =
gu_buf_get(abscat->functions, PgfAbsFun*, i);
pgf_jit_make_space(state);
absfun->predicate = (PgfPredicate) jit_get_ip().ptr;
jit_prolog(2);
int rs_arg = jit_arg_p();
int st_arg = jit_arg_p();
jit_getarg_p(JIT_V1, rs_arg);
jit_getarg_p(JIT_V2, st_arg);
if (i+1 < n_funs) {
PgfAbsFun* absfun =
gu_buf_get(abscat->functions, PgfAbsFun*, i+1);
#ifdef PGF_JIT_DEBUG
gu_puts(" TRY_ELSE ", wtr, err);
gu_string_write(absfun->name, wtr, err);
gu_puts("\n", wtr, err);
#endif
// compile TRY_ELSE
jit_prepare(3);
jit_movi_p(JIT_V0, absfun);
jit_pusharg_p(JIT_V0);
jit_pusharg_p(JIT_V2);
jit_pusharg_p(JIT_V1);
jit_finish(pgf_try_else);
}
size_t n_hypos = gu_seq_length(absfun->type->hypos);
for (size_t i = 0; i < n_hypos; i++) {
PgfHypo* hypo = gu_seq_index(absfun->type->hypos, PgfHypo, i);
jit_insn *ref;
// call the predicate for the category in hypo->type->cid
PgfAbsCat* arg =
gu_map_get(abscats, &hypo->type->cid, PgfAbsCat*);
#ifdef PGF_JIT_DEBUG
gu_puts(" CALL ", wtr, err);
gu_string_write(hypo->type->cid, wtr, err);
gu_printf(wtr, err, " L%d\n", label);
#endif
// compile CALL
ref = jit_movi_p(JIT_V0, jit_forward());
jit_str_p(JIT_V2, JIT_V0);
jit_prepare(2);
jit_pusharg_p(JIT_V2);
jit_pusharg_p(JIT_V1);
if (arg != NULL) {
jit_finish(arg->predicate);
} else {
PgfCallPatch patch;
patch.cid = hypo->type->cid;
patch.ref = jit_finish(jit_forward());
gu_buf_push(state->patches, PgfCallPatch, patch);
}
#ifdef PGF_JIT_DEBUG
gu_puts(" RET\n", wtr, err);
gu_printf(wtr, err, "L%d:\n", label++);
#endif
// compile RET
jit_ret();
pgf_jit_make_space(state);
jit_patch_movi(ref,jit_get_label());
jit_prolog(2);
rs_arg = jit_arg_p();
st_arg = jit_arg_p();
jit_getarg_p(JIT_V1, rs_arg);
jit_getarg_p(JIT_V2, st_arg);
}
#ifdef PGF_JIT_DEBUG
gu_puts(" COMPLETE\n", wtr, err);
#endif
// compile COMPLETE
jit_prepare(2);
jit_pusharg_p(JIT_V2);
jit_pusharg_p(JIT_V1);
jit_finish(pgf_complete);
#ifdef PGF_JIT_DEBUG
gu_puts(" RET\n", wtr, err);
#endif
// compile RET
jit_ret();
#ifdef PGF_JIT_DEBUG
if (i+1 < n_funs) {
PgfAbsFun* absfun =
gu_buf_get(abscat->functions, PgfAbsFun*, i+1);
gu_string_write(absfun->name, wtr, err);
gu_puts(":\n", wtr, err);
}
#endif
}
#ifdef PGF_JIT_DEBUG
gu_pool_free(tmp_pool);
#endif
}
void
pgf_jit_done(PgfJitState* state, PgfAbstr* abstr)
{
size_t n_patches = gu_buf_length(state->patches);
for (size_t i = 0; i < n_patches; i++) {
PgfCallPatch* patch =
gu_buf_index(state->patches, PgfCallPatch, i);
PgfAbsCat* arg =
gu_map_get(abstr->cats, &patch->cid, PgfAbsCat*);
gu_assert(arg != NULL);
jit_patch_at(patch->ref,(jit_insn*) arg->predicate);
}
jit_flush_code(state->buf, jit_get_ip().ptr);
}

View File

@@ -53,7 +53,7 @@ struct jit_local_state {
/* Whether a register is used for the user-accessible registers. */
#define jit_save(reg) 1
#define jit_base_prolog() (_jitl.framesize = 20, _jitl.alloca_offset = _jitl.alloca_slack = 0, \
#define jit_base_prolog() (_jitl.framesize = 20, _jitl.alloca_offset = _jitl.alloca_slack = 0, _jitl.argssize = 0, \
PUSHLr(_EBX), PUSHLr(_ESI), PUSHLr(_EDI), PUSHLr(_EBP), MOVLrr(_ESP, _EBP))
#define jit_base_ret(ofs) \
(((ofs) < 0 ? LEAVE_() : POPLr(_EBP)), \

View File

@@ -30,6 +30,7 @@ pgf_read(const char* fpath,
PgfReader* rdr = pgf_new_reader(in, pool, tmp_pool, err);
PgfPGF* pgf = pgf_read_pgf(rdr);
pgf_reader_done(rdr, pgf);
gu_pool_free(tmp_pool);
return pgf;

View File

@@ -2,6 +2,7 @@
#include "expr.h"
#include "literals.h"
#include "reader.h"
#include "jit.h"
#include <gu/defs.h>
#include <gu/map.h>
#include <gu/seq.h>
@@ -23,6 +24,7 @@ struct PgfReader {
GuExn* err;
GuPool* opool;
GuSymTable* symtab;
PgfJitState* jit_state;
};
typedef struct PgfReadTagExn PgfReadTagExn;
@@ -495,7 +497,7 @@ pgf_read_absfuns(PgfReader* rdr)
}
static PgfAbsCat*
pgf_read_abscat(PgfReader* rdr, PgfAbstr* abstr)
pgf_read_abscat(PgfReader* rdr, PgfAbstr* abstr, PgfCIdMap* abscats)
{
PgfAbsCat* abscat = gu_new(PgfAbsCat, rdr->opool);
@@ -531,6 +533,8 @@ pgf_read_abscat(PgfReader* rdr, PgfAbstr* abstr)
gu_map_get(abstr->funs, &name, PgfAbsFun*);
gu_buf_push(abscat->functions, PgfAbsFun*, absfun);
}
pgf_jit_predicate(rdr->jit_state, abscats, abscat);
return abscat;
}
@@ -548,7 +552,7 @@ pgf_read_abscats(PgfReader* rdr, PgfAbstr* abstr)
gu_return_on_exn(rdr->err, NULL);
for (size_t i = 0; i < len; i++) {
PgfAbsCat* abscat = pgf_read_abscat(rdr, abstr);
PgfAbsCat* abscat = pgf_read_abscat(rdr, abstr, abscats);
gu_return_on_exn(rdr->err, NULL);
gu_map_put(abscats, &abscat->name, PgfAbsCat*, abscat);
@@ -1187,5 +1191,12 @@ pgf_new_reader(GuIn* in, GuPool* opool, GuPool* tmp_pool, GuExn* err)
rdr->symtab = gu_new_symtable(opool, tmp_pool);
rdr->err = err;
rdr->in = in;
rdr->jit_state = pgf_jit_init(tmp_pool, rdr->opool);
return rdr;
}
void
pgf_reader_done(PgfReader* rdr, PgfPGF* pgf)
{
pgf_jit_done(rdr->jit_state, &pgf->abstract);
}

View File

@@ -13,4 +13,7 @@ pgf_new_reader(GuIn* in, GuPool* opool, GuPool* tmp_pool, GuExn* err);
PgfPGF*
pgf_read_pgf(PgfReader* rdr);
void
pgf_reader_done(PgfReader* rdr, PgfPGF* pgf);
#endif // READER_H_

View File

@@ -1,39 +1,62 @@
#include <pgf/pgf.h>
#include <pgf/data.h>
#include <pgf/reasoner.h>
#include <gu/file.h>
#include <math.h>
#include <stdio.h>
//#define PGF_REASONER_DEBUG
typedef struct PgfExprState PgfExprState;
typedef struct {
GuBuf* conts;
GuBuf* parents;
GuBuf* exprs;
prob_t outside_prob;
} PgfAnswers;
struct PgfExprState {
PgfAnswers* answers;
PgfExprProb ep;
PgfHypos hypos;
size_t arg_idx;
#ifdef PGF_REASONER_DEBUG
typedef void (*PgfStatePrinter)(PgfReasonerState* st,
GuWriter* wtr, GuExn* err,
GuPool* tmp_pool);
#endif
struct PgfReasonerState {
// the jitter expects that continuation is the first field
PgfPredicate continuation;
#ifdef PGF_REASONER_DEBUG
PgfStatePrinter print;
#endif
prob_t prob;
};
typedef enum {
PGF_EXPR_QSTATE_PREDICT,
PGF_EXPR_QSTATE_COMBINE1,
PGF_EXPR_QSTATE_COMBINE2
} PGF_EXPR_QSTATE_KIND;
struct PgfExprState {
// base must be the first field in order to be able to cast
// from PgfExprState to PgfReasonerState
PgfReasonerState base;
PgfAnswers* answers;
PgfExpr expr;
#ifdef PGF_REASONER_DEBUG
size_t n_args;
size_t arg_idx;
#endif
};
typedef struct {
prob_t prob;
PGF_EXPR_QSTATE_KIND kind;
void* single;
size_t choice_idx;
GuBuf* choices;
} PgfExprQState;
// base must be the first field in order to be able to cast
// from PgfCombine2State to PgfReasonerState
PgfReasonerState base;
GuBuf* exprs;
PgfExprState* parent;
size_t n_choices;
size_t choice;
} PgfCombine1State;
typedef struct {
// base must be the first field in order to be able to cast
// from PgfCombine2State to PgfReasonerState
PgfReasonerState base;
GuBuf* parents;
PgfExprProb* ep;
size_t n_choices;
size_t choice;
} PgfCombine2State;
static GU_DEFINE_TYPE(PgfAnswers, abstract);
@@ -41,69 +64,68 @@ typedef GuStringMap PgfAbswersMap;
static GU_DEFINE_TYPE(PgfAbswersMap, GuStringMap, gu_ptr_type(PgfAnswers),
&gu_null_struct);
typedef struct {
struct PgfReasoner {
GuPool* pool;
GuPool* tmp_pool;
PgfAbstr* abstract;
PgfAbswersMap* table;
GuBuf* pqueue;
GuBuf* exprs;
PgfExprEnum en;
} PgfReasoner;
};
static int
cmp_expr_qstate(GuOrder* self, const void* a, const void* b)
cmp_expr_state(GuOrder* self, const void* a, const void* b)
{
PgfExprQState *q1 = *((PgfExprQState **) a);
PgfExprQState *q2 = *((PgfExprQState **) b);
PgfReasonerState *st1 = *((PgfReasonerState **) a);
PgfReasonerState *st2 = *((PgfReasonerState **) b);
if (q1->prob < q2->prob)
if (st1->prob < st2->prob)
return -1;
else if (q1->prob > q2->prob)
else if (st1->prob > st2->prob)
return 1;
else
return 0;
}
static GuOrder
pgf_expr_qstate_order = { cmp_expr_qstate };
pgf_expr_state_order = { cmp_expr_state };
#ifdef PGF_REASONER_DEBUG
static void
pgf_print_expr_state(PgfExprState* st,
GuWriter* wtr, GuExn* err, GuBuf* stack)
pgf_print_parent_state(PgfExprState* st,
GuWriter* wtr, GuExn* err, GuBuf* stack)
{
gu_buf_push(stack, int, (gu_seq_length(st->hypos) - st->arg_idx - 1));
gu_buf_push(stack, int, (st->n_args - st->arg_idx - 1));
PgfExprState* cont = gu_buf_get(st->answers->conts, PgfExprState*, 0);
if (cont != NULL)
pgf_print_expr_state(cont, wtr, err, stack);
PgfExprState* parent = gu_buf_get(st->answers->parents, PgfExprState*, 0);
if (parent != NULL)
pgf_print_parent_state(parent, wtr, err, stack);
gu_puts(" (", wtr, err);
pgf_print_expr(st->ep.expr, 0, wtr, err);
pgf_print_expr(st->expr, 0, wtr, err);
}
static void
pgf_print_expr_state0(PgfExprState* st,
GuWriter* wtr, GuExn* err, GuPool* tmp_pool)
pgf_print_expr_state(PgfExprState* st,
GuWriter* wtr, GuExn* err, GuPool* tmp_pool)
{
prob_t prob = st->answers->outside_prob+st->ep.prob;
gu_printf(wtr, err, "[%f]", prob);
size_t n_args = gu_seq_length(st->hypos);
gu_printf(wtr, err, "[%f] ", st->base.prob);
GuBuf* stack = gu_new_buf(int, tmp_pool);
if (n_args > 0)
gu_buf_push(stack, int, n_args - st->arg_idx);
if (st->n_args > 0)
gu_buf_push(stack, int, st->n_args - st->arg_idx);
PgfExprState* cont =
gu_buf_get(st->answers->conts, PgfExprState*, 0);
gu_buf_get(st->answers->parents, PgfExprState*, 0);
if (cont != NULL)
pgf_print_expr_state(cont, wtr, err, stack);
pgf_print_parent_state(cont, wtr, err, stack);
if (n_args > 0)
if (st->n_args > 0)
gu_puts(" (", wtr, err);
else
gu_puts(" ", wtr, err);
pgf_print_expr(st->ep.expr, 0, wtr, err);
pgf_print_expr(st->expr, 0, wtr, err);
size_t n_counts = gu_buf_length(stack);
for (size_t i = 0; i < n_counts; i++) {
@@ -118,130 +140,208 @@ pgf_print_expr_state0(PgfExprState* st,
#endif
static PgfExprState*
pgf_reasoner_combine(PgfReasoner* rs,
PgfExprState* st, PgfExprProb* ep,
GuPool* pool)
{
PgfExprState* nst =
gu_new(PgfExprState, rs->tmp_pool);
nst->answers = st->answers;
nst->ep.expr =
gu_new_variant_i(pool, PGF_EXPR_APP,
pgf_combine1_to_expr(PgfCombine1State* st, GuPool* tmp_pool) {
PgfExprProb* ep =
gu_buf_get(st->exprs, PgfExprProb*, st->choice);
PgfExprState* nst = gu_new(PgfExprState, tmp_pool);
nst->base.continuation = st->parent->base.continuation;
nst->base.prob = st->base.prob;
nst->answers = st->parent->answers;
nst->expr =
gu_new_variant_i(tmp_pool, PGF_EXPR_APP,
PgfExprApp,
.fun = st->ep.expr,
.fun = st->parent->expr,
.arg = ep->expr);
nst->ep.prob = st->ep.prob+ep->prob;
nst->hypos = st->hypos;
nst->arg_idx = st->arg_idx+1;
#ifdef PGF_REASONER_DEBUG
nst->base.print = (PgfStatePrinter) pgf_print_expr_state;
nst->n_args = st->parent->n_args;
nst->arg_idx = st->parent->arg_idx+1;
#endif
return nst;
}
static void
pgf_reasoner_predict(PgfReasoner* rs, PgfExprState* cont,
prob_t outside_prob, PgfCId cat,
GuPool* pool)
static PgfExprState*
pgf_combine2_to_expr(PgfCombine2State* st, GuPool* tmp_pool)
{
PgfExprState* parent =
gu_buf_get(st->parents, PgfExprState*, st->choice);
if (parent == NULL)
return NULL;
PgfExprState* nst =
gu_new(PgfExprState, tmp_pool);
nst->base.continuation = parent->base.continuation;
nst->base.prob = st->base.prob;
nst->answers = parent->answers;
nst->expr =
gu_new_variant_i(tmp_pool, PGF_EXPR_APP,
PgfExprApp,
.fun = parent->expr,
.arg = st->ep->expr);
#ifdef PGF_REASONER_DEBUG
nst->base.print = (PgfStatePrinter) pgf_print_expr_state;
nst->n_args = parent->n_args;
nst->arg_idx = parent->arg_idx+1;
#endif
return nst;
}
#ifdef PGF_REASONER_DEBUG
static void
pgf_print_combine1_state(PgfCombine1State* st,
GuWriter* wtr, GuExn* err, GuPool* tmp_pool)
{
PgfExprState* nst = pgf_combine1_to_expr(st, tmp_pool);
pgf_print_expr_state(nst, wtr, err, tmp_pool);
}
static void
pgf_print_combine2_state(PgfCombine2State* st,
GuWriter* wtr, GuExn* err, GuPool* tmp_pool)
{
PgfExprState* nst = pgf_combine2_to_expr(st, tmp_pool);
if (nst != NULL)
pgf_print_expr_state(nst, wtr, err, tmp_pool);
}
#endif
static void
pgf_combine1(PgfReasoner* rs, PgfCombine1State* st)
{
PgfExprState* nst = pgf_combine1_to_expr(st, rs->tmp_pool);
nst->base.continuation(rs, &nst->base);
st->choice++;
if (st->choice < st->n_choices) {
PgfExprProb* ep =
gu_buf_get(st->exprs, PgfExprProb*, st->choice);
st->base.prob = st->parent->base.prob + ep->prob;
gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st);
}
}
void
pgf_try_first(PgfReasoner* rs, PgfExprState* parent, PgfAbsFun* absfun)
{
PgfCId cat = absfun->type->cid;
PgfAnswers* answers = gu_map_get(rs->table, &cat, PgfAnswers*);
if (answers == NULL) {
answers = gu_new(PgfAnswers, rs->tmp_pool);
answers->conts = gu_new_buf(PgfExprState*, rs->tmp_pool);
answers->exprs = gu_new_buf(PgfExprProb*, rs->tmp_pool);
answers->outside_prob = outside_prob;
answers->parents = gu_new_buf(PgfExprState*, rs->tmp_pool);
answers->exprs = gu_new_buf(PgfExprProb*, rs->tmp_pool);
answers->outside_prob = parent->base.prob;
gu_map_put(rs->table, &cat, PgfAnswers*, answers);
}
gu_buf_push(answers->conts, PgfExprState*, cont);
gu_buf_push(answers->parents, PgfExprState*, parent);
if (gu_buf_length(answers->conts) == 1) {
PgfAbsCat* abscat = gu_map_get(rs->abstract->cats, &cat, PgfAbsCat*);
if (abscat == NULL) {
return;
}
if (gu_buf_length(abscat->functions) > 0) {
PgfExprQState *q = gu_new(PgfExprQState, rs->tmp_pool);
q->kind = PGF_EXPR_QSTATE_PREDICT;
q->single = answers;
q->choice_idx = 0;
q->choices = abscat->functions;
q->prob = answers->outside_prob + gu_buf_get(q->choices, PgfAbsFun*, 0)->ep.prob;
gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q);
}
if (gu_buf_length(answers->parents) == 1) {
PgfExprState* st = gu_new(PgfExprState, rs->tmp_pool);
st->base.continuation = (PgfPredicate) absfun->predicate;
st->base.prob = answers->outside_prob + absfun->ep.prob;
st->answers = answers;
st->expr = absfun->ep.expr;
#ifdef PGF_REASONER_DEBUG
st->base.print = (PgfStatePrinter) pgf_print_expr_state;
st->n_args = gu_seq_length(absfun->type->hypos);
st->arg_idx = 0;
#endif
gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st);
} else {
if (gu_buf_length(answers->exprs) > 0) {
PgfExprQState *q = gu_new(PgfExprQState, rs->tmp_pool);
q->prob = cont->ep.prob + gu_buf_get(answers->exprs, PgfExprProb*, 0)->prob;
q->kind = PGF_EXPR_QSTATE_COMBINE1;
q->single = cont;
q->choice_idx = 0;
q->choices = answers->exprs;
size_t n_exprs = gu_buf_length(answers->exprs);
if (n_exprs > 0) {
PgfExprProb* ep =
gu_buf_get(answers->exprs, PgfExprProb*, 0);
gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q);
PgfCombine1State* nst = gu_new(PgfCombine1State, rs->tmp_pool);
nst->base.continuation = (PgfPredicate) pgf_combine1;
nst->base.prob = parent->base.prob + ep->prob;
nst->exprs = answers->exprs;
nst->choice = 0;
nst->n_choices = gu_buf_length(answers->exprs);
nst->parent = parent;
#ifdef PGF_REASONER_DEBUG
nst->base.print = (PgfStatePrinter) pgf_print_combine1_state;
#endif
gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &nst);
}
}
}
void
pgf_try_else(PgfReasoner* rs, PgfExprState* prev, PgfAbsFun* absfun)
{
PgfExprState *st = gu_new(PgfExprState, rs->tmp_pool);
st->base.continuation = (PgfPredicate) absfun->predicate;
st->base.prob = prev->answers->outside_prob + absfun->ep.prob;
st->answers = prev->answers;
st->expr = absfun->ep.expr;
#ifdef PGF_REASONER_DEBUG
st->base.print = (PgfStatePrinter) pgf_print_expr_state;
st->n_args = gu_seq_length(absfun->type->hypos);
st->arg_idx = 0;
#endif
gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st);
}
static void
pgf_combine2(PgfReasoner* rs, PgfCombine2State* st)
{
PgfExprState* nst = pgf_combine2_to_expr(st, rs->tmp_pool);
if (nst != NULL) {
nst->base.continuation(rs, &nst->base);
}
st->choice++;
if (st->choice < st->n_choices) {
PgfExprState* parent =
gu_buf_get(st->parents, PgfExprState*, st->choice);
st->base.prob = parent->base.prob + st->ep->prob;
gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st);
}
}
void
pgf_complete(PgfReasoner* rs, PgfExprState* st)
{
PgfExprProb* ep = gu_new(PgfExprProb, rs->pool);
ep->prob = st->base.prob - st->answers->outside_prob;
ep->expr = st->expr;
gu_buf_push(st->answers->exprs, PgfExprProb*, ep);
PgfCombine2State* nst = gu_new(PgfCombine2State, rs->tmp_pool);
nst->base.continuation = (PgfPredicate) pgf_combine2;
nst->base.prob = st->base.prob;
nst->parents = st->answers->parents;
nst->choice = 0;
nst->n_choices = gu_buf_length(st->answers->parents);
nst->ep = ep;
#ifdef PGF_REASONER_DEBUG
nst->base.print = (PgfStatePrinter) pgf_print_combine2_state;
#endif
nst->base.continuation(rs, &nst->base);
}
static PgfExprProb*
pgf_reasoner_next(PgfReasoner* rs, GuPool* pool)
pgf_reasoner_next(PgfReasoner* rs)
{
if (rs->tmp_pool == NULL)
return NULL;
size_t n_exprs = gu_buf_length(rs->exprs);
while (gu_buf_length(rs->pqueue) > 0) {
PgfExprQState* q;
gu_buf_heap_pop(rs->pqueue, &pgf_expr_qstate_order, &q);
PgfExprState* st = NULL;
switch (q->kind) {
case PGF_EXPR_QSTATE_PREDICT: {
PgfAbsFun* absfun =
gu_buf_get(q->choices, PgfAbsFun*, q->choice_idx);
st = gu_new(PgfExprState, pool);
st->answers = q->single;
st->ep = absfun->ep;
st->hypos = absfun->type->hypos;
st->arg_idx = 0;
q->choice_idx++;
if (q->choice_idx < gu_buf_length(q->choices)) {
q->prob = st->answers->outside_prob + gu_buf_get(q->choices, PgfAbsFun*, q->choice_idx)->ep.prob;
gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q);
}
break;
}
case PGF_EXPR_QSTATE_COMBINE1: {
PgfExprState* cont = q->single;
PgfExprProb* ep =
gu_buf_get(q->choices, PgfExprProb*, q->choice_idx);
st = pgf_reasoner_combine(rs, cont, ep, pool);
q->choice_idx++;
if (q->choice_idx < gu_buf_length(q->choices)) {
q->prob = cont->ep.prob + gu_buf_get(q->choices, PgfExprProb*, q->choice_idx)->prob;
gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q);
}
break;
}
case PGF_EXPR_QSTATE_COMBINE2: {
PgfExprState* cont =
gu_buf_get(q->choices, PgfExprState*, q->choice_idx);
PgfExprProb* ep = q->single;
st = pgf_reasoner_combine(rs, cont, ep, pool);
q->choice_idx++;
if (q->choice_idx < gu_buf_length(q->choices)) {
q->prob = ep->prob + gu_buf_get(q->choices, PgfExprState*, q->choice_idx)->ep.prob;
gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q);
}
break;
}
default:
gu_impossible();
}
PgfReasonerState* st;
gu_buf_heap_pop(rs->pqueue, &pgf_expr_state_order, &st);
#ifdef PGF_REASONER_DEBUG
{
@@ -249,45 +349,15 @@ pgf_reasoner_next(PgfReasoner* rs, GuPool* pool)
GuOut* out = gu_file_out(stderr, tmp_pool);
GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool);
GuExn* err = gu_exn(NULL, type, tmp_pool);
pgf_print_expr_state0(st, wtr, err, tmp_pool);
st->print(st, wtr, err, tmp_pool);
gu_pool_free(tmp_pool);
}
#endif
if (st->arg_idx < gu_seq_length(st->hypos)) {
PgfHypo *hypo = gu_seq_index(st->hypos, PgfHypo, st->arg_idx);
prob_t outside_prob =
st->ep.prob+st->answers->outside_prob;
pgf_reasoner_predict(rs, st, outside_prob,
hypo->type->cid, pool);
} else {
gu_buf_push(st->answers->exprs, PgfExprProb*, &st->ep);
PgfExprProb* target = NULL;
GuBuf* conts = st->answers->conts;
size_t choice_idx = 0;
PgfExprState* cont =
gu_buf_get(conts, PgfExprState*, 0);
if (cont == NULL) {
target = &st->ep;
cont = gu_buf_get(conts, PgfExprState*, 1);
choice_idx++;
}
if (choice_idx < gu_buf_length(conts)) {
PgfExprQState *q = gu_new(PgfExprQState, rs->tmp_pool);
q->prob = st->ep.prob + cont->ep.prob;
q->kind = PGF_EXPR_QSTATE_COMBINE2;
q->single = &st->ep;
q->choice_idx = choice_idx;
q->choices = conts;
gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q);
}
if (target != NULL)
return target;
st->continuation(rs, st);
if (n_exprs < gu_buf_length(rs->exprs)) {
return gu_buf_get(rs->exprs, PgfExprProb*, n_exprs);
}
}
@@ -301,20 +371,31 @@ static void
pgf_reasoner_enum_next(GuEnum* self, void* to, GuPool* pool)
{
PgfReasoner* pr = gu_container(self, PgfReasoner, en);
*(PgfExprProb**)to = pgf_reasoner_next(pr, pool);
*(PgfExprProb**)to = pgf_reasoner_next(pr);
}
PgfExprEnum*
pgf_generate(PgfPGF* pgf, PgfCId cat, GuPool* pool)
{
PgfReasoner* rs = gu_new(PgfReasoner, pool);
rs->pool = pool;
rs->tmp_pool = gu_new_pool(),
rs->abstract = &pgf->abstract,
rs->table = gu_map_type_new(PgfAbswersMap, rs->tmp_pool),
rs->pqueue = gu_new_buf(PgfExprQState*, rs->tmp_pool);
rs->pqueue = gu_new_buf(PgfReasonerState*, rs->tmp_pool);
rs->exprs = gu_new_buf(PgfExprProb*, rs->tmp_pool);
rs->en.next = pgf_reasoner_enum_next;
pgf_reasoner_predict(rs, NULL, 0, cat, pool);
PgfAnswers* answers = gu_new(PgfAnswers, rs->tmp_pool);
answers->parents = gu_new_buf(PgfExprState*, rs->tmp_pool);
answers->exprs = rs->exprs;
answers->outside_prob = 0;
gu_map_put(rs->table, &cat, PgfAnswers*, answers);
PgfAbsCat* abscat = gu_map_get(rs->abstract->cats, &cat, PgfAbsCat*);
if (abscat != NULL) {
((PgfPredicate) abscat->predicate)(rs, NULL);
}
return &rs->en;
}