module Learning.HMM.Internal (
HMM' (..)
, LogLikelihood
, init'
, withEmission'
, viterbi'
, baumWelch'
) where
import Control.Applicative ((<$>))
import Control.DeepSeq (NFData, force, rnf)
import Control.Monad (forM_, replicateM)
import Control.Monad.ST (runST)
import qualified Data.Map.Strict as M (findWithDefault)
import Data.Random.RVar (RVar)
import Data.Random.Distribution.Simplex (stdSimplex)
import qualified Data.Vector as V (
Vector, filter, foldl1', map, unsafeFreeze, unsafeIndex, unsafeTail
, zip, zipWith3
)
import qualified Data.Vector.Generic as G (convert)
import qualified Data.Vector.Generic.Util as G (frequencies)
import qualified Data.Vector.Mutable as MV (
unsafeNew, unsafeRead, unsafeWrite
)
import qualified Data.Vector.Unboxed as U (
Vector, fromList, length, map, sum, unsafeFreeze, unsafeIndex
, unsafeTail, zip
)
import qualified Data.Vector.Unboxed.Mutable as MU (
unsafeNew, unsafeRead, unsafeWrite
)
import qualified Numeric.LinearAlgebra.Data as H (
(!), Matrix, Vector, diag, fromColumns, fromList, fromLists
, fromRows, konst, maxElement, maxIndex, toColumns, tr
)
import qualified Numeric.LinearAlgebra.HMatrix as H (
(<>), (#>), sumElements
)
-- | A log-likelihood value, represented as a plain 'Double'.
type LogLikelihood = Double
-- | Internal representation of a discrete hidden Markov model, with states
--   and output symbols both indexed by 'Int'.
data HMM' = HMM'
    { nStates'          :: Int
      -- ^ Number of hidden states.
    , nOutputs'         :: Int
      -- ^ Number of distinct output symbols.
    , initialStateDist' :: H.Vector Double
      -- ^ Distribution over the initial hidden state.
    , transitionDist'   :: H.Matrix Double
      -- ^ State transition matrix.
    , emissionDistT'    :: H.Matrix Double
      -- ^ Emission matrix stored transposed: indexing with an output symbol
      --   yields the vector of per-state emission probabilities.
    }
-- | Fully evaluates every field of the model, in declaration order.
instance NFData HMM' where
    rnf (HMM' k l pi0 w phi') =
        rnf k `seq` rnf l `seq` rnf pi0 `seq` rnf w `seq` rnf phi'
-- | Draw a random model with @k@ hidden states and @l@ output symbols.
--
--   The initial distribution, each row of the transition matrix and each row
--   of the emission matrix are sampled independently with 'stdSimplex'.
--   The emission matrix is stored transposed.
--
--   NOTE(review): 'stdSimplex' is called with @k - 1@ / @l - 1@ on the
--   understanding that it takes the simplex dimension and returns one more
--   component than that -- confirm against the random-fu documentation.
init' :: Int -> Int -> RVar HMM'
init' k l = do
    pi0 <- H.fromList  <$> stdSimplex (k - 1)
    w   <- H.fromLists <$> replicateM k (stdSimplex (k - 1))
    phi <- H.fromLists <$> replicateM k (stdSimplex (l - 1))
    return HMM' { nStates'          = k
                , nOutputs'         = l
                , initialStateDist' = pi0
                , transitionDist'   = w
                , emissionDistT'    = H.tr phi
                }
-- | Re-estimate the model against the observation sequence @xs@ by repeating
--   the following step until two successive models are no more than @1e-9@
--   apart in 'euclideanDistance':
--
--   1. decode the most likely state path with 'viterbi'';
--   2. replace the emission matrix by the row-normalised co-occurrence counts
--      of (state, output) pairs along that path;
--   3. run one 'baumWelch1'' update on the result.
--
--   The later model of the first converged pair is returned. The iteration
--   has no upper bound, so this does not terminate if the distance never
--   drops to the threshold.
withEmission' :: HMM' -> U.Vector Int -> HMM'
withEmission' model xs = model'
  where
    n  = U.length xs
    k  = nStates' model
    l  = nOutputs' model
    ss = [0..(k - 1)]
    os = [0..(l - 1)]

    step m = fst $ baumWelch1' (m { emissionDistT' = H.tr phi }) n xs
      where
        phi :: H.Matrix Double
        phi = let zs  = fst $ viterbi' m xs
                  fs  = G.frequencies $ U.zip zs xs
                  hs  = H.fromLists $ map (\s -> map (\o ->
                          M.findWithDefault 0 (s, o) fs) os) ss
                  -- A small constant keeps rows of unvisited states from
                  -- normalising to NaN.
                  hs' = hs + H.konst 1e-9 (k, l)
                  -- Row sums of the k-by-l matrix: the vector of ones must
                  -- have one entry per column, i.e. length l.
                  ns  = hs' H.#> H.konst 1 l
              in  hs' / H.fromColumns (replicate l ns)

    ms  = iterate step model
    ms' = tail ms
    ds  = zipWith euclideanDistance ms ms'
    -- 'ms' is infinite, so 'tail' and 'head' cannot fail here.
    model' = fst $ head $ dropWhile ((> 1e-9) . snd) $ zip ms' ds
