forked from GitHub/gf-core
the first approximation for a statistical model consistent with dependent types in the abstract syntax
This commit is contained in:
@@ -9,14 +9,17 @@ module PGF.Probabilistic
|
|||||||
|
|
||||||
, probTree
|
, probTree
|
||||||
, rankTreesByProbs
|
, rankTreesByProbs
|
||||||
|
, mkProbDefs
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import PGF.CId
|
import PGF.CId
|
||||||
import PGF.Data
|
import PGF.Data
|
||||||
import PGF.Macros
|
import PGF.Macros
|
||||||
|
import PGF.Type
|
||||||
|
import PGF.Expr
|
||||||
|
|
||||||
import qualified Data.Map as Map
|
import qualified Data.Map as Map
|
||||||
import Data.List (sortBy,partition)
|
import Data.List (sortBy,partition,nub,mapAccumL)
|
||||||
import Data.Maybe (fromMaybe, fromJust)
|
import Data.Maybe (fromMaybe, fromJust)
|
||||||
|
|
||||||
-- | An abstract data structure which represents
|
-- | An abstract data structure which represents
|
||||||
@@ -99,3 +102,179 @@ rankTreesByProbs pgf ts = sortBy (\ (_,p) (_,q) -> compare q p)
|
|||||||
[(t, probTree pgf t) | t <- ts]
|
[(t, probTree pgf t) | t <- ts]
|
||||||
|
|
||||||
|
|
||||||
|
-- | Build the definitions of the probability functions of a grammar.
--
-- Every category of the abstract syntax except the built-in
-- @String@, @Int@ and @Float@ categories is processed in turn, while a
-- counter (starting at 1) used for numbering the @p_<n>@ factors is
-- threaded through the categories.
--
-- The result is a pair:
--
-- * the groups of constant names returned by 'computeConstrs' for the
--   categories (concatenated and then reversed);
--
-- * a list of @(name, type, equations)@ triples.  For every category
--   @c@ it contains the triple for @p_c@ (one equation per function
--   listed for @c@), followed by one triple of type @Float@ for every
--   entry that 'computeConstrs' produced with a non-empty equation list.
mkProbDefs :: PGF -> ([[CId]],[(CId,Type,[Equation])])
mkProbDefs pgf =
  let cs = [ (c,hyps,fns)
           | (c,(hyps0,fs,_)) <- Map.toList (cats (abstract pgf))
           , not (elem c [cidString,cidInt,cidFloat])
             -- the hypotheses of the category get the fresh names v1, v2, ...
           , let hyps = zipWith (\(bt,_,ty) n -> (bt,mkCId ('v':show n),ty))
                                hyps0
                                [1..]
                 fns  = [ (f,ty)
                        | (_,f) <- fs
                          -- partial match: fails if a function listed for the
                          -- category is missing from the function table
                        , let Just (ty,_,_,_,_) = Map.lookup f (funs (abstract pgf))
                        ]
           ]
      ((_,css),eqss) = mapAccumL (\(ngen,css) (c,hyps,fns) ->
               let st0              = (1,Map.empty)
                   ((_,eqs_map),ks) = computeConstrs pgf st0 [(fn,[],es) | (fn,(DTyp _ _ es)) <- fns]
                   (ngen', eqs)     = mapAccumL (mkEquation eqs_map hyps) ngen fns
                   -- computeConstrs prepends new equations, hence the reverse
                   ceqs             = [(name,DTyp [] cidFloat [],reverse ceq) | (name,ceq) <- Map.toList eqs_map, not (null ceq)]
               in ((ngen',ks:css),(p_f c, mkType c hyps, eqs):ceqs)) (1,[]) cs
  in (reverse (concat css),concat eqss)
  where
    -- Wrap an expression as an implicit argument unless the binding is explicit.
    mkEImplArg bt e
      | bt == Explicit = e
      | otherwise      = EImplArg e

    -- Wrap a pattern as an implicit argument unless the binding is explicit.
    mkPImplArg bt p
      | bt == Explicit = p
      | otherwise      = PImplArg p

    -- The type of @p_c@: the hypotheses of the category, followed by one
    -- argument of the category applied to the variables bound by those
    -- hypotheses, with result type @Float@.
    mkType c hyps =
      DTyp (hyps++[mkHypo (DTyp [] c es)]) cidFloat []
      where
        is = reverse [0..length hyps-1]
        es = [mkEImplArg bt (EVar i) | (i,(bt,_,_)) <- zip is hyps]

    -- Signature handed to 'normalForm'; the second component always
    -- answers 'Nothing'.
    sig = (funs (abstract pgf), \_ -> Nothing)

    -- Build the equation of @p_c@ for the function @fn@ of type @ty@.
    -- The counter @ngen@ supplies the numbers of the @p_<n>@ factors and
    -- is returned updated.  The right-hand side is the product of three
    -- groups of factors (fs1, fs2, fs3) derived from the dependency
    -- analysis done by 'mkDeps'.
    mkEquation ceqs hyps ngen (fn,ty@(DTyp args _ es)) =
      let fs1 = case Map.lookup (p_f fn) ceqs of
                  Nothing              -> [mkApp (k_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
                  Just eqs | null eqs  -> []
                           | otherwise -> [mkApp (p_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)]
          (ngen',fs2) = mapAccumL mkFactor2 ngen vs2
          fs3         = map mkFactor3 vs3
          eq          = Equ (map mkTildeP xes++[PApp fn (zipWith mkArgP [1..] args)])
                            (mkMult (fs1++fs2++fs3))
      in (ngen',eq)
      where
        -- normalised indices of the result type
        xes = map (normalForm sig k env) es

        -- an index becomes a tilde pattern; the implicit-argument marker
        -- is kept on the outside
        mkTildeP e =
          case e of
            EImplArg e1 -> PImplArg (PTilde e1)
            e1          -> PTilde e1

        -- the n-th argument of the function is matched by the variable v<n>
        mkArgP n (bt,_,_) = mkPImplArg bt (PVar (mkCId ('v':show n)))

        -- product of the factors; the empty product is the literal 1
        mkMult []  = ELit (LFlt 1)
        mkMult [e] = e
        mkMult es  = mkApp (mkCId "mult") es

        -- one numbered factor p_<ngen> applied to the variables of a group
        mkFactor2 ngen (src,dst) =
          let vs = [EVar (k-i-1) | (i,_) <- src]
          in (ngen+1,mkApp (p_i ngen) vs)

        -- the factor for an argument of category c: p_c applied to the
        -- normalised indices of the argument type and to the argument itself
        mkFactor3 (i,DTyp _ c es) =
          let v = EVar (k-i-1)
          in mkApp (p_f c) (map (normalForm sig k env) es++[v])

        (k,env,vs1,vs2,vs3) = mkDeps ty

        -- Dependency analysis of a function type.  Returns the number of
        -- arguments, the environment for 'normalForm', and three groups
        -- of arguments: vs1 (those referenced by the result type's
        -- indices), vs2 (merged groups of arguments referenced from
        -- argument types) and vs3 (arguments nothing depends on).
        mkDeps (DTyp args _ es) =
          let (k,env,dep1) = updateArgs 0 [] [] args
              -- references from the result type are recorded with number k
              dep2         = foldl (update k env) dep1 es
              (vs2,vs3)    = closure k dep2 [] []
              vs1          = concat [src | (src,dst) <- dep2, elem k dst]
          in (k,map (\j -> VGen j []) env,vs1,reverse vs2,vs3)
          where
            -- Walk the arguments left to right.  Argument number k gets
            -- its own entry ([(k,ty)],[]) after the variables used in
            -- its type indices have been recorded.  Only named (non
            -- wildcard) arguments are pushed on the environment.
            updateArgs k env dep []                                = (k,env,dep)
            updateArgs k env dep ((_,x,ty@(DTyp _ _ es)) : args) =
              let dep1 = foldl (update k env) dep es ++ [([(k,ty)],[])]
                  env1 | x == wildCId = env
                       | otherwise    = k : env
              in updateArgs (k+1) env1 dep1 args

            -- Record that position k refers to every variable occurring
            -- in the expression.  Implicit-argument and type-annotation
            -- wrappers are looked through; literals and metavariables
            -- contain no variables.
            update k env dep e =
              case e of
                EApp e1 e2  -> update k env (update k env dep e1) e2
                EFun _      -> dep
                EVar i      -> let (dep1,(src,dst):dep2) = splitAt (env !! i) dep
                               in dep1++(src,k:dst):dep2
                EImplArg e1 -> update k env dep e1
                ETyped e1 _ -> update k env dep e1
                ELit{}      -> dep
                EMeta{}     -> dep
                EAbs{}      -> error "mkProbDefs: lambda abstraction in a type index is not supported"

            -- Merge the entries whose dependents overlap.  Entries
            -- without dependents go to vs3.  A fully merged entry goes
            -- to vs2 unless its only dependent is k (the result type).
            closure k []               vs2 vs3 = (vs2,vs3)
            closure k ((src,dst):deps) vs2 vs3
              | null dst  = closure k deps vs2 (vs3++src)
              | otherwise =
                  let (deps1,deps2) = partition (\(src',dst') -> not (null [v1 | v1 <- dst, v2 <- dst', v1 == v2])) deps
                      deps3 = (src,dst):deps1
                      src2  = concatMap fst deps3
                      dst2  = [v | v <- concatMap snd deps3
                                 , lookup v src2 == Nothing]
                      dep2  = (src2,dst2)
                      dst'  = nub dst
                  in if null deps1
                       then if dst' == [k]
                              then closure k deps2 vs2 vs3
                              else closure k deps2 ((src,dst') : vs2) vs3
                       else closure k (dep2 : deps2) vs2 vs3

        -- NOTE(review): not referenced anywhere in the visible code.
        -- Builds a Float-valued type whose arguments are taken from src.
        mkNewSig src =
          DTyp (mkArgs 0 0 [] src) cidFloat []
          where
            mkArgs k l env []                      = []
            mkArgs k l env ((i,DTyp _ c es) : src)
              | i == k    = let ty = DTyp [] c (map (normalForm sig k env) es)
                            in (Explicit,wildCId,ty) : mkArgs (k+1) (l+1) (VGen l [] : env) src
              | otherwise = mkArgs (k+1) l (VMeta 0 env [] : env) src
|
||||||
|
|
||||||
|
-- | State threaded through 'computeConstrs': a counter used to number the
-- generated @k_<n>@ constants, and a map from names to the equations
-- collected for them (new equations are added in front of existing ones).
type CState = (Int,Map.Map CId [Equation])
|
||||||
|
|
||||||
|
-- | Split a set of functions according to the index expressions of their
-- result types.
--
-- Each input triple holds a function name, the patterns collected for it
-- so far, and the index expressions that still have to be analysed.
--
-- * When the head of the list has no indices left and it is the only
--   entry, its patterns (if any) are recorded under @p_<fun>@ with the
--   constant right-hand side @1.0@ and no group of constants is returned.
--
-- * When the head has no indices left but there are several entries,
--   the entries form one group: an entry without patterns is represented
--   by @k_<fun>@, an entry with patterns gets a fresh @k_<n>@ and an
--   equation mapping its patterns to that constant.
--
-- * Otherwise the entries are partitioned by the head constructor of
--   their first index (entries starting with a variable join every
--   partition) and each partition is processed recursively.
computeConstrs :: PGF -> CState -> [(CId,[Patt],[Expr])] -> (CState,[[CId]])
computeConstrs _ (ngen,eqs_map) [(fn,pts,[])] =
  ((ngen, Map.insertWith (++) (p_f fn) new_eqs eqs_map), [])
  where
    new_eqs
      | null pts  = []
      | otherwise = [Equ pts (ELit (LFlt 1.0))]
computeConstrs _ (ngen,eqs_map) fns@((_,_,[]):_) =
  let (st,ks) = mapAccumL mk_k (ngen,eqs_map) fns
  in (st,[ks])
  where
    -- pick the constant that stands for one entry of the group
    mk_k (n,m) (fn,pts,[])
      | null pts  = ((n,m), k_f fn)
      | otherwise = ((n+1, Map.insertWith (++) (p_f fn) [Equ pts (EFun (k_i n))] m), k_i n)
computeConstrs pgf st fns =
  let (st',res) = mapAccumL (\s (_,group) -> computeConstrs pgf s group)
                            st
                            (computeConstr fns)
  in (st',concat res)
  where
    -- partition the entries by their first index expression
    computeConstr entries = merge (split entries (Map.empty,[]))

    -- one partition per fully expanded constructor pattern, plus a
    -- wildcard partition when some entries start with a variable
    merge (cns,vrs) = byConstr ++ byWild
      where
        -- variable-headed entries, extended with the given pattern
        withPatt p = [(fn,ps++[p],es) | (fn,ps,es) <- vrs]

        byConstr = [ (p, group ++ withPatt p)
                   | (p,group) <- concatMap addArgs (Map.toList cns)
                   ]

        byWild
          | null vrs  = []
          | otherwise = [(PWild, withPatt PWild)]

        -- expand the arguments of the constructor, as many as its type has
        addArgs (cn,group) = addArg (length args) cn [] group
          where
            Just (DTyp args _ _,_,_,_,_) = Map.lookup cn (funs (abstract pgf))

        addArg 0 cn ps group = [(PApp cn (reverse ps),group)]
        addArg n cn ps group =
          concatMap (\(arg,group') -> addArg (n-1) cn (arg:ps) group')
                    (computeConstr group)

    -- sort the entries into those whose first index is headed by a
    -- constructor (keyed by that constructor) and those headed by a variable
    split []                    acc       = acc
    split ((fn,ps,e:es):others) (cns,vrs) = split others (classify e [])
      where
        classify (EFun cn)     args = (Map.insertWith (++) cn [(fn,ps,args++es)] cns, vrs)
        classify (EVar _)      _    = (cns, (fn,ps,es):vrs)
        classify (EApp e1 e2)  args = classify e1 (e2:args)
        classify (ETyped e1 _) args = classify e1 args
        classify (EImplArg e1) args = classify e1 args
|
||||||
|
|
||||||
|
-- | The identifier @p_<c>@ for the identifier @c@.
p_f :: CId -> CId
p_f c = mkCId ("p_"++showCId c)
|
||||||
|
-- | The identifier @p_<i>@, where @i@ is rendered with 'show'.
p_i :: Show a => a -> CId
p_i i = mkCId ("p_"++show i)
|
||||||
|
-- | The identifier @k_<f>@ for the identifier @f@.
k_f :: CId -> CId
k_f f = mkCId ("k_"++showCId f)
|
||||||
|
-- | The identifier @k_<i>@, where @i@ is rendered with 'show'.
k_i :: Show a => a -> CId
k_i i = mkCId ("k_"++show i)
|
||||||
|
|||||||
Reference in New Issue
Block a user