module Learning.HMM (
HMM (..)
, LogLikelihood
, new
, init
, withEmission
, viterbi
, baumWelch
, simulate
) where
import Prelude hiding (init)
import Control.Applicative ((<$>))
import Control.Arrow (first)
import Data.List (elemIndex, genericLength)
import Data.Maybe (fromJust)
import Data.Random.Distribution (pdf, rvar)
import Data.Random.Distribution.Categorical (Categorical)
import qualified Data.Random.Distribution.Categorical as C (
fromList, fromWeightedList, normalizeCategoricalPs
)
import Data.Random.Distribution.Categorical.Util ()
import Data.Random.RVar (RVar)
import Data.Random.Sample (sample)
import qualified Data.Vector as V (
elemIndex, fromList, map, toList, unsafeIndex
)
import qualified Data.Vector.Generic as G (convert)
import qualified Data.Vector.Unboxed as U (fromList)
import qualified Numeric.LinearAlgebra.Data as H (
(!), fromList, fromLists, toList
)
import qualified Numeric.LinearAlgebra.HMatrix as H (tr)
import Learning.HMM.Internal
-- | A hidden Markov model over state values of type @s@ and output
--   (observation) values of type @o@.
data HMM s o = HMM { states :: [s]
                     -- ^ The state values of the model.
                   , outputs :: [o]
                     -- ^ The output values of the model.
                   , initialStateDist :: Categorical Double s
                     -- ^ Distribution the first state is drawn from.
                   , transitionDist :: s -> Categorical Double s
                     -- ^ Distribution of the next state, given the current
                     --   state.
                   , emissionDist :: s -> Categorical Double o
                     -- ^ Distribution of the output, given the current
                     --   state.
                   }
-- | Models are rendered through 'showHMM'; the function-valued fields
--   prevent a derived instance.
instance (Show s, Show o) => Show (HMM s o) where
  show model = showHMM model
-- | Render a model in a record-like syntax. The function-valued fields
--   ('transitionDist' and 'emissionDist') are shown as a list of
--   @(distribution, state)@ pairs, one entry per element of 'states'.
showHMM :: (Show s, Show o) => HMM s o -> String
showHMM model = concat
  [ "HMM {states = ", show stateList
  , ", outputs = ", show (outputs model)
  , ", initialStateDist = ", show (initialStateDist model)
  , ", transitionDist = ", show [(transOf s, s) | s <- stateList]
  , ", emissionDist = ", show [(emitOf s, s) | s <- stateList]
  , "}"
  ]
  where
    stateList = states model
    transOf   = transitionDist model
    emitOf    = emissionDist model
-- | Build a model with default parameters from the given states and
--   outputs.
--
--   * The initial state distribution gives every state the same weight.
--   * From a state @s@, staying in @s@ has probability @1/2 * (1 + 1/k)@
--     and moving to any other state has probability @1/2 / k@, where @k@
--     is the number of states.
--   * Every state emits each output with the same weight.
--
--   For a value that is not one of the given states, the transition and
--   emission distributions are empty.
new :: (Ord s, Ord o) => [s] -> [o] -> HMM s o
new ss os = HMM { states           = ss
                , outputs          = os
                , initialStateDist = startDist
                , transitionDist   = stepDist
                , emissionDist     = emitDist
                }
  where
    startDist = C.fromWeightedList (map ((,) 1) ss)

    stepDist s =
      if s `elem` ss
        then C.fromList (zip (map stepProb ss) ss)
        else C.fromList []
      where
        k = genericLength ss
        -- Probability of moving from @s@ to @s'@.
        stepProb s' = if s' == s
                        then 1/2 * (1 + 1/k)
                        else 1/2 / k

    emitDist s =
      if s `elem` ss
        then C.fromWeightedList (map ((,) 1) os)
        else C.fromList []
-- | Build a model over the given states and outputs whose parameters
--   come from 'init'' (defined in "Learning.HMM.Internal"), which is
--   given only the number of states and the number of outputs.
--   NOTE(review): presumably the parameters are random, since the result
--   lives in 'RVar' -- confirm against 'init''.
init :: (Eq s, Eq o) => [s] -> [o] -> RVar (HMM s o)
init ss os = fmap (fromHMM' ss os) (init' stateCount outputCount)
  where
    stateCount  = length ss
    outputCount = length os
