From dbfa9e4faf124bde3c79f2dd8b804cf48b78d6a1 Mon Sep 17 00:00:00 2001 From: Krasimir Angelov Date: Wed, 16 Apr 2025 12:29:50 +0200 Subject: [PATCH] type inference for lambda abstractions --- src/runtime/c/pgf/typechecker.cxx | 126 +++++++++++++++++++----------- src/runtime/c/pgf/typechecker.h | 6 -- 2 files changed, 82 insertions(+), 50 deletions(-) diff --git a/src/runtime/c/pgf/typechecker.cxx b/src/runtime/c/pgf/typechecker.cxx index a0e5a3973..bcb8e194e 100644 --- a/src/runtime/c/pgf/typechecker.cxx +++ b/src/runtime/c/pgf/typechecker.cxx @@ -3,6 +3,45 @@ #include "typechecker.h" class PgfTypechecker::Unmarshaller1 : public PgfUnmarshaller { + PgfTypechecker *tc; + + virtual PgfExpr eabs(PgfBindType bind_type, PgfText *name, PgfExpr body) { return 0; } + virtual PgfExpr eapp(PgfExpr fun, PgfExpr arg) { return 0; } + virtual PgfExpr elit(PgfLiteral lit) { return 0; } + virtual PgfExpr emeta(PgfMetaId meta_id) { return 0; } + virtual PgfExpr efun(PgfText *name) { return 0; } + virtual PgfExpr evar(int index) { return 0; } + virtual PgfExpr etyped(PgfExpr expr, PgfType ty) { return 0; } + virtual PgfExpr eimplarg(PgfExpr expr) { return 0; } + virtual PgfLiteral lint(size_t size, uintmax_t *val) { return 0; } + virtual PgfLiteral lflt(double val) { return 0; } + virtual PgfLiteral lstr(PgfText *val) { return 0; } + + virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos, + PgfText *name, + size_t n_exprs, PgfExpr *exprs) + { + Type *ty = new(name) Cat; + tc->temps.push_back(ty); + while (n_hypos > 0) { + PgfTypeHypo *hypo = &hypos[--n_hypos]; + ty = new(hypo->bind_type, hypo->cid, (Type*) hypo->type, ty) Pi; + tc->temps.push_back(ty); + } + return (PgfType) ty; + } + + virtual void free_ref(object x) { } + +public: + Unmarshaller1(PgfTypechecker *tc) { + this->tc = tc; + } +}; + +class PgfTypechecker::Unmarshaller2 : public PgfUnmarshaller { + PgfMarshaller *m; + virtual PgfExpr eabs(PgfBindType bind_type, PgfText *name, PgfExpr body) { return 0; } virtual PgfExpr eapp(PgfExpr fun, PgfExpr arg) { return 0; } virtual PgfExpr elit(PgfLiteral lit) { return 0; } @@ -15,29 +54,6 @@ class PgfTypechecker::Unmarshaller1 : public PgfUnmarshaller { virtual PgfLiteral lflt(double val) { return 0; } virtual PgfLiteral lstr(PgfText *val) { return 0; } - virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos, - PgfText *name, - size_t n_exprs, PgfExpr *exprs) - { - Type *ty = new(name) Cat; - while (n_hypos > 0) { - PgfTypeHypo *hypo = &hypos[--n_hypos]; - ty = new(hypo->bind_type, hypo->cid, (Type*) hypo->type, ty) Pi; - } - return (PgfType) ty; - } - - virtual void free_ref(object x) { } -}; - -class PgfTypechecker::Unmarshaller2 : public Unmarshaller1 { - PgfMarshaller *m; - -public: - Unmarshaller2(PgfMarshaller *m) { - this->m = m; - } - virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos, PgfText *name, size_t n_exprs, PgfExpr *exprs) @@ -52,6 +68,11 @@ public: } virtual void free_ref(object x) { } + +public: + Unmarshaller2(PgfMarshaller *m) { + this->m = m; + } }; PgfType PgfTypechecker::marshall_type(Type *ty, PgfUnmarshaller *u) @@ -102,24 +123,37 @@ PgfExpr PgfTypechecker::Context::eabs(PgfBindType btype, PgfText *name, PgfExpr if (!checkImplArgument()) return 0; - if (exp_type == NULL) { - return tc->type_error("Cannot infer the type of a lambda abstraction"); - } - - Pi *pi = exp_type->is_pi(); - if (!pi) { - return tc->type_error("A lambda abstraction must have a function type"); - } - + Type *res; Scope new_scope; new_scope.tail=scope; new_scope.var=name; - new_scope.ty=pi->arg; - Context body_ctxt(tc,&new_scope,pi->res); + if (exp_type != NULL) { + Pi *pi = exp_type->is_pi(); + if (!pi) { + return tc->type_error("A lambda abstraction must have a function type"); + } + res = pi->res; + new_scope.ty=pi->arg; + } else { + res = NULL; + new_scope.ty=NULL; + } + + Context body_ctxt(tc,&new_scope,res); body = tc->m->match_expr(&body_ctxt, body); if (body == 0) return 0; + if (new_scope.ty == NULL) { + tc->u->free_ref(body); + return tc->type_error("Cannot infer the type of a lambda variable"); + } + + PgfText *wild = string2text("_"); + inf_type = new(btype, wild, new_scope.ty, body_ctxt.inf_type) Pi; + free(wild); + tc->temps.push_back(inf_type); + return tc->u->eabs(btype,name,body); } @@ -223,9 +257,8 @@ PgfExpr PgfTypechecker::Context::efun(PgfText *name) if (absfun == 0) return tc->type_error("Function %s is not defined", name->text); - Unmarshaller1 tu; + Unmarshaller1 tu(tc); inf_type = (Type*) tc->db_m.match_type(&tu, absfun->type.as_object()); - tc->temps.push_back(inf_type); PgfExpr e = tc->u->efun(name); @@ -252,15 +285,21 @@ PgfExpr PgfTypechecker::Context::evar(int index) return tc->type_error("Cannot type check an open expression (de Bruijn index %d)", index); } - inf_type = s->ty; - PgfExpr e = tc->u->evar(index); - if (!unifyTypes(&e)) { - tc->u->free_ref(e); - return 0; + inf_type = s->ty; + if (inf_type == NULL) { + if (exp_type == NULL) { + tc->u->free_ref(e); + return tc->type_error("Cannot infer the type of a lambda variable"); + } + s->ty = exp_type; + } else { + if (!unifyTypes(&e)) { + tc->u->free_ref(e); + return 0; + } } - return e; } @@ -365,10 +404,9 @@ PgfType PgfTypechecker::Context::dtyp(size_t n_hypos, PgfTypeHypo *hypos, size_t i, j; for (i = 0, j = 0; i < n_new_exprs && j < n_exprs; i++) { - Unmarshaller1 tu; + Unmarshaller1 tu(tc); ref hypo = abscat->context.elem(i); Type *ty = (Type *) tc->db_m.match_type(&tu,hypo->type.as_object()); - tc->temps.push_back(ty); Context expr_ctxt(tc,scope,ty,hypo->bind_type); new_exprs[i] = tc->m->match_expr(&expr_ctxt, exprs[j]); if (new_exprs[i] == 0) { diff --git a/src/runtime/c/pgf/typechecker.h b/src/runtime/c/pgf/typechecker.h index dd062888c..e7db91ade 100644 --- a/src/runtime/c/pgf/typechecker.h +++ b/src/runtime/c/pgf/typechecker.h @@ -16,7 +16,6 @@ class PGF_INTERNAL_DECL PgfTypechecker { struct Type { virtual Pi *is_pi() { return NULL; } virtual Cat *is_cat() { return NULL; } - virtual ~Type() {} }; struct Pi : Type { @@ -33,11 +32,6 @@ class PGF_INTERNAL_DECL PgfTypechecker { return pi; } virtual Pi *is_pi() { return this; } - - virtual ~Pi() { - delete arg; - delete res; - } }; struct Cat : Type {