mirror of
https://github.com/GrammaticalFramework/gf-core.git
synced 2026-04-22 11:19:32 -06:00
make it possible to replace the probabilities while reading a new .pgf
This commit is contained in:
@@ -407,6 +407,33 @@ void namespace_iter(Namespace<V> map, PgfItor* itor, PgfExn *err)
|
||||
return;
|
||||
}
|
||||
|
||||
template <class V,class A>
|
||||
void namespace_vec_fill_names(Namespace<V> node, size_t offs, Vector<A> *vec)
|
||||
{
|
||||
if (node == 0)
|
||||
return;
|
||||
|
||||
namespace_vec_fill_names(node->left, offs, vec);
|
||||
|
||||
offs += namespace_size(node->left);
|
||||
vector_elem(vec, offs++)->name = &node->value->name;
|
||||
|
||||
namespace_vec_fill_names(node->right, offs, vec);
|
||||
}
|
||||
|
||||
template <class V,class A>
|
||||
Vector<A> *namespace_to_sorted_names(Namespace<V> node)
|
||||
{
|
||||
Vector<A> *vec = (Vector<A> *)
|
||||
malloc(sizeof(Vector<A>)+node->sz*sizeof(A));
|
||||
if (errno != 0)
|
||||
throw pgf_systemerror(errno);
|
||||
vec->len = node->sz;
|
||||
memset(vec->data, 0, node->sz*sizeof(A));
|
||||
namespace_vec_fill_names(node, 0, vec);
|
||||
return vec;
|
||||
}
|
||||
|
||||
template <class V>
|
||||
void namespace_release(Namespace<V> node)
|
||||
{
|
||||
|
||||
@@ -37,8 +37,8 @@ pgf_exn_clear(PgfExn* err)
|
||||
}
|
||||
|
||||
PGF_API
|
||||
PgfDB *pgf_read_pgf(const char* fpath,
|
||||
PgfRevision *revision,
|
||||
PgfDB *pgf_read_pgf(const char* fpath, PgfRevision *revision,
|
||||
PgfProbsCallback *probs_callback,
|
||||
PgfExn* err)
|
||||
{
|
||||
PgfDB *db = NULL;
|
||||
@@ -56,7 +56,7 @@ PgfDB *pgf_read_pgf(const char* fpath,
|
||||
|
||||
db->start_transaction();
|
||||
|
||||
PgfReader rdr(in);
|
||||
PgfReader rdr(in,probs_callback);
|
||||
ref<PgfPGF> pgf = rdr.read_pgf();
|
||||
|
||||
*revision = db->register_revision(pgf.tagged(), PgfDB::get_txn_id());
|
||||
@@ -79,6 +79,7 @@ PgfDB *pgf_read_pgf(const char* fpath,
|
||||
PGF_API
|
||||
PgfDB *pgf_boot_ngf(const char* pgf_path, const char* ngf_path,
|
||||
PgfRevision *revision,
|
||||
PgfProbsCallback *probs_callback,
|
||||
PgfExn* err)
|
||||
{
|
||||
PgfDB *db = NULL;
|
||||
@@ -103,7 +104,7 @@ PgfDB *pgf_boot_ngf(const char* pgf_path, const char* ngf_path,
|
||||
|
||||
db->start_transaction();
|
||||
|
||||
PgfReader rdr(in);
|
||||
PgfReader rdr(in,probs_callback);
|
||||
ref<PgfPGF> pgf = rdr.read_pgf();
|
||||
|
||||
*revision = db->register_revision(pgf.tagged(), PgfDB::get_txn_id());
|
||||
@@ -220,7 +221,7 @@ void pgf_merge_pgf(PgfDB *db, PgfRevision revision,
|
||||
DB_scope scope(db, WRITER_SCOPE);
|
||||
ref<PgfPGF> pgf = db->revision2pgf(revision);
|
||||
|
||||
PgfReader rdr(in);
|
||||
PgfReader rdr(in,NULL);
|
||||
rdr.merge_pgf(pgf);
|
||||
}
|
||||
} PGF_API_END
|
||||
|
||||
@@ -226,11 +226,17 @@ typedef struct PgfDB PgfDB;
|
||||
typedef object PgfRevision;
|
||||
typedef object PgfConcrRevision;
|
||||
|
||||
typedef struct PgfProbsCallback PgfProbsCallback;
|
||||
struct PgfProbsCallback {
|
||||
double (*fn)(PgfProbsCallback* self, PgfText *name);
|
||||
};
|
||||
|
||||
/* Reads a PGF file and builds the database in memory.
|
||||
* If successful, *revision will contain the initial revision of
|
||||
* the grammar. */
|
||||
PGF_API_DECL
|
||||
PgfDB *pgf_read_pgf(const char* fpath, PgfRevision *revision,
|
||||
PgfProbsCallback *probs_callback,
|
||||
PgfExn* err);
|
||||
|
||||
/* Reads a PGF file and stores the unpacked data in an NGF file
|
||||
@@ -240,6 +246,7 @@ PgfDB *pgf_read_pgf(const char* fpath, PgfRevision *revision,
|
||||
PGF_API_DECL
|
||||
PgfDB *pgf_boot_ngf(const char* pgf_path, const char* ngf_path,
|
||||
PgfRevision *revision,
|
||||
PgfProbsCallback *probs_callback,
|
||||
PgfExn* err);
|
||||
|
||||
/* Tries to read the grammar from an already booted NGF file.
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
#include <math.h>
|
||||
#include <string.h>
|
||||
|
||||
PgfReader::PgfReader(FILE *in)
|
||||
// Creates a reader over the stream `in`.
//
// `probs_callback` may be NULL, in which case read_prob() keeps the
// probabilities stored in the file. The abstract and concrete
// references start out as 0.
PgfReader::PgfReader(FILE *in,PgfProbsCallback *probs_callback)
    : in(in),
      probs_callback(probs_callback),
      abstract(0),
      concrete(0)
{
}
|
||||
@@ -71,6 +72,15 @@ double PgfReader::read_double()
|
||||
return sign ? copysign(ret, -1.0) : ret;
|
||||
}
|
||||
|
||||
// Reads one probability from the input and returns it as a negative
// natural logarithm.
//
// The stored double is always consumed from the stream. When a
// probabilities callback is installed, the value it returns for `name`
// is used instead of the stored one.
prob_t PgfReader::read_prob(PgfText *name)
{
    const double stored = read_double();

    const double chosen =
        (probs_callback == NULL)
            ? stored
            : probs_callback->fn(probs_callback, name);

    return - log(chosen);
}
|
||||
|
||||
uint64_t PgfReader::read_uint()
|
||||
{
|
||||
uint64_t u = 0;
|
||||
@@ -318,7 +328,7 @@ ref<PgfAbsFun> PgfReader::read_absfun()
|
||||
default:
|
||||
throw pgf_error("Unknown tag, 0 or 1 expected");
|
||||
}
|
||||
absfun->prob = - log(read_double());
|
||||
absfun->prob = read_prob(&absfun->name);
|
||||
return absfun;
|
||||
}
|
||||
|
||||
@@ -326,10 +336,76 @@ ref<PgfAbsCat> PgfReader::read_abscat()
|
||||
{
|
||||
ref<PgfAbsCat> abscat = read_name<PgfAbsCat>(&PgfAbsCat::name);
|
||||
abscat->context = read_vector<PgfHypo>(&PgfReader::read_hypo);
|
||||
abscat->prob = - log(read_double());
|
||||
abscat->prob = read_prob(&abscat->name);
|
||||
return abscat;
|
||||
}
|
||||
|
||||
struct PGF_INTERNAL_DECL PgfAbsCatCounts
|
||||
{
|
||||
PgfText *name;
|
||||
size_t n_nan_probs;
|
||||
double probs_sum;
|
||||
prob_t prob;
|
||||
};
|
||||
|
||||
struct PGF_INTERNAL_DECL PgfProbItor : PgfItor
|
||||
{
|
||||
Vector<PgfAbsCatCounts> *cats;
|
||||
};
|
||||
|
||||
static
|
||||
PgfAbsCatCounts *find_counts(Vector<PgfAbsCatCounts> *cats, PgfText *name)
|
||||
{
|
||||
size_t i = 0;
|
||||
size_t j = cats->len-1;
|
||||
while (i <= j) {
|
||||
size_t k = (i+j)/2;
|
||||
PgfAbsCatCounts *counts = &cats->data[k];
|
||||
int cmp = textcmp(name, counts->name);
|
||||
if (cmp < 0) {
|
||||
j = k-1;
|
||||
} else if (cmp > 0) {
|
||||
i = k+1;
|
||||
} else {
|
||||
return counts;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static
|
||||
void collect_counts(PgfItor *itor, PgfText *key, object value, PgfExn *err)
|
||||
{
|
||||
PgfProbItor* prob_itor = (PgfProbItor*) itor;
|
||||
ref<PgfAbsFun> absfun = value;
|
||||
|
||||
PgfAbsCatCounts *counts =
|
||||
find_counts(prob_itor->cats, &absfun->type->name);
|
||||
if (counts != NULL) {
|
||||
if (isnan(absfun->prob)) {
|
||||
counts->n_nan_probs++;
|
||||
} else {
|
||||
counts->probs_sum += exp(-absfun->prob);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static
|
||||
void pad_probs(PgfItor *itor, PgfText *key, object value, PgfExn *err)
|
||||
{
|
||||
PgfProbItor* prob_itor = (PgfProbItor*) itor;
|
||||
ref<PgfAbsFun> absfun = value;
|
||||
|
||||
if (isnan(absfun->prob)) {
|
||||
PgfAbsCatCounts *counts =
|
||||
find_counts(prob_itor->cats, &absfun->type->name);
|
||||
if (counts != NULL) {
|
||||
absfun->prob = counts->prob;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PgfReader::read_abstract(ref<PgfAbstr> abstract)
|
||||
{
|
||||
this->abstract = abstract;
|
||||
@@ -338,6 +414,27 @@ void PgfReader::read_abstract(ref<PgfAbstr> abstract)
|
||||
abstract->aflags = read_namespace<PgfFlag>(&PgfReader::read_flag);
|
||||
abstract->funs = read_namespace<PgfAbsFun>(&PgfReader::read_absfun);
|
||||
abstract->cats = read_namespace<PgfAbsCat>(&PgfReader::read_abscat);
|
||||
|
||||
if (probs_callback != NULL) {
|
||||
PgfExn err;
|
||||
err.type = PGF_EXN_NONE;
|
||||
|
||||
PgfProbItor itor;
|
||||
itor.cats = namespace_to_sorted_names<PgfAbsCat,PgfAbsCatCounts>(abstract->cats);
|
||||
|
||||
itor.fn = collect_counts;
|
||||
namespace_iter(abstract->funs, &itor, &err);
|
||||
|
||||
for (size_t i = 0; i < itor.cats->len; i++) {
|
||||
PgfAbsCatCounts *counts = &itor.cats->data[i];
|
||||
counts->prob = - log((1-counts->probs_sum) / counts->n_nan_probs);
|
||||
}
|
||||
|
||||
itor.fn = pad_probs;
|
||||
namespace_iter(abstract->funs, &itor, &err);
|
||||
|
||||
free(itor.cats);
|
||||
}
|
||||
}
|
||||
|
||||
void PgfReader::merge_abstract(ref<PgfAbstr> abstract)
|
||||
|
||||
@@ -8,12 +8,13 @@
|
||||
class PGF_INTERNAL_DECL PgfReader
|
||||
{
|
||||
public:
|
||||
PgfReader(FILE *in);
|
||||
PgfReader(FILE *in,PgfProbsCallback *probs_callback);
|
||||
|
||||
uint8_t read_uint8();
|
||||
uint16_t read_u16be();
|
||||
uint64_t read_u64be();
|
||||
double read_double();
|
||||
prob_t read_prob(PgfText *name);
|
||||
uint64_t read_uint();
|
||||
int64_t read_int() { return (int64_t) read_uint(); };
|
||||
size_t read_len() { return (size_t) read_uint(); };
|
||||
@@ -87,6 +88,7 @@ public:
|
||||
|
||||
private:
|
||||
FILE *in;
|
||||
PgfProbsCallback *probs_callback;
|
||||
ref<PgfAbstr> abstract;
|
||||
ref<PgfConcr> concrete;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user