-- | Apply 'withEmission'' (from "Learning.HMM.Internal") to the model and
--   the observation sequence @xs@, and convert the result back to an
--   'HMM' over the same states and outputs.
--   NOTE(review): presumably this re-estimates the emission distribution
--   from @xs@ -- confirm against 'withEmission''.
--
--   Like 'viterbi' and 'baumWelch', the inputs are validated first: an
--   error is raised if the model has no states or no outputs
--   (@"empty states"@ / @"empty outputs"@) or if @xs@ contains a value
--   that is not one of the model's outputs (@"illegal data"@).
withEmission :: (Eq s, Eq o) => HMM s o -> [o] -> HMM s o
withEmission model xs =
  checkModelIn "withEmission" model `seq`
  checkDataIn "withEmission" model xs `seq`
  fromHMM' ss os (withEmission' model' xs')
  where
    ss     = states model
    os     = outputs model
    os'    = V.fromList os
    model' = toHMM' model
    -- Observations encoded as indices into 'outputs'. 'fromJust' cannot
    -- fail here because 'checkDataIn' has already verified that every
    -- observation is a known output.
    xs'    = U.fromList $ fromJust $ mapM (`V.elemIndex` os') xs
-- | Run 'viterbi'' (from "Learning.HMM.Internal") on the model and the
--   observation sequence @xs@, returning the resulting state path,
--   translated back to state values, together with its log likelihood.
--
--   Raises an error if the model has no states or no outputs, or if @xs@
--   contains a value that is not one of the model's outputs.
viterbi :: (Eq s, Eq o) => HMM s o -> [o] -> ([s], LogLikelihood)
viterbi model xs =
  checkModelIn "viterbi" model `seq`
  checkDataIn "viterbi" model xs `seq`
  first decode (viterbi' (toHMM' model) encoded)
  where
    stateVec  = V.fromList (states model)
    outputVec = V.fromList (outputs model)
    -- Observations as indices into 'outputs'; 'fromJust' is safe because
    -- 'checkDataIn' has already accepted every observation.
    encoded   = U.fromList (fromJust (mapM (\x -> V.elemIndex x outputVec) xs))
    -- Map the index path returned by 'viterbi'' back to state values.
    decode path = V.toList (V.map (V.unsafeIndex stateVec) (G.convert path))
