module Learning.IOHMM (
IOHMM (..)
, LogLikelihood
, init
, withEmission
, viterbi
, baumWelch
, simulate
) where
import Prelude hiding (init)
import Control.Applicative ((<$>))
import Control.Arrow (first)
import Data.List (elemIndex)
import Data.Maybe (fromJust)
import Data.Random.Distribution (pdf, rvar)
import Data.Random.Distribution.Categorical (Categorical)
import qualified Data.Random.Distribution.Categorical as C (
fromList, normalizeCategoricalPs
)
import Data.Random.Distribution.Categorical.Util ()
import Data.Random.RVar (RVar)
import Data.Random.Sample (sample)
import qualified Data.Vector as V (
elemIndex, fromList, map, toList, unsafeIndex
)
import qualified Data.Vector.Generic as G (convert)
import qualified Data.Vector.Unboxed as U (fromList, zip)
import qualified Numeric.LinearAlgebra.Data as H (
(!), fromList, fromLists, toList
)
import qualified Numeric.LinearAlgebra.HMatrix as H (tr)
import Learning.IOHMM.Internal (LogLikelihood)
import qualified Learning.IOHMM.Internal as I
-- | An input-output hidden Markov model with input symbols @i@, hidden
-- states @s@ and output symbols @o@.
--
-- The three lists fix the alphabets. The position of a symbol in its list
-- is the index used when the model is converted to and from the
-- index-based "Learning.IOHMM.Internal" representation.
data IOHMM i s o = IOHMM { inputs :: [i]
                           -- ^ Input alphabet.
                         , states :: [s]
                           -- ^ Hidden-state alphabet.
                         , outputs :: [o]
                           -- ^ Output alphabet.
                         , initialStateDist :: Categorical Double s
                           -- ^ Distribution of the first hidden state.
                         , transitionDist :: i -> s -> Categorical Double s
                           -- ^ Distribution of the next state, given the
                           -- current input and the previous state.
                         , emissionDist :: s -> Categorical Double o
                           -- ^ Distribution of the output emitted in a state.
                         }
-- | Debug rendering, delegated to 'showIOHMM'. The function-valued fields
-- are shown by enumerating them over the listed inputs and states.
instance (Show i, Show s, Show o) => Show (IOHMM i s o) where
  show = showIOHMM
-- | Render a model in record-like syntax for debugging.
--
-- Function-valued fields cannot be shown directly, so they are tabulated:
--
-- * each transition distribution is paired with the @(input, state)@ it
--   was evaluated at, and
-- * each emission distribution is paired with its state.
showIOHMM :: (Show i, Show s, Show o) => IOHMM i s o -> String
showIOHMM hmm = "IOHMM {inputs = " ++ show is
             ++ ", states = " ++ show ss
             ++ ", outputs = " ++ show os
             ++ ", initialStateDist = " ++ show pi0
             -- Key transitions by (input, state): keying by the state alone
             -- made entries for different inputs indistinguishable.
             ++ ", transitionDist = " ++ show [(w i s, (i, s)) | i <- is, s <- ss]
             ++ ", emissionDist = " ++ show [(phi s, s) | s <- ss]
             ++ "}"
  where
    is = inputs hmm
    ss = states hmm
    os = outputs hmm
    pi0 = initialStateDist hmm
    w = transitionDist hmm
    phi = emissionDist hmm
-- | Build an initial model over the given input, state and output
-- alphabets, as an 'RVar'.
--
-- Only the alphabet sizes are passed to 'I.init'. The resulting
-- index-based model is relabelled with the supplied symbols.
init :: (Eq i, Eq s, Eq o) => [i] -> [s] -> [o] -> RVar (IOHMM i s o)
init is ss os = do
  internal <- I.init (length is) (length ss) (length os)
  return (fromInternal is ss os internal)
