{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Math.HiddenMarkovModel.Distribution (
   Emission, Probability, StateShape,
   Info(..), Generate(..), EmissionProb(..), Estimate(..),

   Discrete(..), DiscreteTrained(..),
   Gaussian(..), GaussianTrained(..), gaussian,

   ToCSV(..), FromCSV(..), HMMCSV.CSVParser, CSVSymbol(..),
   ) where

import qualified Math.HiddenMarkovModel.CSV as HMMCSV
import Math.HiddenMarkovModel.Utility (SquareMatrix, randomItemProp, vectorDim)

import qualified Numeric.LAPACK.Matrix.HermitianPositiveDefinite as HermitianPD
import qualified Numeric.LAPACK.Matrix.Hermitian as Hermitian
import qualified Numeric.LAPACK.Matrix.Triangular as Triangular
import qualified Numeric.LAPACK.Matrix.Shape as MatrixShape
import qualified Numeric.LAPACK.Matrix as Matrix
import qualified Numeric.LAPACK.Vector as Vector
import Numeric.LAPACK.Matrix ((<#))
import Numeric.LAPACK.Vector (Vector)
import Numeric.LAPACK.Format (FormatArray, Format(format))

import qualified Numeric.Netlib.Class as Class
import Foreign.Storable (Storable)

import qualified Data.Array.Comfort.Storable as StorableArray
import qualified Data.Array.Comfort.Shape as Shape
import qualified Data.Array.Comfort.Boxed as Array
import Data.Array.Comfort.Boxed (Array, (!))

import qualified System.Random as Rnd

import qualified Text.CSV.Lazy.String as CSV
import qualified Text.PrettyPrint.Boxes as TextBox
import Text.PrettyPrint.Boxes ((<>), (<+>))
import Text.Read.HT (maybeRead)
import Text.Printf (printf)

import qualified Control.Monad.Exception.Synchronous as ME
import qualified Control.Monad.Trans.Class as MT
import qualified Control.Monad.Trans.State as MS
import Control.DeepSeq (NFData, rnf)
import Control.Monad (liftM2)
import Control.Applicative (liftA2, (<|>))

import qualified Data.NonEmpty as NonEmpty
import qualified Data.Foldable as Fold
import qualified Data.Map as Map
import qualified Data.Set as Set
import qualified Data.List.HT as ListHT
import qualified Data.List as List
import Data.Functor.Identity (Identity(Identity), runIdentity)
import Data.Foldable (Foldable, foldMap)
import Data.Tuple.HT (mapFst, fst3, swap)
import Data.Monoid (Endo(Endo, appEndo))
import Data.Map (Map)
import Data.Maybe (fromMaybe, listToMaybe)

import Prelude ()
import Prelude2010


{-
Shorthands for the LAPACK matrix types used by the Gaussian distribution:
HermitianMatrix is used for covariance matrices and covariance sums,
UpperTriangular is used for the Cholesky factors of covariance matrices.
-}
type HermitianMatrix sh = Hermitian.Hermitian sh
type UpperTriangular sh = Triangular.Upper sh


{-
Open type families that associate a distribution type with
the number type of its probabilities,
the type of a single emission
and the shape of the set of hidden states.
-}
type family Probability distr
type family Emission distr
type family StateShape distr


{-
Distributions that can report the shape of the state set
they are defined for.
-}
class (Class.Real (Probability distr)) => Info distr where
   -- shape of the set of states covered by the distribution
   statesShape :: distr -> StateShape distr

{-
Distributions from which emissions can be sampled.
-}
class (Class.Real (Probability distr)) => Generate distr where
   {-
   Draw a random emission for the given state.
   The random generator is threaded through the State monad.
   -}
   generate ::
      (Rnd.RandomGen g, Emission distr ~ emission, StateShape distr ~ sh) =>
      distr -> Shape.Index sh -> MS.State g emission

{-
Distributions that can tell how probable an emission is in each state.
-}
class
   (Shape.Indexed (StateShape distr), Class.Real (Probability distr)) =>
      EmissionProb distr where
   {-
   Probability (for continuous distributions: density) of the given emission,
   with one vector entry for every state.

   This function could be implemented generically in terms of emissionStateProb
   but that would require an Info constraint.
   -}
   emissionProb ::
      distr -> Emission distr -> Vector (StateShape distr) (Probability distr)
   {-
   Probability of the given emission in a single state.
   The default implementation computes the full vector with emissionProb
   and picks the entry of the requested state.
   -}
   emissionStateProb ::
      distr -> Emission distr -> Shape.Index (StateShape distr) -> Probability distr
   emissionStateProb distr e s = emissionProb distr e StorableArray.! s

{-
Connects a distribution type with the type of its training accumulator.
The two associated type synonyms map between both types in either direction
and the superclass equalities require them to be mutually inverse.
-}
class
   (Distribution tdistr ~ distr, Trained distr ~ tdistr, EmissionProb distr) =>
      Estimate tdistr distr where
   type Distribution tdistr
   type Trained distr
   {-
   Build an accumulator from emissions that are grouped by state.
   Every emission is paired with a number of the probability type,
   which the instances in this module use as a weight.
   -}
   accumulateEmissions ::
      (Probability distr ~ prob, StateShape distr ~ sh) =>
      Array sh [(Emission distr, prob)] -> tdistr
   {-
   Merge two accumulators.
   -}
   -- could as well be in Semigroup class
   combine :: tdistr -> tdistr -> tdistr
   {-
   Turn the accumulated statistics into a distribution.
   -}
   normalize :: tdistr -> distr


{-
Distribution over a discrete set of symbols.
The map assigns to every symbol a vector that contains for each state
the probability of emitting that symbol.
-}
newtype Discrete symbol sh prob = Discrete (Map symbol (Vector sh prob))
   deriving (Show)

{-
Training accumulator for Discrete.
It has the same layout as Discrete,
but the vectors hold accumulated weights
that are only turned into probabilities by normalize.
-}
newtype
   DiscreteTrained symbol sh prob =
      DiscreteTrained (Map symbol (Vector sh prob))
   deriving (Show)

{-
For Discrete the associated types are just the type parameters.
-}
type instance Probability (Discrete symbol sh prob) = prob
type instance Emission (Discrete symbol sh prob) = symbol
type instance StateShape (Discrete symbol sh prob) = sh


{-
Fully evaluating a discrete distribution means
fully evaluating the wrapped map.
-}
instance
   (NFData sh, NFData prob, NFData symbol) =>
      NFData (Discrete symbol sh prob) where
   rnf (Discrete table) = rnf table

{-
Fully evaluating a discrete accumulator means
fully evaluating the wrapped map.
-}
instance
   (NFData sh, NFData prob, NFData symbol) =>
      NFData (DiscreteTrained symbol sh prob) where
   rnf (DiscreteTrained table) = rnf table

{-
One box per symbol in ascending symbol order:
the symbol, a colon and the probability vector over the states.
The boxes are stacked vertically, left-aligned, with a separation of 1.
-}
instance
   (FormatArray sh, Class.Real prob, Format symbol) =>
      Format (Discrete symbol sh prob) where
   format fmt (Discrete m) =
      TextBox.vsep 1 TextBox.left
         [ format fmt sym <> TextBox.char ':' <+> format fmt probs
         | (sym, probs) <- Map.toAscList m ]

{-
The state shape is read from the probability vector of the smallest symbol.
All vectors of the map are expected to share this shape.

A distribution without any symbol carries no shape information.
In this case we fail with a message that names this function
instead of the generic error of a partial Map function.
-}
instance
   (Shape.C sh, Class.Real prob, Ord symbol) =>
      Info (Discrete symbol sh prob) where
   statesShape (Discrete m) =
      -- Map.elems lists the values in ascending key order,
      -- thus the first element belongs to the minimal symbol.
      case Map.elems m of
         probs:_ -> StorableArray.shape probs
         [] ->
            error
               "Distribution.statesShape: discrete distribution without symbols"

{-
Pick the entry of the requested state from every symbol's vector
and pass the resulting list of (symbol, probability) pairs,
in ascending symbol order, to randomItemProp for sampling.
-}
instance
   (Shape.Indexed sh, Class.Real prob, Ord symbol, Ord prob, Rnd.Random prob) =>
      Generate (Discrete symbol sh prob) where
   generate (Discrete m) state =
      randomItemProp
         [ (symbol, probs StorableArray.! state)
         | (symbol, probs) <- Map.toAscList m ]

{-
The probability vector of a symbol is simply looked up in the map.
A symbol that is not contained in the map raises an error.
-}
instance
   (Shape.Indexed sh, Class.Real prob, Ord symbol) =>
      EmissionProb (Discrete symbol sh prob) where
   emissionProb (Discrete m) symbol =
      mapLookup "emitDiscrete: unknown emission symbol" m symbol

{-
Training of discrete distributions.

accumulateEmissions:
   Collect all symbols that occur in any state and number them
   in ascending order. For every state, add up the weights per symbol number.
   Finally transpose the result in order to obtain
   one vector over the states for every symbol.
combine:
   Add the weight vectors of symbols that occur in both accumulators
   and keep the vectors of all other symbols.
normalize:
   Turn weights into probabilities.
   An empty accumulator is converted to an empty distribution.
-}
instance
   (Shape.Indexed sh, Eq sh, Class.Real prob, Ord symbol) =>
      Estimate (DiscreteTrained symbol sh prob) (Discrete symbol sh prob) where
   type Distribution (DiscreteTrained symbol sh prob) = Discrete symbol sh prob
   type Trained (Discrete symbol sh prob) = DiscreteTrained symbol sh prob
   accumulateEmissions grouped =
      DiscreteTrained $ Map.fromAscList $ zip symbols $
      transposeVectorList $ Array.map countState grouped
      where
         symbols = Set.toAscList $ foldMap (Set.fromList . map fst) grouped
         symbolIndex = Map.fromAscList $ zip symbols [0..]
         zeroCounts = Vector.constant (Shape.ZeroBased $ length symbols) 0
         countState =
            StorableArray.accumulate (+) zeroCounts .
            map (mapFst
               (mapLookup "estimateDiscrete: unknown emission symbol"
                  symbolIndex))
   combine (DiscreteTrained accA) (DiscreteTrained accB) =
      DiscreteTrained (Map.unionWith Vector.add accA accB)
   normalize (DiscreteTrained acc)
      | Map.null acc = Discrete acc
      | otherwise = Discrete $ normalizeProbVecs acc

{-
Treat the vectors of the array as the columns of a matrix
and return the rows of that matrix.
The height of the matrix is taken from the first vector.
An array without elements yields the empty list.
-}
transposeVectorList ::
   (Shape.C sh, Eq sh, Class.Real a) =>
   Array sh (Vector Matrix.ZeroInt a) -> [Vector sh a]
transposeVectorList xs =
   maybe []
      (\firstColumn ->
         Matrix.toRows $
         Matrix.fromColumnArray (StorableArray.shape firstColumn) xs)
      (listToMaybe $ Array.toList xs)

{-
Scale all vectors element-wise by the reciprocal of their common sum.
Afterwards the entries at every index sum up to one across the container.

The container must not be empty, because the sum is computed by foldl1'.
-}
normalizeProbVecs ::
   (Shape.C sh, Eq sh, Foldable f, Functor f, Class.Real a) =>
   f (Vector sh a) -> f (Vector sh a)
normalizeProbVecs vs =
   fmap (Vector.mul factors) vs
   where
      total = List.foldl1' Vector.add (Fold.toList vs)
      factors = StorableArray.map recip total

{-
Look up a key that is expected to be present.
If the key is missing, the result is an error with the given message.
The error is raised only when the result is evaluated.
-}
mapLookup :: (Ord k) => String -> Map.Map k a -> k -> a
mapLookup msg dict x =
   fromMaybe (error msg) $ Map.lookup x dict


{-
Multivariate normal distribution for every state.
Every array element consists of
the center,
the upper triangular Cholesky factor of the covariance matrix
and the constant factor of the density function.
-}
newtype Gaussian emiSh stateSh a =
      Gaussian (Array stateSh (Vector emiSh a, UpperTriangular emiSh a, a))
   deriving (Show)

{-
Training accumulator for Gaussian.
An array element is Nothing for a state without any emission.
Otherwise it contains
the weighted sum of the emissions,
the matrix computed by covarianceReal from the weighted emissions
and the sum of the weights.
-}
newtype GaussianTrained emiSh stateSh a =
   GaussianTrained
      (Array stateSh
         (Maybe (Vector emiSh a, HermitianMatrix emiSh a, a)))
   deriving (Show)

{-
A Gaussian distribution emits vectors
whose elements have the same type as the probabilities.
-}
type instance Probability (Gaussian emiSh stateSh a) = a
type instance Emission (Gaussian emiSh stateSh a) = Vector emiSh a
type instance StateShape (Gaussian emiSh stateSh a) = stateSh


{-
Fully evaluating a Gaussian distribution means
fully evaluating the array of per-state parameters.
-}
instance
   (NFData emiSh, NFData stateSh, Shape.C stateSh, NFData a, Storable a) =>
      NFData (Gaussian emiSh stateSh a) where
   rnf (Gaussian stateParams) = rnf stateParams

{-
Fully evaluating a Gaussian accumulator means
fully evaluating the array of per-state statistics.
-}
instance
   (NFData emiSh, NFData stateSh, Shape.C stateSh, NFData a, Storable a) =>
      NFData (GaussianTrained emiSh stateSh a) where
   rnf (GaussianTrained stateStats) = rnf stateStats


{-
The instance context provides only Class.Real for the element type,
whereas formatGaussian additionally demands a Format instance for it.
Class.switchReal selects between the two given formatGaussian values,
presumably one per concrete real number type
(TODO confirm against Numeric.Netlib.Class).
-}
instance
   (FormatArray emiSh, Shape.C stateSh, Class.Real a) =>
      Format (Gaussian emiSh stateSh a) where
   format = runFormatGaussian $ Class.switchReal formatGaussian formatGaussian

{-
Wrapper around a formatting function for Gaussian distributions.
It has the element type as last type parameter,
such that it can be used as argument of Class.switchReal.
-}
newtype FormatGaussian emiSh stateSh a =
   FormatGaussian
      {runFormatGaussian :: String -> Gaussian emiSh stateSh a -> TextBox.Box}

{-
Format a Gaussian distribution as the list of its per-state parameters.
-}
formatGaussian ::
   (FormatArray emiSh, Shape.C stateSh, Class.Real a, Format a) =>
   FormatGaussian emiSh stateSh a
formatGaussian =
   FormatGaussian
      (\fmt distr ->
         case distr of
            Gaussian params -> format fmt (Array.toList params))


{-
The state shape is the shape of the parameter array.
-}
instance
   (Shape.Indexed stateSh, Eq stateSh, Class.Real a) =>
      Info (Gaussian emiSh stateSh a) where
   statesShape distr =
      case distr of
         Gaussian params -> Array.shape params

{-
Sample an emission for a state:
Draw a seed from the random generator,
generate a normally distributed random vector from that seed,
multiply it from the left to the Cholesky factor of the covariance
and shift the result by the center of the state.
-}
instance
   (Shape.C emiSh, Eq emiSh, Shape.Indexed stateSh, Eq stateSh, Class.Real a) =>
      Generate (Gaussian emiSh stateSh a) where
   generate (Gaussian allParams) state =
      fmap emissionFromSeed $ MS.state Rnd.random
      where
         (center, covarianceChol, _c) = allParams ! state
         emissionFromSeed seed =
            Vector.add center $
            Vector.random Vector.Normal (StorableArray.shape center) seed
               <# covarianceChol

{-
emissionProb evaluates the density of every state at the emission.
emissionStateProb evaluates only the density of the requested state
and thus avoids the computation of the full vector.
-}
instance
   (Shape.C emiSh, Eq emiSh, Shape.Indexed stateSh, Eq stateSh, Class.Real a) =>
      EmissionProb (Gaussian emiSh stateSh a) where
   emissionProb (Gaussian allParams) x =
      Vector.fromList (Array.shape allParams)
         [ emissionProbGen x stateParams
         | stateParams <- Array.toList allParams ]
   emissionStateProb (Gaussian allParams) x s =
      emissionProbGen x (allParams ! s)

{-
Density of a single Gaussian at the point x.
The difference between x and the center is transformed
by solving a linear system with the transposed Cholesky factor.
The density is c * exp(-1/2 * squared norm of the transformed vector),
where c is the precomputed constant factor.
-}
emissionProbGen ::
   (Shape.C emiSh, Eq emiSh, Class.Real a) =>
   Vector emiSh a -> (Vector emiSh a, UpperTriangular emiSh a, a) -> a
emissionProbGen x (center, covarianceChol, c) =
   c * cexp ((-1/2) * Vector.inner deviation deviation)
   where
      deviation =
         Matrix.solveVector (Triangular.transpose covarianceChol)
            (Vector.sub x center)


{-
Training of Gaussian distributions.

accumulateEmissions:
   For every state with at least one emission compute
   the weighted sum of the emissions,
   the matrix given by covarianceReal
   and the sum of the weights.
   States without emissions get Nothing.
combine:
   Add the statistics element-wise.
   If only one accumulator has statistics for a state, they are kept as is.
normalize:
   Divide the sums by the total weight, subtract the outer product
   of the center and compute the Gaussian parameters from the result.
   A state without statistics raises an error.
-}
instance
   (Shape.C emiSh, Eq emiSh, Shape.Indexed stateSh, Eq stateSh, Class.Real a) =>
      Estimate
         (GaussianTrained emiSh stateSh a)
         (Gaussian emiSh stateSh a) where
   type Distribution (GaussianTrained emiSh stateSh a) =
            Gaussian emiSh stateSh a
   type Trained (Gaussian emiSh stateSh a) = GaussianTrained emiSh stateSh a
   accumulateEmissions grouped =
      GaussianTrained $ fmap (fmap summarize . NonEmpty.fetch) grouped
      where
         summarize weighted =
            (NonEmpty.foldl1Map Vector.add (uncurry $ flip Vector.scale)
               weighted,
             covarianceReal $ fmap swap weighted,
             Fold.sum $ fmap snd weighted)
   combine (GaussianTrained distr0) (GaussianTrained distr1) =
      GaussianTrained $ Array.zipWith (maybePlus addStats) distr0 distr1
      where
         addStats
            (centerA, covarianceA, weightA)
            (centerB, covarianceB, weightB) =
               (Vector.add centerA centerB,
                Vector.add covarianceA covarianceB,
                weightA + weightB)
   {-
   With the mean m of the n vectors xi it holds:

     Sum_i (xi-m) * (xi-m)^T
   = Sum_i xi*xi^T + Sum_i m*m^T - Sum_i xi*m^T - Sum_i m*xi^T
   = Sum_i xi*xi^T - Sum_i m*m^T
   = Sum_i xi*xi^T - n * m*m^T
   -}
   normalize (GaussianTrained distr) =
      Gaussian $
      fmap (gaussianParameters . centerAndCovariance . demand) distr
      where
         demand =
            fromMaybe
               (error "Distribution.normalize: undefined array element")
         centerAndCovariance (centerSum, covarianceSum, weight) =
            (center,
             Vector.sub (Vector.scale c covarianceSum)
                (Hermitian.outer MatrixShape.RowMajor center))
            where
               c = recip weight
               center = Vector.scale c centerSum

{-
Combine two optional values.
If both are present, they are merged by the given function.
If only one is present, it is returned unchanged.
-}
-- ToDo: could be managed by semigroup
maybePlus :: (a -> a -> a) -> Maybe a -> Maybe a -> Maybe a
maybePlus f mx my =
   case (mx, my) of
      (Just x, Just y) -> Just (f x y)
      (Just _, Nothing) -> mx
      (Nothing, _) -> my


{-
Wrapper around a function from a container of (number, vector) pairs
to a Hermitian matrix.
It has the element type as last type parameter,
such that it can be used as argument of Class.switchReal.
-}
newtype CovarianceReal f emiSh a =
   CovarianceReal
      {getCovarianceReal :: f (a, Vector emiSh a) -> HermitianMatrix emiSh a}

{-
Apply Hermitian.sumRank1NonEmpty in row major order
to a non-empty list of (weight, vector) pairs.
Presumably the result is Sum_i wi * xi*xi^T
(TODO confirm against the documentation of sumRank1NonEmpty).

The same expression is passed to Class.switchReal twice.
Do not merge both arguments into one local binding:
presumably every argument is typed for a different concrete number type.
-}
covarianceReal ::
   (Shape.C emiSh, Eq emiSh, Class.Real a) =>
   NonEmpty.T [] (a, Vector emiSh a) -> HermitianMatrix emiSh a
covarianceReal =
   getCovarianceReal $
   Class.switchReal
      (CovarianceReal $ Hermitian.sumRank1NonEmpty MatrixShape.RowMajor)
      (CovarianceReal $ Hermitian.sumRank1NonEmpty MatrixShape.RowMajor)

{-
Construct a Gaussian distribution
from the center and the covariance matrix of every state.
-}
gaussian ::
   (Shape.C emiSh, Shape.C stateSh, Class.Real prob) =>
   Array stateSh (Vector emiSh prob, HermitianMatrix emiSh prob) ->
   Gaussian emiSh stateSh prob
gaussian centersAndCovariances =
   consGaussian (fmap gaussianParameters centersAndCovariances)

{-
Convert center and covariance matrix of a single state
to the internal representation.
For this purpose the covariance matrix is decomposed
by HermitianPD.decompose.
-}
gaussianParameters ::
   (Shape.C emiSh, Class.Real prob) =>
   (Vector emiSh prob, HermitianMatrix emiSh prob) ->
   (Vector emiSh prob, UpperTriangular emiSh prob, prob)
gaussianParameters (center, covariance) =
   gaussianFromCholesky center covarianceChol
   where
      covarianceChol = HermitianPD.decompose covariance

{-
Like the Gaussian constructor but with a Shape.C constraint
on the state shape.
-}
consGaussian ::
   (Shape.C stateSh) =>
   Array stateSh (Vector emiSh a, UpperTriangular emiSh a, a) ->
   Gaussian emiSh stateSh a
consGaussian stateParams = Gaussian stateParams

{-
Complete center and Cholesky factor
by the constant factor of the density function.
The factor is the reciprocal of sqrt(2*pi)^dim
times the product of the diagonal elements of the Cholesky factor.
-}
gaussianFromCholesky ::
   (Shape.C emiSh, Class.Real prob) =>
   Vector emiSh prob -> UpperTriangular emiSh prob ->
   (Vector emiSh prob, UpperTriangular emiSh prob, prob)
gaussianFromCholesky center covarianceChol =
   (center, covarianceChol,
    recip (sqrt2pi ^ vectorDim center * diagonalProduct))
   where
      diagonalProduct =
         Vector.product (Triangular.takeDiagonal covarianceChol)

{-
Square root of 2*pi for every Class.Real type.
The value is defined by sqrt2piAux, which needs a Floating constraint.
Class.switchReal selects between the two given sqrt2piAux values,
presumably one per concrete real number type
(TODO confirm against Numeric.Netlib.Class).
-}
sqrt2pi :: (Class.Real a) => a
sqrt2pi = runIdentity $ Class.switchReal sqrt2piAux sqrt2piAux

{-
Square root of 2*pi wrapped in Identity,
such that it can be used as argument of Class.switchReal.
-}
sqrt2piAux :: (Floating a) => Identity a
sqrt2piAux =
   Identity root
   where
      root = sqrt (2 * pi)

{-
Exponential function for every Class.Real type.
The function exp is wrapped in Endo
in order to pass it through Class.switchReal.
-}
cexp :: (Class.Real a) => a -> a
cexp x = appEndo (Class.switchReal (Endo exp) (Endo exp)) x



{-
Distributions that can be written as rows of CSV cells.
-}
class ToCSV distr where
   -- rows of cells that represent the distribution
   toCells :: distr -> [[String]]

{-
Distributions that can be parsed from CSV rows.
-}
class FromCSV distr where
   -- parser for a distribution with the given state shape
   parseCells :: StateShape distr -> HMMCSV.CSVParser distr

{-
Emission symbols that can be stored in a single CSV cell.
-}
class (Ord symbol) => CSVSymbol symbol where
   -- text of the cell that represents the symbol
   cellFromSymbol :: symbol -> String
   -- symbol represented by the cell text, Nothing if there is none
   symbolFromCell :: String -> Maybe symbol

{-
A character is stored as a cell that consists of exactly this character.

Only cells with exactly one character are accepted when reading.
Accepting longer cells and ignoring all but the first character
would map different cells like "ab" and "ac" to the same symbol.
-}
instance CSVSymbol Char where
   cellFromSymbol = (:[])
   symbolFromCell cell =
      case cell of
         [c] -> Just c
         _ -> Nothing

{-
An integer is stored in its Show representation
and read back by maybeRead.
-}
instance CSVSymbol Int where
   cellFromSymbol n = show n
   symbolFromCell cell = maybeRead cell


{-
One row per symbol in ascending symbol order:
the cell of the symbol followed by the cells of its probability vector.
-}
instance
   (Shape.C sh, Class.Real prob, Show prob, Read prob, CSVSymbol symbol) =>
      ToCSV (Discrete symbol sh prob) where
   toCells (Discrete m) =
      [ cellFromSymbol symbol : HMMCSV.cellsFromVector probs
      | (symbol, probs) <- Map.toAscList m ]

{-
Parse all remaining rows with parseSymbolProb
and collect them in the map of a discrete distribution.

A symbol that occurs in more than one row is reported as parse error.
Without this check Map.fromList would silently keep
only the last row of every repeated symbol.
-}
instance
   (Shape.C sh, Class.Real prob, Show prob, Read prob, CSVSymbol symbol) =>
      FromCSV (Discrete symbol sh prob) where
   parseCells n = do
      rows <- HMMCSV.manyRowsUntilEnd $ parseSymbolProb n
      let occurrences =
             Map.fromListWith (+) $
             map (\(symbol, _) -> (symbol, 1::Int)) rows
          duplicates = Map.keys $ Map.filter (> 1) occurrences
      HMMCSV.assert (null duplicates) $
         printf "duplicate symbols: %s"
            (List.intercalate ", " $ map cellFromSymbol duplicates)
      return $ Discrete $ Map.fromList rows

{-
Parse a row that consists of a symbol cell
followed by the cells of a probability vector.
The parser fails
if the row is empty,
if the first cell does not represent a symbol,
or if the size of the vector differs from the size of the state shape.
On success the vector is reshaped to the state shape.
-}
parseSymbolProb ::
   (Shape.C sh, Class.Real prob, Read prob, CSVSymbol symbol) =>
   sh -> CSV.CSVRow -> HMMCSV.CSVParser (symbol, Vector sh prob)
parseSymbolProb _ [] = MT.lift $ ME.throw "missing symbol"
parseSymbolProb sh (symbolCell:probCells) =
   liftM2 (,) symbolParser probsParser
   where
      symbolText = CSV.csvFieldContent symbolCell
      symbolParser =
         MT.lift $ ME.fromMaybe (printf "unknown symbol %s" symbolText) $
         symbolFromCell symbolText
      probsParser = do
         v <- HMMCSV.parseVectorFields probCells
         let expected = Shape.size sh
             actual = vectorDim v
         HMMCSV.assert (expected == actual)
            (printf "number of states (%d) and size of probability vector (%d) mismatch"
               expected actual)
         return $ StorableArray.reshape sh v


{-
Every state is written as one row with the center
followed by the rows of the Cholesky factor as square matrix.
The blocks of consecutive states are separated by an empty row.
The constant density factor is not written.
-}
instance
   (Shape.Indexed emiSh, Shape.Indexed stateSh,
    Class.Real a, Eq a, Show a, Read a) =>
      ToCSV (Gaussian emiSh stateSh a) where
   toCells (Gaussian params) =
      List.intercalate [[]]
         [ HMMCSV.cellsFromVector center :
           HMMCSV.cellsFromSquare (Triangular.toSquare covarianceChol)
         | (center, covarianceChol, _) <- Array.toList params ]

{-
Parse a sequence of Gaussians until the end of the input.
The parser fails
if the number of Gaussians differs from the size of the state shape
or if the centers do not all have the same dimension.
-}
instance
   (emiSh ~ Matrix.ZeroInt, Shape.Indexed stateSh,
    Class.Real a, Eq a, Show a, Read a) =>
      FromCSV (Gaussian emiSh stateSh a) where
   parseCells sh = do
      gs <- HMMCSV.manySepUntilEnd parseSingleGaussian
      let expected = Shape.size sh
          found = length gs
      HMMCSV.assert (found == expected) $
         printf "number of states (%d) and number of Gaussians (%d) mismatch"
            expected found
      let sizes = [vectorDim (fst3 g) | g <- gs]
      HMMCSV.assert (ListHT.allEqual sizes) $
         printf "dimensions of emissions mismatch: %s" (show sizes)
      return $ consGaussian $ Array.fromList sh gs

{-
Parse the center vector and a square matrix of matching size.
The matrix must be upper triangular
and is interpreted as Cholesky factor of the covariance matrix.
The constant density factor is recomputed from it.
-}
parseSingleGaussian ::
   (emiSh ~ Matrix.ZeroInt, Class.Real prob, Eq prob, Read prob) =>
   HMMCSV.CSVParser (Vector emiSh prob, UpperTriangular emiSh prob, prob)
parseSingleGaussian = do
   center <- HMMCSV.parseNonEmptyVectorCells
   square <- HMMCSV.parseSquareMatrixCells (StorableArray.shape center)
   let upper = Triangular.takeUpper square
   HMMCSV.assert (isUpperTriang square upper)
      "matrices must be upper triangular"
   return (gaussianFromCholesky center upper)


{-
Maybe this test is too strict.
It would also be ok, and certainly more intuitive
to use an orthogonal but not normalized matrix.
We could get such a matrix from the eigensystem.
-}
{-
Check whether a square matrix has the same elements
as the square matrix that is rebuilt from its given upper triangular part.
The elements are compared exactly.
-}
isUpperTriang ::
   (Shape.C sh, Class.Real a, Eq a) =>
   SquareMatrix sh a -> UpperTriangular sh a -> Bool
isUpperTriang m mt =
   originalElements == rebuiltElements
   where
      originalElements = Vector.toList m
      rebuiltElements = Vector.toList (Triangular.toSquare mt)