diff --git a/src/runtime/python/expr.c b/src/runtime/python/expr.c index 1e02682ee..7fc0862f9 100644 --- a/src/runtime/python/expr.c +++ b/src/runtime/python/expr.c @@ -239,6 +239,74 @@ Expr_call(ExprObject* self, PyObject* args, PyObject* kw) return res; } +static PyObject* +Expr_visit(ExprObject* self, PyObject *args) +{ + PyObject* py_visitor = NULL; + if (!PyArg_ParseTuple(args, "O", &py_visitor)) + return NULL; + + if (PyObject_TypeCheck(self, &pgf_ExprTypedType)) { + self = ((ExprTypedObject *) self)->expr; + } + + ExprObject *o = (ExprObject *) self; + size_t arity = 0; + while (PyObject_TypeCheck(o, &pgf_ExprAppType)) { + arity++; + o = ((ExprAppObject *) o)->fun; + if (PyObject_TypeCheck(o, &pgf_ExprTypedType)) { + o = ((ExprTypedObject *) o)->expr; + } + } + if (PyObject_TypeCheck(o, &pgf_ExprFunType)) { + Py_ssize_t len; + const char *text = PyUnicode_AsUTF8AndSize(((ExprFunObject *) o)->name, &len); + if (text == NULL) + return NULL; + + char* method_name = alloca(len+4); + strcpy(method_name, "on_"); + memcpy(method_name+3, text, len+1); + if (PyObject_HasAttrString(py_visitor, method_name)) { + PyObject* method_args = PyTuple_New(arity); + if (method_args == NULL) { + return NULL; + } + + o = (ExprObject *) self; + for (size_t i = 0; i < arity; i++) { + ExprObject *arg = ((ExprAppObject *) o)->arg; + if (PyObject_TypeCheck(arg, &pgf_ExprImplArgType)) { + arg = ((ExprImplArgObject *) arg)->expr; + } + + Py_INCREF(arg); + if (PyTuple_SetItem(method_args, i, (PyObject *) arg) == -1) { + Py_DECREF(method_args); + return NULL; + } + + o = ((ExprAppObject *) o)->fun; + if (PyObject_TypeCheck(o, &pgf_ExprTypedType)) { + o = ((ExprTypedObject *) o)->expr; + } + } + + PyObject* method = + PyObject_GetAttrString(py_visitor, method_name); + if (method == NULL) { + Py_DECREF(method_args); + return NULL; + } + + return PyObject_CallObject(method, method_args); + } + } + + return PyObject_CallMethod(py_visitor, "default", "O", self); +} + static PyObject* Expr_reduce_ex(ExprObject* self, PyObject *args) { @@ -270,6 +338,13 @@ Expr_reduce_ex(ExprObject* self, PyObject *args) } static PyMethodDef Expr_methods[] = { + {"visit", (PyCFunction)Expr_visit, METH_VARARGS, + "Implementation of the visitor pattern for abstract syntax trees. " + "If e is an expression equal to f a1 .. an then " + "e.visit(self) calls method self.on_f(a1,..an). " + "If the method doesn't exist then the method self.default(e) " + "is called." + }, {"__reduce_ex__", (PyCFunction)Expr_reduce_ex, METH_VARARGS, "This method allows for transparent pickling/unpickling of expressions." },