-- | Rebuild @model@ with 'I.withEmission', applied to the paired input
-- sequence @xs@ and output sequence @ys@. This presumably re-estimates the
-- emission distribution from the data; confirm in
-- "Learning.IOHMM.Internal".
--
-- Symbols are encoded as their positions in 'inputs' / 'outputs' before the
-- internal call, and the result is relabelled with the original symbols.
-- As in 'viterbi' and 'baumWelch', the model and data are validated first.
-- An empty alphabet or an unknown symbol therefore raises a descriptive
-- error (e.g. @Learning.IOHMM.withEmission: illegal data@) instead of an
-- anonymous 'fromJust' failure.
withEmission :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> IOHMM i s o
withEmission model xs ys =
  checkModelIn "withEmission" model `seq`
  checkDataIn "withEmission" model xs ys `seq`
  fromInternal is ss os (I.withEmission model' (U.zip xs' ys'))
  where
    is = inputs model
    is' = V.fromList is
    ss = states model
    os = outputs model
    os' = V.fromList os
    model' = toInternal model
    -- The lookups cannot fail: checkDataIn has already verified that every
    -- symbol occurs in its alphabet.
    xs' = U.fromList $ fromJust $ mapM (`V.elemIndex` is') xs
    ys' = U.fromList $ fromJust $ mapM (`V.elemIndex` os') ys
-- | Decode a state path for the paired input/output sequences with
-- 'I.viterbi' (presumably the most likely path). It is returned together
-- with the log-likelihood reported by the internal implementation.
--
-- The model and data are validated first ('checkModelIn', then
-- 'checkDataIn'). Symbols are then encoded as positions in 'inputs' /
-- 'outputs', and the decoded indices are mapped back to 'states'.
viterbi :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> ([s], LogLikelihood)
viterbi model xs ys =
  checkModelIn "viterbi" model `seq`
  checkDataIn "viterbi" model xs ys `seq`
  (map (V.unsafeIndex stateVec) (V.toList (G.convert path)), logL)
  where
    stateVec = V.fromList (states model)
    -- Position of each element of @vals@ within @syms@. The checkDataIn
    -- call above guarantees that every lookup succeeds.
    encode syms vals = U.fromList (fromJust (mapM (`V.elemIndex` V.fromList syms) vals))
    encoded = U.zip (encode (inputs model) xs) (encode (outputs model) ys)
    -- A lazy pattern binding, so nothing is decoded until demanded.
    (path, logL) = I.viterbi (toInternal model) encoded
-- | Run 'I.baumWelch' on the paired input/output sequences. Every
-- @(model, log-likelihood)@ pair it produces is returned, with each
-- internal model relabelled using the original alphabets.
--
-- The model and data are validated first ('checkModelIn', then
-- 'checkDataIn') before any symbol is encoded.
baumWelch :: (Eq i, Eq s, Eq o) => IOHMM i s o -> [i] -> [o] -> [(IOHMM i s o, LogLikelihood)]
baumWelch model xs ys =
  checkModelIn "baumWelch" model `seq`
  checkDataIn "baumWelch" model xs ys `seq`
  map relabel (I.baumWelch (toInternal model) encoded)
  where
    -- Position of each element of @vals@ within @syms@. The checkDataIn
    -- call above guarantees that every lookup succeeds.
    encode syms vals = U.fromList (fromJust (mapM (`V.elemIndex` V.fromList syms) vals))
    encoded = U.zip (encode (inputs model) xs) (encode (outputs model) ys)
    -- The irrefutable pattern keeps this exactly as lazy as Control.Arrow's
    -- 'first' on pairs, so list elements are not forced early.
    relabel ~(internal, logL) =
      (fromInternal (inputs model) (states model) (outputs model) internal, logL)
