introduce a version of namespace_iter with a lambda function

This commit is contained in:
Krasimir Angelov
2023-12-28 10:50:08 +01:00
parent d78aea4170
commit 87b6094ade
3 changed files with 62 additions and 92 deletions

View File

@@ -631,6 +631,24 @@ void namespace_iter(Namespace<V> map, PgfItor* itor, PgfExn *err)
return; return;
} }
template <class V>
bool namespace_iter(Namespace<V> map, std::function<bool(ref<V>)> &f)
{
if (map == 0)
return true;
if (!namespace_iter(map->left, f))
return false;
if (!f(map->value))
return false;
if (!namespace_iter(map->right, f))
return false;
return true;
}
template <class V> template <class V>
void namespace_iter_prefix(Namespace<V> map, PgfText *prefix, PgfItor* itor, PgfExn *err) void namespace_iter_prefix(Namespace<V> map, PgfText *prefix, PgfItor* itor, PgfExn *err)
{ {

View File

@@ -463,23 +463,6 @@ void pgf_iter_categories(PgfDB *db, PgfRevision revision,
} PGF_API_END } PGF_API_END
} }
struct PgfItorConcrHelper : PgfItor
{
PgfDB *db;
txn_t txn_id;
PgfItor *itor;
};
static
void iter_concretes_helper(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
PgfItorConcrHelper* helper = (PgfItorConcrHelper*) itor;
ref<PgfConcr> concr = value;
object rev = helper->db->register_revision(concr.tagged(), helper->txn_id);
helper->db->ref_count++;
helper->itor->fn(helper->itor, key, rev, err);
}
PGF_API PGF_API
void pgf_iter_concretes(PgfDB *db, PgfRevision revision, void pgf_iter_concretes(PgfDB *db, PgfRevision revision,
PgfItor *itor, PgfExn *err) PgfItor *itor, PgfExn *err)
@@ -490,13 +473,14 @@ void pgf_iter_concretes(PgfDB *db, PgfRevision revision,
DB_scope scope(db, READER_SCOPE); DB_scope scope(db, READER_SCOPE);
ref<PgfPGF> pgf = db->revision2pgf(revision, &txn_id); ref<PgfPGF> pgf = db->revision2pgf(revision, &txn_id);
PgfItorConcrHelper helper; std::function<bool(ref<PgfConcr>)> f =
helper.fn = iter_concretes_helper; [txn_id,db,itor,err](ref<PgfConcr> concr) {
helper.db = db; object rev = db->register_revision(concr.tagged(), txn_id);
helper.txn_id = txn_id; db->ref_count++;
helper.itor = itor; itor->fn(itor, &concr->name, rev, err);
return (err->type == PGF_EXN_NONE);
namespace_iter(pgf->concretes, &helper, err); };
namespace_iter(pgf->concretes, f);
} PGF_API_END } PGF_API_END
} }
@@ -1609,30 +1593,19 @@ void pgf_create_category(PgfDB *db, PgfRevision revision,
struct PGF_INTERNAL_DECL PgfDropItor : PgfItor struct PGF_INTERNAL_DECL PgfDropItor : PgfItor
{ {
ref<PgfPGF> pgf; ref<PgfPGF> pgf;
ref<PgfConcr> concrete;
PgfText *name;
}; };
static
void iter_drop_cat_helper2(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
ref<PgfConcr> concr = value;
PgfText* name = ((PgfDropItor*) itor)->name;
drop_lin(concr, name);
}
static static
void iter_drop_cat_helper(PgfItor *itor, PgfText *key, object value, PgfExn *err) void iter_drop_cat_helper(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{ {
ref<PgfPGF> pgf = ((PgfDropItor*) itor)->pgf; ref<PgfPGF> pgf = ((PgfDropItor*) itor)->pgf;
PgfDropItor itor2; std::function<bool(ref<PgfConcr>)> f =
itor2.fn = iter_drop_cat_helper2; [key,err](ref<PgfConcr> concr) {
itor2.pgf = 0; drop_lin(concr, key);
itor2.concrete = 0; return (err->type == PGF_EXN_NONE);
itor2.name = key; };
namespace_iter(pgf->concretes, &itor2, err); namespace_iter(pgf->concretes, f);
ref<PgfAbsFun> fun; ref<PgfAbsFun> fun;
Namespace<PgfAbsFun> funs = Namespace<PgfAbsFun> funs =
@@ -1672,8 +1645,6 @@ void pgf_drop_category(PgfDB *db, PgfRevision revision,
PgfDropItor itor; PgfDropItor itor;
itor.fn = iter_drop_cat_helper; itor.fn = iter_drop_cat_helper;
itor.pgf = pgf; itor.pgf = pgf;
itor.concrete = 0;
itor.name = name;
PgfProbspace funs_by_cat = PgfProbspace funs_by_cat =
probspace_delete_by_cat(pgf->abstract.funs_by_cat, &cat->name, probspace_delete_by_cat(pgf->abstract.funs_by_cat, &cat->name,
&itor, err); &itor, err);

View File

@@ -373,11 +373,6 @@ struct PGF_INTERNAL_DECL PgfAbsCatCounts
prob_t prob; prob_t prob;
}; };
struct PGF_INTERNAL_DECL PgfProbItor : PgfItor
{
Vector<PgfAbsCatCounts> *cats;
};
static static
PgfAbsCatCounts *find_counts(Vector<PgfAbsCatCounts> *cats, PgfText *name) PgfAbsCatCounts *find_counts(Vector<PgfAbsCatCounts> *cats, PgfText *name)
{ {
@@ -399,38 +394,6 @@ PgfAbsCatCounts *find_counts(Vector<PgfAbsCatCounts> *cats, PgfText *name)
return NULL; return NULL;
} }
static
void collect_counts(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
PgfProbItor* prob_itor = (PgfProbItor*) itor;
ref<PgfAbsFun> absfun = value;
PgfAbsCatCounts *counts =
find_counts(prob_itor->cats, &absfun->type->name);
if (counts != NULL) {
if (isnan(absfun->prob)) {
counts->n_nan_probs++;
} else {
counts->probs_sum += exp(-absfun->prob);
}
}
}
static
void pad_probs(PgfItor *itor, PgfText *key, object value, PgfExn *err)
{
PgfProbItor* prob_itor = (PgfProbItor*) itor;
ref<PgfAbsFun> absfun = value;
if (isnan(absfun->prob)) {
PgfAbsCatCounts *counts =
find_counts(prob_itor->cats, &absfun->type->name);
if (counts != NULL) {
absfun->prob = counts->prob;
}
}
}
void PgfReader::read_abstract(ref<PgfAbstr> abstract) void PgfReader::read_abstract(ref<PgfAbstr> abstract)
{ {
this->abstract = abstract; this->abstract = abstract;
@@ -447,24 +410,42 @@ void PgfReader::read_abstract(ref<PgfAbstr> abstract)
abstract->cats = cats; abstract->cats = cats;
if (probs_callback != NULL) { if (probs_callback != NULL) {
PgfExn err; Vector<PgfAbsCatCounts> *cats = namespace_to_sorted_names<PgfAbsCat,PgfAbsCatCounts>(abstract->cats);
err.type = PGF_EXN_NONE;
PgfProbItor itor; std::function<bool(ref<PgfAbsFun>)> collect_counts =
itor.cats = namespace_to_sorted_names<PgfAbsCat,PgfAbsCatCounts>(abstract->cats); [cats](ref<PgfAbsFun> absfun) {
PgfAbsCatCounts *counts =
find_counts(cats, &absfun->type->name);
if (counts != NULL) {
if (isnan(absfun->prob)) {
counts->n_nan_probs++;
} else {
counts->probs_sum += exp(-absfun->prob);
}
}
return true;
};
namespace_iter(abstract->funs, collect_counts);
itor.fn = collect_counts; for (size_t i = 0; i < cats->len; i++) {
namespace_iter(abstract->funs, &itor, &err); PgfAbsCatCounts *counts = &cats->data[i];
for (size_t i = 0; i < itor.cats->len; i++) {
PgfAbsCatCounts *counts = &itor.cats->data[i];
counts->prob = - logf((1-counts->probs_sum) / counts->n_nan_probs); counts->prob = - logf((1-counts->probs_sum) / counts->n_nan_probs);
} }
itor.fn = pad_probs; std::function<bool(ref<PgfAbsFun>)> pad_probs =
namespace_iter(abstract->funs, &itor, &err); [cats](ref<PgfAbsFun> absfun) {
if (isnan(absfun->prob)) {
PgfAbsCatCounts *counts =
find_counts(cats, &absfun->type->name);
if (counts != NULL) {
absfun->prob = counts->prob;
}
}
return true;
};
namespace_iter(abstract->funs, pad_probs);
free(itor.cats); free(cats);
} }
} }