----------------------------------------------------------------------------- -- | -- Module : Data.Random.Distribution.MultivariateNormal -- Copyright : (c) 2016 FP Complete Corporation -- License : MIT (see LICENSE) -- Maintainer : dominic@steinitz.org -- -- Sample from the multivariate normal distribution with a given -- vector-valued \(\mu\) and covariance matrix \(\Sigma\). For example, -- the chart below shows samples from the bivariate normal -- distribution. -- -- <<diagrams/src_Data_Random_Distribution_MultivariateNormal_diagM.svg#diagram=diagM&height=600&width=500>> -- -- Example code to generate the chart: -- -- > import qualified Graphics.Rendering.Chart as C -- > import Graphics.Rendering.Chart.Backend.Diagrams -- > -- > import Data.Random.Distribution.MultivariateNormal -- > -- > import qualified Data.Random as R -- > import Data.Random.Source.PureMT -- > import Control.Monad.State -- > import qualified Numeric.LinearAlgebra.HMatrix as LA -- > -- > nSamples :: Int -- > nSamples = 10000 -- > -- > sigma1, sigma2, rho :: Double -- > sigma1 = 3.0 -- > sigma2 = 1.0 -- > rho = 0.5 -- > -- > singleSample :: R.RVarT (State PureMT) (LA.Vector Double) -- > singleSample = R.sample $ Normal (LA.fromList [0.0, 0.0]) -- > (LA.sym $ (2 LA.>< 2) [ sigma1, rho * sigma1 * sigma2 -- > , rho * sigma1 * sigma2, sigma2]) -- > -- > multiSamples :: [LA.Vector Double] -- > multiSamples = evalState (replicateM nSamples $ R.sample singleSample) (pureMT 3) -- > pts = map (f . LA.toList) multiSamples -- > where -- > f [x, y] = (x, y) -- > f _ = error "Only pairs for this chart" -- > -- > -- > chartPoint pointVals n = C.toRenderable layout -- > where -- > -- > fitted = C.plot_points_values .~ pointVals -- > $ C.plot_points_style . C.point_color .~ opaque red -- > $ C.plot_points_title .~ "Sample" -- > $ def -- > -- > layout = C.layout_title .~ "Sampling Bivariate Normal (" ++ (show n) ++ " samples)" -- > $ C.layout_y_axis . C.laxis_generate .~ C.scaledAxis def (-3,3) -- > $ C.layout_x_axis . C.laxis_generate .~ C.scaledAxis def (-3,3) -- > -- > $ C.layout_plots .~ [C.toPlot fitted] -- > $ def -- > -- > diagM = do -- > denv <- defaultEnv C.vectorAlignmentFns 600 500 -- > return $ fst $ runBackend denv (C.render (chartPoint pts nSamples) (500, 500)) -- ----------------------------------------------------------------------------- {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE MultiParamTypeClasses #-} module Data.Random.Distribution.MultivariateNormal ( Normal(..) ) where import Data.Random.Distribution import qualified Numeric.LinearAlgebra.HMatrix as H import Control.Monad import qualified Data.Random as R import Foreign.Storable ( Storable ) import Data.Maybe ( fromJust ) normalMultivariate :: H.Vector Double -> H.Herm Double -> R.RVarT m (H.Vector Double) normalMultivariate mu bigSigma = do z <- replicateM (H.size mu) (rvarT R.StdNormal) return $ mu + bigA H.#> (H.fromList z) where (vals, bigU) = H.eigSH bigSigma lSqrt = H.diag $ H.cmap sqrt vals bigA = bigU H.<> lSqrt data family Normal k :: * data instance Normal (H.Vector Double) = Normal (H.Vector Double) (H.Herm Double) instance Distribution Normal (H.Vector Double) where rvar (Normal m s) = normalMultivariate m s normalPdf :: (H.Numeric a, H.Field a, H.Indexable (H.Vector a) a, Num (H.Vector a)) => H.Vector a -> H.Herm a -> H.Vector a -> a normalPdf mu sigma x = exp $ normalLogPdf mu sigma x normalLogPdf :: (H.Numeric a, H.Field a, H.Indexable (H.Vector a) a, Num (H.Vector a)) => H.Vector a -> H.Herm a -> H.Vector a -> a normalLogPdf mu bigSigma x = - H.sumElements (H.cmap log (diagonals dec)) - 0.5 * (fromIntegral (H.size mu)) * log (2 * pi) - 0.5 * s where dec = fromJust $ H.mbChol bigSigma t = fromJust $ H.linearSolve (H.tr dec) (H.asColumn $ x - mu) u = H.cmap (\v -> v * v) t s = H.sumElements u diagonals :: (Storable a, H.Element t, H.Indexable (H.Vector t) a) => H.Matrix t -> H.Vector a diagonals m = H.fromList (map (\i -> m H.! i H.! i) [0..n-1]) where n = max (H.rows m) (H.cols m) instance PDF Normal (H.Vector Double) where pdf (Normal m s) = normalPdf m s logPdf (Normal m s) = normalLogPdf m s