diff --git a/src/runtime/c/pgf/db.cxx b/src/runtime/c/pgf/db.cxx index bb3b189a4..6a26b602a 100644 --- a/src/runtime/c/pgf/db.cxx +++ b/src/runtime/c/pgf/db.cxx @@ -988,7 +988,7 @@ PGF_INTERNAL void PgfDB::link_transient_revision(ref pgf) { pgf->next = current_db->ms->transient_revisions; - if (current_db->ms->transient_revisions == 0) + if (current_db->ms->transient_revisions != 0) current_db->ms->transient_revisions->prev = pgf; current_db->ms->transient_revisions = pgf; } diff --git a/src/runtime/c/pgf/expr.cxx b/src/runtime/c/pgf/expr.cxx index 79244642b..d94153d73 100644 --- a/src/runtime/c/pgf/expr.cxx +++ b/src/runtime/c/pgf/expr.cxx @@ -194,9 +194,9 @@ PgfLiteral PgfDBUnmarshaller::lstr(PgfText *val) return ref::tagged(lit_str); } -PgfType PgfDBUnmarshaller::dtyp(int n_hypos, PgfTypeHypo *hypos, +PgfType PgfDBUnmarshaller::dtyp(size_t n_hypos, PgfTypeHypo *hypos, PgfText *cat, - int n_exprs, PgfExpr *exprs) + size_t n_exprs, PgfExpr *exprs) { ref ty = PgfDB::malloc(cat->size+1); @@ -205,8 +205,7 @@ PgfType PgfDBUnmarshaller::dtyp(int n_hypos, PgfTypeHypo *hypos, for (size_t i = 0; i < n_hypos; i++) { ref hypo = vector_elem(ty->hypos,i); hypo->bind_type = hypos[i].bind_type; - hypo->cid = PgfDB::malloc(hypos[i].cid->size+1); - memcpy(hypo->cid, hypos[i].cid, sizeof(PgfText)+hypos[i].cid->size+1); + hypo->cid = textdup_db(hypos[i].cid); hypo->type = m->match_type(this, hypos[i].type); } ty->exprs = vector_new(n_exprs); diff --git a/src/runtime/c/pgf/expr.h b/src/runtime/c/pgf/expr.h index ca035f5a1..1ac3c80ac 100644 --- a/src/runtime/c/pgf/expr.h +++ b/src/runtime/c/pgf/expr.h @@ -116,9 +116,9 @@ struct PGF_INTERNAL_DECL PgfDBUnmarshaller : public PgfUnmarshaller { virtual PgfLiteral lint(size_t size, uintmax_t *val); virtual PgfLiteral lflt(double val); virtual PgfLiteral lstr(PgfText *val); - virtual PgfType dtyp(int n_hypos, PgfTypeHypo *hypos, + virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos, PgfText *cat, - int n_exprs, PgfExpr *exprs); + size_t n_exprs, PgfExpr *exprs); virtual void free_ref(object x); }; diff --git a/src/runtime/c/pgf/pgf.cxx b/src/runtime/c/pgf/pgf.cxx index ece0580e8..b6737c4b4 100644 --- a/src/runtime/c/pgf/pgf.cxx +++ b/src/runtime/c/pgf/pgf.cxx @@ -453,9 +453,7 @@ PgfRevision pgf_clone_revision(PgfDB *db, PgfRevision revision, if (pgf->gflags != 0) pgf->gflags->ref_count++; - new_pgf->abstract.name = - PgfDB::malloc(pgf->abstract.name->size+1); - memcpy(new_pgf->abstract.name, pgf->abstract.name, sizeof(PgfText)+pgf->abstract.name->size+1); + new_pgf->abstract.name = textdup_db(&(*pgf->abstract.name)); new_pgf->abstract.aflags = pgf->abstract.aflags; if (pgf->abstract.aflags != 0) @@ -495,7 +493,8 @@ void pgf_commit_revision(PgfDB *db, PgfRevision revision, PgfDB::unlink_transient_revision(new_pgf); PgfDB::set_revision(new_pgf); - PgfDB::link_transient_revision(old_pgf); + if (old_pgf != 0) + PgfDB::link_transient_revision(old_pgf); } PGF_API_END } @@ -557,3 +556,51 @@ void pgf_drop_function(PgfDB *db, PgfRevision revision, pgf->abstract.funs = funs; } PGF_API_END } + +PGF_API +void pgf_create_category(PgfDB *db, PgfRevision revision, + PgfText *name, + size_t n_hypos, PgfTypeHypo *context, prob_t prob, + PgfMarshaller *m, + PgfExn *err) +{ + PGF_API_BEGIN { + DB_scope scope(db, WRITER_SCOPE); + + PgfDBUnmarshaller u(m); + + ref pgf = PgfDB::revision2pgf(revision); + ref abscat = PgfDB::malloc(name->size+1); + abscat->context = vector_new(n_hypos); + abscat->prob = prob; + memcpy(&abscat->name, name, sizeof(PgfText)+name->size+1); + + for (size_t i = 0; i < n_hypos; i++) { + vector_elem(abscat->context, i)->bind_type = context[i].bind_type; + vector_elem(abscat->context, i)->cid = textdup_db(context[i].cid); + vector_elem(abscat->context, i)->type = m->match_type(&u, context[i].type); + } + + Namespace cats = + namespace_insert(pgf->abstract.cats, abscat); + namespace_release(pgf->abstract.cats); + pgf->abstract.cats = cats; + } PGF_API_END +} + +PGF_API +void pgf_drop_category(PgfDB *db, PgfRevision revision, + PgfText *name, + PgfExn *err) +{ + PGF_API_BEGIN { + DB_scope scope(db, WRITER_SCOPE); + + ref pgf = PgfDB::revision2pgf(revision); + + Namespace cats = + namespace_delete(pgf->abstract.cats, name); + namespace_release(pgf->abstract.cats); + pgf->abstract.cats = cats; + } PGF_API_END +} diff --git a/src/runtime/c/pgf/pgf.h b/src/runtime/c/pgf/pgf.h index fc90f1429..fbb26d187 100644 --- a/src/runtime/c/pgf/pgf.h +++ b/src/runtime/c/pgf/pgf.h @@ -164,9 +164,9 @@ struct PgfUnmarshaller { virtual PgfLiteral lint(size_t size, uintmax_t *v)=0; virtual PgfLiteral lflt(double v)=0; virtual PgfLiteral lstr(PgfText *v)=0; - virtual PgfType dtyp(int n_hypos, PgfTypeHypo *hypos, + virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos, PgfText *cat, - int n_exprs, PgfExpr *exprs)=0; + size_t n_exprs, PgfExpr *exprs)=0; virtual void free_ref(object x)=0; }; @@ -203,9 +203,9 @@ struct PgfUnmarshaller { typedef struct PgfMarshaller PgfMarshaller; typedef struct PgfMarshallerVtbl PgfMarshallerVtbl; struct PgfMarshallerVtbl { - object (*match_lit)(PgfUnmarshaller *u, PgfLiteral lit); - object (*match_expr)(PgfUnmarshaller *u, PgfExpr expr); - object (*match_type)(PgfUnmarshaller *u, PgfType ty); + object (*match_lit)(PgfMarshaller *this, PgfUnmarshaller *u, PgfLiteral lit); + object (*match_expr)(PgfMarshaller *this, PgfUnmarshaller *u, PgfExpr expr); + object (*match_type)(PgfMarshaller *this, PgfUnmarshaller *u, PgfType ty); }; struct PgfMarshaller { PgfMarshallerVtbl *vtbl; @@ -347,4 +347,16 @@ void pgf_drop_function(PgfDB *db, PgfRevision revision, PgfText *name, PgfExn *err); +PGF_API_DECL +void pgf_create_category(PgfDB *db, PgfRevision revision, + PgfText *name, + size_t n_hypos, PgfTypeHypo *context, prob_t prob, + PgfMarshaller *m, + PgfExn *err); + +PGF_API_DECL +void pgf_drop_category(PgfDB *db, PgfRevision revision, + PgfText *name, + PgfExn *err); + #endif // PGF_H_ diff --git a/src/runtime/c/pgf/printer.cxx b/src/runtime/c/pgf/printer.cxx index 4cb3b6738..633c4db14 100644 --- a/src/runtime/c/pgf/printer.cxx +++ b/src/runtime/c/pgf/printer.cxx @@ -380,9 +380,9 @@ PgfLiteral PgfPrinter::lstr(PgfText *v) return 0; } -PgfType PgfPrinter::dtyp(int n_hypos, PgfTypeHypo *hypos, +PgfType PgfPrinter::dtyp(size_t n_hypos, PgfTypeHypo *hypos, PgfText *cat, - int n_exprs, PgfExpr *exprs) + size_t n_exprs, PgfExpr *exprs) { bool p = (prio > 0 && n_hypos > 0) || (prio > 3 && n_exprs > 0); diff --git a/src/runtime/c/pgf/printer.h b/src/runtime/c/pgf/printer.h index 7a7e67987..6bf39537b 100644 --- a/src/runtime/c/pgf/printer.h +++ b/src/runtime/c/pgf/printer.h @@ -61,9 +61,9 @@ public: virtual PgfLiteral lint(size_t size, uintmax_t *v); virtual PgfLiteral lflt(double v); virtual PgfLiteral lstr(PgfText *v); - virtual PgfType dtyp(int n_hypos, PgfTypeHypo *hypos, + virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos, PgfText *cat, - int n_exprs, PgfExpr *exprs); + size_t n_exprs, PgfExpr *exprs); virtual void free_ref(object x); }; diff --git a/src/runtime/c/pgf/text.cxx b/src/runtime/c/pgf/text.cxx index 9b4178582..e0378d728 100644 --- a/src/runtime/c/pgf/text.cxx +++ b/src/runtime/c/pgf/text.cxx @@ -26,6 +26,14 @@ PgfText* textdup(PgfText *t1) return t2; } +PGF_INTERNAL +ref textdup_db(PgfText *t1) +{ + ref t2 = PgfDB::malloc(t1->size+1); + memcpy(&(*t2), t1, sizeof(PgfText)+t1->size+1); + return t2; +} + PGF_API uint32_t pgf_utf8_decode(const uint8_t** src_inout) { diff --git a/src/runtime/c/pgf/text.h b/src/runtime/c/pgf/text.h index 67005d363..004b8313b 100644 --- a/src/runtime/c/pgf/text.h +++ b/src/runtime/c/pgf/text.h @@ -7,6 +7,9 @@ int textcmp(PgfText *t1, PgfText *t2); PGF_INTERNAL_DECL PgfText* textdup(PgfText *t1); +PGF_INTERNAL_DECL +ref textdup_db(PgfText *t1); + PGF_API uint32_t pgf_utf8_decode(const uint8_t** src_inout); diff --git a/src/runtime/haskell/PGF2/FFI.hsc b/src/runtime/haskell/PGF2/FFI.hsc index 1aa8d6e0e..abf5052a6 100644 --- a/src/runtime/haskell/PGF2/FFI.hsc +++ b/src/runtime/haskell/PGF2/FFI.hsc @@ -118,6 +118,10 @@ foreign import ccall pgf_create_function :: Ptr PgfDB -> Ptr PgfRevision -> Ptr foreign import ccall pgf_drop_function :: Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfText -> Ptr PgfExn -> IO () +foreign import ccall pgf_create_category :: Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfText -> CSize -> Ptr PgfTypeHypo -> (#type prob_t) -> Ptr PgfMarshaller -> Ptr PgfExn -> IO () + +foreign import ccall pgf_drop_category :: Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfText -> Ptr PgfExn -> IO () + ----------------------------------------------------------------------- -- Texts @@ -310,7 +314,7 @@ foreign import ccall "dynamic" foreign import ccall "wrapper" wrapLStrFun :: LStrFun -> IO (FunPtr LStrFun) -type DTypFun = Ptr PgfUnmarshaller -> CInt -> Ptr PgfTypeHypo -> Ptr PgfText -> CInt -> Ptr (StablePtr Expr) -> IO (StablePtr Type) +type DTypFun = Ptr PgfUnmarshaller -> CSize -> Ptr PgfTypeHypo -> Ptr PgfText -> CSize -> Ptr (StablePtr Expr) -> IO (StablePtr Type) foreign import ccall "dynamic" callDTypFun :: FunPtr DTypFun -> DTypFun @@ -397,21 +401,19 @@ marshaller = unsafePerformIO $ do ty <- deRefStablePtr c_ty case ty of DTyp hypos cat es -> let n_hypos = length hypos - in allocaBytes (n_hypos * (#size PgfTypeHypo)) $ \c_hypos -> + in withHypos hypos $ \n_hypos c_hypos -> withText cat $ \c_cat -> mask_ $ do - marshalHypos c_hypos hypos c_es <- mapM newStablePtr es res <- withArray c_es $ \c_exprs -> do fun <- (#peek PgfUnmarshallerVtbl, dtyp) vtbl callDTypFun fun u - (fromIntegral n_hypos) + n_hypos c_hypos c_cat (fromIntegral (length es)) c_exprs mapM_ freeStablePtr c_es - freeHypos c_hypos n_hypos return res where marshalHypos _ [] = return () @@ -533,3 +535,26 @@ marshalBindType Implicit = (#const PGF_BIND_TYPE_IMPLICIT) unmarshalBindType :: (#type PgfBindType) -> BindType unmarshalBindType (#const PGF_BIND_TYPE_EXPLICIT) = Explicit unmarshalBindType (#const PGF_BIND_TYPE_IMPLICIT) = Implicit + +withHypos hypos f = + let n_hypos = length hypos + in allocaBytes (n_hypos * (#size PgfTypeHypo)) $ \c_hypos -> + mask_ $ do + marshalHypos c_hypos hypos + res <- f (fromIntegral n_hypos :: CSize) c_hypos + freeHypos n_hypos c_hypos + return res + where + marshalHypos _ [] = return () + marshalHypos ptr ((bt,var,ty):hs) = do + (#poke PgfTypeHypo, bind_type) ptr (marshalBindType bt) + newText var >>= (#poke PgfTypeHypo, cid) ptr + newStablePtr ty >>= (#poke PgfTypeHypo, type) ptr + marshalHypos (ptr `plusPtr` (#size PgfTypeHypo)) hs + + freeHypos 0 ptr = return () + freeHypos n ptr = do + (#peek PgfTypeHypo, cid) ptr >>= free + (#peek PgfTypeHypo, type) ptr >>= freeStablePtr + freeHypos (n-1) (ptr `plusPtr` (#size PgfTypeHypo)) + diff --git a/src/runtime/haskell/PGF2/Transactions.hsc b/src/runtime/haskell/PGF2/Transactions.hsc index e67027f97..230d00bed 100644 --- a/src/runtime/haskell/PGF2/Transactions.hsc +++ b/src/runtime/haskell/PGF2/Transactions.hsc @@ -5,6 +5,8 @@ module PGF2.Transactions , checkoutPGF , createFunction , dropFunction + , createCategory + , dropCategory ) where import PGF2.FFI @@ -116,3 +118,15 @@ dropFunction :: Fun -> Transaction () dropFunction name = Transaction $ \c_db c_revision c_exn -> withText name $ \c_name -> do pgf_drop_function c_db c_revision c_name c_exn + +createCategory :: Fun -> [Hypo] -> Float -> Transaction () +createCategory name hypos prob = Transaction $ \c_db c_revision c_exn -> + withText name $ \c_name -> + withHypos hypos $ \n_hypos c_hypos -> + withForeignPtr marshaller $ \m -> do + pgf_create_category c_db c_revision c_name n_hypos c_hypos prob m c_exn + +dropCategory :: Cat -> Transaction () +dropCategory name = Transaction $ \c_db c_revision c_exn -> + withText name $ \c_name -> do + pgf_drop_category c_db c_revision c_name c_exn diff --git a/src/runtime/haskell/tests/transactions.hs b/src/runtime/haskell/tests/transactions.hs index 3ae455518..1447bed7d 100644 --- a/src/runtime/haskell/tests/transactions.hs +++ b/src/runtime/haskell/tests/transactions.hs @@ -6,13 +6,15 @@ main = do gr1 <- readPGF "tests/basic.pgf" let Just ty = readType "(N -> N) -> P (s z)" - gr2 <- modifyPGF gr1 (createFunction "foo" ty pi) - gr3 <- branchPGF gr1 "bar_branch" (createFunction "bar" ty pi) + gr2 <- modifyPGF gr1 (createFunction "foo" ty pi >> + createCategory "Q" [(Explicit,"x",ty)] pi) + gr3 <- branchPGF gr1 "bar_branch" (createFunction "bar" ty pi >> + createCategory "R" [(Explicit,"x",ty)] pi) Just gr4 <- checkoutPGF gr1 "master" Just gr5 <- checkoutPGF gr1 "bar_branch" - gr6 <- modifyPGF gr1 (dropFunction "ind") + gr6 <- modifyPGF gr1 (dropFunction "ind" >> dropCategory "S") runTestTTAndExit $ TestList $ @@ -21,7 +23,13 @@ main = do ,TestCase (assertEqual "branched functions" ["bar","c","ind","s","z"] (functions gr3)) ,TestCase (assertEqual "checked-out extended functions" ["c","foo","ind","s","z"] (functions gr4)) ,TestCase (assertEqual "checked-out branched functions" ["bar","c","ind","s","z"] (functions gr5)) + ,TestCase (assertEqual "original categories" ["Float","Int","N","P","S","String"] (categories gr1)) + ,TestCase (assertEqual "extended categories" ["Float","Int","N","P","Q","S","String"] (categories gr2)) + ,TestCase (assertEqual "branched categories" ["Float","Int","N","P","R","S","String"] (categories gr3)) + ,TestCase (assertEqual "Q context" [(Explicit,"x",ty)] (categoryContext gr2 "Q")) + ,TestCase (assertEqual "R context" [(Explicit,"x",ty)] (categoryContext gr3 "R")) ,TestCase (assertEqual "reduced functions" ["c","s","z"] (functions gr6)) + ,TestCase (assertEqual "reduced categories" ["Float","Int","N","P","String"] (categories gr6)) ,TestCase (assertEqual "old function type" Nothing (functionType gr1 "foo")) ,TestCase (assertEqual "new function type" (Just ty) (functionType gr2 "foo")) ,TestCase (assertEqual "old function prob" (-log 0) (functionProb gr1 "foo")) diff --git a/src/runtime/python/marshaller.c b/src/runtime/python/marshaller.c index 891761b03..13425574f 100644 --- a/src/runtime/python/marshaller.c +++ b/src/runtime/python/marshaller.c @@ -167,7 +167,7 @@ PyString_AsPgfText(PyObject *pystr) // ---------------------------------------------------------------------------- -object match_lit(PgfUnmarshaller *u, PgfLiteral lit) +object match_lit(PgfMarshaller *this, PgfUnmarshaller *u, PgfLiteral lit) { PyObject *pyobj = (PyObject *)lit; @@ -187,13 +187,13 @@ object match_lit(PgfUnmarshaller *u, PgfLiteral lit) } } -object match_expr(PgfUnmarshaller *u, PgfExpr expr) +object match_expr(PgfMarshaller *this, PgfUnmarshaller *u, PgfExpr expr) { PyErr_SetString(PyExc_NotImplementedError, "match_expr not implemented"); Py_RETURN_NOTIMPLEMENTED; } -object match_type(PgfUnmarshaller *u, PgfType ty) +object match_type(PgfMarshaller *this, PgfUnmarshaller *u, PgfType ty) { // PySys_WriteStdout(">match_type<\n");