{-# LANGUAGE TypeFamilies #-}
{- |
Counterparts to functions in "Math.HiddenMarkovModel.Private"
that normalize interim results.
We need to do this in order to prevent
very small probabilities from being rounded to zero.
-}
module Math.HiddenMarkovModel.Normalized where

import qualified Math.HiddenMarkovModel.Public.Distribution as Distr
import Math.HiddenMarkovModel.Private
          (T(..), Trained(..), emission,
           biscaleTransition, revealGen, sumTransitions)
import Math.HiddenMarkovModel.Utility (normalizeFactor, normalizeProb)

import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix ((-*#), (#*|))
import Numeric.LAPACK.Vector (Vector)

import qualified Numeric.Netlib.Class as Class

import qualified Control.Functor.HT as Functor

import qualified Data.Array.Comfort.Storable as StorableArray
import qualified Data.Array.Comfort.Shape as Shape

import qualified Data.NonEmpty.Class as NonEmptyC
import qualified Data.NonEmpty as NonEmpty
import qualified Data.Foldable as Fold
import Data.Traversable (Traversable)


{- $setup
>>> import qualified Data.NonEmpty as NonEmpty
-}


{- |
Logarithm of the likelihood to observe the given sequence.
We return the logarithm because the likelihood can be so small
that it may be rounded to zero in the chosen number type.

The result is the sum of the logarithms
of the per-step factors returned by 'alpha'.
-}
logLikelihood ::
   (Distr.EmissionProb typ, Shape.C sh, Eq sh, Floating prob,
    Class.Real prob, Distr.Emission typ prob ~ emission,
    Traversable f) =>
   T typ sh prob -> NonEmpty.T f emission -> prob
logLikelihood hmm = Fold.sum . fmap (log . fst) . alpha hmm

{- |
Forward pass with normalization of every interim vector.

The first vector is the element-wise product of @initial hmm@
and the emission vector of the first observation.
Every further step applies the transition matrix to the previous vector
and multiplies the result element-wise
with the emission vector of the current observation.
Each of these products is passed through 'normalizeFactor',
so every element of the result is the pair returned by that function
(presumably the normalization factor
together with the normalized vector - confirm in "Math.HiddenMarkovModel.Utility").
The factor of the previous step is not used for computing the next step.
-}
alpha ::
   (Distr.EmissionProb typ, Shape.C sh, Eq sh,
    Class.Real prob, Distr.Emission typ prob ~ emission,
    Traversable f) =>
   T typ sh prob ->
   NonEmpty.T f emission -> NonEmpty.T f (prob, Vector sh prob)
alpha hmm (NonEmpty.Cons x xs) =
   let -- weight a state vector with the emission vector of y, then normalize
       normMulEmiss y = normalizeFactor . Vector.mul (emission hmm y)
   in  NonEmpty.scanl
          (\(_,alphai) xi -> normMulEmiss xi (transition hmm #*| alphai))
          (normMulEmiss x (initial hmm))
          xs

{- |
Backward pass that is scaled with externally supplied factors.

The input contains a pair of factor and emission for every step;
'alphaBeta' feeds it with the factors computed by 'alpha',
leaving out the first time step.
The scan runs from the right and starts with 'Vector.one'
of the shape of @initial hmm@.
In every step the current vector is multiplied element-wise
with the emission vector of the observation,
then multiplied from the left onto the transition matrix
and finally scaled by the reciprocal of the factor.
-}
beta ::
   (Distr.EmissionProb typ, Shape.C sh, Eq sh,
    Class.Real prob, Distr.Emission typ prob ~ emission,
    Traversable f, NonEmptyC.Reverse f) =>
   T typ sh prob ->
   f (prob, emission) -> NonEmpty.T f (Vector sh prob)
beta hmm =
   nonEmptyScanr
      (\(ci,xi) betai ->
         Vector.scale (recip ci) $
         Vector.mul (emission hmm xi) betai -*# transition hmm)
      (Vector.one $ StorableArray.shape $ initial hmm)

{- |
Run the forward pass 'alpha' and the backward pass 'beta'
on the same emission sequence.

The backward pass receives the factors of the forward pass
paired with the emissions, without the pair of the first time step.
The first component of the result is the unmodified output of 'alpha',
the second one is the output of 'beta'.
-}
alphaBeta ::
   (Distr.EmissionProb typ, Shape.C sh, Eq sh,
    Class.Real prob, Distr.Emission typ prob ~ emission,
    Traversable f, NonEmptyC.Zip f, NonEmptyC.Reverse f) =>
   T typ sh prob ->
   NonEmpty.T f emission ->
   (NonEmpty.T f (prob, Vector sh prob), NonEmpty.T f (Vector sh prob))
alphaBeta hmm xs =
   let calphas = alpha hmm xs
   in  (calphas,
        beta hmm $ NonEmpty.tail $ NonEmptyC.zip (fmap fst calphas) xs)


{- |
Compute one square matrix for every pair of adjacent time steps
from the results of 'alphaBeta'.

For each pair the matrix is
@biscaleTransition hmm x alpha0 beta1@ scaled by @recip c1@, where
@alpha0@ is the forward vector of the earlier time step and
@x@, @c1@ and @beta1@ are the emission, the forward factor
and the backward vector of the later time step.
The result thus has one element less than the emission sequence.
-}
xiFromAlphaBeta ::
   (Distr.EmissionProb typ, Shape.C sh, Eq sh,
    Class.Real prob, Distr.Emission typ prob ~ emission,
    Traversable f, NonEmptyC.Zip f) =>
   T typ sh prob ->
   NonEmpty.T f emission ->
   NonEmpty.T f (prob, Vector sh prob) ->
   NonEmpty.T f (Vector sh prob) ->
   f (Matrix.Square sh prob)
xiFromAlphaBeta hmm xs calphas betas =
   let (cs,alphas) = Functor.unzip calphas
   in  NonEmptyC.zipWith4
          (\x alpha0 c1 beta1 ->
             Matrix.scale (recip c1) $
             biscaleTransition hmm x alpha0 beta1)
          -- 'init' selects the earlier, 'tail' the later step of each pair
          (NonEmpty.tail xs)
          (NonEmpty.init alphas)
          (NonEmpty.tail cs)
          (NonEmpty.tail betas)

{- |
Combine the results of 'alphaBeta' per time step:
Each result vector is the element-wise product
of the forward vector (the factor paired with it is ignored)
and the backward vector of the same time step.
-}
zetaFromAlphaBeta ::
   (Shape.C sh, Eq sh, Class.Real prob, NonEmptyC.Zip f) =>
   NonEmpty.T f (prob, Vector sh prob) ->
   NonEmpty.T f (Vector sh prob) ->
   NonEmpty.T f (Vector sh prob)
zetaFromAlphaBeta calphas betas =
   NonEmptyC.zipWith (Vector.mul . snd) calphas betas


{- |
Reveal the state sequence
that led most likely to the observed sequence of emissions.
It is found using the Viterbi algorithm.

This is 'revealGen' with 'normalizeProb' as its vector transformation argument.
-}
reveal ::
   (Distr.EmissionProb typ, Shape.InvIndexed sh, Eq sh, Shape.Index sh ~ state,
    Distr.Emission typ prob ~ emission, Class.Real prob, Traversable f) =>
   T typ sh prob -> NonEmpty.T f emission -> NonEmpty.T f state
reveal = revealGen normalizeProb


{- |
Variant of NonEmpty.scanr with less stack consumption.

It reverses the input, performs a left scan with the flipped function
and reverses the result of the scan.

prop> \x xs -> nonEmptyScanr (-) x xs == NonEmpty.scanr (-) x (xs::[Int])
-}
nonEmptyScanr ::
   (Traversable f, NonEmptyC.Reverse f) =>
   (a -> b -> b) -> b -> f a -> NonEmpty.T f b
nonEmptyScanr f x =
   NonEmptyC.reverse . NonEmpty.scanl (flip f) x . NonEmptyC.reverse


{- |
Consider a superposition of all possible state sequences
weighted by the likelihood to produce the observed emission sequence.
Now train the model on all of these sequences,
taking their weights into account.
This is done by the Baum-Welch algorithm.

The fields of the result are computed as follows:

* 'trainedInitial': the first vector of 'zetaFromAlphaBeta',

* 'trainedTransition': 'sumTransitions' applied to the matrices
  of 'xiFromAlphaBeta',

* 'trainedDistribution': 'Distr.accumulateEmissionVectors'
  applied to the emissions paired with the vectors of 'zetaFromAlphaBeta'.
-}
trainUnsupervised ::
   (Distr.Estimate typ, Shape.C sh, Eq sh,
    Class.Real prob, Distr.Emission typ prob ~ emission) =>
   T typ sh prob -> NonEmpty.T [] emission -> Trained typ sh prob
trainUnsupervised hmm xs =
   let (alphas, betas) = alphaBeta hmm xs
       zetas = zetaFromAlphaBeta alphas betas
       zeta0 = NonEmpty.head zetas

   in  Trained {
          trainedInitial = zeta0,
          trainedTransition =
             sumTransitions hmm $ xiFromAlphaBeta hmm xs alphas betas,
          trainedDistribution =
             Distr.accumulateEmissionVectors $ NonEmptyC.zip xs zetas
       }