-- | Euclidean distance between two models, taken over the entries of their
--   transition matrices and (transposed) emission matrices. The initial
--   state distributions are not compared.
euclideanDistance :: HMM' -> HMM' -> Double
euclideanDistance model model' =
    sqrt $ H.sumElements ((w - w') ** 2) + H.sumElements ((phi - phi') ** 2)
  where
    w    = transitionDist' model
    w'   = transitionDist' model'
    phi  = emissionDistT' model
    phi' = emissionDistT' model'
-- | Viterbi decoding in log space.
--
--   Returns the most likely hidden state path for the observations @xs@
--   together with the log-probability of that path.
--
--   Precondition: @xs@ must be non-empty; index 0 is read without bounds
--   checking.
viterbi' :: HMM' -> U.Vector Int -> (U.Vector Int, LogLikelihood)
viterbi' model xs = (path, logL)
  where
    n = U.length xs

    -- deltas ! t : best log-score of any path ending in each state at time t
    -- psis   ! t : for each state at time t, the best predecessor state
    deltas :: V.Vector (H.Vector Double)
    psis   :: V.Vector (U.Vector Int)
    (deltas, psis) = runST $ do
        ds <- MV.unsafeNew n
        ps <- MV.unsafeNew n
        let x0 = U.unsafeIndex xs 0
        MV.unsafeWrite ds 0 $ log (phi' H.! x0) + log pi0
        -- Slot 0 of 'ps' is never written; backtracking below never reads it.
        forM_ [1..(n - 1)] $ \t -> do
            d <- MV.unsafeRead ds (t - 1)
            let x   = U.unsafeIndex xs t
                dws = map (\wj -> d + log wj) w'
            MV.unsafeWrite ds t $
                log (phi' H.! x) + H.fromList (map H.maxElement dws)
            MV.unsafeWrite ps t $ U.fromList (map H.maxIndex dws)
        ds' <- V.unsafeFreeze ds
        ps' <- V.unsafeFreeze ps
        return (ds', ps')
      where
        pi0  = initialStateDist' model
        w'   = H.toColumns $ transitionDist' model
        phi' = emissionDistT' model

    deltaE = V.unsafeIndex deltas (n - 1)

    -- Backtrack from the best final state through the stored predecessors.
    path = runST $ do
        ix <- MU.unsafeNew n
        MU.unsafeWrite ix (n - 1) $ H.maxIndex deltaE
        forM_ [n - 1, n - 2 .. 1] $ \t -> do
            i <- MU.unsafeRead ix t
            let psi = V.unsafeIndex psis t
            MU.unsafeWrite ix (t - 1) $ U.unsafeIndex psi i
        U.unsafeFreeze ix

    logL = H.maxElement deltaE
-- | Infinite, lazily produced trace of Baum-Welch iterations.
--
--   Each element pairs a model with the log-likelihood that 'baumWelch1''
--   reports when run on that model; the first element carries the input
--   model itself.
baumWelch' :: HMM' -> U.Vector Int -> [(HMM', LogLikelihood)]
baumWelch' model xs = trace model
  where
    n = U.length xs
    trace current =
        let (next, logL) = baumWelch1' current n xs
        in  (current, logL) : trace next
-- | Perform a single Baum-Welch (EM) update.
--
--   Takes the current model, the number of observations @n@ and the
--   observations @xs@. Returns the re-estimated model together with the
--   log-likelihood of @xs@ under the /input/ model. The result is fully
--   evaluated with 'force'.
--
--   Precondition: @n >= 2@. The transition update folds over the @n - 1@
--   pairwise posteriors with 'V.foldl1'', which fails on an empty vector.
baumWelch1' :: HMM' -> Int -> U.Vector Int -> (HMM', LogLikelihood)
baumWelch1' model n xs = force (model', logL)
  where
    k = nStates' model
    l = nOutputs' model

    (alphas, cs)  = forward' model n xs
    betas         = backward' model n xs cs
    (gammas, xis) = posterior' model n xs alphas betas cs

    -- New initial distribution: state posterior at time 0.
    pi0 = V.unsafeIndex gammas 0

    -- New transition matrix: summed pairwise posteriors, rows normalised.
    w = let ds = V.foldl1' (+) xis
            ns = ds H.#> H.konst 1 k
        in  H.diag (H.konst 1 k / ns) H.<> ds

    -- New (transposed) emission matrix: for every output symbol, the state
    -- posteriors at the times that symbol was observed, summed and divided
    -- by the total state posteriors. Time steps showing another symbol
    -- contribute a zero vector instead of being filtered out, so a symbol
    -- that never occurs in xs yields a zero row rather than an empty fold.
    phi' = let obs :: V.Vector Int
               obs = G.convert xs
               zero :: H.Vector Double
               zero = H.konst 0 k
               pick o (x, g) = if x == o then g else zero
               ds o = V.foldl1' (+) $ V.map (pick o) $ V.zip obs gammas
               ns   = V.foldl1' (+) gammas
           in  H.fromRows $ map (\o -> ds o / ns) [0..(l - 1)]

    model' = model { initialStateDist' = pi0
                   , transitionDist'   = w
                   , emissionDistT'    = phi'
                   }

    -- 'cs' holds the reciprocals of the per-step normalisers (see
    -- 'forward''), so the log-likelihood is the negated sum of their logs.
    logL = negate (U.sum $ U.map log cs)
-- | Scaled forward pass.
--
--   Returns, for every time step, the forward vector rescaled to sum to one,
--   together with the scaling factors @cs@. Each factor is the reciprocal of
--   the sum of the unscaled vector at that step.
--
--   Precondition: @n@ is the length of @xs@ and is at least 1; index 0 is
--   read without bounds checking.
forward' :: HMM' -> Int -> U.Vector Int -> (V.Vector (H.Vector Double), U.Vector Double)
forward' model n xs = runST $ do
    as <- MV.unsafeNew n
    cs <- MU.unsafeNew n
    let x0 = U.unsafeIndex xs 0
        a0 = (phi' H.! x0) * pi0
        c0 = 1 / H.sumElements a0
    MV.unsafeWrite as 0 (H.konst c0 k * a0)
    MU.unsafeWrite cs 0 c0
    forM_ [1..(n - 1)] $ \t -> do
        a <- MV.unsafeRead as (t - 1)
        let x  = U.unsafeIndex xs t
            a' = (phi' H.! x) * (w' H.#> a)
            c' = 1 / H.sumElements a'
        MV.unsafeWrite as t (H.konst c' k * a')
        MU.unsafeWrite cs t c'
    as' <- V.unsafeFreeze as
    cs' <- U.unsafeFreeze cs
    return (as', cs')
  where
    k    = nStates' model
    pi0  = initialStateDist' model
    -- Transposed so that (w' #> a) propagates the previous forward vector
    -- one step along the transitions.
    w'   = H.tr $ transitionDist' model
    phi' = emissionDistT' model
-- | Scaled backward pass.
--
--   Computes the backward vectors for every time step, rescaling each with
--   the matching factor from @cs@ (the scaling factors produced by
--   'forward''). The vectors are filled in from the last time step down to
--   the first.
--
--   Precondition: @n@ is the length of @xs@ and @cs@ and is at least 1.
backward' :: HMM' -> Int -> U.Vector Int -> U.Vector Double -> V.Vector (H.Vector Double)
backward' model n xs cs = runST $ do
    bs <- MV.unsafeNew n
    let bE = H.konst 1 k
        cE = U.unsafeIndex cs (n - 1)
    MV.unsafeWrite bs (n - 1) (H.konst cE k * bE)
    forM_ [n - 1, n - 2 .. 1] $ \t -> do
        b <- MV.unsafeRead bs t
        let x  = U.unsafeIndex xs t
            b' = w H.#> ((phi' H.! x) * b)
            c' = U.unsafeIndex cs (t - 1)
        MV.unsafeWrite bs (t - 1) (H.konst c' k * b')
    V.unsafeFreeze bs
  where
    k    = nStates' model
    w    = transitionDist' model
    phi' = emissionDistT' model
-- | Combine the scaled forward and backward vectors into posteriors.
--
--   The first component holds one vector per time step (@alpha * beta / c@);
--   the second holds one matrix per pair of consecutive time steps. The
--   @Int@ argument is ignored.
posterior' :: HMM' -> Int -> U.Vector Int -> V.Vector (H.Vector Double) -> V.Vector (H.Vector Double) -> U.Vector Double -> (V.Vector (H.Vector Double), V.Vector (H.Matrix Double))
posterior' model _ xs alphas betas cs = (gammas, xis)
  where
    k    = nStates' model
    w    = transitionDist' model
    phi' = emissionDistT' model

    -- Per-step posterior from one forward/backward/scale triple.
    stateOf a b c = a * b / H.konst c k

    -- Pairwise posterior from the forward vector at one step and the
    -- backward vector and observation at the following step.
    pairOf a b x = H.diag a H.<> w H.<> H.diag (b * (phi' H.! x))

    gammas = V.zipWith3 stateOf alphas betas (G.convert cs)
    xis    = V.zipWith3 pairOf alphas
                 (V.unsafeTail betas)
                 (G.convert $ U.unsafeTail xs)