--------------------------------------------------------------
--- An implementation of multivariate normal distributions ---
--------------------------------------------------------------
{-
Written by:    Dominic Steinitz, Jacob West
Last modified: 2016-07-27

Summary: Multivariate normal distributions are necessary for Kalman
filters and smoothers. However, strictly speaking, the functionality
provided here should exist elsewhere, perhaps in the package: random-fu.
-}

---------------------------
--- File header pragmas ---
---------------------------

{-# LANGUAGE RecordWildCards #-}       -- Used by the Show instance, multiNormalRV,
                                       -- multiNormalConstant and
                                       -- multiNormalQuadraticForm
{-# LANGUAGE MultiParamTypeClasses #-} -- Necessary for Distribution instance
{-# LANGUAGE FlexibleInstances #-}     -- Necessary for Show instance
{-# LANGUAGE TypeFamilies #-}          -- Necessary for MultiNormal definition

------------------------
--- Module / Exports ---
------------------------

module Data.Random.Distribution.MultiNormal
  ( MultiNormal(..)
  , inv
  ) where

---------------
--- Imports ---
---------------

import Control.Monad (replicateM, when)
import Data.Maybe (fromMaybe, fromJust)
import Data.Random
import GHC.TypeLits
-- Newer hmatrix versions export their own 'inv' from the static interface;
-- hide it so the 'inv' defined (and exported) below is unambiguous. On
-- versions without it, hiding is only a -Wdodgy-imports warning.
import Numeric.LinearAlgebra.Static hiding (inv)
import qualified Numeric.LinearAlgebra as LA

------------------------
--- Helper Functions ---
------------------------

-- | Invert a square, statically sized matrix by solving @m x = I@ with
-- 'linSolve' against the identity matrix 'eye'.
--
-- WARNING: this is partial. If 'linSolve' returns 'Nothing' (presumably
-- because the matrix is singular), it calls 'error'. Better error handling
-- (e.g. returning 'Maybe') would be preferable.
inv :: KnownNat n => Sq n -> Sq n
inv m =
  fromMaybe
    (error "Failed attempting to invert non-invertible matrix.")
    (linSolve m eye)

-----------------------------------------
--- Multivariate Normal Distributions ---
-----------------------------------------

-- This probably belongs elsewhere, maybe Data.Random, but that would
-- create a dependence on Numeric.LinearAlgebra which I believe is not
-- there now and may be undesirable.
-- | A multivariate normal distribution, indexed by the type of its samples.
data family MultiNormal k :: *

-- | Multivariate normal distribution over statically sized vectors of
-- length @n@, described by its mean and its (symmetric) covariance matrix.
--
-- No datatype context is given: such contexts require the deprecated
-- DatatypeContexts extension and add nothing, since every function below
-- already demands @KnownNat n@.
data instance MultiNormal (R n) = MultiNormal
  { mu  :: R n   -- ^ Mean vector.
  , cov :: Sym n -- ^ Covariance matrix.
  }

--- Show Instance ---

instance KnownNat n => Show (MultiNormal (R n)) where
  show MultiNormal {..} = "Normal " ++ show mu ++ " " ++ show cov

--- Distribution Instance ---

instance KnownNat n => Distribution MultiNormal (R n) where
  rvar = multiNormalRV

-- | Draw one sample from a multivariate normal distribution.
--
-- The covariance is factored through its eigendecomposition. hmatrix
-- returns the eigenvectors as the columns of @vecs@, so
-- @cov = vecs <> diag vals <> tr vecs@. With
-- @bigA = vecs <> diag (sqrt vals)@ we get @bigA <> tr bigA = cov@. A
-- vector of independent standard normals is then mapped through @bigA@
-- and shifted by the mean.
--
-- Calls 'error' if any eigenvalue of the covariance is negative.
-- WARNING: needs better error handling.
-- NOTE(review): a singular but positive semi-definite covariance may yield
-- eigenvalues that are negative only by rounding error, and this check
-- will reject it. Consider a small tolerance.
multiNormalRV :: KnownNat n => MultiNormal (R n) -> RVarT m (R n)
multiNormalRV MultiNormal {..} = do
  let (vals, vecs) = eigensystem cov
  when (any (< 0) (LA.toList $ unwrap vals))
    (error "Covariance matrix is not positive semi-definite.")
  let lSqrt = diag (dvmap sqrt vals)
      -- Eigenvectors are columns, so the factor is V * sqrt(Lambda),
      -- not V^T * sqrt(Lambda).
      bigA  = vecs <> lSqrt
  gnoise <- replicateM (size mu) (rvarT StdNormal)
  return $ mu + bigA #> vector gnoise

--- PDF Instance ---

instance KnownNat n => PDF MultiNormal (R n) where
  pdf    = multiNormalPDF
  logPdf = multiNormalLogPDF

-- | Density of the distribution at a point:
-- @constant * exp (quadratic form)@.
multiNormalPDF :: KnownNat n => MultiNormal (R n) -> R n -> Double
multiNormalPDF mn pt =
  multiNormalConstant mn * exp (multiNormalQuadraticForm mn pt)

-- | Log-density of the distribution at a point.
--
-- This is the logarithm of 'multiNormalPDF': the log of the normalising
-- constant plus the quadratic form. The constant is computed directly in
-- log space, so @(2*pi)^n@ is never formed and cannot overflow.
multiNormalLogPDF :: KnownNat n => MultiNormal (R n) -> R n -> Double
multiNormalLogPDF mn pt =
  multiNormalLogConstant mn + multiNormalQuadraticForm mn pt

-- | Normalising constant @1 / sqrt ((2*pi)^n * det cov)@, where @n@ is
-- the dimension of the mean vector.
multiNormalConstant :: KnownNat n => MultiNormal (R n) -> Double
multiNormalConstant MultiNormal {..} = recip . sqrt $ (2*pi)^n * detCov
  where
    n      = size mu
    detCov = LA.det . extract . unSym $ cov

-- | Logarithm of 'multiNormalConstant':
-- @-(n * log (2*pi) + log (det cov)) / 2@.
multiNormalLogConstant :: KnownNat n => MultiNormal (R n) -> Double
multiNormalLogConstant MultiNormal {..} =
  (fromIntegral n * log (2 * pi) + log detCov) / (-2)
  where
    n      = size mu
    detCov = LA.det . extract . unSym $ cov

-- | Exponent of the density: @-(mu - pt)^T cov^-1 (mu - pt) / 2@.
--
-- The inverse is obtained through 'inv', which calls 'error' if the
-- covariance cannot be inverted.
multiNormalQuadraticForm :: KnownNat n => MultiNormal (R n) -> R n -> Double
multiNormalQuadraticForm MultiNormal {..} pt =
  (diff LA.<.> invCov LA.#> diff) / (-2)
  where
    diff   = extract (mu - pt)
    invCov = extract . inv . unSym $ cov