{-# LANGUAGE FlexibleContexts #-}

-- |
-- Description :  Rate matrix helper functions
-- Copyright   :  (c) Dominik Schrempf 2021
-- License     :  GPLv3
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  unstable
-- Portability :  non-portable (not tested)
--
-- Some helper functions that come handy when working with rate matrices of
-- continuous-time discrete-state Markov processes.
--
-- * Changelog
--
-- To be imported qualified.
module ELynx.Data.MarkovProcess.RateMatrix
  ( RateMatrix,
    ExchangeabilityMatrix,
    StationaryDistribution,
    isValid,
    normalizeSD,
    totalRate,
    totalRateWith,
    normalize,
    normalizeWith,
    setDiagonal,
    toExchangeabilityMatrix,
    fromExchangeabilityMatrix,
    getStationaryDistribution,
    exchFromListLower,
    exchFromListUpper,
  )
where

import qualified Data.Vector.Storable as V
import Numeric.LinearAlgebra hiding (normalize)
import Numeric.SpecFunctions
import Prelude hiding ((<>))

-- | Generator (Q) of a continuous-time Markov process, stored as a plain real
-- matrix.
type RateMatrix = Matrix Double

-- | Matrix of exchangeabilities E. Together with the diagonal matrix Pi that
-- holds the stationary frequency distribution, a rate matrix factors as
-- Q = E * Pi.
type ExchangeabilityMatrix = Matrix Double

-- | Stationary (equilibrium) distribution of a rate matrix.
type StationaryDistribution = Vector Double

-- Tolerance used by 'isValid' when comparing the total mass of a distribution
-- with 1.0.
epsRelaxed :: Double
epsRelaxed = 1.0e-5

-- | True if the vector is a probability distribution: every entry is
-- non-negative and the entries sum to 1.0, both up to the tolerance
-- 'epsRelaxed'.
--
-- The signed sum is used (not the 1-norm, which sums absolute values), and the
-- sign of each entry is checked. Together these reject vectors with negative
-- entries such as @[0.5, -0.5]@, whose 1-norm happens to be 1.0.
isValid :: StationaryDistribution -> Bool
isValid d = nonNegative && sumsToOne
  where
    -- Allow tiny negative values caused by numerical round-off (e.g. from an
    -- eigendecomposition).
    nonNegative = V.all (> negate epsRelaxed) d
    sumsToOne = epsRelaxed > abs (V.sum d - 1.0)

-- | Normalize a stationary distribution so that the elements sum to 1.0.
--
-- Every element is divided by the signed sum of all elements. For vectors with
-- non-negative entries this equals dividing by the 1-norm. Using the signed
-- sum keeps the documented contract when entries are negative; for example, an
-- all-negative vector is mapped to a non-negative one summing to 1.0.
normalizeSD :: StationaryDistribution -> StationaryDistribution
normalizeSD d = V.map (/ total) d
  where
    total = V.sum d

-- Copy of the matrix with its diagonal replaced by zeros; off-diagonal entries
-- are kept. Assumes a square matrix (the subtracted diagonal matrix is square).
matrixSetDiagToZero :: Matrix R -> Matrix R
matrixSetDiagToZero m = m - onlyDiagonal
  where
    onlyDiagonal = diag (takeDiag m)
{-# INLINE matrixSetDiagToZero #-}

-- | Get average number of substitutions per unit time.
--
-- The off-diagonal part of the rate matrix is multiplied from the left by the
-- given distribution (row vector times matrix). The result is the 1-norm of
-- that vector.
totalRateWith :: StationaryDistribution -> RateMatrix -> Double
totalRateWith d m = norm_1 flux
  where
    flux = d <# matrixSetDiagToZero m

-- | Get average number of substitutions per unit time.
--
-- The stationary distribution is computed from the rate matrix itself with
-- 'getStationaryDistribution' and then passed to 'totalRateWith'.
totalRate :: RateMatrix -> Double
totalRate m = totalRateWith sd m
  where
    sd = getStationaryDistribution m

-- | Normalizes a Markov process generator such that one event happens per unit
-- time.
--
-- The stationary distribution is first computed from the rate matrix with
-- 'getStationaryDistribution' and then handed to 'normalizeWith'.
normalize :: RateMatrix -> RateMatrix
normalize m = normalizeWith sd m
  where
    sd = getStationaryDistribution m

-- | Normalizes a Markov process generator such that one event happens per unit
-- time. Faster than 'normalize', but the stationary distribution has to be
-- given.
--
-- The whole matrix is scaled by the reciprocal of 'totalRateWith'.
normalizeWith :: StationaryDistribution -> RateMatrix -> RateMatrix
normalizeWith d m = scale factor m
  where
    factor = recip (totalRateWith d m)

-- | Set the diagonal entries of a matrix such that the rows sum to 0.
--
-- The diagonal of the input is ignored. Each new diagonal entry is the negated
-- sum of the off-diagonal entries in its row.
setDiagonal :: RateMatrix -> RateMatrix
setDiagonal m = diagZeroes - diag rowSums
  where
    diagZeroes :: Matrix Double
    diagZeroes = matrixSetDiagToZero m
    -- Signed row sums (not the 1-norm), so the rows sum to 0 even when some
    -- off-diagonal entries are negative.
    rowSums :: Vector Double
    rowSums = V.fromList (map V.sum (toRows diagZeroes))

-- | Extract the exchangeability matrix from a rate matrix.
--
-- Column j of the rate matrix is divided by the stationary frequency f_j, i.e.,
-- E = Q * diag(1/f). Diagonal entries are divided as well; they are not reset.
toExchangeabilityMatrix ::
  RateMatrix -> StationaryDistribution -> ExchangeabilityMatrix
toExchangeabilityMatrix m f = m <> diag (cmap recip f)

-- | Convert exchangeability matrix to rate matrix.
--
-- Computes E * diag(pi) and then sets the diagonal with 'setDiagonal' so that
-- every row sums to zero.
fromExchangeabilityMatrix ::
  ExchangeabilityMatrix -> StationaryDistribution -> RateMatrix
fromExchangeabilityMatrix em d = setDiagonal (em <> diag d)

-- Threshold below which the magnitude of an eigenvalue counts as zero in
-- 'getStationaryDistribution'.
eps :: Double
eps = 1.0e-12

-- Divide every element by the sum of all elements, so that the result sums to
-- 1.0 (when the sum is non-zero).
normalizeSumVec :: V.Vector Double -> V.Vector Double
normalizeSumVec v =
  let total = V.sum v
   in V.map (/ total) v
{-# INLINE normalizeSumVec #-}

-- | Get stationary distribution from 'RateMatrix'. Involves eigendecomposition.
--
-- The transposed matrix is decomposed, so its eigenvectors are the left
-- eigenvectors of the rate matrix. The eigenvalue picked by 'minIndex' is used
-- (presumably the one of smallest magnitude for complex vectors in hmatrix;
-- confirm). If that eigenvalue's magnitude is not below 'eps', the matrix does
-- not behave like a transition rate matrix and 'error' is called. Otherwise the
-- real part of the matching eigenvector is normalized to sum to 1.0.
--
-- NOTE(review): returning 'Maybe' or 'Either' would let callers handle the
-- failure instead of crashing; the original author left this as an open
-- question.
getStationaryDistribution :: RateMatrix -> StationaryDistribution
getStationaryDistribution m
  | magnitude lambda < eps = normalizeSumVec (cmap realPart v)
  | otherwise =
    error "getStationaryDistribution: Could not retrieve stationary distribution."
  where
    (eigenvalues, eigenvectors) = eig (tr m)
    k = minIndex eigenvalues
    lambda = eigenvalues ! k
    v = toColumns eigenvectors !! k

-- The next functions handle the conceptually simple but fiddly problem of
-- converting a triangular matrix (excluding the diagonal), given as a list, into
-- a symmetric matrix. The diagonal entries are set to zero.

-- Lower triangular matrix. This is how the exchangeabilities are specified in
-- PAML. Conversion from matrix indices (i,j) to list index k.
--
-- (i,j) k
--
-- (0,0) -
-- (1,0) 0  (1,1) -
-- (2,0) 1  (2,1) 2  (2,2) -
-- (3,0) 3  (3,1) 4  (3,2) 5 (3,3) -
-- (4,0) 6  (4,1) 7  (4,2) 8 (4,3) 9 (4,4) -
--   .
--   .
--   .
--
-- k = (i choose 2) + j = i*(i-1)/2 + j.
--
-- The binomial coefficient is computed with exact integer arithmetic instead of
-- rounding a floating-point value. Calls 'error' unless (i,j) lies strictly
-- below the diagonal.
ijToKLower :: Int -> Int -> Int
ijToKLower i j
  | i > j = choose2 i + j
  | otherwise = error "ijToKLower: not defined for upper triangular matrix."
  where
    -- @k choose 2@; zero for @k < 2@.
    choose2 :: Int -> Int
    choose2 k
      | k < 2 = 0
      | otherwise = k * (k - 1) `div` 2

-- Upper triangular matrix. Conversion from matrix indices (i,j) to list index
-- k. Matrix is square of size n.
--
-- (i,j) k
--
-- (0,0) -  (0,1) 0  (0,2) 1    (0,3) 2     (0,4) 3     ...
--          (1,1) -  (1,2) n-1  (1,3) n     (1,4) n+1
--                   (2,2) -    (2,3) 2n-3  (2,4) 2n-2
--                              (3,3) -     (3,4) 3n-6
--                                          (4,4) -
--                                                      ...
--
-- k = i*(n-2) - (i choose 2) + (j - 1)
--
-- The binomial coefficient is computed with exact integer arithmetic instead of
-- rounding a floating-point value. Calls 'error' unless (i,j) lies strictly
-- above the diagonal.
ijToKUpper :: Int -> Int -> Int -> Int
ijToKUpper n i j
  | i < j = i * (n - 2) - choose2 i + (j - 1)
  | otherwise = error "ijToKUpper: not defined for lower triangular matrix."
  where
    -- @k choose 2@; zero for @k < 2@.
    choose2 :: Int -> Int
    choose2 k
      | k < 2 = 0
      | otherwise = k * (k - 1) `div` 2

-- The function is a little weird because HMatrix uses Double indices for Matrix
-- Double builders.
--
-- Entry (i,j) of the symmetric matrix described by the lower-triangular list
-- @es@ (diagonal excluded); the diagonal itself is zero. The indices arrive as
-- floating-point values and are rounded to 'Int'. The last branch is reached
-- only if the indices cannot be ordered (e.g. NaN).
fromListBuilderLower :: RealFrac a => [a] -> a -> a -> a
fromListBuilderLower es i j
  | i > j = entryAt row col
  | i == j = 0.0
  | i < j = entryAt col row
  | otherwise =
    error
      "Float indices could not be compared during matrix creation."
  where
    row = round i :: Int
    col = round j :: Int
    -- Look up element (r,c), r > c, of the lower triangle.
    entryAt r c = es !! ijToKLower r c

-- The function is a little weird because HMatrix uses Double indices for Matrix
-- Double builders.
--
-- Entry (i,j) of the symmetric n x n matrix described by the upper-triangular
-- list @es@ (diagonal excluded); the diagonal itself is zero. The indices arrive
-- as floating-point values and are rounded to 'Int'. The last branch is reached
-- only if the indices cannot be ordered (e.g. NaN).
fromListBuilderUpper :: RealFrac a => Int -> [a] -> a -> a -> a
fromListBuilderUpper n es i j
  | i < j = entryAt row col
  | i == j = 0.0
  | i > j = entryAt col row
  | otherwise =
    error
      "Float indices could not be compared during matrix creation."
  where
    row = round i :: Int
    col = round j :: Int
    -- Look up element (r,c), r < c, of the upper triangle.
    entryAt r c = es !! ijToKUpper n r c

-- Check that the list of exchangeabilities has the length required for an
-- n x n symmetric matrix without diagonal, namely @n choose 2@ entries. Returns
-- the list unchanged on success and calls 'error' with a descriptive message
-- otherwise. Used by both 'exchFromListLower' and 'exchFromListUpper', so the
-- message names both.
checkEs :: RealFrac a => Int -> [a] -> [a]
checkEs n es
  | nReceived == nExp = es
  | otherwise = error eStr
  where
    nReceived :: Int
    nReceived = length es
    -- Exact integer form of @n choose 2@ (zero for @n < 2@).
    nExp :: Int
    nExp
      | n < 2 = 0
      | otherwise = n * (n - 1) `div` 2
    eStr :: String
    eStr =
      unlines
        [ "exchFromList{Lower,Upper}: the number of exchangeabilities does not match the matrix size",
          "matrix size: " ++ show n,
          "expected number of exchangeabilities: " ++ show nExp,
          "received number of exchangeabilities: " ++ show nReceived
        ]

-- | Build exchangeability matrix from list denoting lower triangular matrix,
-- and excluding diagonal. This is how the exchangeabilities are specified in
-- PAML.
--
-- The result is an n x n symmetric matrix with zeros on the diagonal. The list
-- length is validated with 'checkEs'.
exchFromListLower :: (RealFrac a, Container Vector a) => Int -> [a] -> Matrix a
exchFromListLower n es = build (n, n) entry
  where
    entry = fromListBuilderLower (checkEs n es)

-- | Build exchangeability matrix from list denoting upper triangular matrix,
-- and excluding diagonal.
--
-- The result is an n x n symmetric matrix with zeros on the diagonal. The list
-- length is validated with 'checkEs'.
exchFromListUpper :: (RealFrac a, Container Vector a) => Int -> [a] -> Matrix a
exchFromListUpper n es = build (n, n) entry
  where
    entry = fromListBuilderUpper n (checkEs n es)