{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE PatternGuards #-}
-- | Top-level interface of the CRF implementation in this package.
--
-- Exports the 'CRF' record, helpers for inspecting and shrinking it
-- ('size', 'prune'), the SGD-based training entry points ('train',
-- 'reTrain') and the inference entry points ('tag', 'marginals').
-- The external dataset representation and the feature module are
-- re-exported so that clients can work with this single import.
module Data.CRF.Chain2.Tiers
(
CRF (..)
, size
, prune
, train
, reTrain
, tag
, marginals
, module Data.CRF.Chain2.Tiers.Dataset.External
, module Data.CRF.Chain2.Tiers.Feature
) where
import System.IO (hSetBuffering, stdout, BufferMode (..))
import Control.Applicative ((<$>), (<*>))
import Data.Maybe (maybeToList)
import qualified Data.Map as M
import Data.Binary (Binary, get, put)
import qualified Data.Number.LogFloat as LogFloat
import qualified Numeric.SGD as SGD
import qualified Numeric.SGD.LogSigned as L
import Data.CRF.Chain2.Tiers.Dataset.Internal
import Data.CRF.Chain2.Tiers.Dataset.Codec
import Data.CRF.Chain2.Tiers.Dataset.External
import Data.CRF.Chain2.Tiers.Feature
import Data.CRF.Chain2.Tiers.Model
import qualified Data.CRF.Chain2.Tiers.Inference as I
-- | A conditional random field bundled with everything needed to apply
-- it to external data: the layer count, the codec and the model itself.
data CRF a b = CRF
    -- Number of layers; 'train' hands this same value to 'mkCodec'.
    { numOfLayers :: Int
    -- Codec used to encode input sentences and to decode resulting labels.
    , codec :: Codec a b
    -- The underlying model; its 'values' field receives the parameters
    -- produced by SGD in 'train' and 'reTrain'.
    , model :: Model }
-- | Binary serialisation of a 'CRF'.  The three fields are written, and
-- read back, in declaration order: layer count, codec, model.
instance (Ord a, Ord b, Binary a, Binary b) => Binary (CRF a b) where
    put crf = do
        put (numOfLayers crf)
        put (codec crf)
        put (model crf)
    get = do
        layers <- get
        cdc <- get
        mdl <- get
        return (CRF layers cdc mdl)
-- | Number of entries in the map representation ('toMap') of the
-- CRF's model.
size :: CRF a b -> Int
size = M.size . toMap . model
-- | Discard model entries whose value is too close to neutral.
--
-- An entry is kept only when the absolute value of its log-domain
-- representation is strictly greater than the given threshold; the
-- model is then rebuilt from the surviving entries.
prune :: Double -> CRF a b -> CRF a b
prune threshold crf =
    crf { model = fromMap kept }
  where
    kept = M.filter significant (toMap (model crf))
    significant val = abs (LogFloat.logFromLogFloat val) > threshold
-- | Train a new 'CRF' with stochastic gradient descent.
--
-- The training action is run twice: once to build the codec and once
-- more to produce the encoded training set.  Standard output is
-- switched to unbuffered mode so that progress reports printed by
-- 'notify' show up immediately.
train
    :: (Ord a, Ord b)
    => Int                  -- ^ Number of layers (handed to 'mkCodec')
    -> FeatSel              -- ^ Feature selection (handed to 'mkModel')
    -> SGD.SgdArgs          -- ^ Arguments of the SGD procedure
    -> Bool                 -- ^ Forwarded to 'SGD.withData' for both datasets
    -> IO [SentL a b]       -- ^ Action yielding the training data
    -> IO [SentL a b]       -- ^ Action yielding the evaluation data
    -> IO (CRF a b)         -- ^ The trained CRF
train numOfLayers featSel sgdArgs onDisk trainIO evalIO = do
    hSetBuffering stdout NoBuffering
    -- First pass over the training data: build the codec.
    cdc <- fmap (mkCodec numOfLayers) trainIO
    -- Second pass: encode the training data with that codec.
    encodedTrain <- fmap (encodeDataL cdc) trainIO
    SGD.withData onDisk encodedTrain $ \trainData -> do
        encodedEval <- fmap (encodeDataL cdc) evalIO
        SGD.withData onDisk encodedEval $ \evalData -> do
            initial <- fmap (mkModel featSel) (SGD.loadData trainData)
            let report = notify sgdArgs initial trainData evalData
                gradient = gradOn initial
            fitted <- SGD.sgd sgdArgs report gradient trainData (values initial)
            return (CRF numOfLayers cdc (initial { values = fitted }))
-- | Continue training an existing 'CRF' with stochastic gradient descent.
--
-- The codec and the model of the given CRF are reused as they are; SGD
-- starts from the model's current parameter values, and only those
-- values are replaced in the result.  Standard output is switched to
-- unbuffered mode so that progress reports appear immediately.
reTrain
    :: (Ord a, Ord b)
    => CRF a b              -- ^ The CRF to start from
    -> SGD.SgdArgs          -- ^ Arguments of the SGD procedure
    -> Bool                 -- ^ Forwarded to 'SGD.withData' for both datasets
    -> IO [SentL a b]       -- ^ Action yielding the training data
    -> IO [SentL a b]       -- ^ Action yielding the evaluation data
    -> IO (CRF a b)         -- ^ The re-trained CRF
reTrain crf sgdArgs onDisk trainIO evalIO = do
    hSetBuffering stdout NoBuffering
    encodedTrain <- fmap encode trainIO
    SGD.withData onDisk encodedTrain $ \trainData -> do
        encodedEval <- fmap encode evalIO
        SGD.withData onDisk encodedEval $ \evalData -> do
            let report = notify sgdArgs start trainData evalData
                gradient = gradOn start
            fitted <- SGD.sgd sgdArgs report gradient trainData (values start)
            return (crf { model = start { values = fitted } })
  where
    -- Both datasets are encoded with the codec of the existing CRF.
    encode = encodeDataL (codec crf)
    -- The existing model supplies the starting parameters.
    start = model crf
-- | Gradient contribution of a single training element, computed with
-- respect to the given parameter values.
--
-- Features present in the @(xs, ys)@ pair contribute with a positive
-- sign, features expected under the current parameters contribute with
-- a negative sign.  Features for which the model has no index are
-- dropped.
gradOn :: Model -> SGD.Para -> (Xs, Ys) -> SGD.Grad
gradOn model para (xs, ys) =
    SGD.fromLogList (observed ++ expected)
  where
    -- The model with the parameter values currently under evaluation.
    curr = model { values = para }
    observed = signedBy L.fromPos (presentFeats xs ys)
    expected = signedBy L.fromNeg (I.expectedFeatures curr xs)
    -- Pair each feature known to the model with its raw index and the
    -- signed value.
    signedBy sign feats = do
        (ft, val) <- feats
        ix <- maybeToList (index curr ft)
        return (unFeatIx ix, sign val)
-- | Progress callback passed to 'SGD.sgd'.
--
-- While the whole-number part of the \"passes over the training data\"
-- ratio stays the same between steps @k - 1@ and @k@, only a dot is
-- printed.  When it changes, a new line with that number is printed,
-- followed by the accuracy on the evaluation data, or by @#@ when the
-- evaluation dataset is empty.
notify
    :: SGD.SgdArgs -> Model
    -> SGD.Dataset (Xs, Ys)
    -> SGD.Dataset (Xs, Ys)
    -> SGD.Para -> Int -> IO ()
notify SGD.SgdArgs{..} model trainData evalData para k
    | passes k == passes (k - 1) = putStr "."
    | otherwise = do
        score <-
            if SGD.size evalData > 0
                then show . I.accuracy (model { values = para })
                        <$> SGD.loadData evalData
                else return "#"
        putStrLn ("\n" ++ "[" ++ show (passes k) ++ "] f = " ++ score)
  where
    -- Whole number of passes completed after the given step.
    passes :: Int -> Int
    passes = floor . ratio
    -- Elements processed after the given step, relative to the size of
    -- the training dataset.
    ratio :: Int -> Double
    ratio step =
        fromIntegral (step * batchSize) / fromIntegral (SGD.size trainData)
-- | Determine the labels for every word of a sentence.
--
-- The sentence is encoded with the CRF's codec, tagged by the model,
-- and the result is decoded again; each decoded label is finally
-- combined with its source word through 'unJust'.
tag :: (Ord a, Ord b) => CRF a b -> Sent a b -> [[b]]
tag crf sent =
    zipWith (unJust cdc) sent decoded
  where
    cdc = codec crf
    encoded = encodeSent cdc sent
    decoded = decodeLabels cdc (I.tag (model crf) encoded)
-- | Marginal values computed by the model for a sentence, converted
-- from the log-domain representation to plain 'Double's.
marginals :: (Ord a, Ord b) => CRF a b -> Sent a b -> [[Double]]
marginals crf sent =
    [ map LogFloat.fromLogFloat row
    | row <- I.marginals (model crf) (encodeSent (codec crf) sent) ]