{-# LANGUAGE RecordWildCards #-}
module Data.CRF.Chain2.Tiers.Inference
( tag
, marginals
, expectedFeatures
, accuracy
, zx
, zx'
) where
import Data.List (foldl', maximumBy)
import Data.Ord (comparing)
import GHC.Conc (numCapabilities)

import Control.Parallel (par, pseq)
import Control.Parallel.Strategies (rseq, parMap)
import qualified Data.Array as A
import qualified Data.Number.LogFloat as L
import qualified Data.Vector as V

import Data.CRF.Chain2.Tiers.Dataset.Internal
import Data.CRF.Chain2.Tiers.Model
import Data.CRF.Chain2.Tiers.Feature
import Data.CRF.Chain2.Tiers.Util (partition)
import qualified Data.CRF.Chain2.Tiers.DP as DP
-- | Accumulation function: collapses a list of log-domain values into one.
-- Every use visible in this module passes 'sum'.
type AccF = [L.LogFloat] -> L.LogFloat
-- | A table of forward or backward values, exposed as a lookup function.
--
-- The first argument is a sentence /position/ (this module always passes
-- plain 'Int' positions such as @k - 1@ or @k + 1@ there), so it is spelled
-- 'Int' rather than 'CbIx' to avoid suggesting that it is a label index.
-- The remaining two arguments are label indices at two neighbouring
-- positions; which positions exactly depends on the table (see 'forward'
-- and 'backward').
type ProbArray = Int -> CbIx -> CbIx -> L.LogFloat
-- | Product of the model potentials ('phi') of all observation features
-- that 'obFeatsOn' yields for position @i@ and label index @u@.
onWord :: Model -> Xs -> Int -> CbIx -> L.LogFloat
onWord crf xs i u =
    product [phi crf ft | ft <- obFeatsOn xs i u]
{-# INLINE onWord #-}
-- | Product of the model potentials ('phi') of all transition features
-- that 'trFeatsOn' yields for position @i@ and the label-index triple
-- @(u, w, v)@.
onTransition :: Model -> Xs -> Int -> CbIx -> CbIx -> CbIx -> L.LogFloat
onTransition crf xs i u w v =
    product [phi crf ft | ft <- trFeatsOn xs i u w v]
{-# INLINE onTransition #-}
-- | Tabulate 'onWord' for every label index at position @i@ and return the
-- lookup function of that table.
--
-- The array is built when the function is applied to its first three
-- arguments, so a partially applied @computePsi crf xs i@ can be queried
-- repeatedly without recomputing the observation potentials.
computePsi :: Model -> Xs -> Int -> CbIx -> L.LogFloat
computePsi crf xs i = (A.!) psiTable
  where
    -- Bounds come from 'lbNum', the filled indices from 'lbIxs'.
    psiTable = A.array
        (0, lbNum xs i - 1)
        [ (ix, onWord crf xs i ix) | ix <- lbIxs xs i ]
-- | Forward table.
--
-- The table is indexed by @(i, j, k)@ where @i@ ranges over
-- @-1 .. length - 1@, @j@ is a label index at position @i@ and @k@ a label
-- index at position @i - 1@.  The recurrence is
--
-- > alpha (-1) _ _ = 1
-- > alpha i j k    = acc [ alpha (i-1) k h * psi_i j * onTransition i j k h
-- >                      | h <- label indices at i-2 ]
--
-- Cells are produced through 'DP.flexible3' (presumably a memoising table
-- builder -- see "Data.CRF.Chain2.Tiers.DP").
forward :: AccF -> Model -> Xs -> ProbArray
forward acc crf sent = alpha
  where
    lastPos = V.length sent - 1
    alpha = DP.flexible3 (-1, lastPos)
        (\i -> (0, lbNum sent i - 1))
        (\i _ -> (0, lbNum sent (i - 1) - 1))
        (\table i -> cell (computePsi crf sent i) table i)
    -- One cell of the table; @table@ gives access to already defined cells.
    cell psi table i j k =
        if i == -1
            then 1.0
            else acc (map term (lbIxs sent (i - 2)))
      where
        term h =
            table (i - 1) k h * psi j
                * onTransition crf sent i j k h
-- | Backward table.
--
-- The table is indexed by @(i, j, k)@ where @i@ ranges over @0 .. length@,
-- @j@ is a label index at position @i - 1@ and @k@ a label index at
-- position @i - 2@.  The recurrence is
--
-- > beta length _ _ = 1
-- > beta i j k      = acc [ beta (i+1) h j * psi_i h * onTransition i h j k
-- >                       | h <- label indices at i ]
--
-- Cells are produced through 'DP.flexible3' (presumably a memoising table
-- builder -- see "Data.CRF.Chain2.Tiers.DP").
backward :: AccF -> Model -> Xs -> ProbArray
backward acc crf sent = beta
  where
    n = V.length sent
    beta = DP.flexible3 (0, n)
        (\i -> (0, lbNum sent (i - 1) - 1))
        (\i _ -> (0, lbNum sent (i - 2) - 1))
        (\table i -> cell (computePsi crf sent i) table i)
    -- One cell of the table; @table@ gives access to already defined cells.
    cell psi table i j k =
        if i == n
            then 1.0
            else acc (map term (lbIxs sent i))
      where
        term h =
            table (i + 1) h j * psi h
                * onTransition crf sent i h j k
-- | Normalisation factor read off a backward table: its @(0, 0, 0)@ cell.
zxBeta :: ProbArray -> L.LogFloat
zxBeta table = table 0 0 0
-- | Normalisation factor computed from a forward table: the accumulation
-- (with @acc@) of all cells at the last sentence position, over every label
-- index at the last position paired with every label index at the
-- position before it.
zxAlpha :: AccF -> Xs -> ProbArray -> L.LogFloat
zxAlpha acc sent alpha = acc finalCells
  where
    lastPos = V.length sent - 1
    finalCells =
        [ alpha lastPos i j
        | i <- lbIxs sent lastPos
        , j <- lbIxs sent (lastPos - 1) ]
-- | Normalisation factor of a sentence, computed with the backward table
-- (using 'sum' as the accumulation function).
zx :: Model -> Xs -> L.LogFloat
zx crf sent = zxBeta (backward sum crf sent)
-- | Normalisation factor of a sentence, computed with the forward table
-- (using 'sum' as the accumulation function).
zx' :: Model -> Xs -> L.LogFloat
zx' crf sent = zxAlpha sum sent alpha
  where
    alpha = forward sum crf sent
-- | Pair the element of the list that maximises @f@ with its score.
--
-- When several elements share the maximal score the /last/ of them is
-- returned, because an earlier candidate is only kept while its score is
-- strictly greater.  The list must be non-empty ('foldl1' is partial).
argmax :: (Ord b) => (a -> b) -> [a] -> (a, b)
argmax f = foldl1 pick . map score
  where
    score x = (x, f x)
    pick old@(_, vOld) new@(_, vNew)
        | vOld > vNew = old
        | otherwise = new
-- | Indices of the best-scoring label sequence for the sentence.
--
-- A table shaped like the 'backward' one is built, but each cell stores
-- the pair (best label index at position @i@, score of that choice)
-- selected with 'argmax' instead of an accumulated sum.  Cells at
-- @i == length@ hold the sentinel index @-1@ with score 1.  The result is
-- read off by following the stored choices from cell @(0, 0, 0)@ until the
-- sentinel is met.
tagIxs :: Model -> Xs -> [Int]
tagIxs crf sent = follow (0, 0, 0) []
  where
    n = V.length sent
    best = DP.flexible3 (0, n)
        (\i -> (0, lbNum sent (i - 1) - 1))
        (\i _ -> (0, lbNum sent (i - 2) - 1))
        (\table i -> cell (computePsi crf sent i) table i)
    -- One cell: the maximising label index at @i@ together with its score.
    cell psi table i j k
        | i == n = (-1, 1)
        | otherwise = argmax score (lbIxs sent i)
      where
        score h =
            snd (table (i + 1) h j) * psi h
                * onTransition crf sent i h j k
    -- Walk the stored choices; @chosen@ is accumulated in reverse order.
    follow (i, j, k) chosen =
        case best i j k of
            (h, _)
                | h == -1 -> reverse chosen
                | otherwise -> follow (i + 1, h, j) (h : chosen)
-- | Best-scoring label sequence for the sentence: the indices found by
-- 'tagIxs' are turned into labels with 'lbAt', position by position.
tag :: Model -> Xs -> [Cb]
tag crf sent = zipWith lbAt (V.toList sent) (tagIxs crf sent)
-- | For every sentence position, the list of 'prob1' values of its label
-- indices (in the order given by 'lbIxs').  Forward and backward tables
-- are both built with 'sum'.
marginals :: Model -> Xs -> [[L.LogFloat]]
marginals crf sent = map atPosition [0 .. V.length sent - 1]
  where
    alpha = forward sum crf sent
    beta = backward sum crf sent
    atPosition i = map (prob1 crf alpha beta sent i) (lbIxs sent i)
-- | Count the positions of one sentence on which the model's choice agrees
-- (first component) or disagrees (second component) with the reference.
--
-- The reference label of a position is the first component of the entry of
-- @unY@ with the greatest second component, or 'Nothing' when there are no
-- entries.  Predicted labels are always wrapped in 'Just', so a position
-- without reference entries is always counted as bad.  Only as many
-- positions as the shorter of the two label lists are compared ('zip').
--
-- The fold is strict and both counters are forced at every step, so no
-- chain of suspended additions is built and forcing the result performs
-- all of the tagging work.
goodAndBad :: Model -> Xs -> Ys -> (Int, Int)
goodAndBad crf xs ys =
    foldl' gather (0, 0) $ zip labels labels'
  where
    labels = map (best . unY) (V.toList ys)
    best zs
        | null zs = Nothing
        | otherwise = Just . fst $ maximumBy (comparing snd) zs
    labels' = map Just $ tag crf xs
    gather (good, bad) (x, y)
        | x == y =
            let good' = good + 1
            in good' `seq` (good', bad)
        | otherwise =
            let bad' = bad + 1
            in bad' `seq` (good, bad')
-- | Sum the 'goodAndBad' counts over a list of (sentence, reference) pairs.
--
-- A strict left fold with forced components is used so that the running
-- totals are plain numbers rather than a chain of suspended additions as
-- long as the dataset.
goodAndBad' :: Model -> [(Xs, Ys)] -> (Int, Int)
goodAndBad' crf dataset =
    foldl' add (0, 0) [goodAndBad crf x y | (x, y) <- dataset]
  where
    add (g, b) (g', b') =
        let goodSum = g + g'
            badSum = b + b'
        in goodSum `seq` badSum `seq` (goodSum, badSum)
-- | Fraction of correctly tagged positions in the dataset.
--
-- The dataset is split with 'partition' into 'numCapabilities' parts and
-- the parts are scored with 'parMap' 'rseq'.  If no position is compared
-- at all the quotient is @0 / 0@, i.e. NaN.
accuracy :: Model -> [(Xs, Ys)] -> Double
accuracy crf dataset =
    fromIntegral good / fromIntegral (good + bad)
  where
    parts = partition numCapabilities dataset
    counts = parMap rseq (goodAndBad' crf) parts
    good = sum (map fst counts)
    bad = sum (map snd counts)
-- | Normalised value for the label-index triple @(x, y, z)@ at positions
-- @(k, k - 1, k - 2)@: forward cell up to @k - 1@, backward cell from
-- @k + 1@, observation potential of @x@ (looked up through @psiMem@) and
-- the transition potential at @k@, divided by 'zxBeta' of the backward
-- table.
prob3
    :: Model -> ProbArray -> ProbArray -> Xs
    -> Int -> (CbIx -> L.LogFloat) -> CbIx -> CbIx -> CbIx
    -> L.LogFloat
prob3 crf alpha beta sent k psiMem x y z =
    unnormalised / zxBeta beta
  where
    unnormalised =
        alpha (k - 1) y z * beta (k + 1) x y * psiMem x
            * onTransition crf sent k x y z
{-# INLINE prob3 #-}
-- | Normalised value for the label-index pair @(x, y)@ at positions
-- @(k, k - 1)@: the forward cell at @k@ times the backward cell at
-- @k + 1@, divided by 'zxBeta' of the backward table.  The model and the
-- sentence arguments are not used.
prob2
    :: Model -> ProbArray -> ProbArray
    -> Xs -> Int -> CbIx -> CbIx -> L.LogFloat
prob2 _ alpha beta _ k x y =
    joint / zxBeta beta
  where
    joint = alpha k x y * beta (k + 1) x y
{-# INLINE prob2 #-}
-- | Normalised value for label index @x@ at position @k@: the sum of
-- 'prob2' over every label index at position @k - 1@.
prob1
    :: Model -> ProbArray -> ProbArray
    -> Xs -> Int -> CbIx -> L.LogFloat
prob1 crf alpha beta sent k x =
    sum (map pairProb (lbIxs sent (k - 1)))
  where
    pairProb = prob2 crf alpha beta sent k x
-- | Features occurring at position @k@, each paired with the normalised
-- value of the label configuration it occurs with.
--
-- Transition features come first: for every label-index triple at
-- positions @(k, k - 1, k - 2)@ each feature from 'trFeatsOn' is paired
-- with the 'prob3' value of the triple.  Observation features follow: for
-- every label index at @k@ each feature from 'obFeatsOn' is paired with
-- its 'prob1' value.  A feature may appear several times in the result.
expectedFeaturesOn
    :: Model -> ProbArray -> ProbArray
    -> Xs -> Int -> [(Feat, L.LogFloat)]
expectedFeaturesOn crf alpha beta sent k =
    transitionPart ++ observationPart
  where
    psi = computePsi crf sent k
    single = prob1 crf alpha beta sent k
    triple = prob3 crf alpha beta sent k psi

    observationPart = do
        a <- lbIxs sent k
        let p = single a
        ft <- obFeatsOn sent k a
        return (ft, p)

    transitionPart = do
        a <- lbIxs sent k
        b <- lbIxs sent (k - 1)
        c <- lbIxs sent (k - 2)
        let p = triple a b c
        ft <- trFeatsOn sent k a b c
        return (ft, p)
-- | Concatenation of 'expectedFeaturesOn' over all sentence positions.
--
-- Before the list is produced the forward normaliser @zx1@ is sparked
-- with 'par' while the backward normaliser @zx2@ is evaluated, and then
-- @zx1@ is awaited; this forces the forward and backward tables to be
-- computed in parallel.
expectedFeatures :: Model -> Xs -> [(Feat, L.LogFloat)]
expectedFeatures crf sent =
    zx1 `par` (zx2 `pseq` (zx1 `pseq` perPosition))
  where
    alpha = forward sum crf sent
    beta = backward sum crf sent
    zx1 = zxAlpha sum sent alpha
    zx2 = zxBeta beta
    perPosition =
        concatMap
            (expectedFeaturesOn crf alpha beta sent)
            [0 .. V.length sent - 1]