type inference for lambda abstractions

This commit is contained in:
Krasimir Angelov
2025-04-16 12:29:50 +02:00
parent a9d4fecd33
commit dbfa9e4faf
2 changed files with 82 additions and 50 deletions

View File

@@ -3,6 +3,45 @@
#include "typechecker.h"
class PgfTypechecker::Unmarshaller1 : public PgfUnmarshaller {
PgfTypechecker *tc;
virtual PgfExpr eabs(PgfBindType bind_type, PgfText *name, PgfExpr body) { return 0; }
virtual PgfExpr eapp(PgfExpr fun, PgfExpr arg) { return 0; }
virtual PgfExpr elit(PgfLiteral lit) { return 0; }
virtual PgfExpr emeta(PgfMetaId meta_id) { return 0; }
virtual PgfExpr efun(PgfText *name) { return 0; }
virtual PgfExpr evar(int index) { return 0; }
virtual PgfExpr etyped(PgfExpr expr, PgfType ty) { return 0; }
virtual PgfExpr eimplarg(PgfExpr expr) { return 0; }
virtual PgfLiteral lint(size_t size, uintmax_t *val) { return 0; }
virtual PgfLiteral lflt(double val) { return 0; }
virtual PgfLiteral lstr(PgfText *val) { return 0; }
virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos,
PgfText *name,
size_t n_exprs, PgfExpr *exprs)
{
Type *ty = new(name) Cat;
tc->temps.push_back(ty);
while (n_hypos > 0) {
PgfTypeHypo *hypo = &hypos[--n_hypos];
ty = new(hypo->bind_type, hypo->cid, (Type*) hypo->type, ty) Pi;
tc->temps.push_back(ty);
}
return (PgfType) ty;
}
virtual void free_ref(object x) { }
public:
Unmarshaller1(PgfTypechecker *tc) {
this->tc = tc;
}
};
class PgfTypechecker::Unmarshaller2 : public PgfUnmarshaller {
PgfMarshaller *m;
virtual PgfExpr eabs(PgfBindType bind_type, PgfText *name, PgfExpr body) { return 0; }
virtual PgfExpr eapp(PgfExpr fun, PgfExpr arg) { return 0; }
virtual PgfExpr elit(PgfLiteral lit) { return 0; }
@@ -15,29 +54,6 @@ class PgfTypechecker::Unmarshaller1 : public PgfUnmarshaller {
virtual PgfLiteral lflt(double val) { return 0; }
virtual PgfLiteral lstr(PgfText *val) { return 0; }
virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos,
PgfText *name,
size_t n_exprs, PgfExpr *exprs)
{
Type *ty = new(name) Cat;
while (n_hypos > 0) {
PgfTypeHypo *hypo = &hypos[--n_hypos];
ty = new(hypo->bind_type, hypo->cid, (Type*) hypo->type, ty) Pi;
}
return (PgfType) ty;
}
virtual void free_ref(object x) { }
};
class PgfTypechecker::Unmarshaller2 : public Unmarshaller1 {
PgfMarshaller *m;
public:
Unmarshaller2(PgfMarshaller *m) {
this->m = m;
}
virtual PgfType dtyp(size_t n_hypos, PgfTypeHypo *hypos,
PgfText *name,
size_t n_exprs, PgfExpr *exprs)
@@ -52,6 +68,11 @@ public:
}
virtual void free_ref(object x) { }
public:
Unmarshaller2(PgfMarshaller *m) {
this->m = m;
}
};
PgfType PgfTypechecker::marshall_type(Type *ty, PgfUnmarshaller *u)
@@ -102,24 +123,37 @@ PgfExpr PgfTypechecker::Context::eabs(PgfBindType btype, PgfText *name, PgfExpr
if (!checkImplArgument())
return 0;
if (exp_type == NULL) {
return tc->type_error("Cannot infer the type of a lambda abstraction");
}
Pi *pi = exp_type->is_pi();
if (!pi) {
return tc->type_error("A lambda abstraction must have a function type");
}
Type *res;
Scope new_scope;
new_scope.tail=scope;
new_scope.var=name;
new_scope.ty=pi->arg;
Context body_ctxt(tc,&new_scope,pi->res);
if (exp_type != NULL) {
Pi *pi = exp_type->is_pi();
if (!pi) {
return tc->type_error("A lambda abstraction must have a function type");
}
res = pi->res;
new_scope.ty=pi->arg;
} else {
res = NULL;
new_scope.ty=NULL;
}
Context body_ctxt(tc,&new_scope,res);
body = tc->m->match_expr(&body_ctxt, body);
if (body == 0)
return 0;
if (new_scope.ty == NULL) {
tc->u->free_ref(body);
return tc->type_error("Cannot infer the type of a lambda variable");
}
PgfText *wild = string2text("_");
inf_type = new(btype, wild, new_scope.ty, body_ctxt.inf_type) Pi;
free(wild);
tc->temps.push_back(inf_type);
return tc->u->eabs(btype,name,body);
}
@@ -223,9 +257,8 @@ PgfExpr PgfTypechecker::Context::efun(PgfText *name)
if (absfun == 0)
return tc->type_error("Function %s is not defined", name->text);
Unmarshaller1 tu;
Unmarshaller1 tu(tc);
inf_type = (Type*) tc->db_m.match_type(&tu, absfun->type.as_object());
tc->temps.push_back(inf_type);
PgfExpr e = tc->u->efun(name);
@@ -252,15 +285,21 @@ PgfExpr PgfTypechecker::Context::evar(int index)
return tc->type_error("Cannot type check an open expression (de Bruijn index %d)", index);
}
inf_type = s->ty;
PgfExpr e = tc->u->evar(index);
if (!unifyTypes(&e)) {
tc->u->free_ref(e);
return 0;
inf_type = s->ty;
if (inf_type == NULL) {
if (exp_type == NULL) {
tc->u->free_ref(e);
return tc->type_error("Cannot infer the type of a lambda variable");
}
s->ty = exp_type;
} else {
if (!unifyTypes(&e)) {
tc->u->free_ref(e);
return 0;
}
}
return e;
}
@@ -365,10 +404,9 @@ PgfType PgfTypechecker::Context::dtyp(size_t n_hypos, PgfTypeHypo *hypos,
size_t i, j;
for (i = 0, j = 0; i < n_new_exprs && j < n_exprs; i++) {
Unmarshaller1 tu;
Unmarshaller1 tu(tc);
ref<PgfHypo> hypo = abscat->context.elem(i);
Type *ty = (Type *) tc->db_m.match_type(&tu,hypo->type.as_object());
tc->temps.push_back(ty);
Context expr_ctxt(tc,scope,ty,hypo->bind_type);
new_exprs[i] = tc->m->match_expr(&expr_ctxt, exprs[j]);
if (new_exprs[i] == 0) {

View File

@@ -16,7 +16,6 @@ class PGF_INTERNAL_DECL PgfTypechecker {
struct Type {
virtual Pi *is_pi() { return NULL; }
virtual Cat *is_cat() { return NULL; }
virtual ~Type() {}
};
struct Pi : Type {
@@ -33,11 +32,6 @@ class PGF_INTERNAL_DECL PgfTypechecker {
return pi;
}
virtual Pi *is_pi() { return this; }
virtual ~Pi() {
delete arg;
delete res;
}
};
struct Cat : Type {