From 6c3a4f5dcd5919eb004430aa5021a99e179917ed Mon Sep 17 00:00:00 2001 From: Krasimir Angelov Date: Thu, 16 Mar 2023 17:22:42 +0100 Subject: [PATCH] random generation always produces something if possible --- src/runtime/c/pgf/data.h | 1 + src/runtime/c/pgf/generator.cxx | 106 ++++++++++++++++++-------------- src/runtime/c/pgf/pgf.cxx | 40 +++++------- src/runtime/c/pgf/probspace.cxx | 51 +++++++-------- src/runtime/c/pgf/probspace.h | 3 +- src/runtime/haskell/PGF2.hsc | 32 +++++----- 6 files changed, 123 insertions(+), 110 deletions(-) diff --git a/src/runtime/c/pgf/data.h b/src/runtime/c/pgf/data.h index a16612345..a9818df0e 100644 --- a/src/runtime/c/pgf/data.h +++ b/src/runtime/c/pgf/data.h @@ -8,6 +8,7 @@ #include #include #include +#include #include "pgf.h" diff --git a/src/runtime/c/pgf/generator.cxx b/src/runtime/c/pgf/generator.cxx index f7e81ae9d..537534eee 100644 --- a/src/runtime/c/pgf/generator.cxx +++ b/src/runtime/c/pgf/generator.cxx @@ -173,42 +173,49 @@ again: { } PgfExpr expr = 0; - PgfExpr var_expr = 0; + std::set> excluded; - int index = 0; - Scope *sc = scope; - auto tmp = m; - while (sc != NULL) { - m = sc->m; - PgfVarGenerator v_gen(this, index, cat, n_exprs, exprs); - expr = m->match_type(&v_gen, sc->type); - if (expr != 0) { - if (rand() < VAR_PROB) { - prob += -log(VAR_PROB); - break; - } else { - prob += -log(1-VAR_PROB); - if (var_expr != 0) - u->free_ref(var_expr); - var_expr = expr; - expr = 0; + prob_t save_prob = prob; + prob_t total_prob = 1; + for (;;) { + prob = save_prob; + + int index = 0; + Scope *sc = scope; + auto tmp = m; + while (sc != NULL) { + m = sc->m; + PgfVarGenerator v_gen(this, index, cat, n_exprs, exprs); + expr = m->match_type(&v_gen, sc->type); + if (expr != 0) { + if (rand() < VAR_PROB) { + prob += -log(VAR_PROB); + break; + } else { + prob += -log(1-VAR_PROB); + u->free_ref(expr); + expr = 0; + } } + sc = sc->next; + index++; } - sc = sc->next; - index++; - } - m = tmp; + m = tmp; + + if (expr != 0) + break; - if (expr == 0) { if (strcmp(cat->text, "Int") == 0) { uintmax_t value = 999; PgfLiteral lint = u->lint(1,&value); expr = u->elit(lint); u->free_ref(lint); + break; } else if (strcmp(cat->text, "Float") == 0) { PgfLiteral lflt = u->lflt(3.14); expr = u->elit(lflt); u->free_ref(lflt); + break; } else if (strcmp(cat->text, "String") == 0) { PgfText *value = (PgfText *) alloca(sizeof(PgfText)+4); value->size = 3; @@ -217,25 +224,26 @@ again: { PgfLiteral lstr = u->lstr(value); expr = u->elit(lstr); u->free_ref(lstr); + break; } else { - prob_t rand_value = rand(); + prob_t rand_value = rand() * total_prob; - ref fun = probspace_random(pgf->abstract.funs_by_cat, cat, rand_value); + ref fun = + probspace_random(pgf->abstract.funs_by_cat, cat, rand_value, excluded); + if (fun == 0) + return 0; - if (!function_has_lins(&fun->name)) - fun = 0; - - if (fun == 0) { - if (var_expr != 0) { - prob += -log(VAR_PROB/(1-VAR_PROB)); - expr = var_expr; - } + if (!function_has_lins(&fun->name)) { + excluded.insert(fun); + total_prob -= exp(-fun->prob); + if (total_prob < 0) // possible because of rounding + total_prob = 0; } else { - if (depth > 0 || fun->type->hypos->len > 0) { + ref> hypos = fun->type->hypos; + if (depth > ((hypos->len > 0) ? 1 : 0)) { prob += fun->prob; expr = u->efun(&fun->name); - ref> hypos = fun->type->hypos; PgfTypeHypo *t_hypos = (PgfTypeHypo *) alloca(hypos->len * sizeof(PgfTypeHypo)); for (size_t i = 0; i < hypos->len; i++) { @@ -248,19 +256,25 @@ again: { expr = descend(expr, hypos->len, t_hypos); this->m = tmp; } + + if (expr != 0) + break; + + excluded.insert(fun); + total_prob -= exp(-fun->prob); + if (total_prob < 0) // possible because of rounding + total_prob = 0; } } } - if (expr != 0) { - while (scope != entry_scope) { - PgfExpr abs_expr = u->eabs(scope->bind_type, &scope->var, expr); - u->free_ref(expr); - expr = abs_expr; - Scope *next = scope->next; - free(scope); - scope = next; - } + while (scope != entry_scope) { + PgfExpr abs_expr = u->eabs(scope->bind_type, &scope->var, expr); + u->free_ref(expr); + expr = abs_expr; + Scope *next = scope->next; + free(scope); + scope = next; } return expr; @@ -269,9 +283,10 @@ again: { PgfExpr PgfRandomGenerator::descend(PgfExpr expr, size_t n_hypos, PgfTypeHypo *hypos) { - depth--; for (size_t i = 0; i < n_hypos; i++) { + depth--; PgfExpr arg = m->match_type(this, hypos[i].type); + depth++; if (arg == 0) { u->free_ref(expr); return 0; @@ -287,7 +302,6 @@ PgfExpr PgfRandomGenerator::descend(PgfExpr expr, u->free_ref(expr); expr = app; } - depth++; return expr; } diff --git a/src/runtime/c/pgf/pgf.cxx b/src/runtime/c/pgf/pgf.cxx index d4d8ff130..3234358af 100644 --- a/src/runtime/c/pgf/pgf.cxx +++ b/src/runtime/c/pgf/pgf.cxx @@ -1205,18 +1205,14 @@ PgfExpr pgf_generate_random(PgfDB *db, PgfRevision revision, ref pgf = db->revision2pgf(revision); - // Generation may fail for certain random choices, but succeed - // for others. We try 10 time to increase the chance of succeess. - for (size_t i = 0; i < 10; i++) { - PgfRandomGenerator gen(pgf, depth, seed, m, u); - for (size_t i = 0; i < n_concr_revisions; i++) { - gen.addConcr(db->revision2concr(concr_revisions[i])); - } - PgfExpr expr = m->match_type(&gen, type); - if (expr != 0) { - *prob = gen.getProb(); - return expr; - } + PgfRandomGenerator gen(pgf, depth, seed, m, u); + for (size_t i = 0; i < n_concr_revisions; i++) { + gen.addConcr(db->revision2concr(concr_revisions[i])); + } + PgfExpr expr = m->match_type(&gen, type); + if (expr != 0) { + *prob = gen.getProb(); + return expr; } } PGF_API_END @@ -1237,18 +1233,14 @@ PgfExpr pgf_generate_random_from ref pgf = db->revision2pgf(revision); - // Generation may fail for certain random choices, but succeed - // for others. We try 10 time to increase the chance of succeess. - for (size_t i = 0; i < 10; i++) { - PgfRandomGenerator gen(pgf, depth, seed, m, u); - for (size_t i = 0; i < n_concr_revisions; i++) { - gen.addConcr(db->revision2concr(concr_revisions[i])); - } - PgfExpr new_expr = m->match_expr(&gen, expr); - if (new_expr != 0) { - *prob = gen.getProb(); - return new_expr; - } + PgfRandomGenerator gen(pgf, depth, seed, m, u); + for (size_t i = 0; i < n_concr_revisions; i++) { + gen.addConcr(db->revision2concr(concr_revisions[i])); + } + PgfExpr new_expr = m->match_expr(&gen, expr); + if (new_expr != 0) { + *prob = gen.getProb(); + return new_expr; } } PGF_API_END diff --git a/src/runtime/c/pgf/probspace.cxx b/src/runtime/c/pgf/probspace.cxx index 4cbcd7403..848b7212b 100644 --- a/src/runtime/c/pgf/probspace.cxx +++ b/src/runtime/c/pgf/probspace.cxx @@ -204,50 +204,51 @@ void probspace_iter(PgfProbspace space, PgfText *cat, } } +struct PGF_INTERNAL RSState { + const std::set> &excluded; + prob_t rand; + ref result; +}; + static -ref probspace_random(PgfProbspace space, - PgfText *cat, prob_t *rand, - bool is_last) +bool probspace_random(PgfProbspace space, PgfText *cat, + RSState *st) { if (space == 0) - return 0; + return false; int cmp = textcmp(cat,&(*space->value.cat)); if (cmp < 0) { - return probspace_random(space->left, cat, rand, true); + return probspace_random(space->left, cat, st); } else if (cmp > 0) { - return probspace_random(space->right, cat, rand, true); + return probspace_random(space->right, cat, st); } else { - ref fun; - - fun = probspace_random(space->left, cat, rand, false); - if (fun != 0) - return fun; + if (probspace_random(space->left, cat, st)) + return true; bool is_res = space->value.is_result(); - if (is_res) { - *rand -= exp(-space->value.fun->prob); - if (*rand <= 0) - return space->value.fun; + if (is_res && !st->excluded.count(space->value.fun)) { + st->rand -= exp(-space->value.fun->prob); + st->result = space->value.fun; + if (st->rand <= 0) + return true; } - fun = probspace_random(space->right, cat, rand, is_last); - if (fun != 0) - return fun; - if (is_last && is_res) { - // necessary due to floating point rounding - return space->value.fun; - } + if (probspace_random(space->right, cat, st)) + return true; } - return 0; + return false; } PGF_INTERNAL ref probspace_random(PgfProbspace space, - PgfText *cat, prob_t rand) + PgfText *cat, prob_t rand, + const std::set> &excluded) { - return probspace_random(space,cat,&rand,true); + RSState st = {excluded, rand, 0}; + probspace_random(space,cat,&st); + return st.result; } PGF_INTERNAL diff --git a/src/runtime/c/pgf/probspace.h b/src/runtime/c/pgf/probspace.h index 04210bf80..714a00720 100644 --- a/src/runtime/c/pgf/probspace.h +++ b/src/runtime/c/pgf/probspace.h @@ -73,7 +73,8 @@ void probspace_iter(PgfProbspace space, PgfText *cat, * the given category */ PGF_INTERNAL_DECL ref probspace_random(PgfProbspace space, - PgfText *cat, prob_t rand); + PgfText *cat, prob_t rand, + const std::set> &excluded); PGF_INTERNAL_DECL void probspace_release(PgfProbspace space); diff --git a/src/runtime/haskell/PGF2.hsc b/src/runtime/haskell/PGF2.hsc index e22a8f5f5..2d55ea381 100644 --- a/src/runtime/haskell/PGF2.hsc +++ b/src/runtime/haskell/PGF2.hsc @@ -1005,14 +1005,16 @@ generateAllDepth :: PGF -> Type -> Int -> [(Expr,Float)] generateAllDepth p ty dp = generateAllExt p ty dp [] generateAllExt :: PGF -> Type -> Int -> [Concr] -> [(Expr,Float)] -generateAllExt p ty dp cs = - unsafePerformIO $ - bracket (newStablePtr ty) freeStablePtr $ \c_ty -> - withForeignPtr (a_revision p) $ \a_revision -> - withPgfConcrs cs $ \c_db c_revisions n_revisions -> - mask_ $ do - c_enum <- withPgfExn "generateAllExt" (pgf_generate_all (a_db p) a_revision c_revisions n_revisions c_ty (fromIntegral dp) marshaller unmarshaller) - enumerateExprs (a_db p) c_enum +generateAllExt p ty dp cs + | dp <= 0 = [] + | otherwise = + unsafePerformIO $ + bracket (newStablePtr ty) freeStablePtr $ \c_ty -> + withForeignPtr (a_revision p) $ \a_revision -> + withPgfConcrs cs $ \c_db c_revisions n_revisions -> + mask_ $ do + c_enum <- withPgfExn "generateAllExt" (pgf_generate_all (a_db p) a_revision c_revisions n_revisions c_ty (fromIntegral dp) marshaller unmarshaller) + enumerateExprs (a_db p) c_enum generateAllFrom :: PGF -> Expr -> [(Expr,Float)] generateAllFrom p ty = generateAllFromExt p ty maxBound [] @@ -1033,9 +1035,10 @@ generateRandomDepth :: RandomGen g => g -> PGF -> Type -> Int -> [(Expr,Float)] generateRandomDepth g p ty dp = generateRandomExt g p ty dp [] generateRandomExt :: RandomGen g => g -> PGF -> Type -> Int -> [Concr] -> [(Expr,Float)] -generateRandomExt g p ty dp cs = - let (seed,_) = random g - in generate seed +generateRandomExt g p ty dp cs + | dp <= 0 = [] + | otherwise = let (seed,_) = random g + in generate seed where generate seed = unsafePerformIO $ @@ -1062,9 +1065,10 @@ generateRandomFromDepth :: RandomGen g => g -> PGF -> Expr -> Int -> [(Expr,Floa generateRandomFromDepth g p e dp = generateRandomFromExt g p e dp [] generateRandomFromExt :: RandomGen g => g -> PGF -> Expr -> Int -> [Concr] -> [(Expr,Float)] -generateRandomFromExt g p e dp cs = - let (seed,_) = random g - in generate seed +generateRandomFromExt g p e dp cs + | dp <= 0 = [] + | otherwise = let (seed,_) = random g + in generate seed where generate seed = unsafePerformIO $