From 383d829d5a435125a82e39d31e22cded0401e230 Mon Sep 17 00:00:00 2001 From: "kr.angelov" Date: Tue, 30 Jul 2013 07:29:11 +0000 Subject: [PATCH] the first approximation for a statistical model consistent with dependent types in the abstract syntax --- src/runtime/haskell/PGF/Probabilistic.hs | 181 ++++++++++++++++++++++- 1 file changed, 180 insertions(+), 1 deletion(-) diff --git a/src/runtime/haskell/PGF/Probabilistic.hs b/src/runtime/haskell/PGF/Probabilistic.hs index d0206af26..c1fd58fc2 100644 --- a/src/runtime/haskell/PGF/Probabilistic.hs +++ b/src/runtime/haskell/PGF/Probabilistic.hs @@ -9,14 +9,17 @@ module PGF.Probabilistic , probTree , rankTreesByProbs + , mkProbDefs ) where import PGF.CId import PGF.Data import PGF.Macros +import PGF.Type +import PGF.Expr import qualified Data.Map as Map -import Data.List (sortBy,partition) +import Data.List (sortBy,partition,nub,mapAccumL) import Data.Maybe (fromMaybe, fromJust) -- | An abstract data structure which represents @@ -99,3 +102,179 @@ rankTreesByProbs pgf ts = sortBy (\ (_,p) (_,q) -> compare q p) [(t, probTree pgf t) | t <- ts] +mkProbDefs :: PGF -> ([[CId]],[(CId,Type,[Equation])]) +mkProbDefs pgf = + let cs = [(c,hyps,fns) | (c,(hyps0,fs,_)) <- Map.toList (cats (abstract pgf)), + not (elem c [cidString,cidInt,cidFloat]), + let hyps = zipWith (\(bt,_,ty) n -> (bt,mkCId ('v':show n),ty)) + hyps0 + [1..] + fns = [(f,ty) | (_,f) <- fs, + let Just (ty,_,_,_,_) = Map.lookup f (funs (abstract pgf))] + ] + ((_,css),eqss) = mapAccumL (\(ngen,css) (c,hyps,fns) -> + let st0 = (1,Map.empty) + ((_,eqs_map),cs) = computeConstrs pgf st0 [(fn,[],es) | (fn,(DTyp _ _ es)) <- fns] + (ngen', eqs) = mapAccumL (mkEquation eqs_map hyps) ngen fns + ceqs = [(id,DTyp [] cidFloat [],reverse eqs) | (id,eqs) <- Map.toList eqs_map, not (null eqs)] + in ((ngen',cs:css),(p_f c, mkType c hyps, eqs):ceqs)) (1,[]) cs + in (reverse (concat css),concat eqss) + where + mkEImplArg bt e + | bt == Explicit = e + | otherwise = EImplArg e + + mkPImplArg bt p + | bt == Explicit = p + | otherwise = PImplArg p + + mkType c hyps = + DTyp (hyps++[mkHypo (DTyp [] c es)]) cidFloat [] + where + is = reverse [0..length hyps-1] + es = [mkEImplArg bt (EVar i) | (i,(bt,_,_)) <- zip is hyps] + + sig = (funs (abstract pgf), \_ -> Nothing) + + mkEquation ceqs hyps ngen (fn,ty@(DTyp args _ es)) = + let fs1 = case Map.lookup (p_f fn) ceqs of + Nothing -> [mkApp (k_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)] + Just eqs | null eqs -> [] + | otherwise -> [mkApp (p_f fn) (map (\(i,_) -> EVar (k-i-1)) vs1)] + (ngen',fs2) = mapAccumL mkFactor2 ngen vs2 + fs3 = map mkFactor3 vs3 + eq = Equ (map mkTildeP xes++[PApp fn (zipWith mkArgP [1..] args)]) + (mkMult (fs1++fs2++fs3)) + in (ngen',eq) + where + xes = map (normalForm sig k env) es + + mkTildeP e = + case e of + EImplArg e -> PImplArg (PTilde e) + e -> PTilde e + + mkArgP n (bt,_,_) = mkPImplArg bt (PVar (mkCId ('v':show n))) + + mkMult [] = ELit (LFlt 1) + mkMult [e] = e + mkMult es = mkApp (mkCId "mult") es + + mkFactor2 ngen (src,dst) = + let vs = [EVar (k-i-1) | (i,ty) <- src] + in (ngen+1,mkApp (p_i ngen) vs) + + mkFactor3 (i,DTyp _ c es) = + let v = EVar (k-i-1) + in mkApp (p_f c) (map (normalForm sig k env) es++[v]) + + (k,env,vs1,vs2,vs3) = mkDeps ty + + mkDeps (DTyp args _ es) = + let (k,env,dep1) = updateArgs 0 [] [] args + dep2 = foldl (update k env) dep1 es + (vs2,vs3) = closure k dep2 [] [] + vs1 = concat [src | (src,dst) <- dep2, elem k dst] + in (k,map (\k -> VGen k []) env,vs1,reverse vs2,vs3) + where + updateArgs k env dep [] = (k,env,dep) + updateArgs k env dep ((_,x,ty@(DTyp _ _ es)) : args) = + let dep1 = foldl (update k env) dep es ++ [([(k,ty)],[])] + env1 | x == wildCId = env + | otherwise = k : env + in updateArgs (k+1) env1 dep1 args + + update k env dep e = + case e of + EApp e1 e2 -> update k env (update k env dep e1) e2 + EFun _ -> dep + EVar i -> let (dep1,(src,dst):dep2) = splitAt (env !! i) dep + in dep1++(src,k:dst):dep2 + + closure k [] vs2 vs3 = (vs2,vs3) + closure k ((src,dst):deps) vs2 vs3 + | null dst = closure k deps vs2 (vs3++src) + | otherwise = + let (deps1,deps2) = partition (\(src',dst') -> not (null [v1 | v1 <- dst, v2 <- dst', v1 == v2])) deps + deps3 = (src,dst):deps1 + src2 = concatMap fst deps3 + dst2 = [v | v <- concatMap snd deps3 + , lookup v src2 == Nothing] + dep2 = (src2,dst2) + dst' = nub dst + in if null deps1 + then if dst' == [k] + then closure k deps2 vs2 vs3 + else closure k deps2 ((src,dst') : vs2) vs3 + else closure k (dep2 : deps2) vs2 vs3 + + mkNewSig src = + DTyp (mkArgs 0 0 [] src) cidFloat [] + where + mkArgs k l env [] = [] + mkArgs k l env ((i,DTyp _ c es) : src) + | i == k = let ty = DTyp [] c (map (normalForm sig k env) es) + in (Explicit,wildCId,ty) : mkArgs (k+1) (l+1) (VGen l [] : env) src + | otherwise = mkArgs (k+1) l (VMeta 0 env [] : env) src + +type CState = (Int,Map.Map CId [Equation]) + +computeConstrs :: PGF -> CState -> [(CId,[Patt],[Expr])] -> (CState,[[CId]]) +computeConstrs pgf (ngen,eqs_map) fns@((id,pts,[]):rest) + | null rest = + let eqs_map' = + Map.insertWith (++) (p_f id) + (if null pts + then [] + else [Equ pts (ELit (LFlt 1.0))]) + eqs_map + in ((ngen,eqs_map'),[]) + | otherwise = + let (st,ks) = mapAccumL mk_k (ngen,eqs_map) fns + + mk_k (ngen,eqs_map) (id,pts,[]) + | null pts = ((ngen,eqs_map),k_f id) + | otherwise = let eqs_map' = + Map.insertWith (++) + (p_f id) + [Equ pts (EFun (k_i ngen))] + eqs_map + in ((ngen+1,eqs_map'),k_i ngen) + + in (st,[ks]) +computeConstrs pgf st fns = + let (st',res) = mapAccumL (\st (p,fns) -> computeConstrs pgf st fns) + st + (computeConstr fns) + in (st',concat res) + where + computeConstr fns = merge (split fns (Map.empty,[])) + + merge (cns,vrs) = + [(p,fns++[(id,ps++[p],es) | (id,ps,es) <- vrs]) + | (p,fns) <- concatMap addArgs (Map.toList cns)] + ++ + if null vrs + then [] + else [(PWild,[(id,ps++[PWild],es) | (id,ps,es) <- vrs])] + where + addArgs (cn,fns) = addArg (length args) cn [] fns + where + Just (ty@(DTyp args _ es),_,_,_,_) = Map.lookup cn (funs (abstract pgf)) + + addArg 0 cn ps fns = [(PApp cn (reverse ps),fns)] + addArg n cn ps fns = concat [addArg (n-1) cn (arg:ps) fns' | (arg,fns') <- computeConstr fns] + + split [] (cns,vrs) = (cns,vrs) + split ((id, ps, e:es):fns) (cns,vrs) = split fns (extract e []) + where + extract (EFun cn) args = (Map.insertWith (++) cn [(id,ps,args++es)] cns, vrs) + extract (EVar i) args = (cns, (id,ps,es):vrs) + extract (EApp e1 e2) args = extract e1 (e2:args) + extract (ETyped e ty) args = extract e args + extract (EImplArg e) args = extract e args + +p_f c = mkCId ("p_"++showCId c) +p_i i = mkCId ("p_"++show i) +k_f f = mkCId ("k_"++showCId f) +k_i i = mkCId ("k_"++show i)