forked from GitHub/gf-core
improved script for training from PennTreebank
This commit is contained in:
@@ -4,32 +4,72 @@ import Data.Maybe
|
||||
import Data.List
|
||||
|
||||
-- Entry point: loads a grammar with readPGF, reads the lines of "log.txt",
-- turns every line into an expression with readExpr (after rewriting it with
-- toQ), folds the expressions into count statistics and prints tab-separated
-- probabilities to stdout, one entry per line.
--
-- NOTE(review): this listing looks like a rendered diff without +/- markers:
-- `pgf` and `stats` are each bound twice, and the `|` / `||||` lines are
-- rendering residue. Presumably the first binding of each pair is the earlier
-- version and the second one the current version -- confirm against the
-- repository before relying on either.
main = do
|
||||
pgf <- readPGF "PennTreebank.pgf"
|
||||
pgf <- readPGF "ParseEngAbs.pgf"
|
||||
ls <- fmap lines $ readFile "log.txt"
|
||||
-- Variant 1: lines that readExpr cannot parse are silently dropped by the
-- `Just e <-` pattern; counts go into a single map and the entry for "Q" is
-- filtered out of the output.
let stats = foldl' collectStats Map.empty [e | l <- ls, Just e <- [readExpr (map toQ l)]]
|
||||
mapM_ putStrLn [show f ++ "\t" ++ show p | (f,p) <- Map.toList (probs pgf stats), f /= mkCId "Q"]
|
||||
-- Variant 2: a line that readExpr cannot parse aborts the run with the
-- offending line as the error message. Every tree is paired with
-- Just (mkCId "Phr") and Nothing before being folded into the pair of maps
-- produced by initStats.
let stats = foldl' (collectStats pgf)
|
||||
(initStats pgf)
|
||||
[(fromMaybe (error l) (readExpr (toQ l)),Just (mkCId "Phr"),Nothing) | l <- ls]
|
||||
-- First the per-function probabilities, then the category-pair probabilities.
mapM_ putStrLn [show f ++ "\t" ++ show p | (f,p) <- uprobs pgf stats]
|
||||
mapM_ putStrLn [show cat1 ++ "\t" ++ show cat2 ++ "\t" ++ show p | (cat1,cat2,p) <- bprobs pgf stats]
|
||||
where
|
||||
-- Character-level toQ (used with `map toQ l` above): replaces '?' by 'Q' and
-- leaves every other character unchanged.
toQ '?' = 'Q'
|
||||
toQ c = c
|
||||
-- Rewrites one input line before it is handed to the expression parser:
--
--   * a bracketed span "[...]" (up to the first ']') is replaced by a
--     single 'Q';
--   * every '?' is replaced by 'Q';
--   * all other characters are copied unchanged.
--
-- A '[' without a matching ']' used to crash with an irrefutable-pattern
-- failure; it is now copied through literally and the rest of the line is
-- processed as usual, so the caller decides what to do with the malformed
-- line.
toQ :: String -> String
toQ [] = []
toQ ('[':cs) =
  case break (==']') cs of
    -- the bracketed text itself is discarded, only the placeholder remains
    (_,']':rest) -> 'Q' : toQ rest
    -- unterminated bracket: keep the '[' and carry on with the remainder
    _            -> '[' : toQ cs
toQ ('?':cs) = 'Q' : toQ cs
toQ (c:cs) = c : toQ cs
|
||||
|
||||
collectStats stats e =
|
||||
|
||||
-- Starting value for the statistics fold: a pair of count maps.
--
-- The first map receives one count for every function of the grammar and,
-- under the key of that function's category (as returned by `cat`), one
-- count per function belonging to it. Function names and category names
-- share this single map. The second map starts out empty.
initStats pgf = (unigrams, Map.empty)
  where
    funs     = functions pgf
    seed key = (key,1)
    -- function entries first, category entries second; duplicates are summed
    unigrams = Map.fromListWith (+) (map seed funs ++ map (seed . cat pgf) funs)
