diff --git a/src/compiler/GF/Command/Importing.hs b/src/compiler/GF/Command/Importing.hs
index ba8d51a6a..5d1f88ee7 100644
--- a/src/compiler/GF/Command/Importing.hs
+++ b/src/compiler/GF/Command/Importing.hs
@@ -64,8 +64,15 @@ importPGF opts Nothing f
                                           then removeFile f'
                                           else return ()
                                         putStr ("(Boot image "++f'++") ")
-                                        fmap Just (bootNGF f f')
-  | otherwise                      = fmap Just (readPGF f)
+                                        mb_probs <- readProbs
+                                        fmap Just (bootNGFWithProbs f mb_probs f')
+  | otherwise                      = do mb_probs <- readProbs
+                                        fmap Just (readPGFWithProbs f mb_probs)
+  where
+    -- Load the file selected by optProbsFile, if any; shared by both guards.
+    readProbs = case flag optProbsFile opts of
+                  Nothing   -> return Nothing
+                  Just file -> fmap Just (readProbabilitiesFromFile file)
 importPGF opts (Just pgf) f = fmap Just (modifyPGF pgf (mergePGF f) `catch`
                                            (\e@(PGFError loc msg) ->
                                                  if msg == "The abstract syntax names doesn't match"
@@ -73,7 +80,6 @@ importPGF opts (Just pgf) f = fmap Just (modifyPGF pgf (mergePGF f) `catc
                                                           readPGF f
                                                    else throwIO e))
 
-
 importSource :: Options -> [FilePath] -> IO (ModuleName,SourceGrammar)
 importSource opts files = fmap snd (batchCompile opts files)
 
diff --git a/src/compiler/GF/Compiler.hs b/src/compiler/GF/Compiler.hs
index e5b14117b..92b207900 100644
--- a/src/compiler/GF/Compiler.hs
+++ b/src/compiler/GF/Compiler.hs
@@ -140,7 +140,10 @@ unionPGFFiles opts fs =
     doIt =
       case fs of
         []     -> return ()
-        (f:fs) -> do pgf <- if snd (flag optLinkTargets opts)
+        (f:fs) -> do mb_probs <- case flag optProbsFile opts of
+                                   Nothing   -> return Nothing
+                                   Just file -> fmap Just (readProbabilitiesFromFile file)
+                     pgf <- if snd (flag optLinkTargets opts)
                               then case flag optName opts of
                                      Just name -> do let fname = maybe id (</>) (flag optOutputDir opts) (name<.>"ngf")
                                                      putStrLnE ("(Boot image "++fname++")")
@@ -148,10 +151,10 @@ unionPGFFiles opts fs =
                                                      if exists
                                                        then removeFile fname
                                                        else return ()
-                                                     echo (\f -> bootNGF f fname) f
+                                                     echo (\f -> bootNGFWithProbs f mb_probs fname) f
                                      Nothing   -> do putStrLnE $ "To boot from a list of .pgf files add option -name"
-                                                     echo readPGF f
-                              else echo readPGF f
+                                                     echo (\f -> readPGFWithProbs f mb_probs) f
+                              else echo (\f -> readPGFWithProbs f mb_probs) f
                      pgf <- foldM (\pgf -> echo (modifyPGF pgf . mergePGF)) pgf fs
                      let pgfFile = outputPath opts (grammarName opts pgf <.> "pgf")
                      if pgfFile `elem` fs
diff --git a/src/runtime/c/pgf/namespace.h b/src/runtime/c/pgf/namespace.h
index 4e3362968..e331a47e4 100644
--- a/src/runtime/c/pgf/namespace.h
+++ b/src/runtime/c/pgf/namespace.h
@@ -407,6 +407,45 @@ void namespace_iter(Namespace<V> map, PgfItor* itor, PgfExn *err)
         return;
 }
 
+/* In-order walk over the tree that stores, for every entry, a pointer
+ * to the entry's name in the `name` field of the matching vector
+ * element. `offs` is the index of the first element for this subtree. */
+template <class V,class A>
+void namespace_vec_fill_names(Namespace<V> node, size_t offs, Vector<A> *vec)
+{
+    if (node == 0)
+        return;
+
+    namespace_vec_fill_names(node->left, offs, vec);
+
+    offs += namespace_size(node->left);
+    vector_elem(vec, offs++)->name = &node->value->name;
+
+    namespace_vec_fill_names(node->right, offs, vec);
+}
+
+/* Allocates a zero-initialised Vector<A> with one element per entry of
+ * the namespace and points each element's `name` at the entry's name.
+ * The vector comes from malloc(); the caller must free() it. */
+template <class V,class A>
+Vector<A> *namespace_to_sorted_names(Namespace<V> node)
+{
+    // namespace_size() is already applied to possibly-empty subtrees
+    // above, so use it here too instead of dereferencing node->sz.
+    size_t n = namespace_size(node);
+
+    Vector<A> *vec = (Vector<A> *)
+        malloc(sizeof(Vector<A>)+n*sizeof(A));
+    // malloc() reports failure through its return value; errno may
+    // still hold a stale non-zero code after a successful call.
+    if (vec == NULL)
+        throw pgf_systemerror(ENOMEM);
+    vec->len = n;
+    memset(vec->data, 0, n*sizeof(A));
+    namespace_vec_fill_names(node, 0, vec);
+    return vec;
+}
+
 template <class V>
 void namespace_release(Namespace<V> node)
 {
diff --git a/src/runtime/c/pgf/pgf.cxx b/src/runtime/c/pgf/pgf.cxx
index 0aaff467b..1ab3a1256 100644
--- a/src/runtime/c/pgf/pgf.cxx
+++ b/src/runtime/c/pgf/pgf.cxx
@@ -37,8 +37,8 @@ pgf_exn_clear(PgfExn* err)
 }
 
 PGF_API
-PgfDB *pgf_read_pgf(const char* fpath,
-                    PgfRevision *revision,
+PgfDB *pgf_read_pgf(const char* fpath, PgfRevision *revision,
+                    PgfProbsCallback *probs_callback,
                     PgfExn* err)
 {
     PgfDB *db = NULL;
@@ -56,7 +56,7 @@ PgfDB *pgf_read_pgf(const char* fpath,
 
         db->start_transaction();
 
-        PgfReader rdr(in);
+        PgfReader rdr(in,probs_callback);
         ref<PgfPGF> pgf = rdr.read_pgf();
 
         *revision = db->register_revision(pgf.tagged(), PgfDB::get_txn_id());
@@ -79,6 +79,7 @@ PgfDB *pgf_read_pgf(const char* fpath,
 PGF_API
 PgfDB *pgf_boot_ngf(const char* pgf_path, const char* ngf_path,
                     PgfRevision *revision,
+                    PgfProbsCallback *probs_callback,
                     PgfExn* err)
 {
     PgfDB *db = NULL;
@@ -103,7 +104,7 @@ PgfDB *pgf_boot_ngf(const char* pgf_path, const char* ngf_path,
 
         db->start_transaction();
 
-        PgfReader rdr(in);
+        PgfReader rdr(in,probs_callback);
         ref<PgfPGF> pgf = rdr.read_pgf();
 
         *revision = db->register_revision(pgf.tagged(), PgfDB::get_txn_id());
@@ -220,7 +221,7 @@ void pgf_merge_pgf(PgfDB *db, PgfRevision revision,
         DB_scope scope(db, WRITER_SCOPE);
         ref<PgfPGF> pgf = db->revision2pgf(revision);
 
-        PgfReader rdr(in);
+        PgfReader rdr(in,NULL);
         rdr.merge_pgf(pgf);
     }
 } PGF_API_END
diff --git a/src/runtime/c/pgf/pgf.h b/src/runtime/c/pgf/pgf.h
index 81f8f9188..93f5c30ed 100644
--- a/src/runtime/c/pgf/pgf.h
+++ b/src/runtime/c/pgf/pgf.h
@@ -226,11 +226,20 @@ typedef struct PgfDB PgfDB;
 typedef object PgfRevision;
 typedef object PgfConcrRevision;
 
+/* Optional callback consulted while a PGF file is being read. fn gets
+ * the name whose probability is being read and returns the probability
+ * to use instead of the stored one; NaN means "not specified". */
+typedef struct PgfProbsCallback PgfProbsCallback;
+struct PgfProbsCallback {
+    double (*fn)(PgfProbsCallback* self, PgfText *name);
+};
+
 /* Reads a PGF file and builds the database in memory.
  * If successful, *revision will contain the initial revision of
  * the grammar. */
 PGF_API_DECL
 PgfDB *pgf_read_pgf(const char* fpath, PgfRevision *revision,
+                    PgfProbsCallback *probs_callback,
                     PgfExn* err);
 
 /* Reads a PGF file and stores the unpacked data in an NGF file
@@ -240,6 +249,7 @@ PgfDB *pgf_read_pgf(const char* fpath, PgfRevision *revision,
 PGF_API_DECL
 PgfDB *pgf_boot_ngf(const char* pgf_path, const char* ngf_path,
                     PgfRevision *revision,
+                    PgfProbsCallback *probs_callback,
                     PgfExn* err);
 
 /* Tries to read the grammar from an already booted NGF file.
diff --git a/src/runtime/c/pgf/reader.cxx b/src/runtime/c/pgf/reader.cxx
index ef5843899..f875ba980 100644
--- a/src/runtime/c/pgf/reader.cxx
+++ b/src/runtime/c/pgf/reader.cxx
@@ -3,9 +3,10 @@
 #include <math.h>
 #include <string.h>
 
-PgfReader::PgfReader(FILE *in)
+PgfReader::PgfReader(FILE *in,PgfProbsCallback *probs_callback)
 {
     this->in = in;
+    this->probs_callback = probs_callback;
     this->abstract = 0;
     this->concrete = 0;
 }
@@ -71,6 +72,19 @@ double PgfReader::read_double()
     return sign ? copysign(ret, -1.0) : ret;
 }
 
+// Reads the probability stored in the file and returns it as a
+// negative logarithm. When a probabilities callback is installed the
+// stored value is still consumed, but the callback's answer for `name`
+// is used instead; the result is NaN when the callback returns NaN.
+prob_t PgfReader::read_prob(PgfText *name)
+{
+    double d = read_double();
+    if (probs_callback != NULL) {
+        d = probs_callback->fn(probs_callback, name);
+    }
+    return - log(d);
+}
+
 uint64_t PgfReader::read_uint()
 {
     uint64_t u = 0;
@@ -318,7 +332,7 @@ ref<PgfAbsFun> PgfReader::read_absfun()
     default:
         throw pgf_error("Unknown tag, 0 or 1 expected");
     }
-    absfun->prob = - log(read_double());
+    absfun->prob = read_prob(&absfun->name);
     return absfun;
 }
 
@@ -326,10 +340,97 @@ ref<PgfAbsCat> PgfReader::read_abscat()
 {
     ref<PgfAbsCat> abscat = read_name<PgfAbsCat>(&PgfAbsCat::name);
     abscat->context = read_vector<PgfHypo>(&PgfReader::read_hypo);
-    abscat->prob = - log(read_double());
+    // Only functions are padded in read_abstract(), so a category that
+    // the callback does not know keeps the probability from the file
+    // instead of ending up as NaN.
+    double cat_prob = read_double();
+    if (probs_callback != NULL) {
+        double new_prob = probs_callback->fn(probs_callback, &abscat->name);
+        if (!isnan(new_prob))
+            cat_prob = new_prob;
+    }
+    abscat->prob = - log(cat_prob);
     return abscat;
 }
 
+// Per-category bookkeeping used to share out the probability mass that
+// the callback left unassigned.
+struct PGF_INTERNAL_DECL PgfAbsCatCounts
+{
+    PgfText *name;        // category name
+    size_t n_nan_probs;   // functions of the category whose probability is NaN
+    double probs_sum;     // sum of the probabilities that are not NaN
+    prob_t prob;          // negative log-probability given to each NaN function
+};
+
+struct PGF_INTERNAL_DECL PgfProbItor : PgfItor
+{
+    Vector<PgfAbsCatCounts> *cats;
+};
+
+// Binary search for `name` in `cats`, which must be ordered by name.
+// Returns NULL when there is no entry with that name.
+static
+PgfAbsCatCounts *find_counts(Vector<PgfAbsCatCounts> *cats, PgfText *name)
+{
+    // Search the half-open range [lo, hi). With unsigned indices an
+    // inclusive upper bound wraps around on an empty vector and when
+    // `name` sorts before the first entry.
+    size_t lo = 0;
+    size_t hi = cats->len;
+    while (lo < hi) {
+        size_t mid = lo + (hi-lo)/2;
+        PgfAbsCatCounts *counts = &cats->data[mid];
+        int cmp = textcmp(name, counts->name);
+        if (cmp < 0) {
+            hi = mid;
+        } else if (cmp > 0) {
+            lo = mid+1;
+        } else {
+            return counts;
+        }
+    }
+
+    return NULL;
+}
+
+// Iterator callback over the functions: for the category named by the
+// function's type, either counts the function as unspecified (NaN) or
+// adds its probability to the category's sum.
+static
+void collect_counts(PgfItor *itor, PgfText *key, object value, PgfExn *err)
+{
+    PgfProbItor* prob_itor = (PgfProbItor*) itor;
+    ref<PgfAbsFun> absfun = value;
+
+    PgfAbsCatCounts *counts =
+        find_counts(prob_itor->cats, &absfun->type->name);
+    if (counts != NULL) {
+        if (isnan(absfun->prob)) {
+            counts->n_nan_probs++;
+        } else {
+            counts->probs_sum += exp(-absfun->prob);
+        }
+    }
+}
+
+// Iterator callback over the functions: replaces a NaN probability by
+// the value computed for the function's category.
+static
+void pad_probs(PgfItor *itor, PgfText *key, object value, PgfExn *err)
+{
+    PgfProbItor* prob_itor = (PgfProbItor*) itor;
+    ref<PgfAbsFun> absfun = value;
+
+    if (isnan(absfun->prob)) {
+        PgfAbsCatCounts *counts =
+            find_counts(prob_itor->cats, &absfun->type->name);
+        if (counts != NULL) {
+            absfun->prob = counts->prob;
+        }
+    }
+}
+
 void PgfReader::read_abstract(ref<PgfAbstr> abstract)
 {
     this->abstract = abstract;
@@ -338,6 +439,39 @@ void PgfReader::read_abstract(ref<PgfAbstr> abstract)
     abstract->aflags = read_namespace<PgfFlag>(&PgfReader::read_flag);
     abstract->funs = read_namespace<PgfAbsFun>(&PgfReader::read_absfun);
     abstract->cats = read_namespace<PgfAbsCat>(&PgfReader::read_abscat);
+
+    // With a probabilities callback, functions it did not know have a
+    // NaN probability at this point. Share the mass left unassigned in
+    // each category equally among those functions.
+    if (probs_callback != NULL && abstract->cats != 0) {
+        PgfExn err;
+        err.type = PGF_EXN_NONE;
+
+        PgfProbItor itor;
+        itor.cats = namespace_to_sorted_names<PgfAbsCat,PgfAbsCatCounts>(abstract->cats);
+
+        itor.fn = collect_counts;
+        namespace_iter(abstract->funs, &itor, &err);
+
+        for (size_t i = 0; i < itor.cats->len; i++) {
+            PgfAbsCatCounts *counts = &itor.cats->data[i];
+            if (counts->n_nan_probs == 0)
+                continue;   // nothing to pad; also avoids dividing by zero
+
+            // Clamp at zero: when the given probabilities sum to more
+            // than 1 (e.g. through rounding) the logarithm of a negative
+            // number would put NaN back into the grammar.
+            double rest = 1-counts->probs_sum;
+            if (rest < 0)
+                rest = 0;
+            counts->prob = - log(rest / counts->n_nan_probs);
+        }
+
+        itor.fn = pad_probs;
+        namespace_iter(abstract->funs, &itor, &err);
+
+        free(itor.cats);
+    }
 }
 
 void PgfReader::merge_abstract(ref<PgfAbstr> abstract)
diff --git a/src/runtime/c/pgf/reader.h b/src/runtime/c/pgf/reader.h
index 6e1c5a554..609b9b8bc 100644
--- a/src/runtime/c/pgf/reader.h
+++ b/src/runtime/c/pgf/reader.h
@@ -8,12 +8,13 @@
 class PGF_INTERNAL_DECL PgfReader
 {
 public:
-    PgfReader(FILE *in);
+    PgfReader(FILE *in,PgfProbsCallback *probs_callback);
 
     uint8_t read_uint8();
     uint16_t read_u16be();
     uint64_t read_u64be();
     double read_double();
+    prob_t read_prob(PgfText *name);
     uint64_t read_uint();
     int64_t read_int() { return (int64_t) read_uint(); };
     size_t read_len() { return (size_t) read_uint(); };
@@ -87,6 +88,7 @@ public:
 
 private:
     FILE *in;
+    PgfProbsCallback *probs_callback;
     ref<PgfAbstr> abstract;
     ref<PgfConcr> concrete;
 
diff --git a/src/runtime/haskell/PGF2.hsc b/src/runtime/haskell/PGF2.hsc
index 6af03cefc..4bf67937e 100644
--- a/src/runtime/haskell/PGF2.hsc
+++ b/src/runtime/haskell/PGF2.hsc
@@ -15,6 +15,7 @@
 
 module PGF2 (-- * PGF
              PGF,readPGF,bootNGF,readNGF,newNGF,writePGF,showPGF,
+             readPGFWithProbs, bootNGFWithProbs,
 
              -- * Abstract syntax
              AbsName,abstractName,globalFlag,abstractFlag,
@@ -109,11 +110,18 @@ import Text.PrettyPrint
 
 -- | Reads a PGF file and keeps it in memory.
 readPGF :: FilePath -> IO PGF
-readPGF fpath =
+readPGF fpath = readPGFWithProbs fpath Nothing
+
+-- | Like 'readPGF', but when a map is given the runtime is asked to use
+-- its probabilities instead of the ones stored in the file. Names that
+-- are missing from the map are reported to the runtime as NaN.
+readPGFWithProbs :: FilePath -> Maybe (Map.Map String Double) -> IO PGF
+readPGFWithProbs fpath mb_probs =
   withCString fpath $ \c_fpath ->
   alloca $ \p_revision ->
+  withProbsCallback mb_probs $ \c_pcallback ->
   mask_ $ do
-    c_db <- withPgfExn "readPGF" (pgf_read_pgf c_fpath p_revision)
+    c_db <- withPgfExn "readPGF" (pgf_read_pgf c_fpath p_revision c_pcallback)
     c_revision <- peek p_revision
     fptr <- newForeignPtrEnv pgf_free_revision c_db c_revision
     langs <- getConcretes c_db fptr
@@ -124,17 +132,41 @@ readPGF fpath =
 -- The NGF file is platform dependent and should not be copied
 -- between machines.
 bootNGF :: FilePath -> FilePath -> IO PGF
-bootNGF pgf_path ngf_path =
+bootNGF pgf_path ngf_path = bootNGFWithProbs pgf_path Nothing ngf_path
+
+-- | Like 'bootNGF', but takes an optional map of probabilities that is
+-- handed to the runtime in the same way as in 'readPGFWithProbs'.
+bootNGFWithProbs :: FilePath -> Maybe (Map.Map String Double) -> FilePath -> IO PGF
+bootNGFWithProbs pgf_path mb_probs ngf_path =
   withCString pgf_path $ \c_pgf_path ->
   withCString ngf_path $ \c_ngf_path ->
   alloca $ \p_revision ->
+  withProbsCallback mb_probs $ \c_pcallback ->
   mask_ $ do
-    c_db <- withPgfExn "bootNGF" (pgf_boot_ngf c_pgf_path c_ngf_path p_revision)
+    c_db <- withPgfExn "bootNGF" (pgf_boot_ngf c_pgf_path c_ngf_path p_revision c_pcallback)
     c_revision <- peek p_revision
     fptr <- newForeignPtrEnv pgf_free_revision c_db c_revision
     langs <- getConcretes c_db fptr
     return (PGF c_db fptr langs)
 
+-- Runs the action with a C callback structure that looks names up in
+-- the given map, or with a null pointer when there is no map. The
+-- structure and the wrapped function pointer are released afterwards.
+withProbsCallback :: Maybe (Map.Map String Double) -> (Ptr PgfProbsCallback -> IO a) -> IO a
+withProbsCallback Nothing      f = f nullPtr
+withProbsCallback (Just probs) f =
+  allocaBytes (#size PgfProbsCallback) $ \callback ->
+  bracket (wrapProbsCallback getProb) freeHaskellFunPtr $ \fptr -> do
+    (#poke PgfProbsCallback, fn) callback fptr
+    f callback
+  where
+    getProb _ c_name = do
+      name <- peekText c_name
+      -- NaN is returned for a name that has no entry in the map
+      return (Map.findWithDefault nan name probs)
+
+    nan = 0/0
+
 -- | Reads the grammar from an already booted NGF file.
 -- The function fails if the file does not exist.
 readNGF :: FilePath -> IO PGF
diff --git a/src/runtime/haskell/PGF2/FFI.hsc b/src/runtime/haskell/PGF2/FFI.hsc
index 2d588b786..bfefaabd4 100644
--- a/src/runtime/haskell/PGF2/FFI.hsc
+++ b/src/runtime/haskell/PGF2/FFI.hsc
@@ -46,6 +46,7 @@ data PgfLinBuilderIface
 data PgfLinearizationOutputIface
 data PgfGraphvizOptions
 data PgfSequenceItor
+data PgfProbsCallback
 data PgfMorphoCallback
 data PgfCohortsCallback
 data PgfPhrasetableIds
@@ -60,10 +61,14 @@ foreign import ccall unsafe "pgf_utf8_encode"
   pgf_utf8_encode :: Word32 -> Ptr CString -> IO ()
 
 foreign import ccall "pgf_read_pgf"
-  pgf_read_pgf :: CString -> Ptr (Ptr PGF) -> Ptr PgfExn -> IO (Ptr PgfDB)
+  pgf_read_pgf :: CString -> Ptr (Ptr PGF) -> Ptr PgfProbsCallback -> Ptr PgfExn -> IO (Ptr PgfDB)
 
 foreign import ccall "pgf_boot_ngf"
-  pgf_boot_ngf :: CString -> CString -> Ptr (Ptr PGF) -> Ptr PgfExn -> IO (Ptr PgfDB)
+  pgf_boot_ngf :: CString -> CString -> Ptr (Ptr PGF) -> Ptr PgfProbsCallback -> Ptr PgfExn -> IO (Ptr PgfDB)
+
+type ProbsCallback = Ptr PgfProbsCallback -> Ptr PgfText -> IO Double
+
+foreign import ccall "wrapper" wrapProbsCallback :: Wrapper ProbsCallback
 
 foreign import ccall "pgf_read_ngf"
   pgf_read_ngf :: CString -> Ptr (Ptr PGF) -> Ptr PgfExn -> IO (Ptr PgfDB)
diff --git a/src/runtime/python/pypgf.c b/src/runtime/python/pypgf.c
index 679bf29e9..fb592792c 100644
--- a/src/runtime/python/pypgf.c
+++ b/src/runtime/python/pypgf.c
@@ -515,7 +515,7 @@ pgf_readPGF(PyObject *self, PyObject *args)
     PGFObject *py_pgf = (PGFObject *)pgf_PGFType.tp_alloc(&pgf_PGFType, 0);
 
     PgfExn err;
-    py_pgf->db = pgf_read_pgf(fpath, &py_pgf->revision, &err);
+    py_pgf->db = pgf_read_pgf(fpath, &py_pgf->revision, NULL, &err);
     if (handleError(err) != PGF_EXN_NONE) {
         Py_DECREF(py_pgf);
         return NULL;
@@ -535,7 +535,7 @@ pgf_bootNGF(PyObject *self, PyObject *args)
     PGFObject *py_pgf = (PGFObject *)pgf_PGFType.tp_alloc(&pgf_PGFType, 0);
 
     PgfExn err;
-    py_pgf->db = pgf_boot_ngf(fpath, npath, &py_pgf->revision, &err);
+    py_pgf->db = pgf_boot_ngf(fpath, npath, &py_pgf->revision, NULL, &err);
     if (handleError(err) != PGF_EXN_NONE) {
         Py_DECREF(py_pgf);
         return NULL;