diff --git a/src/runtime/c/pgf/parser.c b/src/runtime/c/pgf/parser.c index 4dfef0ee5..7b0dda0be 100644 --- a/src/runtime/c/pgf/parser.c +++ b/src/runtime/c/pgf/parser.c @@ -57,6 +57,7 @@ typedef struct { prob_t heuristic_factor; PgfCallbacksMap* callbacks; + PgfOracleCallback* oracle; } PgfParsing; typedef enum { BIND_NONE, BIND_HARD, BIND_SOFT } BIND_TYPE; @@ -938,6 +939,15 @@ pgf_parsing_new_production(PgfItem* item, PgfExprProb *ep, GuPool *pool) static void pgf_parsing_complete(PgfParsing* ps, PgfItem* item, PgfExprProb *ep) { + if (ps->oracle && ps->oracle->complete) { + // ask the oracle whether to complete + if (!ps->oracle->complete(ps->oracle, + item->conts->ccat->cnccat->abscat->name, + item->conts->ccat->cnccat->labels[item->conts->lin_idx], + ps->before->end_offset)) + return; + } + PgfProduction prod = pgf_parsing_new_production(item, ep, ps->pool); #ifdef PGF_COUNTS_DEBUG @@ -1269,6 +1279,15 @@ pgf_parsing_td_predict(PgfParsing* ps, * of this category at the current position, * so predict it. 
*/ + if (ps->oracle != NULL && ps->oracle->predict) { + // if there is an oracle ask him if this prediction is appropriate + if (!ps->oracle->predict(ps->oracle, + ccat->cnccat->abscat->name, + ccat->cnccat->labels[lin_idx], + ps->before->end_offset)) + return; + } + conts->outside_prob = item->inside_prob-conts->ccat->viterbi_prob+ item->conts->outside_prob; @@ -1451,40 +1470,49 @@ pgf_parsing_symbol(PgfParsing* ps, PgfItem* item, PgfSymbol sym) bool match = false; if (!ps->before->needs_bind) { - PgfLiteralCallback* callback = - gu_map_get(ps->callbacks, - parg->ccat->cnccat, - PgfLiteralCallback*); + size_t start = ps->before->end_offset; + size_t offset = start; + PgfExprProb *ep = NULL; - if (callback != NULL) { - size_t start = ps->before->end_offset; - size_t offset = start; - PgfExprProb *ep = - callback->match(callback, ps->concr, - slit->r, - ps->sentence, &offset, - ps->out_pool); + if (ps->oracle != NULL && ps->oracle->literal) { + ep = ps->oracle->literal(ps->oracle, + parg->ccat->cnccat->abscat->name, + parg->ccat->cnccat->labels[slit->r], + &offset, + ps->out_pool); + } else { + PgfLiteralCallback* callback = + gu_map_get(ps->callbacks, + parg->ccat->cnccat, + PgfLiteralCallback*); - if (ep != NULL) { - PgfProduction prod; - PgfProductionExtern* pext = - gu_new_variant(PGF_PRODUCTION_EXTERN, - PgfProductionExtern, - &prod, ps->pool); - pext->ep = ep; - pext->lins = NULL; - - PgfItem* item = - pgf_new_item(ps, conts, prod); - item->curr_sym = pgf_collect_extern_tok(ps,start,offset); - item->sym_idx = pgf_item_symbols_length(item); - PgfParseState* state = - pgf_new_parse_state(ps, offset, BIND_NONE, - item->inside_prob+item->conts->outside_prob); - gu_buf_heap_push(state->agenda, pgf_item_prob_order, &item); - match = true; + if (callback != NULL) { + ep = callback->match(callback, ps->concr, + slit->r, + ps->sentence, &offset, + ps->out_pool); } } + + if (ep != NULL) { + PgfProduction prod; + PgfProductionExtern* pext = + 
gu_new_variant(PGF_PRODUCTION_EXTERN, + PgfProductionExtern, + &prod, ps->pool); + pext->ep = ep; + pext->lins = NULL; + + PgfItem* item = + pgf_new_item(ps, conts, prod); + item->curr_sym = pgf_collect_extern_tok(ps,start,offset); + item->sym_idx = pgf_item_symbols_length(item); + PgfParseState* state = + pgf_new_parse_state(ps, offset, BIND_NONE, + item->inside_prob+item->conts->outside_prob); + gu_buf_heap_push(state->agenda, pgf_item_prob_order, &item); + match = true; + } } if (!match) { @@ -1659,7 +1687,8 @@ pgf_parsing_set_default_factors(PgfParsing* ps, PgfAbstr* abstr) } static PgfParsing* -pgf_new_parsing(PgfConcr* concr, GuString sentence, PgfCallbacksMap* callbacks, +pgf_new_parsing(PgfConcr* concr, GuString sentence, + PgfCallbacksMap* callbacks, PgfOracleCallback* oracle, GuPool* pool, GuPool* out_pool) { PgfParsing* ps = gu_new(PgfParsing, pool); @@ -1685,6 +1714,7 @@ pgf_new_parsing(PgfConcr* concr, GuString sentence, PgfCallbacksMap* callbacks, ps->free_item = NULL; ps->heuristic_factor = 0; ps->callbacks = callbacks; + ps->oracle = oracle; pgf_parsing_set_default_factors(ps, concr->abstr); @@ -1857,7 +1887,8 @@ pgf_parse_result_is_new(PgfExprState* st) static PgfParsing* pgf_parsing_init(PgfConcr* concr, PgfCId cat, size_t lin_idx, GuString sentence, - double heuristic_factor, PgfCallbacksMap* callbacks, + double heuristic_factor, + PgfCallbacksMap* callbacks, PgfOracleCallback* oracle, GuExn* err, GuPool* pool, GuPool* out_pool) { PgfCncCat* cnccat = @@ -1871,7 +1902,7 @@ pgf_parsing_init(PgfConcr* concr, PgfCId cat, size_t lin_idx, gu_assert(lin_idx < cnccat->n_lins); PgfParsing* ps = - pgf_new_parsing(concr, sentence, callbacks, pool, out_pool); + pgf_new_parsing(concr, sentence, callbacks, oracle, pool, out_pool); if (heuristic_factor >= 0) { ps->heuristic_factor = heuristic_factor; @@ -2100,7 +2131,52 @@ pgf_parse_with_heuristics(PgfConcr* concr, PgfCId cat, GuString sentence, // Begin parsing a sentence with the specified category 
PgfParsing* ps = - pgf_parsing_init(concr, cat, 0, sentence, heuristics, callbacks, err, pool, out_pool); + pgf_parsing_init(concr, cat, 0, sentence, heuristics, callbacks, NULL, err, pool, out_pool); + if (ps == NULL) { + return NULL; + } + +#ifdef PGF_COUNTS_DEBUG + pgf_parsing_print_counts(ps); +#endif + + while (gu_buf_length(ps->expr_queue) == 0) { + if (!pgf_parsing_proceed(ps)) { + GuExnData* exn = gu_raise(err, PgfParseError); + exn->data = (void*) pgf_parsing_last_token(ps, exn->pool); + return NULL; + } + +#ifdef PGF_COUNTS_DEBUG + pgf_parsing_print_counts(ps); +#endif + } + + // Now begin enumerating the resulting syntax trees + ps->en.next = pgf_parse_result_enum_next; + return &ps->en; +} + +PgfExprEnum* +pgf_parse_with_oracle(PgfConcr* concr, PgfCId cat, + GuString sentence, + PgfOracleCallback* oracle, + GuExn* err, + GuPool* pool, GuPool* out_pool) +{ + if (concr->sequences == NULL || + concr->cnccats == NULL) { + GuExnData* err_data = gu_raise(err, PgfExn); + if (err_data) { + err_data->data = "The concrete syntax is not loaded"; + } + return NULL; /* always stop here: an unloaded concrete syntax must not be parsed, even when gu_raise hands back no data slot */ + } + + // Begin parsing a sentence with the specified category + PgfCallbacksMap* callbacks = pgf_new_callbacks_map(concr, out_pool); + PgfParsing* ps = + pgf_parsing_init(concr, cat, 0, sentence, -1, callbacks, oracle, err, pool, out_pool); if (ps == NULL) { return NULL; } @@ -2162,7 +2238,7 @@ pgf_complete(PgfConcr* concr, PgfCId cat, GuString sentence, PgfCallbacksMap* callbacks = pgf_new_callbacks_map(concr, pool); PgfParsing* ps = - pgf_parsing_init(concr, cat, 0, sentence, -1.0, callbacks, err, pool, pool); + pgf_parsing_init(concr, cat, 0, sentence, -1.0, callbacks, NULL, err, pool, pool); if (ps == NULL) { return NULL; } diff --git a/src/runtime/c/pgf/pgf.h b/src/runtime/c/pgf/pgf.h index 80cbe29f2..4a6199c6f 100644 --- a/src/runtime/c/pgf/pgf.h +++ b/src/runtime/c/pgf/pgf.h @@ -137,6 +137,31 @@ pgf_parse_with_heuristics(PgfConcr* concr, PgfCId cat, GuExn* err, GuPool* pool, GuPool* out_pool); 
+typedef struct PgfOracleCallback PgfOracleCallback; + +struct PgfOracleCallback { + bool (*predict) (PgfOracleCallback* self, + PgfCId cat, + GuString label, + size_t offset); + bool (*complete)(PgfOracleCallback* self, + PgfCId cat, + GuString label, + size_t offset); + PgfExprProb* (*literal)(PgfOracleCallback* self, + PgfCId cat, + GuString label, + size_t* poffset, + GuPool *out_pool); +}; + +PgfExprEnum* +pgf_parse_with_oracle(PgfConcr* concr, PgfCId cat, + GuString sentence, + PgfOracleCallback* oracle, + GuExn* err, + GuPool* pool, GuPool* out_pool); + typedef struct { PgfToken tok; PgfCId cat; diff --git a/src/runtime/haskell-bind/PGF2.hsc b/src/runtime/haskell-bind/PGF2.hsc index e88dcc9ce..80abc3775 100644 --- a/src/runtime/haskell-bind/PGF2.hsc +++ b/src/runtime/haskell-bind/PGF2.hsc @@ -17,7 +17,8 @@ module PGF2 (-- * CId -- * PGF PGF,readPGF,AbsName,abstractName,Cat,startCat,categories, -- * Concrete syntax - ConcName,Concr,languages,parse,parseWithHeuristics, + ConcName,Concr,languages,parse, + parseWithHeuristics, parseWithOracle, hasLinearization,linearize,linearizeAll,alignWords, -- * Types Type(..), Hypo, BindType(..), showType, functionType, @@ -340,6 +341,88 @@ mkCallbacksMap concr callbacks pool = do predict_callback _ _ _ _ = return nullPtr +-- | The oracle is a triple of optional functions. +-- The first two take a category name, a linearization field name and the +-- current input offset, and should return True when the corresponding +-- prediction or completion is appropriate and False otherwise. The third function +-- is the oracle for literals: it may return an expression, its probability and the literal's end offset. 
+type Oracle = (Maybe (Cat -> String -> Int -> Bool) + ,Maybe (Cat -> String -> Int -> Bool) + ,Maybe (Cat -> String -> Int -> Maybe (Expr,Float,Int)) + ) + +parseWithOracle :: Concr -- ^ the language with which we parse + -> Cat -- ^ the start category + -> String -- ^ the input sentence + -> Oracle + -> Either String [(Expr,Float)] +parseWithOracle lang cat sent (predict,complete,literal) = + unsafePerformIO $ + do parsePl <- gu_new_pool + exprPl <- gu_new_pool + exn <- gu_new_exn parsePl + enum <- withCString cat $ \cat -> + withCString sent $ \sent -> do + predictPtr <- maybe (return nullFunPtr) (wrapOracleCallback . oracleWrapper) predict + completePtr <- maybe (return nullFunPtr) (wrapOracleCallback . oracleWrapper) complete + literalPtr <- maybe (return nullFunPtr) (wrapOracleLiteralCallback . oracleLiteralWrapper) literal + cback <- hspgf_new_oracle_callback predictPtr completePtr literalPtr parsePl + pgf_parse_with_oracle (concr lang) cat sent cback exn parsePl exprPl + failed <- gu_exn_is_raised exn + if failed + then do is_parse_error <- gu_exn_caught exn gu_exn_type_PgfParseError + if is_parse_error + then do c_tok <- (#peek GuExn, data.data) exn + tok <- peekCString c_tok + gu_pool_free parsePl + gu_pool_free exprPl + return (Left tok) + else do is_exn <- gu_exn_caught exn gu_exn_type_PgfExn + if is_exn + then do c_msg <- (#peek GuExn, data.data) exn + msg <- peekCString c_msg + gu_pool_free parsePl + gu_pool_free exprPl + throwIO (PGFError msg) + else do gu_pool_free parsePl + gu_pool_free exprPl + throwIO (PGFError "Parsing failed") + else do parseFPl <- newForeignPtr gu_pool_finalizer parsePl + exprFPl <- newForeignPtr gu_pool_finalizer exprPl + exprs <- fromPgfExprEnum enum parseFPl (lang,exprFPl) + return (Right exprs) + where + oracleWrapper oracle _ catPtr lblPtr offset = do + cat <- peekCString catPtr + lbl <- peekCString lblPtr + return (oracle cat lbl (fromIntegral offset)) + + oracleLiteralWrapper oracle _ catPtr lblPtr poffset out_pool = do 
+ cat <- peekCString catPtr + lbl <- peekCString lblPtr + offset <- peek poffset + case oracle cat lbl (fromIntegral offset) of + Just (e,prob,offset) -> + do poke poffset (fromIntegral offset) + + -- here we copy the expression to out_pool + c_e <- withGuPool $ \tmpPl -> do + exn <- gu_new_exn tmpPl + + (sb,out) <- newOut tmpPl + let printCtxt = nullPtr + pgf_print_expr (expr e) printCtxt 1 out exn + c_str <- gu_string_buf_freeze sb tmpPl + + guin <- gu_string_in c_str tmpPl + pgf_read_expr guin out_pool exn + + ep <- gu_malloc out_pool (#size PgfExprProb) + (#poke PgfExprProb, expr) ep c_e + (#poke PgfExprProb, prob) ep prob + return ep + Nothing -> do return nullPtr + hasLinearization :: Concr -> Fun -> Bool hasLinearization lang id = unsafePerformIO $ withCString id (pgf_has_linearization (concr lang)) diff --git a/src/runtime/haskell-bind/PGF2/FFI.hs b/src/runtime/haskell-bind/PGF2/FFI.hs index fc658d83d..67830e890 100644 --- a/src/runtime/haskell-bind/PGF2/FFI.hs +++ b/src/runtime/haskell-bind/PGF2/FFI.hs @@ -98,6 +98,7 @@ data PgfMorphoCallback data PgfPrintContext data PgfType data PgfCallbacksMap +data PgfOracleCallback data PgfCncTree foreign import ccall "pgf/pgf.h pgf_read" @@ -179,6 +180,21 @@ foreign import ccall "pgf/pgf.h pgf_new_callbacks_map" foreign import ccall hspgf_callbacks_map_add_literal :: Ptr PgfConcr -> Ptr PgfCallbacksMap -> CString -> FunPtr LiteralMatchCallback -> FunPtr LiteralPredictCallback -> Ptr GuPool -> IO () +type OracleCallback = Ptr PgfOracleCallback -> CString -> CString -> CSize -> IO Bool -- CSize mirrors the size_t offset of predict/complete in struct PgfOracleCallback +type OracleLiteralCallback = Ptr PgfOracleCallback -> CString -> CString -> Ptr CSize -> Ptr GuPool -> IO (Ptr PgfExprProb) -- Ptr CSize mirrors size_t* poffset, so peek/poke cover the whole word + +foreign import ccall "wrapper" + wrapOracleCallback :: OracleCallback -> IO (FunPtr OracleCallback) + +foreign import ccall "wrapper" + wrapOracleLiteralCallback :: OracleLiteralCallback -> IO (FunPtr OracleLiteralCallback) + +foreign import ccall + hspgf_new_oracle_callback :: FunPtr OracleCallback -> FunPtr 
OracleCallback -> FunPtr OracleLiteralCallback -> Ptr GuPool -> IO (Ptr PgfOracleCallback) + +foreign import ccall "pgf/pgf.h pgf_parse_with_oracle" + pgf_parse_with_oracle :: Ptr PgfConcr -> CString -> CString -> Ptr PgfOracleCallback -> Ptr GuExn -> Ptr GuPool -> Ptr GuPool -> IO (Ptr GuEnum) + foreign import ccall "pgf/pgf.h pgf_lookup_morpho" pgf_lookup_morpho :: Ptr PgfConcr -> CString -> Ptr PgfMorphoCallback -> Ptr GuExn -> IO () diff --git a/src/runtime/haskell-bind/utils.c b/src/runtime/haskell-bind/utils.c index a00527df5..0dd9ae03b 100644 --- a/src/runtime/haskell-bind/utils.c +++ b/src/runtime/haskell-bind/utils.c @@ -67,3 +67,33 @@ hspgf_callbacks_map_add_literal(PgfConcr* concr, PgfCallbacksMap* callbacks, gu_pool_finally(pool, &callback->fin); pgf_callbacks_map_add_literal(concr, callbacks, cat, &callback->callback); } + +typedef struct { + PgfOracleCallback oracle; + GuFinalizer fin; +} HSPgfOracleCallback; + +static void +hspgf_oracle_callback_fin(GuFinalizer* self) +{ + HSPgfOracleCallback* oracle = gu_container(self, HSPgfOracleCallback, fin); + + if (oracle->oracle.predict != NULL) + hs_free_fun_ptr((HsFunPtr) oracle->oracle.predict); + if (oracle->oracle.complete != NULL) + hs_free_fun_ptr((HsFunPtr) oracle->oracle.complete); + if (oracle->oracle.literal != NULL) + hs_free_fun_ptr((HsFunPtr) oracle->oracle.literal); +} + +PgfOracleCallback* +hspgf_new_oracle_callback(HsFunPtr predict, HsFunPtr complete, HsFunPtr literal, GuPool* pool) +{ + HSPgfOracleCallback* oracle = gu_new(HSPgfOracleCallback, pool); + oracle->oracle.predict = (void*) predict; + oracle->oracle.complete = (void*) complete; + oracle->oracle.literal = (void*) literal; + oracle->fin.fn = hspgf_oracle_callback_fin; + gu_pool_finally(pool, &oracle->fin); + return &oracle->oracle; +}