diff --git a/src/runtime/c/Makefile.am b/src/runtime/c/Makefile.am index cfef382e6..e9dc866a5 100644 --- a/src/runtime/c/Makefile.am +++ b/src/runtime/c/Makefile.am @@ -104,6 +104,7 @@ libpgf_la_SOURCES = \ pgf/expr.h \ pgf/parser.c \ pgf/parser.h \ + pgf/jit.c \ pgf/parseval.c \ pgf/lexer.c \ pgf/lexer.h \ diff --git a/src/runtime/c/pgf/data.h b/src/runtime/c/pgf/data.h index 24ebbea32..1be7a3fcf 100644 --- a/src/runtime/c/pgf/data.h +++ b/src/runtime/c/pgf/data.h @@ -76,6 +76,7 @@ typedef struct { int arity; PgfEquations defns; // maybe null PgfExprProb ep; + void* predicate; } PgfAbsFun; extern GU_DECLARE_TYPE(PgfAbsFun, abstract); @@ -91,7 +92,8 @@ typedef struct { prob_t meta_token_prob; PgfMetaChildMap* meta_child_probs; - GuBuf* functions; // -->PgfFunDecl + GuBuf* functions; // -->PgfAbsFun + void* predicate; } PgfAbsCat; extern GU_DECLARE_TYPE(PgfAbsCat, abstract); diff --git a/src/runtime/c/pgf/jit.c b/src/runtime/c/pgf/jit.c new file mode 100644 index 000000000..16f72139d --- /dev/null +++ b/src/runtime/c/pgf/jit.c @@ -0,0 +1,290 @@ +#include +#include +#include +#include +#include +#include "sys/mman.h" +#include "lightning.h" + +//#define PGF_JIT_DEBUG + + +struct PgfJitState { + GuPool* tmp_pool; + GuPool* pool; + jit_state jit; + jit_insn *buf; + GuBuf* patches; +}; + +#define _jit (state->jit) + +typedef struct { + PgfCId cid; + jit_insn *ref; +} PgfCallPatch; + +// Between two calls to pgf_jit_make_space we are not allowed +// to emit more that JIT_CODE_WINDOW bytes. This is not quite +// safe but this is how GNU lightning is designed. +#define JIT_CODE_WINDOW 128 + +typedef struct { + GuFinalizer fin; + void *page; +} PgfPageFinalizer; + +static void +pgf_jit_finalize_page(GuFinalizer* self) +{ + PgfPageFinalizer* fin = gu_container(self, PgfPageFinalizer, fin); + free(fin->page); +} + +static size_t total_size = 0; + +static void +pgf_jit_alloc_page(PgfJitState* state) +{ + void *page; + + size_t page_size = sysconf(_SC_PAGESIZE); + total_size += page_size; + + if (posix_memalign(&page, page_size, page_size) != 0) { + gu_fatal("Memory allocation failed"); + } + if (mprotect(page, page_size, + PROT_READ | PROT_WRITE | PROT_EXEC) != 0) { + gu_fatal("mprotect failed"); + } + + PgfPageFinalizer* fin = gu_new(PgfPageFinalizer, state->pool); + fin->fin.fn = pgf_jit_finalize_page; + fin->page = page; + gu_pool_finally(state->pool, &fin->fin); + + state->buf = page; + jit_set_ip(state->buf); +} + +PgfJitState* +pgf_jit_init(GuPool* tmp_pool, GuPool* pool) +{ + PgfJitState* state = gu_new(PgfJitState, tmp_pool); + state->tmp_pool = tmp_pool; + state->pool = pool; + state->buf = NULL; + state->patches = gu_new_buf(PgfCallPatch, tmp_pool); + + pgf_jit_alloc_page(state); + + return state; +} + +static void +pgf_jit_make_space(PgfJitState* state) +{ + size_t page_size = sysconf(_SC_PAGESIZE); + if (jit_get_ip().ptr + JIT_CODE_WINDOW > ((char*) state->buf) + page_size) { + jit_flush_code(state->buf, jit_get_ip().ptr); + pgf_jit_alloc_page(state); + } +} + +void +pgf_jit_predicate(PgfJitState* state, + PgfCIdMap* abscats, PgfAbsCat* abscat) +{ +#ifdef PGF_JIT_DEBUG + GuPool* tmp_pool = gu_new_pool(); + GuOut* out = gu_file_out(stderr, tmp_pool); + GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool); + GuExn* err = gu_exn(NULL, type, tmp_pool); + + gu_string_write(abscat->name, wtr, err); + gu_puts(":\n", wtr, err); + + int label = 0; +#endif + + size_t n_funs = gu_buf_length(abscat->functions); + + pgf_jit_make_space(state); + + abscat->predicate = (PgfPredicate) jit_get_ip().ptr; + + jit_prolog(2); + + if (n_funs > 0) { + PgfAbsFun* absfun = + gu_buf_get(abscat->functions, PgfAbsFun*, 0); + +#ifdef PGF_JIT_DEBUG + gu_puts(" TRY_FIRST ", wtr, err); + gu_string_write(absfun->name, wtr, err); + gu_puts("\n", wtr, err); +#endif + + int rs_arg = jit_arg_p(); + int parent_arg = jit_arg_p(); + jit_getarg_p(JIT_V1, rs_arg); + jit_getarg_p(JIT_V2, parent_arg); + + // compile TRY_FIRST + jit_prepare(3); + jit_movi_p(JIT_V0,absfun); + jit_pusharg_p(JIT_V0); + jit_pusharg_p(JIT_V2); + jit_pusharg_p(JIT_V1); + jit_finish(pgf_try_first); + } + +#ifdef PGF_JIT_DEBUG + gu_puts(" RET\n", wtr, err); +#endif + // compile RET + jit_ret(); + +#ifdef PGF_JIT_DEBUG + if (n_funs > 0) { + PgfAbsFun* absfun = + gu_buf_get(abscat->functions, PgfAbsFun*, 0); + + gu_string_write(absfun->name, wtr, err); + gu_puts(":\n", wtr, err); + } +#endif + + for (size_t i = 0; i < n_funs; i++) { + PgfAbsFun* absfun = + gu_buf_get(abscat->functions, PgfAbsFun*, i); + + pgf_jit_make_space(state); + + absfun->predicate = (PgfPredicate) jit_get_ip().ptr; + + jit_prolog(2); + int rs_arg = jit_arg_p(); + int st_arg = jit_arg_p(); + jit_getarg_p(JIT_V1, rs_arg); + jit_getarg_p(JIT_V2, st_arg); + + if (i+1 < n_funs) { + PgfAbsFun* absfun = + gu_buf_get(abscat->functions, PgfAbsFun*, i+1); + +#ifdef PGF_JIT_DEBUG + gu_puts(" TRY_ELSE ", wtr, err); + gu_string_write(absfun->name, wtr, err); + gu_puts("\n", wtr, err); +#endif + + // compile TRY_ELSE + jit_prepare(3); + jit_movi_p(JIT_V0, absfun); + jit_pusharg_p(JIT_V0); + jit_pusharg_p(JIT_V2); + jit_pusharg_p(JIT_V1); + jit_finish(pgf_try_else); + } + + size_t n_hypos = gu_seq_length(absfun->type->hypos); + for (size_t i = 0; i < n_hypos; i++) { + PgfHypo* hypo = gu_seq_index(absfun->type->hypos, PgfHypo, i); + + jit_insn *ref; + + // call the predicate for the category in hypo->type->cid + PgfAbsCat* arg = + gu_map_get(abscats, &hypo->type->cid, PgfAbsCat*); + +#ifdef PGF_JIT_DEBUG + gu_puts(" CALL ", wtr, err); + gu_string_write(hypo->type->cid, wtr, err); + gu_printf(wtr, err, " L%d\n", label); +#endif + + // compile CALL + ref = jit_movi_p(JIT_V0, jit_forward()); + jit_str_p(JIT_V2, JIT_V0); + jit_prepare(2); + jit_pusharg_p(JIT_V2); + jit_pusharg_p(JIT_V1); + if (arg != NULL) { + jit_finish(arg->predicate); + } else { + PgfCallPatch patch; + patch.cid = hypo->type->cid; + patch.ref = jit_finish(jit_forward()); + gu_buf_push(state->patches, PgfCallPatch, patch); + } + +#ifdef PGF_JIT_DEBUG + gu_puts(" RET\n", wtr, err); + gu_printf(wtr, err, "L%d:\n", label++); +#endif + + // compile RET + jit_ret(); + + pgf_jit_make_space(state); + + jit_patch_movi(ref,jit_get_label()); + + jit_prolog(2); + rs_arg = jit_arg_p(); + st_arg = jit_arg_p(); + jit_getarg_p(JIT_V1, rs_arg); + jit_getarg_p(JIT_V2, st_arg); + } + +#ifdef PGF_JIT_DEBUG + gu_puts(" COMPLETE\n", wtr, err); +#endif + + // compile COMPLETE + jit_prepare(2); + jit_pusharg_p(JIT_V2); + jit_pusharg_p(JIT_V1); + jit_finish(pgf_complete); + +#ifdef PGF_JIT_DEBUG + gu_puts(" RET\n", wtr, err); +#endif + + // compile RET + jit_ret(); + +#ifdef PGF_JIT_DEBUG + if (i+1 < n_funs) { + PgfAbsFun* absfun = + gu_buf_get(abscat->functions, PgfAbsFun*, i+1); + + gu_string_write(absfun->name, wtr, err); + gu_puts(":\n", wtr, err); + } +#endif + } + +#ifdef PGF_JIT_DEBUG + gu_pool_free(tmp_pool); +#endif +} + +void +pgf_jit_done(PgfJitState* state, PgfAbstr* abstr) +{ + size_t n_patches = gu_buf_length(state->patches); + for (size_t i = 0; i < n_patches; i++) { + PgfCallPatch* patch = + gu_buf_index(state->patches, PgfCallPatch, i); + PgfAbsCat* arg = + gu_map_get(abstr->cats, &patch->cid, PgfAbsCat*); + gu_assert(arg != NULL); + + jit_patch_at(patch->ref,(jit_insn*) arg->predicate); + } + + jit_flush_code(state->buf, jit_get_ip().ptr); +} diff --git a/src/runtime/c/pgf/lightning/i386/core-32.h b/src/runtime/c/pgf/lightning/i386/core-32.h index 48117ddb9..6d5e8b6ab 100644 --- a/src/runtime/c/pgf/lightning/i386/core-32.h +++ b/src/runtime/c/pgf/lightning/i386/core-32.h @@ -53,7 +53,7 @@ struct jit_local_state { /* Whether a register is used for the user-accessible registers. */ #define jit_save(reg) 1 -#define jit_base_prolog() (_jitl.framesize = 20, _jitl.alloca_offset = _jitl.alloca_slack = 0, \ +#define jit_base_prolog() (_jitl.framesize = 20, _jitl.alloca_offset = _jitl.alloca_slack = 0, _jitl.argssize = 0, \ PUSHLr(_EBX), PUSHLr(_ESI), PUSHLr(_EDI), PUSHLr(_EBP), MOVLrr(_ESP, _EBP)) #define jit_base_ret(ofs) \ (((ofs) < 0 ? LEAVE_() : POPLr(_EBP)), \ diff --git a/src/runtime/c/pgf/pgf.c b/src/runtime/c/pgf/pgf.c index 81b1fa05c..f1b85cae3 100644 --- a/src/runtime/c/pgf/pgf.c +++ b/src/runtime/c/pgf/pgf.c @@ -30,6 +30,7 @@ pgf_read(const char* fpath, PgfReader* rdr = pgf_new_reader(in, pool, tmp_pool, err); PgfPGF* pgf = pgf_read_pgf(rdr); + pgf_reader_done(rdr, pgf); gu_pool_free(tmp_pool); return pgf; diff --git a/src/runtime/c/pgf/reader.c b/src/runtime/c/pgf/reader.c index f1b17c7f7..550cfa5d6 100644 --- a/src/runtime/c/pgf/reader.c +++ b/src/runtime/c/pgf/reader.c @@ -2,6 +2,7 @@ #include "expr.h" #include "literals.h" #include "reader.h" +#include "jit.h" #include #include #include @@ -23,6 +24,7 @@ struct PgfReader { GuExn* err; GuPool* opool; GuSymTable* symtab; + PgfJitState* jit_state; }; typedef struct PgfReadTagExn PgfReadTagExn; @@ -495,7 +497,7 @@ pgf_read_absfuns(PgfReader* rdr) } static PgfAbsCat* -pgf_read_abscat(PgfReader* rdr, PgfAbstr* abstr) +pgf_read_abscat(PgfReader* rdr, PgfAbstr* abstr, PgfCIdMap* abscats) { PgfAbsCat* abscat = gu_new(PgfAbsCat, rdr->opool); @@ -531,6 +533,8 @@ pgf_read_abscat(PgfReader* rdr, PgfAbstr* abstr) gu_map_get(abstr->funs, &name, PgfAbsFun*); gu_buf_push(abscat->functions, PgfAbsFun*, absfun); } + + pgf_jit_predicate(rdr->jit_state, abscats, abscat); return abscat; } @@ -548,7 +552,7 @@ pgf_read_abscats(PgfReader* rdr, PgfAbstr* abstr) gu_return_on_exn(rdr->err, NULL); for (size_t i = 0; i < len; i++) { - PgfAbsCat* abscat = pgf_read_abscat(rdr, abstr); + PgfAbsCat* abscat = pgf_read_abscat(rdr, abstr, abscats); gu_return_on_exn(rdr->err, NULL); gu_map_put(abscats, &abscat->name, PgfAbsCat*, abscat); @@ -1187,5 +1191,12 @@ pgf_new_reader(GuIn* in, GuPool* opool, GuPool* tmp_pool, GuExn* err) rdr->symtab = gu_new_symtable(opool, tmp_pool); rdr->err = err; rdr->in = in; + rdr->jit_state = pgf_jit_init(tmp_pool, rdr->opool); return rdr; } + +void +pgf_reader_done(PgfReader* rdr, PgfPGF* pgf) +{ + pgf_jit_done(rdr->jit_state, &pgf->abstract); +} diff --git a/src/runtime/c/pgf/reader.h b/src/runtime/c/pgf/reader.h index adfcad0ef..95dfc855f 100644 --- a/src/runtime/c/pgf/reader.h +++ b/src/runtime/c/pgf/reader.h @@ -13,4 +13,7 @@ pgf_new_reader(GuIn* in, GuPool* opool, GuPool* tmp_pool, GuExn* err); PgfPGF* pgf_read_pgf(PgfReader* rdr); +void +pgf_reader_done(PgfReader* rdr, PgfPGF* pgf); + #endif // READER_H_ diff --git a/src/runtime/c/pgf/reasoner.c b/src/runtime/c/pgf/reasoner.c index 672d4c5b2..3e5c64692 100644 --- a/src/runtime/c/pgf/reasoner.c +++ b/src/runtime/c/pgf/reasoner.c @@ -1,39 +1,62 @@ #include #include +#include #include -#include -#include //#define PGF_REASONER_DEBUG -typedef struct PgfExprState PgfExprState; - typedef struct { - GuBuf* conts; + GuBuf* parents; GuBuf* exprs; prob_t outside_prob; } PgfAnswers; -struct PgfExprState { - PgfAnswers* answers; - PgfExprProb ep; - PgfHypos hypos; - size_t arg_idx; +#ifdef PGF_REASONER_DEBUG +typedef void (*PgfStatePrinter)(PgfReasonerState* st, + GuWriter* wtr, GuExn* err, + GuPool* tmp_pool); +#endif + +struct PgfReasonerState { + // the jitter expects that continuation is the first field + PgfPredicate continuation; +#ifdef PGF_REASONER_DEBUG + PgfStatePrinter print; +#endif + prob_t prob; }; -typedef enum { - PGF_EXPR_QSTATE_PREDICT, - PGF_EXPR_QSTATE_COMBINE1, - PGF_EXPR_QSTATE_COMBINE2 -} PGF_EXPR_QSTATE_KIND; +struct PgfExprState { + // base must be the first field in order to be able to cast + // from PgfExprState to PgfReasonerState + PgfReasonerState base; + PgfAnswers* answers; + PgfExpr expr; +#ifdef PGF_REASONER_DEBUG + size_t n_args; + size_t arg_idx; +#endif +}; typedef struct { - prob_t prob; - PGF_EXPR_QSTATE_KIND kind; - void* single; - size_t choice_idx; - GuBuf* choices; -} PgfExprQState; + // base must be the first field in order to be able to cast + // from PgfCombine2State to PgfReasonerState + PgfReasonerState base; + GuBuf* exprs; + PgfExprState* parent; + size_t n_choices; + size_t choice; +} PgfCombine1State; + +typedef struct { + // base must be the first field in order to be able to cast + // from PgfCombine2State to PgfReasonerState + PgfReasonerState base; + GuBuf* parents; + PgfExprProb* ep; + size_t n_choices; + size_t choice; +} PgfCombine2State; static GU_DEFINE_TYPE(PgfAnswers, abstract); @@ -41,69 +64,68 @@ typedef GuStringMap PgfAbswersMap; static GU_DEFINE_TYPE(PgfAbswersMap, GuStringMap, gu_ptr_type(PgfAnswers), &gu_null_struct); -typedef struct { +struct PgfReasoner { + GuPool* pool; GuPool* tmp_pool; PgfAbstr* abstract; PgfAbswersMap* table; GuBuf* pqueue; + GuBuf* exprs; PgfExprEnum en; -} PgfReasoner; +}; static int -cmp_expr_qstate(GuOrder* self, const void* a, const void* b) +cmp_expr_state(GuOrder* self, const void* a, const void* b) { - PgfExprQState *q1 = *((PgfExprQState **) a); - PgfExprQState *q2 = *((PgfExprQState **) b); + PgfReasonerState *st1 = *((PgfReasonerState **) a); + PgfReasonerState *st2 = *((PgfReasonerState **) b); - if (q1->prob < q2->prob) + if (st1->prob < st2->prob) return -1; - else if (q1->prob > q2->prob) + else if (st1->prob > st2->prob) return 1; else return 0; } static GuOrder -pgf_expr_qstate_order = { cmp_expr_qstate }; +pgf_expr_state_order = { cmp_expr_state }; #ifdef PGF_REASONER_DEBUG static void -pgf_print_expr_state(PgfExprState* st, - GuWriter* wtr, GuExn* err, GuBuf* stack) +pgf_print_parent_state(PgfExprState* st, + GuWriter* wtr, GuExn* err, GuBuf* stack) { - gu_buf_push(stack, int, (gu_seq_length(st->hypos) - st->arg_idx - 1)); + gu_buf_push(stack, int, (st->n_args - st->arg_idx - 1)); - PgfExprState* cont = gu_buf_get(st->answers->conts, PgfExprState*, 0); - if (cont != NULL) - pgf_print_expr_state(cont, wtr, err, stack); + PgfExprState* parent = gu_buf_get(st->answers->parents, PgfExprState*, 0); + if (parent != NULL) + pgf_print_parent_state(parent, wtr, err, stack); gu_puts(" (", wtr, err); - pgf_print_expr(st->ep.expr, 0, wtr, err); + pgf_print_expr(st->expr, 0, wtr, err); } static void -pgf_print_expr_state0(PgfExprState* st, - GuWriter* wtr, GuExn* err, GuPool* tmp_pool) +pgf_print_expr_state(PgfExprState* st, + GuWriter* wtr, GuExn* err, GuPool* tmp_pool) { - prob_t prob = st->answers->outside_prob+st->ep.prob; - gu_printf(wtr, err, "[%f]", prob); - - size_t n_args = gu_seq_length(st->hypos); + gu_printf(wtr, err, "[%f] ", st->base.prob); GuBuf* stack = gu_new_buf(int, tmp_pool); - if (n_args > 0) - gu_buf_push(stack, int, n_args - st->arg_idx); + if (st->n_args > 0) + gu_buf_push(stack, int, st->n_args - st->arg_idx); PgfExprState* cont = - gu_buf_get(st->answers->conts, PgfExprState*, 0); + gu_buf_get(st->answers->parents, PgfExprState*, 0); if (cont != NULL) - pgf_print_expr_state(cont, wtr, err, stack); + pgf_print_parent_state(cont, wtr, err, stack); - if (n_args > 0) + if (st->n_args > 0) gu_puts(" (", wtr, err); else gu_puts(" ", wtr, err); - pgf_print_expr(st->ep.expr, 0, wtr, err); + pgf_print_expr(st->expr, 0, wtr, err); size_t n_counts = gu_buf_length(stack); for (size_t i = 0; i < n_counts; i++) { @@ -118,130 +140,208 @@ pgf_print_expr_state0(PgfExprState* st, #endif static PgfExprState* -pgf_reasoner_combine(PgfReasoner* rs, - PgfExprState* st, PgfExprProb* ep, - GuPool* pool) -{ - PgfExprState* nst = - gu_new(PgfExprState, rs->tmp_pool); - nst->answers = st->answers; - nst->ep.expr = - gu_new_variant_i(pool, PGF_EXPR_APP, +pgf_combine1_to_expr(PgfCombine1State* st, GuPool* tmp_pool) { + PgfExprProb* ep = + gu_buf_get(st->exprs, PgfExprProb*, st->choice); + + PgfExprState* nst = gu_new(PgfExprState, tmp_pool); + nst->base.continuation = st->parent->base.continuation; + nst->base.prob = st->base.prob; + nst->answers = st->parent->answers; + nst->expr = + gu_new_variant_i(tmp_pool, PGF_EXPR_APP, PgfExprApp, - .fun = st->ep.expr, + .fun = st->parent->expr, .arg = ep->expr); - nst->ep.prob = st->ep.prob+ep->prob; - nst->hypos = st->hypos; - nst->arg_idx = st->arg_idx+1; +#ifdef PGF_REASONER_DEBUG + nst->base.print = (PgfStatePrinter) pgf_print_expr_state; + nst->n_args = st->parent->n_args; + nst->arg_idx = st->parent->arg_idx+1; +#endif + return nst; } -static void -pgf_reasoner_predict(PgfReasoner* rs, PgfExprState* cont, - prob_t outside_prob, PgfCId cat, - GuPool* pool) +static PgfExprState* +pgf_combine2_to_expr(PgfCombine2State* st, GuPool* tmp_pool) { + PgfExprState* parent = + gu_buf_get(st->parents, PgfExprState*, st->choice); + if (parent == NULL) + return NULL; + + PgfExprState* nst = + gu_new(PgfExprState, tmp_pool); + nst->base.continuation = parent->base.continuation; + nst->base.prob = st->base.prob; + nst->answers = parent->answers; + nst->expr = + gu_new_variant_i(tmp_pool, PGF_EXPR_APP, + PgfExprApp, + .fun = parent->expr, + .arg = st->ep->expr); +#ifdef PGF_REASONER_DEBUG + nst->base.print = (PgfStatePrinter) pgf_print_expr_state; + nst->n_args = parent->n_args; + nst->arg_idx = parent->arg_idx+1; +#endif + + return nst; +} + +#ifdef PGF_REASONER_DEBUG +static void +pgf_print_combine1_state(PgfCombine1State* st, + GuWriter* wtr, GuExn* err, GuPool* tmp_pool) +{ + PgfExprState* nst = pgf_combine1_to_expr(st, tmp_pool); + pgf_print_expr_state(nst, wtr, err, tmp_pool); +} + +static void +pgf_print_combine2_state(PgfCombine2State* st, + GuWriter* wtr, GuExn* err, GuPool* tmp_pool) +{ + PgfExprState* nst = pgf_combine2_to_expr(st, tmp_pool); + if (nst != NULL) + pgf_print_expr_state(nst, wtr, err, tmp_pool); +} +#endif + +static void +pgf_combine1(PgfReasoner* rs, PgfCombine1State* st) +{ + PgfExprState* nst = pgf_combine1_to_expr(st, rs->tmp_pool); + nst->base.continuation(rs, &nst->base); + + st->choice++; + + if (st->choice < st->n_choices) { + PgfExprProb* ep = + gu_buf_get(st->exprs, PgfExprProb*, st->choice); + + st->base.prob = st->parent->base.prob + ep->prob; + gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st); + } +} + +void +pgf_try_first(PgfReasoner* rs, PgfExprState* parent, PgfAbsFun* absfun) +{ + PgfCId cat = absfun->type->cid; + PgfAnswers* answers = gu_map_get(rs->table, &cat, PgfAnswers*); if (answers == NULL) { answers = gu_new(PgfAnswers, rs->tmp_pool); - answers->conts = gu_new_buf(PgfExprState*, rs->tmp_pool); - answers->exprs = gu_new_buf(PgfExprProb*, rs->tmp_pool); - answers->outside_prob = outside_prob; + answers->parents = gu_new_buf(PgfExprState*, rs->tmp_pool); + answers->exprs = gu_new_buf(PgfExprProb*, rs->tmp_pool); + answers->outside_prob = parent->base.prob; gu_map_put(rs->table, &cat, PgfAnswers*, answers); } - gu_buf_push(answers->conts, PgfExprState*, cont); + gu_buf_push(answers->parents, PgfExprState*, parent); - if (gu_buf_length(answers->conts) == 1) { - PgfAbsCat* abscat = gu_map_get(rs->abstract->cats, &cat, PgfAbsCat*); - if (abscat == NULL) { - return; - } - - if (gu_buf_length(abscat->functions) > 0) { - PgfExprQState *q = gu_new(PgfExprQState, rs->tmp_pool); - q->kind = PGF_EXPR_QSTATE_PREDICT; - q->single = answers; - q->choice_idx = 0; - q->choices = abscat->functions; - - q->prob = answers->outside_prob + gu_buf_get(q->choices, PgfAbsFun*, 0)->ep.prob; - gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q); - } + if (gu_buf_length(answers->parents) == 1) { + PgfExprState* st = gu_new(PgfExprState, rs->tmp_pool); + st->base.continuation = (PgfPredicate) absfun->predicate; + st->base.prob = answers->outside_prob + absfun->ep.prob; + st->answers = answers; + st->expr = absfun->ep.expr; +#ifdef PGF_REASONER_DEBUG + st->base.print = (PgfStatePrinter) pgf_print_expr_state; + st->n_args = gu_seq_length(absfun->type->hypos); + st->arg_idx = 0; +#endif + gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st); } else { - if (gu_buf_length(answers->exprs) > 0) { - PgfExprQState *q = gu_new(PgfExprQState, rs->tmp_pool); - q->prob = cont->ep.prob + gu_buf_get(answers->exprs, PgfExprProb*, 0)->prob; - q->kind = PGF_EXPR_QSTATE_COMBINE1; - q->single = cont; - q->choice_idx = 0; - q->choices = answers->exprs; + size_t n_exprs = gu_buf_length(answers->exprs); + if (n_exprs > 0) { + PgfExprProb* ep = + gu_buf_get(answers->exprs, PgfExprProb*, 0); - gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q); + PgfCombine1State* nst = gu_new(PgfCombine1State, rs->tmp_pool); + nst->base.continuation = (PgfPredicate) pgf_combine1; + nst->base.prob = parent->base.prob + ep->prob; + nst->exprs = answers->exprs; + nst->choice = 0; + nst->n_choices = gu_buf_length(answers->exprs); + nst->parent = parent; +#ifdef PGF_REASONER_DEBUG + nst->base.print = (PgfStatePrinter) pgf_print_combine1_state; +#endif + gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &nst); } } } +void +pgf_try_else(PgfReasoner* rs, PgfExprState* prev, PgfAbsFun* absfun) +{ + PgfExprState *st = gu_new(PgfExprState, rs->tmp_pool); + st->base.continuation = (PgfPredicate) absfun->predicate; + st->base.prob = prev->answers->outside_prob + absfun->ep.prob; + st->answers = prev->answers; + st->expr = absfun->ep.expr; +#ifdef PGF_REASONER_DEBUG + st->base.print = (PgfStatePrinter) pgf_print_expr_state; + st->n_args = gu_seq_length(absfun->type->hypos); + st->arg_idx = 0; +#endif + gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st); +} + +static void +pgf_combine2(PgfReasoner* rs, PgfCombine2State* st) +{ + PgfExprState* nst = pgf_combine2_to_expr(st, rs->tmp_pool); + if (nst != NULL) { + nst->base.continuation(rs, &nst->base); + } + + st->choice++; + + if (st->choice < st->n_choices) { + PgfExprState* parent = + gu_buf_get(st->parents, PgfExprState*, st->choice); + + st->base.prob = parent->base.prob + st->ep->prob; + gu_buf_heap_push(rs->pqueue, &pgf_expr_state_order, &st); + } +} + +void +pgf_complete(PgfReasoner* rs, PgfExprState* st) +{ + PgfExprProb* ep = gu_new(PgfExprProb, rs->pool); + ep->prob = st->base.prob - st->answers->outside_prob; + ep->expr = st->expr; + gu_buf_push(st->answers->exprs, PgfExprProb*, ep); + + PgfCombine2State* nst = gu_new(PgfCombine2State, rs->tmp_pool); + nst->base.continuation = (PgfPredicate) pgf_combine2; + nst->base.prob = st->base.prob; + nst->parents = st->answers->parents; + nst->choice = 0; + nst->n_choices = gu_buf_length(st->answers->parents); + nst->ep = ep; +#ifdef PGF_REASONER_DEBUG + nst->base.print = (PgfStatePrinter) pgf_print_combine2_state; +#endif + nst->base.continuation(rs, &nst->base); +} + static PgfExprProb* -pgf_reasoner_next(PgfReasoner* rs, GuPool* pool) +pgf_reasoner_next(PgfReasoner* rs) { if (rs->tmp_pool == NULL) return NULL; + + size_t n_exprs = gu_buf_length(rs->exprs); while (gu_buf_length(rs->pqueue) > 0) { - PgfExprQState* q; - gu_buf_heap_pop(rs->pqueue, &pgf_expr_qstate_order, &q); - - PgfExprState* st = NULL; - switch (q->kind) { - case PGF_EXPR_QSTATE_PREDICT: { - PgfAbsFun* absfun = - gu_buf_get(q->choices, PgfAbsFun*, q->choice_idx); - - st = gu_new(PgfExprState, pool); - st->answers = q->single; - st->ep = absfun->ep; - st->hypos = absfun->type->hypos; - st->arg_idx = 0; - - q->choice_idx++; - if (q->choice_idx < gu_buf_length(q->choices)) { - q->prob = st->answers->outside_prob + gu_buf_get(q->choices, PgfAbsFun*, q->choice_idx)->ep.prob; - gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q); - } - break; - } - case PGF_EXPR_QSTATE_COMBINE1: { - PgfExprState* cont = q->single; - PgfExprProb* ep = - gu_buf_get(q->choices, PgfExprProb*, q->choice_idx); - st = pgf_reasoner_combine(rs, cont, ep, pool); - - q->choice_idx++; - if (q->choice_idx < gu_buf_length(q->choices)) { - q->prob = cont->ep.prob + gu_buf_get(q->choices, PgfExprProb*, q->choice_idx)->prob; - gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q); - } - break; - } - case PGF_EXPR_QSTATE_COMBINE2: { - PgfExprState* cont = - gu_buf_get(q->choices, PgfExprState*, q->choice_idx); - PgfExprProb* ep = q->single; - st = pgf_reasoner_combine(rs, cont, ep, pool); - - q->choice_idx++; - if (q->choice_idx < gu_buf_length(q->choices)) { - q->prob = ep->prob + gu_buf_get(q->choices, PgfExprState*, q->choice_idx)->ep.prob; - gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q); - } - break; - } - default: - gu_impossible(); - } - + PgfReasonerState* st; + gu_buf_heap_pop(rs->pqueue, &pgf_expr_state_order, &st); #ifdef PGF_REASONER_DEBUG { @@ -249,45 +349,15 @@ pgf_reasoner_next(PgfReasoner* rs, GuPool* pool) GuOut* out = gu_file_out(stderr, tmp_pool); GuWriter* wtr = gu_new_utf8_writer(out, tmp_pool); GuExn* err = gu_exn(NULL, type, tmp_pool); - pgf_print_expr_state0(st, wtr, err, tmp_pool); + st->print(st, wtr, err, tmp_pool); gu_pool_free(tmp_pool); } #endif - if (st->arg_idx < gu_seq_length(st->hypos)) { - PgfHypo *hypo = gu_seq_index(st->hypos, PgfHypo, st->arg_idx); - prob_t outside_prob = - st->ep.prob+st->answers->outside_prob; - pgf_reasoner_predict(rs, st, outside_prob, - hypo->type->cid, pool); - } else { - gu_buf_push(st->answers->exprs, PgfExprProb*, &st->ep); - - PgfExprProb* target = NULL; - - GuBuf* conts = st->answers->conts; - size_t choice_idx = 0; - PgfExprState* cont = - gu_buf_get(conts, PgfExprState*, 0); - if (cont == NULL) { - target = &st->ep; - cont = gu_buf_get(conts, PgfExprState*, 1); - choice_idx++; - } - - if (choice_idx < gu_buf_length(conts)) { - PgfExprQState *q = gu_new(PgfExprQState, rs->tmp_pool); - q->prob = st->ep.prob + cont->ep.prob; - q->kind = PGF_EXPR_QSTATE_COMBINE2; - q->single = &st->ep; - q->choice_idx = choice_idx; - q->choices = conts; - - gu_buf_heap_push(rs->pqueue, &pgf_expr_qstate_order, &q); - } - - if (target != NULL) - return target; + st->continuation(rs, st); + + if (n_exprs < gu_buf_length(rs->exprs)) { + return gu_buf_get(rs->exprs, PgfExprProb*, n_exprs); } } @@ -301,20 +371,31 @@ static void pgf_reasoner_enum_next(GuEnum* self, void* to, GuPool* pool) { PgfReasoner* pr = gu_container(self, PgfReasoner, en); - *(PgfExprProb**)to = pgf_reasoner_next(pr, pool); + *(PgfExprProb**)to = pgf_reasoner_next(pr); } PgfExprEnum* pgf_generate(PgfPGF* pgf, PgfCId cat, GuPool* pool) { PgfReasoner* rs = gu_new(PgfReasoner, pool); + rs->pool = pool; rs->tmp_pool = gu_new_pool(), rs->abstract = &pgf->abstract, rs->table = gu_map_type_new(PgfAbswersMap, rs->tmp_pool), - rs->pqueue = gu_new_buf(PgfExprQState*, rs->tmp_pool); + rs->pqueue = gu_new_buf(PgfReasonerState*, rs->tmp_pool); + rs->exprs = gu_new_buf(PgfExprProb*, rs->tmp_pool); rs->en.next = pgf_reasoner_enum_next; - pgf_reasoner_predict(rs, NULL, 0, cat, pool); + PgfAnswers* answers = gu_new(PgfAnswers, rs->tmp_pool); + answers->parents = gu_new_buf(PgfExprState*, rs->tmp_pool); + answers->exprs = rs->exprs; + answers->outside_prob = 0; + gu_map_put(rs->table, &cat, PgfAnswers*, answers); + + PgfAbsCat* abscat = gu_map_get(rs->abstract->cats, &cat, PgfAbsCat*); + if (abscat != NULL) { + ((PgfPredicate) abscat->predicate)(rs, NULL); + } return &rs->en; }