From 8f53e010d1a06aef34111f0f574100eff0233768 Mon Sep 17 00:00:00 2001 From: "kr.angelov" Date: Tue, 12 Jun 2012 10:05:17 +0000 Subject: [PATCH] improved script for training from PennTreebank --- examples/PennTreebank/training.hs | 86 ++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 23 deletions(-) diff --git a/examples/PennTreebank/training.hs b/examples/PennTreebank/training.hs index 080b11a4e..5f50d6a78 100644 --- a/examples/PennTreebank/training.hs +++ b/examples/PennTreebank/training.hs @@ -4,32 +4,72 @@ import Data.Maybe import Data.List main = do - pgf <- readPGF "PennTreebank.pgf" + pgf <- readPGF "ParseEngAbs.pgf" ls <- fmap lines $ readFile "log.txt" - let stats = foldl' collectStats Map.empty [e | l <- ls, Just e <- [readExpr (map toQ l)]] - mapM_ putStrLn [show f ++ "\t" ++ show p | (f,p) <- Map.toList (probs pgf stats), f /= mkCId "Q"] + let stats = foldl' (collectStats pgf) + (initStats pgf) + [(fromMaybe (error l) (readExpr (toQ l)),Just (mkCId "Phr"),Nothing) | l <- ls] + mapM_ putStrLn [show f ++ "\t" ++ show p | (f,p) <- uprobs pgf stats] + mapM_ putStrLn [show cat1 ++ "\t" ++ show cat2 ++ "\t" ++ show p | (cat1,cat2,p) <- bprobs pgf stats] where - toQ '?' = 'Q' - toQ c = c + toQ [] = [] + toQ ('[':cs) = let (xs,']':ys) = break (==']') cs + in toQ ('?' : ys) + toQ ('?':cs) = 'Q' : toQ cs + toQ (c:cs) = c : toQ cs -collectStats stats e = + +initStats pgf = + (Map.fromListWith (+) + ([(f,1) | f <- functions pgf] ++ + [(cat pgf f,1) | f <- functions pgf]) + ,Map.empty + ) + +collectStats pgf (ustats,bstats) (e,mb_cat1,mb_cat2) = case unApp e of - Just (f,args) -> let c = fromMaybe 0 (Map.lookup f stats) - in c `seq` foldl' collectStats (Map.insert f (c+1) stats) args - Nothing -> stats - -probs pgf stats = - Map.mapWithKey toProb stats + Just (f,args) -> let fcat = fromMaybe (cat2 pgf f e) mb_cat1 + cf = fromMaybe 0 (Map.lookup f ustats) + cc = fromMaybe 0 (Map.lookup fcat ustats) + in cf `seq` cc `seq` bstats `seq` + foldl' (collectStats pgf) + (Map.insert f (cf+1) (Map.insert fcat (cc+1) ustats) + ,(if null args + then Map.insertWith (+) (fcat,wildCId) 1 + else id) + (maybe bstats (\cat2 -> Map.insertWith (+) (cat2,fcat) 1 bstats) mb_cat2) + ) + (zip3 args (argCats f) (repeat (Just fcat))) + Nothing -> (ustats,bstats) where - toProb f c - | f == mkCId "Q" = 1.0 - | otherwise = let (_,cat,_) = case functionType pgf f of - Just ty -> unType ty - Nothing -> error ("unknown: "++show f) - cat_mass = fromMaybe 0 (Map.lookup cat mass) - in (fromIntegral c / fromIntegral cat_mass :: Double) + argCats f = + case fmap unType (functionType pgf f) of + Just (arg_tys,_,_) -> let tyCat (_,_,ty) = let (_,cat,_) = unType ty in Just cat + in map tyCat arg_tys + Nothing -> repeat Nothing - mass = Map.fromListWith (+) - [(cat,c) | f <- functions pgf, - let Just (_,cat,_) = fmap unType (functionType pgf f), - let c = fromMaybe 0 (Map.lookup f stats)] +uprobs pgf (ustats,bstats) = + [toProb f (cat pgf f) | f <- functions pgf] + where + toProb f cat = + let count = fromMaybe 0 (Map.lookup f ustats) + cat_mass = fromMaybe 0 (Map.lookup cat ustats) + in (f, fromIntegral count / fromIntegral cat_mass :: Double) + +bprobs pgf (ustats,bstats) = + concat [toProb cat | cat <- categories pgf] + where + toProb cat = + let mass = sum [count | ((cat1,cat2),count) <- Map.toList bstats, cat1==cat] + in [(cat1,cat2,fromIntegral count / fromIntegral mass) + | ((cat1,cat2),count) <- Map.toList bstats, cat1==cat] + +cat pgf f = + case fmap unType (functionType pgf f) of + Just (_,cat,_) -> cat + Nothing -> error ("Unknown function "++showCId f) + +cat2 pgf f e = + case fmap unType (functionType pgf f) of + Just (_,cat,_) -> cat + Nothing -> error ("Unknown function "++showCId f++show e)