module Data.Random.Distribution.MultiNormal
(
MultiNormal(..)
, inv
)
where
import Control.Monad (replicateM, when)
import Data.Maybe (fromMaybe, fromJust)
import Data.Random
import GHC.TypeLits
import Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra as LA
-- | Invert a square matrix by solving @m x = I@ for @x@.
--
-- Calls 'error' when 'linSolve' reports that the system has no solution,
-- i.e. when the matrix is not invertible.
inv :: KnownNat n => Sq n -> Sq n
inv m = case linSolve m eye of
  Just mInverse -> mInverse
  Nothing -> error "Failed attempting to invert non-invertible matrix."
-- | Multivariate normal distribution, declared as a data family indexed by
-- the type of its samples.
data family MultiNormal k :: *
-- | Multivariate normal over statically-sized vectors @R n@: @mu@ is the
-- mean vector and @cov@ the (symmetric) covariance matrix.
data instance KnownNat n => MultiNormal (R n) =
  MultiNormal { mu :: (R n), cov :: (Sym n) }
-- | Renders the label @Normal@ followed by the mean and the covariance,
-- separated by single spaces.
instance KnownNat n => Show (MultiNormal (R n)) where
  show (MultiNormal m s) = unwords ["Normal", show m, show s]
-- | Sampling goes through 'multiNormalRV'.
instance KnownNat n => Distribution MultiNormal (R n) where
  rvar = multiNormalRV

-- | Draw one sample from a multivariate normal distribution.
--
-- The covariance is factorised with its symmetric eigendecomposition
-- @cov = V Λ Vᵀ@, where 'eigensystem' returns the eigenvectors as the
-- /columns/ of @V@. With @A = V Λ^(1/2)@ we get @A Aᵀ = cov@. So if @z@ is a
-- vector of independent standard normals, @mu + A z@ has mean @mu@ and
-- covariance @cov@.
--
-- Calls 'error' (when the action is run) if any eigenvalue of @cov@ is
-- negative.
--
-- NOTE(review): a singular positive semi-definite @cov@ may yield tiny
-- negative eigenvalues (e.g. @-1e-16@) from rounding, which also trips this
-- check. Consider a tolerance if that matters to callers.
multiNormalRV :: KnownNat n => MultiNormal (R n) -> RVarT m (R n)
multiNormalRV MultiNormal {..} = do
  let (vals, vecs) = eigensystem cov
  when (any (< 0) (LA.toList $ unwrap vals))
    (error "Covariance matrix is not positive semi-definite.")
  -- 'create' only fails on a length mismatch. That cannot happen here:
  -- the vector comes straight from 'extract vals'.
  let lSqrt = diag (fromJust . create $ LA.cmap sqrt (extract vals))
      -- The eigenvectors are the columns of 'vecs', so the factor is
      -- V * sqrt(Λ). Using 'tr vecs' here would give covariance Vᵀ Λ V
      -- instead of 'cov'.
      bigA = vecs <> lSqrt
  gnoise <- replicateM (size mu) (rvarT StdNormal)
  return $ mu + bigA #> (vector gnoise)
-- | Density and log-density are both defined explicitly.
instance KnownNat n => PDF MultiNormal (R n) where
  pdf = multiNormalPDF
  logPdf = multiNormalLogPDF

-- | Probability density at @pt@: the normalising constant multiplied by
-- @exp@ of the quadratic-form exponent.
multiNormalPDF :: KnownNat n => MultiNormal (R n) -> R n -> Double
multiNormalPDF mn pt =
  multiNormalConstant mn * exp (multiNormalQuadraticForm mn pt)

-- | Log of the probability density at @pt@, i.e. @log . multiNormalPDF mn@.
--
-- The normalising constant multiplies the density, so its /logarithm/ is
-- added to the exponent. Adding the raw constant, as before, was incorrect.
--
-- NOTE(review): for high dimensions the constant itself can under/overflow
-- before the 'log' is taken. Computing @-(k log 2π + log det cov)/2@
-- directly would be more robust.
multiNormalLogPDF :: KnownNat n => MultiNormal (R n) -> R n -> Double
multiNormalLogPDF mn pt =
  log (multiNormalConstant mn) + multiNormalQuadraticForm mn pt
-- | Normalising constant of the multivariate normal density,
-- @1 / sqrt ((2π)^k * det cov)@, where @k@ is the dimension of @mu@.
multiNormalConstant :: KnownNat n => MultiNormal (R n) -> Double
multiNormalConstant (MultiNormal m s) = recip (sqrt (twoPiPow * detS))
  where
    twoPiPow = (2 * pi) ^ size m
    detS = LA.det (extract (unSym s))
-- | Exponent of the multivariate normal density at @pt@:
-- @-(1/2) (pt - mu)ᵀ cov⁻¹ (pt - mu)@.
--
-- The result is never positive for a positive-definite @cov@, so the
-- density peaks at @mu@. Calls 'error' (via 'inv') if @cov@ is not
-- invertible.
multiNormalQuadraticForm :: KnownNat n => MultiNormal (R n) -> R n -> Double
multiNormalQuadraticForm MultiNormal {..} pt =
  negate (diff LA.<.> (invCov LA.#> diff)) / 2
  where
    -- Deviation from the mean. Its sign does not affect the quadratic form.
    diff = extract (pt - mu)
    invCov = extract . inv . unSym $ cov