high-order syntax in PMCFG

This commit is contained in:
krasimir
2008-10-15 14:58:00 +00:00
parent 063b82cf6c
commit bb6623f6e7
2 changed files with 113 additions and 85 deletions

View File

@@ -37,76 +37,22 @@ import Debug.Trace
-- | Main conversion function: builds the 'ParserInfo' of a concrete syntax from
-- the abstract function definitions (@funs abs@), the concrete terms
-- (@opers cnc@ merged with @lins cnc@) and the linearization types (@lincats cnc@).
--
-- NOTE(review): two right-hand sides for 'convertConcrete' appear below. The first
-- pre-expands the definitions with 'expandHOAS' and post-processes the result with
-- 'fixHoasFuns'; the second hands the definitions straight to 'convert'. This looks
-- like the before/after pair of a diff -- confirm which one is current.
convertConcrete :: Abstr -> Concr -> ParserInfo
convertConcrete abs cnc = fixHoasFuns $ convert abs_defs' conc' cats'
convertConcrete abs cnc = convert abs_defs conc cats
-- abs_defs: association list of the abstract functions
where abs_defs = Map.assocs (funs abs)
-- conc: operations and linearizations in one map (Map.union prefers the first map on key clashes)
conc = Map.union (opers cnc) (lins cnc) -- "union big+small most efficient"
-- cats: linearization type of every category
cats = lincats cnc
-- primed versions are only referenced by the first right-hand side above
(abs_defs',conc',cats') = expandHOAS abs_defs conc cats
-- | Rewrites a grammar with higher-order function arguments into a first-order one.
-- Takes the abstract function definitions, the linearizations and the lincats and
-- returns all three extended with the helper categories and functions built below.
--
-- NOTE(review): a second, four-argument 'expandHOAS' operating on a GrammarEnv is
-- defined further down in this file; this tuple-returning version looks like the
-- superseded one -- confirm.
expandHOAS :: [(CId,(Type,Expr))] -> TermMap -> TermMap -> ([(CId,(Type,Expr))],TermMap,TermMap)
expandHOAS funs lins lincats = (funs' ++ hoFuns ++ varFuns,
Map.unions [lins, hoLins, varLins],
Map.unions [lincats, hoLincats, varLincat])
where
-- replace higher-order fun argument types with new categories
funs' = [(f,(fixType ty,e)) | (f,(ty,e)) <- funs]
where
fixType :: Type -> Type
fixType ty = let (ats,rt) = typeSkeleton ty in cftype (map catName ats) rt
-- all argument types (number of bindings, category) that bind at least one variable
hoTypes :: [(Int,CId)]
hoTypes = sortNub [(n,c) | (_,(ty,_)) <- funs, (n,c) <- fst (typeSkeleton ty), n > 0]
-- the categories occurring in hoTypes, without duplicates
hoCats = sortNub (map snd hoTypes)
-- for each Cat with N bindings, we add a new category _NCat
-- each new category contains a single function __NCat : Cat -> _Var -> ... -> _Var -> _NCat
hoFuns = [(funName ty,(cftype (c : replicate n varCat) (catName ty),EEq [])) | ty@(n,c) <- hoTypes]
-- lincats for the new categories: the lincat of Cat extended with n extra string fields
hoLincats = Map.fromList [(catName ty, modifyRec (++ replicate n (S [])) (lincatOf c)) | ty@(n,c) <- hoTypes]
-- linearizations of the new functions, lin __NCat v_0 ... v_n-1 x = { s1 = x.s1; ...; sk = x.sk; $0 = v_0.s ...
hoLins = Map.fromList [ (funName ty, mkLin c n) | ty@(n,c) <- hoTypes]
-- copies every field of argument 0, then appends field 0 of each of the arguments 1..n
where mkLin c n = modifyRec (\fs -> [P (V 0) (C j) | j <- [0..length fs-1]] ++ [P (V i) (C 0) | i <- [1..n]]) (lincatOf c)
-- for each Cat, we add a fun _Var_Cat : _Var -> Cat
varFuns = [(varFunName cat, (cftype [varCat] cat,EEq [])) | cat <- hoCats]
-- linearizations of the _Var_Cat functions
varLins = Map.fromList [(varFunName cat, R [P (V 0) (C 0)]) | cat <- hoCats]
-- lincat for the _Var category
varLincat = Map.singleton varCat (R [S []])
-- lincat of a category; calls 'error' when the category has none
lincatOf c = fromMaybe (error $ "No lincat for " ++ prCId c) $ Map.lookup c lincats
-- applies a function to the fields of a record term; calls 'error' on any other term
modifyRec :: ([Term] -> [Term]) -> Term -> Term
modifyRec f (R xs) = R (f xs)
modifyRec _ t = error $ "Not a record: " ++ show t
varCat = mkCId "_Var"
-- name of the helper category for (n,c): c itself when n is 0, otherwise "_" ++ n ++ c
catName :: (Int,CId) -> CId
catName (0,c) = c
catName (n,c) = mkCId ("_" ++ show n ++ prCId c)
-- temporary name of the helper function for (n,c): "__" ++ n ++ c
funName :: (Int,CId) -> CId
funName (n,c) = mkCId ("__" ++ show n ++ prCId c)
-- temporary name of the variable-injection function for c: "_Var_" ++ c
varFunName :: CId -> CId
varFunName c = mkCId ("_Var_" ++ prCId c)
-- | Gives the temporary HOAS helper functions their final names: every function
-- whose name starts with "__" is renamed to "_B" and every function whose name
-- starts with "_Var_" is renamed to the wildcard identifier. All other names are
-- kept. The temporary names only exist to avoid name collisions.
fixHoasFuns :: ParserInfo -> ParserInfo
fixHoasFuns pinfo = pinfo { functions = mkArray renamed }
  where
    renamed = [ FFun (fixName name) prof lins
              | FFun name prof lins <- elems (functions pinfo) ]

    binderPrefix = BS.pack "__"
    varPrefix    = BS.pack "_Var_"

    fixName (CId raw)
      | binderPrefix `BS.isPrefixOf` raw = mkCId "_B"
      | varPrefix `BS.isPrefixOf` raw    = wildCId
    fixName other = other
-- | Converts abstract function definitions, concrete terms and lincats into a
-- 'ParserInfo': every function that has a concrete term becomes an 'XRule', the
-- rules are folded into a grammar environment with 'convertRule', and the final
-- environment is turned into the result by 'getParserInfo'.
--
-- NOTE(review): two versions are interleaved below (this looks like the before/after
-- pair of a diff). One starts from the empty environment and uses bare categories
-- ('catSkeleton'); the other first extends the environment with 'expandHOAS' and
-- uses (context size, category) pairs ('typeSkeleton'). Confirm which one is current.
convert :: [(CId,(Type,Expr))] -> TermMap -> TermMap -> ParserInfo
convert abs_defs cnc_defs cat_defs = getParserInfo (List.foldl' (convertRule cnc_defs) (emptyGrammarEnv cnc_defs cat_defs) xrules)
convert abs_defs cnc_defs cat_defs =
let env = expandHOAS abs_defs cnc_defs cat_defs (emptyGrammarEnv cnc_defs cat_defs)
in getParserInfo (List.foldl' (convertRule cnc_defs) env xrules)
where
-- one rule per abstract function
xrules = [
(XRule id args res (map findLinType args) (findLinType res) term) |
(id, (ty,_)) <- abs_defs, let (args,res) = catSkeleton ty,
-- in this variant the result type is always paired with context size 0
(XRule id args (0,res) (map findLinType args) (findLinType (0,res)) term) |
(id, (ty,_)) <- abs_defs, let (args,res) = typeSkeleton ty,
-- presumably a failed lookup yields no rule for that function -- TODO confirm
term <- Map.lookup id cnc_defs]
-- lincat of a category; calls 'error' when it is missing from cat_defs
findLinType id = fromMaybe (error $ "No lincat for " ++ show id) (Map.lookup id cat_defs)
-- same lookup for the (context size, category) variant; the context size is ignored
findLinType (_,id) = fromMaybe (error $ "No lincat for " ++ show id) (Map.lookup id cat_defs)
brk :: (GrammarEnv -> GrammarEnv) -> (GrammarEnv -> GrammarEnv)
brk f (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) =
@@ -149,21 +95,27 @@ convertRule cnc_defs grammarEnv (XRule fun args res ctypes ctype term) =
-- Conversion monad: backtracking over the state 'Env'.
type CnvMonad a = BacktrackM Env a
-- A path into a (nested) linearization record, as a list of field indices.
type FPath = [FIndex]
-- A category under construction: category name, paths of its string fields, and
-- for every parameter field its path together with the values still allowed.
-- NOTE(review): two declarations follow (looks like the before/after pair of a diff);
-- the second one has an extra leading Int, the context size of the category.
data ProtoFCat = PFCat CId [FPath] [(FPath,[FIndex])]
data ProtoFCat = PFCat Int CId [FPath] [(FPath,[FIndex])]
-- Conversion state: the category of the rule head and the categories of its arguments.
type Env = (ProtoFCat, [ProtoFCat])
-- A linearization record: one symbol sequence per path.
type LinRec = [(FPath, [FSymbol])]
-- | An abstract function together with everything needed to convert it: its
-- argument and result categories, their linearization types and its concrete term.
--
-- NOTE(review): two declarations follow (looks like the before/after pair of a diff).
-- The first identifies categories by name only; the second pairs each category
-- with its context size.
data XRule = XRule CId {- function -}
[CId] {- argument types -}
CId {- result type -}
[Term] {- argument lin-types representation -}
Term {- result lin-type representation -}
Term {- body -}
data XRule = XRule CId {- function -}
[(Int,CId)] {- argument types: context size and category -}
(Int,CId) {- result type : context size (always 0) and category -}
[Term] {- argument lin-types representation -}
Term {- result lin-type representation -}
Term {- body -}
protoFCat :: TermMap -> CId -> Term -> ProtoFCat
protoFCat cnc_defs cat ctype =
let (rcs,tcs) = loop [] [] [] ctype
in PFCat cat rcs tcs
protoFCat :: TermMap -> (Int,CId) -> Term -> ProtoFCat
protoFCat cnc_defs (n,cat) ctype =
let (rcs,tcs) = loop [] [] [] ctype'
in PFCat n cat rcs tcs
where
ctype' -- extend the high-order linearization type
| n > 0 = case ctype of
R xs -> R (xs ++ replicate n (S []))
_ -> error $ "Not a record: " ++ show ctype
| otherwise = ctype
loop path rcs tcs (R record) = List.foldl' (\(rcs,tcs) (index,term) -> loop (index:path) rcs tcs term) (rcs,tcs) (zip [0..] record)
loop path rcs tcs (C i) = ( rcs,(path,[0..i]):tcs)
loop path rcs tcs (S _) = (path:rcs, tcs)
@@ -209,7 +161,7 @@ convertArg (C max) nr path lbl_path lin lins = do
return lins
convertArg (S _) nr path lbl_path lin lins = do
(_, args) <- readState
let PFCat cat rcs tcs = args !! nr
let PFCat _ cat rcs tcs = args !! nr
return ((lbl_path, FSymCat nr (index path rcs 0) : lin) : lins)
where
index lbl' (lbl:lbls) idx
@@ -236,7 +188,7 @@ convertRec cnc_defs (index:sub_sel) ctype record lbl_path lin lins = do
evalTerm :: TermMap -> FPath -> Term -> CnvMonad FIndex
evalTerm cnc_defs path (V nr) = do (_, args) <- readState
let PFCat _ _ tcs = args !! nr
let PFCat _ _ _ tcs = args !! nr
rpath = reverse path
index <- member (fromMaybe (error "evalTerm: wrong path") (lookup rpath tcs))
restrictArg nr rpath index
@@ -256,14 +208,14 @@ evalTerm cnc_defs path x = error ("evalTerm ("++show x++")")
-- GrammarEnv
-- | State threaded through the conversion. The fields are, in order: an id counter
-- for PMCFG categories, the category ranges, the known sequences, the known
-- functions, the coercion categories and the productions indexed by category.
data GrammarEnv = GrammarEnv {-# UNPACK #-} !Int CatSet SeqSet FunSet CoerceSet (IntMap.IntMap (Set.Set Production))
-- Per category: first PMCFG category, last PMCFG category and a list of Ints
-- (presumably the multipliers used to index into the range -- TODO confirm).
-- NOTE(review): two aliases follow (looks like the before/after pair of a diff); the
-- second adds an outer IntMap keyed by the context size of the category.
type CatSet = Map.Map CId (FCat,FCat,[Int])
type CatSet = IntMap.IntMap (Map.Map CId (FCat,FCat,[Int]))
-- Sequences and functions mapped to the ids assigned to them.
type SeqSet = Map.Map FSeq SeqId
type FunSet = Map.Map FFun FunId
-- A list of categories mapped to the category that coerces to each of them.
type CoerceSet= Map.Map [FCat] FCat
emptyGrammarEnv cnc_defs lincats =
let (last_id,catSet) = Map.mapAccumWithKey computeCatRange 0 lincats
in GrammarEnv last_id catSet Map.empty Map.empty Map.empty IntMap.empty
in GrammarEnv last_id (IntMap.singleton 0 catSet) Map.empty Map.empty Map.empty IntMap.empty
where
cidString = mkCId "String"
cidInt = mkCId "Int"
@@ -286,6 +238,64 @@ emptyGrammarEnv cnc_defs lincats =
Just term -> getMultipliers m ms term
Nothing -> error ("unknown identifier: "++prCId id)
-- | Extends a grammar environment with the PMCFG categories, functions and
-- productions needed for higher-order abstract syntax.
--
-- Arguments, as supplied by the caller @convert@:
--
--   * @abs_defs@ - the abstract function definitions, @[(CId,(Type,Expr))]@
--   * @cnc_defs@ - the concrete terms, only passed on to 'protoFCat'
--   * @lincats@  - the linearization type of every category
--   * @env@      - the grammar environment to extend
--
-- For every argument type @(n,cat)@ with @n > 0@ a fresh range of PMCFG
-- categories is allocated and a function @_B@ is added. For every category
-- that is an argument category of a hypothesis type a function @_V@ is added.
--
-- Calls 'error' with @"No lincat for ..."@ when a category that needs a helper
-- function has no entry in @lincats@.
expandHOAS abs_defs cnc_defs lincats env =
  List.foldl' add_varFun envWithBinders hoCats
  where
    -- environment after the categories and the _B function of every
    -- higher-order type have been added
    envWithBinders =
      List.foldl' (\acc ncat -> add_hoFun (add_hoCat acc ncat) ncat) env hoTypes

    -- argument types (context size, category) that bind at least one variable
    hoTypes :: [(Int,CId)]
    hoTypes = sortNub [(n,c) | (_,(ty,_)) <- abs_defs
                             , (n,c) <- fst (typeSkeleton ty), n > 0]

    -- argument categories of the types of all hypotheses
    hoCats :: [CId]
    hoCats = sortNub [c | (_,(ty,_)) <- abs_defs
                        , Hyp _ hypTy <- case ty of {DTyp hyps val _ -> hyps}
                        , c <- fst (catSkeleton hypTy)]

    -- lincat of a category; a missing lincat is a fatal error
    lincatOf cat =
      fromMaybe (error $ "No lincat for " ++ prCId cat) (Map.lookup cat lincats)

    -- add a range of PMCFG categories for each GF high-order category;
    -- the new range has the same size as the range of the first-order category
    add_hoCat env0@(GrammarEnv last_id catSet seqSet funSet crcSet prodSet) (n,cat) =
      case IntMap.lookup 0 catSet >>= Map.lookup cat of
        Just (start,end,ms) ->
          let !catSet'  = IntMap.insertWith Map.union n (Map.singleton cat (last_id,last_id+(end-start),ms)) catSet
              !last_id' = last_id+(end-start)+1
          in GrammarEnv last_id' catSet' seqSet funSet crcSet prodSet
        Nothing -> env0

    -- add one PMCFG function for each high-order type: _B : Cat -> Var -> ... -> Var -> HoCat
    add_hoFun env0 (n,cat) =
      let linRec = reverse $
                     [(l ,[FSymCat 0 i]) | (l,i) <- case arg of {PFCat _ _ rcs _ -> zip rcs [0..]}] ++
                     [([],[FSymCat i 0]) | i <- [1..n]]
          (env1,lins)  = List.mapAccumL addFSeq env0 linRec
          (env2,funid) = addFFun env1 (FFun _B [[i] | i <- [0..n]] (mkArray lins))
      in List.foldl' (\acc (a,r) -> addProduction acc r (FApply funid (a : replicate n fcatVar)))
                     env2
                     (zip (getFCats env2 arg) (getFCats env2 res))
      where
        ctype = lincatOf cat
        arg   = protoFCat cnc_defs (0,cat) ctype
        res   = protoFCat cnc_defs (n,cat) ctype

    -- add one PMCFG function for each high-order category: _V : Var -> Cat
    add_varFun env0 cat =
      let (env1,seqid) = addFSeq env0 ([],[FSymCat 0 0])
          lins         = replicate (case res of {PFCat _ _ rcs _ -> length rcs}) seqid
          (env2,funid) = addFFun env1 (FFun _V [[0]] (mkArray lins))
      -- The accumulator must be threaded through the fold. Folding over the
      -- fixed env2 instead keeps only the production of the last category.
      in List.foldl' (\acc fcat -> addProduction acc fcat (FApply funid [fcatVar]))
                     env2
                     (getFCats env2 res)
      where
        res = protoFCat cnc_defs (0,cat) (lincatOf cat)

    _B = mkCId "_B"
    _V = mkCId "_V"
-- | Records production @p@ for category @cat@. Productions already stored for
-- that category are kept; every other part of the environment is unchanged.
addProduction :: GrammarEnv -> FCat -> Production -> GrammarEnv
addProduction (GrammarEnv lastId cats seqs funs coercions prods) cat p =
  GrammarEnv lastId cats seqs funs coercions prods'
  where
    prods' = IntMap.insertWith Set.union cat (Set.singleton p) prods
@@ -320,7 +330,7 @@ getParserInfo (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) =
ParserInfo { functions = mkArray funSet
, sequences = mkArray seqSet
, productions = IntMap.union prodSet coercions
, startCats = Map.map (\(start,end,_) -> range (start,end)) catSet
, startCats = maybe Map.empty (Map.map (\(start,end,_) -> range (start,end))) (IntMap.lookup 0 catSet)
, totalCats = last_id+1
}
where
@@ -329,8 +339,8 @@ getParserInfo (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) =
coercions = IntMap.fromList [(fcat,Set.fromList (map FCoerce sub_fcats)) | (sub_fcats,fcat) <- Map.toList crcSet]
getFCats :: GrammarEnv -> ProtoFCat -> [FCat]
getFCats (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) (PFCat cat rcs tcs) =
case Map.lookup cat catSet of
getFCats (GrammarEnv last_id catSet seqSet funSet crcSet prodSet) (PFCat n cat rcs tcs) =
case IntMap.lookup n catSet >>= Map.lookup cat of
Just (start,end,ms) -> reverse (solutions (variants ms tcs start) ())
where
variants _ [] fcat = return fcat
@@ -353,9 +363,9 @@ restrictHead path term
writeState (head', args)
restrictProtoFCat :: FPath -> FIndex -> ProtoFCat -> CnvMonad ProtoFCat
restrictProtoFCat path0 index0 (PFCat cat rcs tcs) = do
restrictProtoFCat path0 index0 (PFCat n cat rcs tcs) = do
tcs <- addConstraint tcs
return (PFCat cat rcs tcs)
return (PFCat n cat rcs tcs)
where
addConstraint [] = error "restrictProtoFCat: unknown path"
addConstraint (c@(path,indices) : tcs)

View File

@@ -109,20 +109,38 @@ extractExps (State pinfo chart items) start = exps
let FFun fn _ lins = functions pinfo ! funid
lbl <- indices lins
Just fid <- [lookupPC (PK cat lbl 0) (passive st)]
go Set.empty 0 (0,fid)
(fvs,tree) <- go Set.empty 0 (0,fid)
guard (Set.null fvs)
return tree
go rec fcat' (d,fcat)
| fcat < totalCats pinfo = [Meta (fcat'*10+d)] -- FIXME: here we assume that every rule has at most 10 arguments
| fcat < totalCats pinfo = return (Set.empty,Meta (fcat'*10+d)) -- FIXME: here we assume that every rule has at most 10 arguments
| Set.member fcat rec = mzero
| otherwise = foldForest (\funid args trees ->
do let FFun fn _ lins = functions pinfo ! funid
args <- mapM (go (Set.insert fcat rec) fcat) (zip [0..] args)
return (Fun fn args)
check_ho_fun fn args
`mplus`
trees)
(\const _ trees ->
return (freeVar const,const)
`mplus`
trees)
(\const _ trees -> const : trees)
[] fcat (forest st)
-- Builds the result for an application of `fun`. Every element of `args` is a
-- pair (set of free variables, tree).
--   _V: the result is the first argument unchanged.
--   _B: an abstraction whose body is the first argument and whose bound variables
--       come from the remaining arguments; their free variables are removed
--       from the free variables of the body.
--   otherwise: an ordinary application; the free variables are the union over all arguments.
-- NOTE(review): the _V and _B cases use head/tail/foldl1 and therefore fail on
-- an empty argument list -- presumably these functions never occur without arguments; confirm.
check_ho_fun fun args
| fun == _V = return (head args)
| fun == _B = return (foldl1 Set.difference (map fst args),Abs [mkVar (snd e) | e <- tail args] (snd (head args)))
| otherwise = return (Set.unions (map fst args),Fun fun (map snd args))
-- Name of a bound variable: the variable itself, or the wildcard for a metavariable.
-- NOTE(review): no equation for other kinds of tree -- confirm none can occur here.
mkVar (Var v) = v
mkVar (Meta _) = wildCId
-- Free variables of a tree: a single variable is free, everything else contributes none.
freeVar (Var v) = Set.singleton v
freeVar _ = Set.empty
-- Names of the binder function and of the variable-injection function.
_B = mkCId "_B"
_V = mkCId "_V"
process fn !seqs !funs [] acc chart = (acc,chart)
process fn !seqs !funs (item@(Active j ppos funid seqid args key0):items) acc chart