diff --git a/src/runtime/c/pgf/expr.cxx b/src/runtime/c/pgf/expr.cxx index c3ee2e0cc..6692ee50f 100644 --- a/src/runtime/c/pgf/expr.cxx +++ b/src/runtime/c/pgf/expr.cxx @@ -211,7 +211,7 @@ PgfType PgfDBUnmarshaller::dtyp(int n_hypos, PgfTypeHypo *hypos, } ty->exprs = vector_new(n_exprs); for (size_t i = 0; i < n_exprs; i++) { - vector_elem(ty->exprs,i) = m->match_expr(this, exprs[i]); + *vector_elem(ty->exprs,i) = m->match_expr(this, exprs[i]); } return ty.as_object(); diff --git a/src/runtime/c/pgf/pgf.cxx b/src/runtime/c/pgf/pgf.cxx index e03ebf7a4..a08ce83a1 100644 --- a/src/runtime/c/pgf/pgf.cxx +++ b/src/runtime/c/pgf/pgf.cxx @@ -179,7 +179,7 @@ void pgf_iter_categories(PgfDB *db, PgfRevision revision, DB_scope scope(db, READER_SCOPE); ref pgf = revision; - err->type = PGF_EXN_NONE; + pgf_exn_clear(err); namespace_iter(pgf->abstract.cats, itor, err); } @@ -270,7 +270,7 @@ void pgf_iter_functions(PgfDB *db, PgfRevision revision, DB_scope scope(db, READER_SCOPE); ref pgf = revision; - err->type = PGF_EXN_NONE; + pgf_exn_clear(err); namespace_iter(pgf->abstract.funs, itor, err); } @@ -302,7 +302,7 @@ void pgf_iter_functions_by_cat(PgfDB *db, PgfRevision revision, helper.cat = cat; helper.itor = itor; - err->type = PGF_EXN_NONE; + pgf_exn_clear(err); namespace_iter(pgf->abstract.funs, &helper, err); } @@ -395,6 +395,53 @@ PgfType pgf_read_type(PgfText *input, PgfUnmarshaller *u) return res; } +PGF_API_DECL +PgfRevision pgf_clone_revision(PgfDB *db, PgfRevision revision, + PgfExn *err) +{ + DB_scope scope(db, WRITER_SCOPE); + + pgf_exn_clear(err); + + try { + ref pgf = revision; + + ref new_pgf = PgfDB::malloc(); + new_pgf->major_version = pgf->major_version; + new_pgf->minor_version = pgf->minor_version; + + new_pgf->gflags = pgf->gflags; + if (pgf->gflags != 0) + pgf->gflags->ref_count++; + + new_pgf->abstract.name = + PgfDB::malloc(sizeof(PgfText)+pgf->abstract.name->size+1); + memcpy(new_pgf->abstract.name, pgf->abstract.name, sizeof(PgfText)+pgf->abstract.name->size+1); + + new_pgf->abstract.aflags = pgf->abstract.aflags; + if (pgf->abstract.aflags != 0) + pgf->abstract.aflags->ref_count++; + + new_pgf->abstract.funs = pgf->abstract.funs; + if (pgf->abstract.funs != 0) + pgf->abstract.funs->ref_count++; + + new_pgf->abstract.cats = pgf->abstract.cats; + if (pgf->abstract.cats != 0) + pgf->abstract.cats->ref_count++; + + return new_pgf.as_object(); + } catch (std::system_error& e) { + err->type = PGF_EXN_SYSTEM_ERROR; + err->code = e.code().value(); + } catch (pgf_error& e) { + err->type = PGF_EXN_PGF_ERROR; + err->msg = strdup(e.what()); + } + + return 0; +} + PGF_API void pgf_create_function(PgfDB *db, PgfRevision revision, PgfText *name, @@ -404,21 +451,31 @@ void pgf_create_function(PgfDB *db, PgfRevision revision, { DB_scope scope(db, WRITER_SCOPE); - PgfDBUnmarshaller u(m); + pgf_exn_clear(err); - ref pgf = revision; - ref absfun = PgfDB::malloc(sizeof(PgfAbsFun)+name->size+1); - absfun->type = m->match_type(&u, ty); - absfun->arity = 0; - absfun->defns = 0; - absfun->ep.prob = prob; - ref efun = - ref::from_ptr((PgfExprFun*) &absfun->name); - absfun->ep.expr = ref::tagged(efun); - memcpy(&absfun->name, name, sizeof(PgfText)+name->size+1); - - Namespace nmsp = - namespace_insert(pgf->abstract.funs, absfun); - namespace_release(pgf->abstract.funs); - pgf->abstract.funs = nmsp; + try { + PgfDBUnmarshaller u(m); + + ref pgf = revision; + ref absfun = PgfDB::malloc(sizeof(PgfAbsFun)+name->size+1); + absfun->type = m->match_type(&u, ty); + absfun->arity = 0; + absfun->defns = 0; + absfun->ep.prob = prob; + ref efun = + ref::from_ptr((PgfExprFun*) &absfun->name); + absfun->ep.expr = ref::tagged(efun); + memcpy(&absfun->name, name, sizeof(PgfText)+name->size+1); + + Namespace nmsp = + namespace_insert(pgf->abstract.funs, absfun); + namespace_release(pgf->abstract.funs); + pgf->abstract.funs = nmsp; + } catch (std::system_error& e) { + err->type = PGF_EXN_SYSTEM_ERROR; + err->code = e.code().value(); + } catch (pgf_error& e) { + err->type = PGF_EXN_PGF_ERROR; + err->msg = strdup(e.what()); + } } diff --git a/src/runtime/c/pgf/pgf.h b/src/runtime/c/pgf/pgf.h index f64d8b1fc..8405d8a8f 100644 --- a/src/runtime/c/pgf/pgf.h +++ b/src/runtime/c/pgf/pgf.h @@ -308,6 +308,10 @@ PgfText *pgf_print_type(PgfType ty, PGF_API_DECL PgfType pgf_read_type(PgfText *input, PgfUnmarshaller *u); +PGF_API_DECL +PgfRevision pgf_clone_revision(PgfDB *db, PgfRevision revision, + PgfExn *err); + PGF_API_DECL void pgf_create_function(PgfDB *db, PgfRevision revision, PgfText *name, diff --git a/src/runtime/haskell/PGF2.hsc b/src/runtime/haskell/PGF2.hsc index 34962fb38..b5ac81c06 100644 --- a/src/runtime/haskell/PGF2.hsc +++ b/src/runtime/haskell/PGF2.hsc @@ -40,8 +40,6 @@ module PGF2 (-- * PGF mkType, unType, mkHypo, mkDepHypo, mkImplHypo, - createFunction, - -- * Concrete syntax ConcName, @@ -49,13 +47,13 @@ module PGF2 (-- * PGF PGFError(..) ) where -import Control.Exception(mask_,bracket) -import System.IO.Unsafe(unsafePerformIO) import PGF2.Expr import PGF2.FFI import Foreign import Foreign.C +import Control.Exception(mask_,bracket) +import System.IO.Unsafe(unsafePerformIO) import qualified Foreign.Concurrent as C import qualified Data.Map as Map import Data.IORef @@ -336,12 +334,3 @@ readType str = else do ty <- deRefStablePtr c_ty freeStablePtr c_ty return (Just ty) - -createFunction :: PGF -> Fun -> Type -> Float -> IO () -createFunction p name ty prob = - withForeignPtr (a_db p) $ \c_db -> - withForeignPtr (revision p) $ \c_revision -> - withText name $ \c_name -> - bracket (newStablePtr ty) freeStablePtr $ \c_ty -> - withForeignPtr marshaller $ \m -> do - pgf_create_function c_db c_revision c_name c_ty prob m diff --git a/src/runtime/haskell/PGF2/FFI.hsc b/src/runtime/haskell/PGF2/FFI.hsc index 52f275227..450500097 100644 --- a/src/runtime/haskell/PGF2/FFI.hsc +++ b/src/runtime/haskell/PGF2/FFI.hsc @@ -108,8 +108,11 @@ foreign import ccall "pgf/expr.h pgf_function_is_constructor" foreign import ccall "pgf_function_prob" pgf_function_prob :: Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfText -> IO (#type prob_t) +foreign import ccall "pgf_clone_revision" + pgf_clone_revision :: Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfExn -> IO (Ptr PgfRevision) + foreign import ccall "pgf_create_function" - pgf_create_function :: Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfText -> StablePtr Type -> (#type prob_t) -> Ptr PgfMarshaller -> IO () + pgf_create_function :: Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfText -> StablePtr Type -> (#type prob_t) -> Ptr PgfMarshaller -> Ptr PgfExn -> IO () ----------------------------------------------------------------------- diff --git a/src/runtime/haskell/PGF2/Transactions.hsc b/src/runtime/haskell/PGF2/Transactions.hsc new file mode 100644 index 000000000..2800e5fc0 --- /dev/null +++ b/src/runtime/haskell/PGF2/Transactions.hsc @@ -0,0 +1,55 @@ +module PGF2.Transactions + ( Transaction + , modifyPGF + , createFunction + ) where + +import PGF2.FFI +import PGF2.Expr + +import Foreign +import Foreign.C +import qualified Foreign.Concurrent as C +import Control.Exception(bracket) + +#include + +newtype Transaction a = + Transaction (Ptr PgfDB -> Ptr PgfRevision -> Ptr PgfExn -> IO a) + +instance Functor Transaction where + fmap f (Transaction g) = Transaction $ \c_db c_revision c_exn -> do + res <- g c_db c_revision c_exn + return (f res) + +instance Applicative Transaction where + pure x = Transaction $ \c_db c_revision c_exn -> return x + +instance Monad Transaction where + (Transaction f) >>= g = Transaction $ \c_db c_revision c_exn -> do + res <- f c_db c_revision c_exn + ex_type <- (#peek PgfExn, type) c_exn + if (ex_type :: (#type PgfExnType)) == (#const PGF_EXN_NONE) + then case g res of + Transaction g -> g c_db c_revision c_exn + else return undefined + +modifyPGF :: PGF -> Transaction a -> IO PGF +modifyPGF p (Transaction f) = + withForeignPtr (a_db p) $ \c_db -> + withForeignPtr (revision p) $ \c_revision -> + withPgfExn "" $ \c_exn -> do + c_revision <- pgf_clone_revision c_db c_revision c_exn + ex_type <- (#peek PgfExn, type) c_exn + if (ex_type :: (#type PgfExnType)) == (#const PGF_EXN_NONE) + then do f c_db c_revision c_exn + fptr2 <- C.newForeignPtr c_revision (withForeignPtr (a_db p) (\c_db -> pgf_free_revision c_db c_revision)) + return (PGF (a_db p) fptr2 (langs p)) + else return p + +createFunction :: Fun -> Type -> Float -> Transaction () +createFunction name ty prob = Transaction $ \c_db c_revision c_exn -> + withText name $ \c_name -> + bracket (newStablePtr ty) freeStablePtr $ \c_ty -> + withForeignPtr marshaller $ \m -> do + pgf_create_function c_db c_revision c_name c_ty prob m c_exn diff --git a/src/runtime/haskell/pgf2.cabal b/src/runtime/haskell/pgf2.cabal index 3364cb7b9..beb137b03 100644 --- a/src/runtime/haskell/pgf2.cabal +++ b/src/runtime/haskell/pgf2.cabal @@ -22,6 +22,7 @@ extra-source-files: library exposed-modules: PGF2, + PGF2.Transactions, PGF2.Internal, -- backwards compatibility API: PGF @@ -53,3 +54,12 @@ test-suite basic random, directory, pgf2 + +test-suite transactions + type: exitcode-stdio-1.0 + main-is: tests/transactions.hs + default-language: Haskell2010 + build-depends: + base, + HUnit, + pgf2 diff --git a/src/runtime/haskell/tests/basic.hs b/src/runtime/haskell/tests/basic.hs index f5a40d752..0a5783e67 100644 --- a/src/runtime/haskell/tests/basic.hs +++ b/src/runtime/haskell/tests/basic.hs @@ -74,11 +74,8 @@ main = do print (e :: SomeException) gr1 <- readPGF "tests/basic.pgf" - print (abstractName gr1) gr2 <- bootNGF "tests/basic.pgf" "tests/basic.ngf" - print (abstractName gr2) gr3 <- readNGF "tests/basic.ngf" - print (abstractName gr3) rp1 <- testLoadFailure (readPGF "non-existing.pgf") rp2 <- testLoadFailure (readPGF "tests/basic.gf") diff --git a/src/runtime/haskell/tests/transactions.hs b/src/runtime/haskell/tests/transactions.hs new file mode 100644 index 000000000..e48039d05 --- /dev/null +++ b/src/runtime/haskell/tests/transactions.hs @@ -0,0 +1,18 @@ +import Test.HUnit +import PGF2 +import PGF2.Transactions + +main = do + gr1 <- readPGF "tests/basic.pgf" + let Just ty = readType "(N -> N) -> P (s z)" + gr2 <- modifyPGF gr1 (createFunction "foo" ty pi) + + runTestTTAndExit $ + TestList $ + [TestCase (assertEqual "original functions" ["c","ind","s","z"] (functions gr1)) + ,TestCase (assertEqual "extended functions" ["c","foo","ind","s","z"] (functions gr2)) + ,TestCase (assertEqual "old function type" Nothing (functionType gr1 "foo")) + ,TestCase (assertEqual "new function type" (Just ty) (functionType gr2 "foo")) + ,TestCase (assertEqual "old function prob" (-log 0) (functionProb gr1 "foo")) + ,TestCase (assertEqual "new function prob" pi (functionProb gr2 "foo")) + ]