From bb6623f6e70c347a862368efa138d34c2dcb7818 Mon Sep 17 00:00:00 2001 From: krasimir Date: Wed, 15 Oct 2008 14:58:00 +0000 Subject: [PATCH] high-order syntax in PMCFG --- src/GF/Compile/GeneratePMCFG.hs | 172 +++++++++++++++------------- src/PGF/Parsing/FCFG/Incremental.hs | 26 ++++- 2 files changed, 113 insertions(+), 85 deletions(-) diff --git a/src/GF/Compile/GeneratePMCFG.hs b/src/GF/Compile/GeneratePMCFG.hs index ac81279b5..619e5088b 100644 --- a/src/GF/Compile/GeneratePMCFG.hs +++ b/src/GF/Compile/GeneratePMCFG.hs @@ -37,76 +37,22 @@ import Debug.Trace -- main conversion function convertConcrete :: Abstr -> Concr -> ParserInfo -convertConcrete abs cnc = fixHoasFuns $ convert abs_defs' conc' cats' +convertConcrete abs cnc = convert abs_defs conc cats where abs_defs = Map.assocs (funs abs) conc = Map.union (opers cnc) (lins cnc) -- "union big+small most efficient" cats = lincats cnc - (abs_defs',conc',cats') = expandHOAS abs_defs conc cats - -expandHOAS :: [(CId,(Type,Expr))] -> TermMap -> TermMap -> ([(CId,(Type,Expr))],TermMap,TermMap) -expandHOAS funs lins lincats = (funs' ++ hoFuns ++ varFuns, - Map.unions [lins, hoLins, varLins], - Map.unions [lincats, hoLincats, varLincat]) - where - -- replace higher-order fun argument types with new categories - funs' = [(f,(fixType ty,e)) | (f,(ty,e)) <- funs] - where - fixType :: Type -> Type - fixType ty = let (ats,rt) = typeSkeleton ty in cftype (map catName ats) rt - - hoTypes :: [(Int,CId)] - hoTypes = sortNub [(n,c) | (_,(ty,_)) <- funs, (n,c) <- fst (typeSkeleton ty), n > 0] - hoCats = sortNub (map snd hoTypes) - -- for each Cat with N bindings, we add a new category _NCat - -- each new category contains a single function __NCat : Cat -> _Var -> ... -> _Var -> _NCat - hoFuns = [(funName ty,(cftype (c : replicate n varCat) (catName ty),EEq [])) | ty@(n,c) <- hoTypes] - -- lincats for the new categories - hoLincats = Map.fromList [(catName ty, modifyRec (++ replicate n (S [])) (lincatOf c)) | ty@(n,c) <- hoTypes] - -- linearizations of the new functions, lin __NCat v_0 ... v_n-1 x = { s1 = x.s1; ...; sk = x.sk; $0 = v_0.s ... - hoLins = Map.fromList [ (funName ty, mkLin c n) | ty@(n,c) <- hoTypes] - where mkLin c n = modifyRec (\fs -> [P (V 0) (C j) | j <- [0..length fs-1]] ++ [P (V i) (C 0) | i <- [1..n]]) (lincatOf c) - -- for each Cat, we a add a fun _Var_Cat : _Var -> Cat - varFuns = [(varFunName cat, (cftype [varCat] cat,EEq [])) | cat <- hoCats] - -- linearizations of the _Var_Cat functions - varLins = Map.fromList [(varFunName cat, R [P (V 0) (C 0)]) | cat <- hoCats] - -- lincat for the _Var category - varLincat = Map.singleton varCat (R [S []]) - - lincatOf c = fromMaybe (error $ "No lincat for " ++ prCId c) $ Map.lookup c lincats - - modifyRec :: ([Term] -> [Term]) -> Term -> Term - modifyRec f (R xs) = R (f xs) - modifyRec _ t = error $ "Not a record: " ++ show t - - varCat = mkCId "_Var" - - catName :: (Int,CId) -> CId - catName (0,c) = c - catName (n,c) = mkCId ("_" ++ show n ++ prCId c) - - funName :: (Int,CId) -> CId - funName (n,c) = mkCId ("__" ++ show n ++ prCId c) - - varFunName :: CId -> CId - varFunName c = mkCId ("_Var_" ++ prCId c) - --- replaces __NCat with _B and _Var_Cat with _. --- the temporary names are just there to avoid name collisions. -fixHoasFuns :: ParserInfo -> ParserInfo -fixHoasFuns pinfo = pinfo{functions=mkArray [FFun (fixName n) prof lins | FFun n prof lins <- elems (functions pinfo)]} - where fixName (CId n) | BS.pack "__" `BS.isPrefixOf` n = (mkCId "_B") - | BS.pack "_Var_" `BS.isPrefixOf` n = wildCId - fixName n = n convert :: [(CId,(Type,Expr))] -> TermMap -> TermMap -> ParserInfo -convert abs_defs cnc_defs cat_defs = getParserInfo (List.foldl' (convertRule cnc_defs) (emptyGrammarEnv cnc_defs cat_defs) xrules) +convert abs_defs cnc_defs cat_defs = + let env = expandHOAS abs_defs cnc_defs cat_defs (emptyGrammarEnv cnc_defs cat_defs) + in getParserInfo (List.foldl' (convertRule cnc_defs) env xrules) where xrules = [ - (XRule id args res (map findLinType args) (findLinType res) term) | - (id, (ty,_)) <- abs_defs, let (args,res) = catSkeleton ty, + (XRule id args (0,res) (map findLinType args) (findLinType (0,res)) term) | + (id, (ty,_)) <- abs_defs, let (args,res) = typeSkeleton ty, term <- Map.lookup id cnc_defs] - findLinType id = fromMaybe (error $ "No lincat for " ++ show id) (Map.lookup id cat_defs) + findLinType (_,id) = fromMaybe (error $ "No lincat for " ++ show id) (Map.lookup id cat_defs) brk :: (GrammarEnv -> GrammarEnv) -> (GrammarEnv -> GrammarEnv) brk f (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) = @@ -149,21 +95,27 @@ convertRule cnc_defs grammarEnv (XRule fun args res ctypes ctype term) = type CnvMonad a = BacktrackM Env a type FPath = [FIndex] -data ProtoFCat = PFCat CId [FPath] [(FPath,[FIndex])] +data ProtoFCat = PFCat Int CId [FPath] [(FPath,[FIndex])] type Env = (ProtoFCat, [ProtoFCat]) type LinRec = [(FPath, [FSymbol])] -data XRule = XRule CId {- function -} - [CId] {- argument types -} - CId {- result type -} - [Term] {- argument lin-types representation -} - Term {- result lin-type representation -} - Term {- body -} +data XRule = XRule CId {- function -} + [(Int,CId)] {- argument types: context size and category -} + (Int,CId) {- result type : context size (always 0) and category -} + [Term] {- argument lin-types representation -} + Term {- result lin-type representation -} + Term {- body -} -protoFCat :: TermMap -> CId -> Term -> ProtoFCat -protoFCat cnc_defs cat ctype = - let (rcs,tcs) = loop [] [] [] ctype - in PFCat cat rcs tcs +protoFCat :: TermMap -> (Int,CId) -> Term -> ProtoFCat +protoFCat cnc_defs (n,cat) ctype = + let (rcs,tcs) = loop [] [] [] ctype' + in PFCat n cat rcs tcs where + ctype' -- extend the high-order linearization type + | n > 0 = case ctype of + R xs -> R (xs ++ replicate n (S [])) + _ -> error $ "Not a record: " ++ show ctype + | otherwise = ctype + loop path rcs tcs (R record) = List.foldl' (\(rcs,tcs) (index,term) -> loop (index:path) rcs tcs term) (rcs,tcs) (zip [0..] record) loop path rcs tcs (C i) = ( rcs,(path,[0..i]):tcs) loop path rcs tcs (S _) = (path:rcs, tcs) @@ -209,7 +161,7 @@ convertArg (C max) nr path lbl_path lin lins = do return lins convertArg (S _) nr path lbl_path lin lins = do (_, args) <- readState - let PFCat cat rcs tcs = args !! nr + let PFCat _ cat rcs tcs = args !! nr return ((lbl_path, FSymCat nr (index path rcs 0) : lin) : lins) where index lbl' (lbl:lbls) idx @@ -236,7 +188,7 @@ convertRec cnc_defs (index:sub_sel) ctype record lbl_path lin lins = do evalTerm :: TermMap -> FPath -> Term -> CnvMonad FIndex evalTerm cnc_defs path (V nr) = do (_, args) <- readState - let PFCat _ _ tcs = args !! nr + let PFCat _ _ _ tcs = args !! nr rpath = reverse path index <- member (fromMaybe (error "evalTerm: wrong path") (lookup rpath tcs)) restrictArg nr rpath index @@ -256,14 +208,14 @@ evalTerm cnc_defs path x = error ("evalTerm ("++show x++")") -- GrammarEnv data GrammarEnv = GrammarEnv {-# UNPACK #-} !Int CatSet SeqSet FunSet CoerceSet (IntMap.IntMap (Set.Set Production)) -type CatSet = Map.Map CId (FCat,FCat,[Int]) +type CatSet = IntMap.IntMap (Map.Map CId (FCat,FCat,[Int])) type SeqSet = Map.Map FSeq SeqId type FunSet = Map.Map FFun FunId type CoerceSet= Map.Map [FCat] FCat emptyGrammarEnv cnc_defs lincats = let (last_id,catSet) = Map.mapAccumWithKey computeCatRange 0 lincats - in GrammarEnv last_id catSet Map.empty Map.empty Map.empty IntMap.empty + in GrammarEnv last_id (IntMap.singleton 0 catSet) Map.empty Map.empty Map.empty IntMap.empty where cidString = mkCId "String" cidInt = mkCId "Int" @@ -286,6 +238,64 @@ emptyGrammarEnv cnc_defs lincats = Just term -> getMultipliers m ms term Nothing -> error ("unknown identifier: "++prCId id) + +expandHOAS abs_defs cnc_defs lincats env = + foldl add_varFun (foldl (\env ncat -> add_hoFun (add_hoCat env ncat) ncat) env hoTypes) hoCats + where + hoTypes :: [(Int,CId)] + hoTypes = sortNub [(n,c) | (_,(ty,_)) <- abs_defs + , (n,c) <- fst (typeSkeleton ty), n > 0] + + hoCats :: [CId] + hoCats = sortNub [c | (_,(ty,_)) <- abs_defs + , Hyp _ ty <- case ty of {DTyp hyps val _ -> hyps} + , c <- fst (catSkeleton ty)] + + -- add a range of PMCFG categories for each GF high-order category + add_hoCat env@(GrammarEnv last_id catSet seqSet funSet crcSet prodSet) (n,cat) = + case IntMap.lookup 0 catSet >>= Map.lookup cat of + Just (start,end,ms) -> let !catSet' = IntMap.insertWith Map.union n (Map.singleton cat (last_id,last_id+(end-start),ms)) catSet + !last_id' = last_id+(end-start)+1 + in (GrammarEnv last_id' catSet' seqSet funSet crcSet prodSet) + Nothing -> env + + -- add one PMCFG function for each high-order type: _B : Cat -> Var -> ... -> Var -> HoCat + add_hoFun env (n,cat) = + let linRec = reverse $ + [(l ,[FSymCat 0 i]) | (l,i) <- case arg of {PFCat _ _ rcs _ -> zip rcs [0..]}] ++ + [([],[FSymCat i 0]) | i <- [1..n]] + (env1,lins) = List.mapAccumL addFSeq env linRec + newLinRec = mkArray lins + + (env2,funid) = addFFun env1 (FFun _B [[i] | i <- [0..n]] newLinRec) + + env3 = foldl (\env (arg,res) -> addProduction env res (FApply funid (arg : replicate n fcatVar))) + env2 + (zip (getFCats env2 arg) (getFCats env2 res)) + in env3 + where + (arg,res) = case Map.lookup cat lincats of + Nothing -> error $ "No lincat for " ++ prCId cat + Just ctype -> (protoFCat cnc_defs (0,cat) ctype, protoFCat cnc_defs (n,cat) ctype) + + -- add one PMCFG function for each high-order category: _V : Var -> Cat + add_varFun env cat = + let (env1,seqid) = addFSeq env ([],[FSymCat 0 0]) + lins = replicate (case res of {PFCat _ _ rcs _ -> length rcs}) seqid + (env2,funid) = addFFun env1 (FFun _V [[0]] (mkArray lins)) + env3 = foldl (\env res -> addProduction env2 res (FApply funid [fcatVar])) + env2 + (getFCats env2 res) + in env3 + where + res = case Map.lookup cat lincats of + Nothing -> error $ "No lincat for " ++ prCId cat + Just ctype -> protoFCat cnc_defs (0,cat) ctype + + _B = mkCId "_B" + _V = mkCId "_V" + + addProduction :: GrammarEnv -> FCat -> Production -> GrammarEnv addProduction (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) cat p = GrammarEnv last_id catSet seqSet funSet crcSet (IntMap.insertWith Set.union cat (Set.singleton p) prodSet) @@ -320,7 +330,7 @@ getParserInfo (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) = ParserInfo { functions = mkArray funSet , sequences = mkArray seqSet , productions = IntMap.union prodSet coercions - , startCats = Map.map (\(start,end,_) -> range (start,end)) catSet + , startCats = maybe Map.empty (Map.map (\(start,end,_) -> range (start,end))) (IntMap.lookup 0 catSet) , totalCats = last_id+1 } where @@ -329,8 +339,8 @@ getParserInfo (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) = coercions = IntMap.fromList [(fcat,Set.fromList (map FCoerce sub_fcats)) | (sub_fcats,fcat) <- Map.toList crcSet] getFCats :: GrammarEnv -> ProtoFCat -> [FCat] -getFCats (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) (PFCat cat rcs tcs) = - case Map.lookup cat catSet of +getFCats (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) (PFCat n cat rcs tcs) = + case IntMap.lookup n catSet >>= Map.lookup cat of Just (start,end,ms) -> reverse (solutions (variants ms tcs start) ()) where variants _ [] fcat = return fcat @@ -353,9 +363,9 @@ restrictHead path term writeState (head', args) restrictProtoFCat :: FPath -> FIndex -> ProtoFCat -> CnvMonad ProtoFCat -restrictProtoFCat path0 index0 (PFCat cat rcs tcs) = do +restrictProtoFCat path0 index0 (PFCat n cat rcs tcs) = do tcs <- addConstraint tcs - return (PFCat cat rcs tcs) + return (PFCat n cat rcs tcs) where addConstraint [] = error "restrictProtoFCat: unknown path" addConstraint (c@(path,indices) : tcs) diff --git a/src/PGF/Parsing/FCFG/Incremental.hs b/src/PGF/Parsing/FCFG/Incremental.hs index 2ab04acf2..e5f64365f 100644 --- a/src/PGF/Parsing/FCFG/Incremental.hs +++ b/src/PGF/Parsing/FCFG/Incremental.hs @@ -109,20 +109,38 @@ extractExps (State pinfo chart items) start = exps let FFun fn _ lins = functions pinfo ! funid lbl <- indices lins Just fid <- [lookupPC (PK cat lbl 0) (passive st)] - go Set.empty 0 (0,fid) + (fvs,tree) <- go Set.empty 0 (0,fid) + guard (Set.null fvs) + return tree go rec fcat' (d,fcat) - | fcat < totalCats pinfo = [Meta (fcat'*10+d)] -- FIXME: here we assume that every rule has at most 10 arguments + | fcat < totalCats pinfo = return (Set.empty,Meta (fcat'*10+d)) -- FIXME: here we assume that every rule has at most 10 arguments | Set.member fcat rec = mzero | otherwise = foldForest (\funid args trees -> do let FFun fn _ lins = functions pinfo ! funid args <- mapM (go (Set.insert fcat rec) fcat) (zip [0..] args) - return (Fun fn args) + check_ho_fun fn args + `mplus` + trees) + (\const _ trees -> + return (freeVar const,const) `mplus` trees) - (\const _ trees -> const : trees) [] fcat (forest st) + check_ho_fun fun args + | fun == _V = return (head args) + | fun == _B = return (foldl1 Set.difference (map fst args),Abs [mkVar (snd e) | e <- tail args] (snd (head args))) + | otherwise = return (Set.unions (map fst args),Fun fun (map snd args)) + + mkVar (Var v) = v + mkVar (Meta _) = wildCId + + freeVar (Var v) = Set.singleton v + freeVar _ = Set.empty + +_B = mkCId "_B" +_V = mkCId "_V" process fn !seqs !funs [] acc chart = (acc,chart) process fn !seqs !funs (item@(Active j ppos funid seqid args key0):items) acc chart