-- | Sample a hidden-state path and an output sequence driven by the input
-- sequence.
--
-- The first state is drawn from 'initialStateDist'; the first input is not
-- consulted for it. Each later state is drawn from 'transitionDist',
-- applied to the input at that position and the previous state. Every state
-- emits one output drawn from 'emissionDist'. Both result lists are as long
-- as the input list, so an empty input yields @([], [])@.
simulate :: IOHMM i s o -> [i] -> RVar ([s], [o])
simulate _ [] = return ([], [])
simulate model (_ : rest) = do
  s0 <- sample $ rvar pi0
  y0 <- sample $ rvar $ phi s0
  unzip . ((s0, y0) :) <$> sim s0 rest
  where
    -- Walk the remaining inputs, threading the previous state through.
    sim _ [] = return []
    sim s (x : xs') = do
      s' <- sample $ rvar $ w x s
      y' <- sample $ rvar $ phi s'
      ((s', y') :) <$> sim s' xs'
    pi0 = initialStateDist model
    w = transitionDist model
    phi = emissionDist model
-- | Reject a model whose input, state or output alphabet is empty. The
-- checks run in that order, and the first failure raises an error naming
-- @fun@. A valid model yields @()@, which callers chain with 'seq'.
checkModelIn :: String -> IOHMM i s o -> ()
checkModelIn fun hmm =
  case [msg | (failed, msg) <- checks, failed] of
    msg : _ -> errorIn fun msg
    []      -> ()
  where
    checks =
      [ (null (inputs hmm), "empty inputs")
      , (null (states hmm), "empty states")
      , (null (outputs hmm), "empty outputs")
      ]
-- | Ensure every input in @xs@ belongs to the model's 'inputs' and every
-- output in @ys@ belongs to its 'outputs'. The input check is evaluated
-- first. Returns @()@ on success; otherwise raises an "illegal data" error
-- naming @fun@.
checkDataIn :: (Eq i, Eq o) => String -> IOHMM i s o -> [i] -> [o] -> ()
checkDataIn fun hmm xs ys =
  if allKnown (inputs hmm) xs && allKnown (outputs hmm) ys
    then ()
    else errorIn fun "illegal data"
  where
    allKnown alphabet = all (`elem` alphabet)
-- | Convert an index-based internal model into a labelled one.
--
-- Vector and matrix positions are mapped to the symbols at the same
-- positions in @is@, @ss@ and @os@. The internal emission matrix is
-- transposed so that row @k@ belongs to state @k@. Querying
-- 'transitionDist' or 'emissionDist' with a symbol absent from its list
-- yields @C.fromList []@, an empty categorical distribution.
fromInternal :: (Eq i, Eq s, Eq o) => [i] -> [s] -> [o] -> I.IOHMM -> IOHMM i s o
fromInternal is ss os hmm' =
  IOHMM { inputs           = is
        , states           = ss
        , outputs          = os
        , initialStateDist = C.fromList (zip (H.toList (I.initialStateDist hmm')) ss)
        , transitionDist   = \i s ->
            -- The input is looked up first; the state only if that succeeds.
            maybe (C.fromList []) (C.fromList . uncurry transRow)
                  ((,) <$> elemIndex i is <*> elemIndex s ss)
        , emissionDist     = \s ->
            maybe (C.fromList []) (C.fromList . emisRow) (elemIndex s ss)
        }
  where
    transitions = I.transitionDist hmm'
    emissionsByState = H.tr (I.emissionDistT hmm')
    -- Row k of the transition matrix for input j, labelled with the states.
    transRow j k = zip (H.toList (V.unsafeIndex transitions j H.! k)) ss
    -- Row k of the state-major emission matrix, labelled with the outputs.
    emisRow k = zip (H.toList (emissionsByState H.! k)) os
-- | Convert a labelled model into its index-based internal form.
--
-- Each distribution is normalised with 'C.normalizeCategoricalPs' and then
-- tabulated with 'pdf' over the alphabets, in list order:
--
-- * the initial distribution becomes a vector indexed by state;
-- * each input gets a matrix whose row @s@ is the next-state distribution
--   from state @s@;
-- * emissions are stored transposed, with row @o@ holding the probability
--   of output @o@ in each state.
toInternal :: (Eq i, Eq s, Eq o) => IOHMM i s o -> I.IOHMM
toInternal hmm =
  I.IOHMM { I.nInputs          = length is
          , I.nStates          = length ss
          , I.nOutputs         = length os
          , I.initialStateDist = H.fromList [pdf initial s | s <- ss]
          , I.transitionDist   = V.fromList [transitionMatrix i | i <- is]
          , I.emissionDistT    = H.fromLists [[pdf (emission s) o | s <- ss] | o <- os]
          }
  where
    is = inputs hmm
    ss = states hmm
    os = outputs hmm
    initial = C.normalizeCategoricalPs (initialStateDist hmm)
    transition i s = C.normalizeCategoricalPs (transitionDist hmm i s)
    emission s = C.normalizeCategoricalPs (emissionDist hmm s)
    transitionMatrix i = H.fromLists [[pdf (transition i s) s' | s' <- ss] | s <- ss]
-- | Abort with @msg@, prefixed by the module-qualified name of the public
-- function @fun@ that detected the problem.
errorIn :: String -> String -> a
errorIn fun msg = error (concat ["Learning.IOHMM.", fun, ": ", msg])