diff --git a/src/runtime/python/ffi.c b/src/runtime/python/ffi.c index 944a615b9..d308850a0 100644 --- a/src/runtime/python/ffi.c +++ b/src/runtime/python/ffi.c @@ -326,7 +326,7 @@ match_type(PgfMarshaller *this, PgfUnmarshaller *u, PgfType ty) PgfTypeHypo hypos[n_hypos]; // PgfTypeHypo *hypos = alloca(sizeof(PgfTypeHypo)*n_hypos); for (Py_ssize_t i = 0; i < n_hypos; i++) { - PyObject *hytup = (PyObject *)PyList_GetItem(type->hypos, i); + PyObject *hytup = PyList_GetItem(type->hypos, i); hypos[i].bind_type = PyLong_AsLong(PyTuple_GetItem(hytup, 0)); hypos[i].cid = PyUnicode_AsPgfText(PyTuple_GetItem(hytup, 1)); hypos[i].type = (PgfType) PyTuple_GetItem(hytup, 2); diff --git a/src/runtime/python/tests/test_transactions.py b/src/runtime/python/tests/test_transactions.py index 20cf806ad..92808420e 100644 --- a/src/runtime/python/tests/test_transactions.py +++ b/src/runtime/python/tests/test_transactions.py @@ -7,24 +7,25 @@ prob = math.pi @pytest.fixture(scope="module") def gr1(): - return readPGF("../haskell/tests/basic.pgf") + gr = readPGF("../haskell/tests/basic.pgf") + yield gr @pytest.fixture(scope="module") -def gr2(gr1): - t = gr1.newTransaction() - t.createFunction("foo", ty, 0, prob), +def gr2(): + gr = readPGF("../haskell/tests/basic.pgf") + t = gr.newTransaction() + t.createFunction("foo", ty, 0, prob) t.createCategory("Q", [(BIND_TYPE_EXPLICIT, "x", ty)], prob) - assert t.commit() - return gr1 + t.commit() + yield gr @pytest.fixture(scope="module") def gr3(): - # TODO how to avoid reloading from file? - gr1 = readPGF("../haskell/tests/basic.pgf") - with gr1.newTransaction("bar_branch") as t: - t.createFunction("bar", ty, 0, prob), + gr = readPGF("../haskell/tests/basic.pgf") + with gr.newTransaction("bar_branch") as t: + t.createFunction("bar", ty, 0, prob) t.createCategory("R", [(BIND_TYPE_EXPLICIT, "x", ty)], prob) - return gr1 + yield gr # gr1 @@ -67,5 +68,11 @@ def test_extended_function_prob(gr2): def test_branched_functions(gr3): assert gr3.functions == ["bar", "c", "ind", "s", "z"] +def test_branched_categories(gr3): + assert gr3.categories == ["Float","Int","N","P","R","S","String"] + +def test_extended_category_context(gr3): + assert gr3.categoryContext("R") == [(BIND_TYPE_EXPLICIT, "x", ty)] + def test_branched_function_type(gr3): assert gr3.functionType("bar") == ty diff --git a/src/runtime/python/transactions.c b/src/runtime/python/transactions.c index 83fbda43b..30517d33a 100644 --- a/src/runtime/python/transactions.c +++ b/src/runtime/python/transactions.c @@ -49,8 +49,9 @@ Transaction_commit(TransactionObject *self, PyObject *args) pgf_free_revision(self->pgf->db, self->pgf->revision); self->pgf->revision = self->revision; + Py_INCREF(self->pgf->db); - Py_RETURN_TRUE; + Py_RETURN_NONE; } static PyObject * @@ -107,18 +108,24 @@ Transaction_createCategory(TransactionObject *self, PyObject *args) Py_ssize_t size; PyObject *hypos; float prob = 0.0; - if (!PyArg_ParseTuple(args, "s#O!f", &s, &size, &PyList_Type, &hypos, prob)) + // if (!PyArg_ParseTuple(args, "s#O!f", &s, &size, &PyList_Type, &hypos, prob)) // segfaults in Python 3.8 but not 3.7 + // return NULL; + if (!PyArg_ParseTuple(args, "s#Of", &s, &size, &hypos, prob)) return NULL; + if (!PyObject_TypeCheck(hypos, &PyList_Type)) { + PyErr_SetString(PyExc_TypeError, "hypos must be a list"); + return NULL; + } PgfText *catname = (PgfText *)PyMem_Malloc(sizeof(PgfText)+size+1); memcpy(catname->text, s, size+1); catname->size = size; Py_ssize_t n_hypos = PyList_Size(hypos); - // PgfTypeHypo context[n_hypos]; - PgfTypeHypo *context = alloca(sizeof(PgfTypeHypo)*n_hypos); + PgfTypeHypo context[n_hypos]; + // PgfTypeHypo *context = alloca(sizeof(PgfTypeHypo)*n_hypos); for (Py_ssize_t i = 0; i < n_hypos; i++) { - PyObject *hytup = (PyObject *)PyList_GetItem(hypos, i); + PyObject *hytup = PyList_GetItem(hypos, i); context[i].bind_type = PyLong_AsLong(PyTuple_GetItem(hytup, 0)); context[i].cid = PyUnicode_AsPgfText(PyTuple_GetItem(hytup, 1)); context[i].type = (PgfType) PyTuple_GetItem(hytup, 2); @@ -160,35 +167,51 @@ Transaction_dropCategory(TransactionObject *self, PyObject *args) static TransactionObject * Transaction_enter(TransactionObject *self, PyObject *Py_UNUSED(ignored)) { + Py_INCREF(self); return self; } +static PyObject * +Transaction_exit_impl(TransactionObject *self, PyObject *exc_type, PyObject *exc_value, PyObject *exc_tb) +{ + if (exc_type == Py_None && exc_value == Py_None && exc_tb == Py_None) { + return Transaction_commit(self, NULL); + } else { + PyErr_SetObject(exc_type, exc_value); + return NULL; + } +} + +// cpython/Modules/_multiprocessing/clinic/semaphore.c.h +// cpython/Modules/_sqlite/connection.c static PyObject * Transaction_exit(TransactionObject *self, PyObject *const *args, Py_ssize_t nargs) { - // PyObject *exc_type = Py_None; - // PyObject *exc_value = Py_None; - // PyObject *exc_tb = Py_None; + PyObject *return_value = NULL; + PyObject *exc_type = Py_None; + PyObject *exc_value = Py_None; + PyObject *exc_tb = Py_None; - // if (!_PyArg_CheckPositional("__exit__", nargs, 0, 3)) { - // Py_RETURN_FALSE; - // } + if (nargs < 0 || nargs > 3) { + goto exit; + } if (nargs < 1) { goto skip_optional; } - // exc_type = args[0]; - // if (nargs < 2) { - // goto skip_optional; - // } - // exc_value = args[1]; - // if (nargs < 3) { - // goto skip_optional; - // } - // exc_tb = args[2]; + exc_type = args[0]; + if (nargs < 2) { + goto skip_optional; + } + exc_value = args[1]; + if (nargs < 3) { + goto skip_optional; + } + exc_tb = args[2]; skip_optional: - // TODO check exception + return_value = Transaction_exit_impl(self, exc_type, exc_value, exc_tb); - return Transaction_commit(self, NULL); +exit: + return return_value; } // static void