-- | Run 'baumWelch'' (from "Learning.HMM.Internal") on the model and the
--   observation sequence @xs@. Each internal model in the result is
--   converted back to an 'HMM' over the original states and outputs and
--   paired with its log likelihood.
--   NOTE(review): presumably the list holds the successive iterations of
--   the algorithm -- confirm against 'baumWelch''.
--
--   Raises an error if the model has no states or no outputs, or if @xs@
--   contains a value that is not one of the model's outputs.
baumWelch :: (Eq s, Eq o) => HMM s o -> [o] -> [(HMM s o, LogLikelihood)]
baumWelch model xs =
  checkModelIn "baumWelch" model `seq`
  checkDataIn "baumWelch" model xs `seq`
  map toPublic (baumWelch' (toHMM' model) encoded)
  where
    stateList  = states model
    outputList = outputs model
    outputVec  = V.fromList outputList
    -- Observations as indices into 'outputs'; 'fromJust' is safe because
    -- 'checkDataIn' has already accepted every observation.
    encoded    = U.fromList (fromJust (mapM (\x -> V.elemIndex x outputVec) xs))
    -- Convert the model component of a result pair, keeping the likelihood.
    toPublic   = first (fromHMM' stateList outputList)
-- | Simulate the model for @step@ time steps, returning the visited
--   states and the emitted outputs (both lists have length @step@).
--
--   The first state is drawn from 'initialStateDist'; every later state
--   is drawn from 'transitionDist' of the previous state; at each step
--   the output is drawn from 'emissionDist' of the current state.
--   If @step < 1@ both lists are empty.
simulate :: HMM s o -> Int -> RVar ([s], [o])
simulate model step
  | step < 1  = return ([], [])
  | otherwise = do s0 <- sample $ rvar pi0
                   x0 <- sample $ rvar $ phi s0
                   -- One step has been produced; simulate the rest.
                   unzip . ((s0, x0) :) <$> sim s0 (step - 1)
  where
    -- @sim s t@ produces @t@ further (state, output) pairs, starting
    -- from the previous state @s@. @t@ is never negative here because
    -- the caller guarantees @step >= 1@.
    sim _ 0 = return []
    sim s t = do s' <- sample $ rvar $ w s
                 x' <- sample $ rvar $ phi s'
                 ((s', x') :) <$> sim s' (t - 1)
    pi0 = initialStateDist model
    w   = transitionDist model
    phi = emissionDist model
-- | Validate a model on behalf of the function named @fun@. Returns @()@
--   when the model has at least one state and one output; otherwise
--   raises an error via 'errorIn'. States are checked before outputs.
checkModelIn :: String -> HMM s o -> ()
checkModelIn fun hmm =
  case (states hmm, outputs hmm) of
    ([], _) -> errorIn fun "empty states"
    (_, []) -> errorIn fun "empty outputs"
    _       -> ()
-- | Validate an observation sequence on behalf of the function named
--   @fun@. Returns @()@ when every observation is one of the model's
--   'outputs'; otherwise raises an error via 'errorIn'.
checkDataIn :: Eq o => String -> HMM s o -> [o] -> ()
checkDataIn fun hmm xs =
  if any (`notElem` outputs hmm) xs
    then errorIn fun "illegal data"
    else ()
-- | Convert an internal, index-based model ('HMM'') into an 'HMM' over
--   the given state and output values. Position @i@ in @ss@ (resp. @os@)
--   labels index @i@ of the internal vectors and matrix rows.
--
--   For a value that is not an element of @ss@, the resulting transition
--   and emission distributions are empty.
fromHMM' :: (Eq s, Eq o) => [s] -> [o] -> HMM' -> HMM s o
fromHMM' ss os hmm' =
  HMM { states           = ss
      , outputs          = os
      , initialStateDist = C.fromList startRow
      , transitionDist   = \s ->
          maybe (C.fromList []) (C.fromList . transRow) (elemIndex s ss)
      , emissionDist     = \s ->
          maybe (C.fromList []) (C.fromList . emitRow) (elemIndex s ss)
      }
  where
    transMat = transitionDist' hmm'
    -- The internal model stores the emission matrix transposed;
    -- transpose it so that rows are indexed by state.
    emitMat  = H.tr (emissionDistT' hmm')
    startRow   = zip (H.toList (initialStateDist' hmm')) ss
    transRow i = zip (H.toList (transMat H.! i)) ss
    emitRow i  = zip (H.toList (emitMat H.! i)) os
-- | Convert an 'HMM' into the internal, index-based representation
--   ('HMM''). Every distribution is normalized with
--   'C.normalizeCategoricalPs' and then tabulated with 'pdf' in the
--   order of 'states' and 'outputs'. The emission table is built with
--   one row per output and one column per state.
toHMM' :: (Eq s, Eq o) => HMM s o -> HMM'
toHMM' hmm =
  HMM' { nStates'          = length stateList
       , nOutputs'         = length outputList
       , initialStateDist' = startVec
       , transitionDist'   = transMat
       , emissionDistT'    = emitMatT
       }
  where
    stateList  = states hmm
    outputList = outputs hmm
    startDist   = C.normalizeCategoricalPs (initialStateDist hmm)
    transFrom s = C.normalizeCategoricalPs (transitionDist hmm s)
    emitFrom s  = C.normalizeCategoricalPs (emissionDist hmm s)
    startVec = H.fromList (map (pdf startDist) stateList)
    transMat = H.fromLists [ map (pdf (transFrom s)) stateList
                           | s <- stateList ]
    emitMatT = H.fromLists [ [ pdf (emitFrom s) o | s <- stateList ]
                           | o <- outputList ]
-- | Raise an error whose message is prefixed with the qualified name of
--   the function @fun@ of this module, i.e.
--   @"Learning.HMM.\<fun\>: \<msg\>"@.
errorIn :: String -> String -> a
errorIn fun msg = error (concat ["Learning.HMM.", fun, ": ", msg])