|
||||
|
||||
-- collectStats pgf (ustats,bstats) (e,mb_cat1,mb_cat2) folds one expression
-- into the pair of count maps and returns the updated pair.
--   mb_cat1: when present, used as the category of the head function of `e`
--            instead of looking the function up in the grammar.
--   mb_cat2: when present, paired with that category as a key of bstats.
--
-- NOTE(review): the lines below look like a rendered diff without +/-
-- markers. A single-map body of collectStats and a separate `probs` function
-- (with its helpers toProb and mass) are interleaved with the pair-of-maps
-- body and its helper argCats; the `|` / `||||` lines are rendering residue.
-- Presumably the single-map code is the earlier version -- confirm against
-- the repository.
collectStats pgf (ustats,bstats) (e,mb_cat1,mb_cat2) =
|
||||
case unApp e of
|
||||
-- Single-map variant: bump the count of f (forced with seq before
-- recursing), then fold over the arguments.
Just (f,args) -> let c = fromMaybe 0 (Map.lookup f stats)
|
||||
in c `seq` foldl' collectStats (Map.insert f (c+1) stats) args
|
||||
Nothing -> stats
|
||||
|
||||
-- probs: converts every count in `stats` with toProb (defined under the
-- `where` further down).
probs pgf stats =
|
||||
Map.mapWithKey toProb stats
|
||||
-- Pair-of-maps variant. The category of f is mb_cat1 when given, otherwise
-- it is looked up with cat2 (which raises an error for unknown functions).
Just (f,args) -> let fcat = fromMaybe (cat2 pgf f e) mb_cat1
|
||||
cf = fromMaybe 0 (Map.lookup f ustats)
|
||||
cc = fromMaybe 0 (Map.lookup fcat ustats)
|
||||
-- force the looked-up counts and bstats before recursing
in cf `seq` cc `seq` bstats `seq`
|
||||
foldl' (collectStats pgf)
|
||||
-- ustats: one more occurrence of the function and of its category
(Map.insert f (cf+1) (Map.insert fcat (cc+1) ustats)
|
||||
-- bstats: a function without arguments additionally counts the pair
-- (fcat,wildCId) ...
,(if null args
|
||||
then Map.insertWith (+) (fcat,wildCId) 1
|
||||
else id)
|
||||
-- ... and, when mb_cat2 is given, the pair (that category, fcat) is counted
(maybe bstats (\cat2 -> Map.insertWith (+) (cat2,fcat) 1 bstats) mb_cat2)
|
||||
)
|
||||
-- every argument is visited together with its category from argCats (as
-- the next mb_cat1) and Just fcat (as the next mb_cat2)
(zip3 args (argCats f) (repeat (Just fcat)))
|
||||
-- expressions that unApp does not decompose leave the counts untouched
Nothing -> (ustats,bstats)
|
||||
where
|
||||
-- toProb (helper of probs): "Q" always gets 1.0; any other function gets
-- its count divided by the total count of its category taken from `mass`.
-- A function without a type in pgf raises an error.
toProb f c
|
||||
| f == mkCId "Q" = 1.0
|
||||
| otherwise = let (_,cat,_) = case functionType pgf f of
|
||||
Just ty -> unType ty
|
||||
Nothing -> error ("unknown: "++show f)
|
||||
cat_mass = fromMaybe 0 (Map.lookup cat mass)
|
||||
in (fromIntegral c / fromIntegral cat_mass :: Double)
|
||||
-- argCats (helper of the pair-of-maps variant): for each argument type of
-- f, the middle component of its unType triple wrapped in Just; an endless
-- list of Nothing when f has no type in pgf.
argCats f =
|
||||
case fmap unType (functionType pgf f) of
|
||||
Just (arg_tys,_,_) -> let tyCat (_,_,ty) = let (_,cat,_) = unType ty in Just cat
|
||||
in map tyCat arg_tys
|
||||
Nothing -> repeat Nothing
|
||||
|
||||
-- mass (helper of probs): total count per category, summed over all
-- functions of pgf; functions missing from `stats` contribute 0.
mass = Map.fromListWith (+)
|
||||
[(cat,c) | f <- functions pgf,
|
||||
let Just (_,cat,_) = fmap unType (functionType pgf f),
|
||||
let c = fromMaybe 0 (Map.lookup f stats)]
|
||||
-- Per-function probabilities: for every function of the grammar, its count
-- in the first statistics map divided by the count stored there for its
-- category (as returned by `cat`). Missing entries count as 0. The second
-- map of the pair is not consulted.
--
-- Returns a list of (function, probability) pairs in the order of
-- `functions pgf`.
uprobs pgf (ustats,_) = map funProb (functions pgf)
  where
    -- count stored under a key, 0 when the key is absent
    occurrences key = fromMaybe 0 (Map.lookup key ustats)
    funProb f =
      (f, fromIntegral (occurrences f) / fromIntegral (occurrences (cat pgf f)) :: Double)
|
||||
|
||||
-- Category-pair probabilities: for every category of the grammar, each
-- entry of the second statistics map whose key starts with that category is
-- turned into (first category, second category, count / total), where the
-- total is the sum of the counts of all entries starting with the same
-- category. The first map of the pair is not consulted.
--
-- Results follow the order of `categories pgf`; within one category they
-- follow the key order of the map.
bprobs pgf (_,bstats) = concatMap catProbs (categories pgf)
  where
    catProbs c =
      let rows  = [(c1,c2,n) | ((c1,c2),n) <- Map.toList bstats, c1 == c]
          total = sum [n | (_,_,n) <- rows]
      in [(c1,c2,fromIntegral n / fromIntegral total) | (c1,c2,n) <- rows]
|
||||
|
||||
-- Category of function `f` in grammar `pgf`: the middle component of the
-- triple that unType yields for the function's type. Raises an error naming
-- the function when the grammar has no type for it.
cat pgf f = maybe unknown middle (functionType pgf f)
  where
    unknown   = error ("Unknown function "++showCId f)
    middle ty = case unType ty of
                  (_,c,_) -> c
|
||||
|
||||
-- Category of function `f` in grammar `pgf`, like `cat`, but the error
-- raised for an unknown function also shows the expression `e` in which the
-- function was met. The name and the expression are now separated by " in "
-- (they used to be glued together, which made the message hard to read).
cat2 pgf f e =
  case fmap unType (functionType pgf f) of
    Just (_,cat,_) -> cat
    Nothing        -> error ("Unknown function "++showCId f++" in "++show e)
|
||||
|
||||
Reference in New Issue
Block a user