-- |
-- Module      : MonusWeightedSearch.Examples.Viterbi
-- Copyright   : (c) Donnacha Oisín Kidney 2021
-- Maintainer  : mail@doisinkidney.com
-- Stability   : experimental
-- Portability : non-portable
--
-- An implementation of the Viterbi algorithm using the 'Heap' monad.
--
-- This algorithm follows almost exactly the
-- <https://en.wikipedia.org/wiki/Viterbi_algorithm#Example example given on Wikipedia>.
--
-- This actually implements the /lazy/ Viterbi algorithm, since the heap
-- prioritises likely results.

module MonusWeightedSearch.Examples.Viterbi where


-- $setup
-- >>> import Data.Bifunctor (first)
-- >>> :set -XTypeApplications

import Control.Monad (guard)
import Data.Maybe

import Control.Monad.Writer

import Control.Monad.Heap
import Data.Monus.Prob

-- | A heap of probabilities; similar to a probability monad, but prioritises
-- likely outcomes.
--
-- Every outcome carries a 'Prob' weight, and binding combines the weights
-- along a branch multiplicatively: the 0.01512 reported in the 'likely'
-- example is 0.6 * 0.5 * 0.7 * 0.4 * 0.3 * 0.6 (start, emit, trans, emit,
-- trans, emit).
type Viterbi = Heap Prob

-- | The possible observations: the visible values produced by 'emit' for a
-- given hidden state.
data Obs
  = Normal
  | Cold
  | Dizzy
  deriving (Show, Eq)

-- | The possible hidden states: never observed directly, only inferred from
-- the observations they emit.
data States
  = Healthy
  | Fever
  deriving (Show, Eq)

-- | The initial state distribution (i.e. the estimated fever rate in the
-- population): probability 0.6 of starting out healthy and 0.4 of starting
-- out with a fever.
start :: Viterbi States
start =
  fromList
    [ (Healthy, 0.6)
    , (Fever,   0.4)
    ]

-- | The transition function: given today's hidden state, the distribution
-- over tomorrow's. A healthy person stays healthy with probability 0.7; a
-- person with a fever keeps it with probability 0.6.
trans :: States -> Viterbi States
trans today = case today of
  Healthy -> fromList [(Healthy, 0.7), (Fever, 0.3)]
  Fever   -> fromList [(Healthy, 0.4), (Fever, 0.6)]

-- | The emission function: given the hidden state, the likelihood of each
-- observation. Both states can emit every 'Obs', with different weights.
emit :: States -> Viterbi Obs
emit hiddenState = case hiddenState of
  Healthy -> fromList [(Normal, 0.5), (Cold, 0.4), (Dizzy, 0.1)]
  Fever   -> fromList [(Normal, 0.1), (Cold, 0.3), (Dizzy, 0.6)]

-- | @'iterateM' n f x@ runs @x@ to obtain a first value, then repeatedly
-- feeds the latest value to @f@, collecting the first @n@ values in order.
--
-- The result is @[x0, x1, ..., x(n-1)]@, where @x0@ comes from @x@ and each
-- @x(i+1)@ comes from @f xi@; the actions returned by @f@ are therefore
-- bound @n - 1@ times. A non-positive @n@ yields @pure []@ without running
-- @x@ (previously a negative @n@ recursed forever).
iterateM :: Monad m => Int -> (a -> m a) -> m a -> m [a]
iterateM n f = go n id
  where
    -- @k@ is a difference list holding the values collected so far, so the
    -- results come out in order without a final 'reverse'.
    go remaining k xs
      | remaining <= 0 = pure (k [])
      | otherwise = xs >>= \x -> go (remaining - 1) (k . (x :)) (f x)

-- | Given a sequence of observations, what is the most likely sequence of
-- hidden states?
--
-- For instance, if you observe normal, then cold, then dizzy, the underlying
-- states are most likely to be healthy, then healthy, then fever, with
-- probability 0.01512.
--
-- >>> first (realToFrac @_ @Double) (likely [Normal,Cold,Dizzy])
-- (1.512e-2,[Healthy,Healthy,Fever])
--
-- The search draws one hidden state per observation (via 'start' and
-- 'trans'), draws an emission for each of those states (via 'emit'), keeps
-- only the branches whose emissions equal @obs@, and returns the branch that
-- 'best' ranks first together with its probability.
likely :: [Obs] -> (Prob, [States])
likely obs = fromMaybe noMatchingPath . best $ do
  hidden <- iterateM (length obs) trans start
  pobs <- traverse emit hidden
  guard (obs == pobs)
  pure hidden
  where
    -- Every hidden state can emit every 'Obs' (see 'emit'), so at least one
    -- branch always survives the 'guard' and 'best' should never come back
    -- empty. If the model tables are edited so this no longer holds, fail
    -- with a descriptive message rather than 'fromJust's generic one.
    noMatchingPath =
      error "MonusWeightedSearch.Examples.Viterbi.likely: no hidden-state sequence can produce the given observations"