{-# LANGUAGE
MultiParamTypeClasses, FlexibleInstances, FlexibleContexts,
UndecidableInstances, ForeignFunctionInterface, BangPatterns,
RankNTypes
#-}
module Data.Random.Distribution.Normal
( Normal(..)
, normal, normalT
, stdNormal, stdNormalT
, doubleStdNormal
, floatStdNormal
, realFloatStdNormal
, normalTail
, normalPair
, boxMullerNormalPair
, knuthPolarNormalPair
) where
import Data.Random.Internal.Words
import Data.Bits
import Data.Random.Source
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Ziggurat
import Data.Random.RVar
import Data.Vector.Generic (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV
import Data.Number.Erf
-- | Draw a pair of normal variates.
--
-- This is a plain alias for 'boxMullerNormalPair'.
normalPair
  :: (Floating a, Distribution StdUniform a)
  => RVar (a, a)
normalPair = boxMullerNormalPair
{-# INLINE boxMullerNormalPair #-}
-- | Draw a pair of normal variates with the Box-Muller transform.
--
-- Two draws @u1@, @u2@ from 'stdUniform' give a radius
-- @sqrt (-2 * log u1)@ and an angle @2 * pi * u2@. The result is that
-- polar point's cartesian coordinates @(radius * cos angle, radius * sin angle)@.
--
-- NOTE(review): if 'stdUniform' can ever yield exactly 0, @log u1@ is
-- @-Infinity@ and both components become non-finite. Confirm the range of
-- 'stdUniform' for the element types this is used at.
boxMullerNormalPair
  :: (Floating a, Distribution StdUniform a)
  => RVar (a, a)
boxMullerNormalPair = do
  u1 <- stdUniform
  u2 <- stdUniform
  let radius = sqrt (-2 * log u1)
      angle  = (2 * pi) * u2
  return (radius * cos angle, radius * sin angle)
{-# INLINE knuthPolarNormalPair #-}
-- | Draw a pair of normal variates with the polar (rejection) method.
--
-- A point @(v1, v2)@ is drawn with each coordinate from @uniform (-1) 1@.
-- Points with @s = v1^2 + v2^2 >= 1@ are rejected and the whole draw is
-- retried. An accepted point is scaled by @sqrt (-2 * log s / s)@. The
-- degenerate point @s == 0@ (where that factor would be NaN) yields @(0, 0)@.
knuthPolarNormalPair
  :: (Floating a, Ord a, Distribution Uniform a)
  => RVar (a, a)
knuthPolarNormalPair = do
  v1 <- uniform (-1) 1
  v2 <- uniform (-1) 1
  let s = v1 * v1 + v2 * v2
  if s >= 1
    then knuthPolarNormalPair
    else return (polarScale v1 v2 s)
  where
    -- Map an accepted point inside the unit disc to the output pair.
    polarScale p q r
      | r == 0    = (0, 0)
      | otherwise = (p * k, q * k)
      where
        k = sqrt (-2 * log r / r)
{-# INLINE normalTail #-}
-- | Sample the normal tail beyond the cut-off @r@, which is assumed positive.
--
-- Each attempt draws two uniforms, @u1@ and @u2@, and sets
-- @dx = log u1 / r@ and @ly = log u2@. The attempt is retried while
-- @dx*dx + 2*ly > 0@; otherwise @r - dx@ is returned. This is Marsaglia's
-- tail algorithm. For @r > 0@ and @u1@ in @(0, 1]@, @dx <= 0@, so every
-- accepted value is at least @r@.
normalTail
  :: (Distribution StdUniform a, Floating a, Ord a)
  => a
  -> RVarT m a
normalTail r = attempt
  where
    attempt = do
      !u1 <- stdUniformT
      let !dx = log u1 / r
      !u2 <- stdUniformT
      let !ly = log u2
      -- Written as a `> 0` test (not `<= 0`) so a NaN sum accepts,
      -- exactly as before.
      if dx * dx + ly + ly > 0 then attempt else return (r - dx)
-- | Build a ziggurat for the normal distribution via 'mkZigguratRec'.
--
-- The inputs are the unnormalised density 'normalF', its inverse
-- 'normalFInv', its integral 'normalFInt' and total volume 'normalFVol'.
-- The size argument is @2 ^ p@. The caller supplies the random action
-- that yields one @(Int, a)@ pair per sampling step.
normalZ
  :: (RealFloat a, Erf a, Vector v a, Distribution Uniform a, Integral b)
  => b
  -> (forall m. RVarT m (Int, a))
  -> Ziggurat v a
normalZ p drawIU =
  mkZigguratRec True normalF normalFInv normalFInt normalFVol size drawIU
  where
    size = 2 ^ p
-- | Unnormalised standard-normal density restricted to @x > 0@.
--
-- Returns @exp (-x^2 / 2)@ for positive @x@ and the peak value @1@ for
-- every @x <= 0@.
normalF :: (Floating a, Ord a) => a -> a
normalF x =
  if x <= 0
    then 1
    else exp ((-0.5) * x * x)
-- | Inverse of 'normalF' on its positive branch.
--
-- Computes @sqrt (-2 * log y)@, so @normalFInv (exp (-x^2 / 2)) == x@
-- (up to rounding) for @x >= 0@.
normalFInv :: Floating a => a -> a
normalFInv = sqrt . ((-2) *) . log
-- | Integral of 'normalF' from 0 to @x@.
--
-- For @x > 0@ this is @sqrt (pi / 2) * erf (x / sqrt 2)@, written as
-- @normalFVol * erf (x * sqrt 0.5)@. It is 0 for @x <= 0@.
normalFInt :: (Floating a, Erf a, Ord a) => a -> a
normalFInt x =
  if x <= 0
    then 0
    else normalFVol * erf (x * sqrt 0.5)
-- | Total area under 'normalF' on @[0, infinity)@, i.e. @sqrt (pi / 2)@.
normalFVol :: Floating a => a
normalFVol = sqrt halfPi
  where
    halfPi = 0.5 * pi
-- | Standard normal variate for any 'RealFloat' type with an 'Erf'
-- instance, sampled by running a ziggurat built with 'normalZ'.
--
-- 'asTypeOf' pins the table's container to a boxed 'V.Vector'. 'normalZ'
-- is given @p = 6@, so it receives a size of @2 ^ 6 = 64@. Each step draws
-- one random 'Word8' and masks it down to an index in @[0, 63]@, plus one
-- value from @uniformT (-1) 1@.
realFloatStdNormal :: (RealFloat a, Erf a, Distribution Uniform a) => RVarT m a
realFloatStdNormal =
  runZiggurat
    (normalZ layerBits drawIndexAndUniform
      `asTypeOf` (undefined :: Ziggurat V.Vector a))
  where
    layerBits :: Int
    layerBits = 6

    drawIndexAndUniform :: (Num a, Distribution Uniform a) => RVarT m (Int, a)
    drawIndexAndUniform = do
      byte <- getRandomWord8
      v <- uniformT (-1) 1
      let idx = fromIntegral byte .&. (2 ^ layerBits - 1)
      return (idx, v)
-- | Standard normal 'Double' variate, sampled by running the precomputed
-- ziggurat 'doubleStdNormalZ'.
doubleStdNormal :: RVarT m Double
doubleStdNormal = runZiggurat doubleStdNormalZ
-- | Size argument given to 'mkZiggurat_' for 'doubleStdNormalZ'.
-- @doubleStdNormalC - 1@ is also used as a bit mask when drawing the
-- index, so this must remain a power of two.
doubleStdNormalC :: Int
doubleStdNormalC = 512

-- | Tail cut-off for 'doubleStdNormalZ'; 'normalTail' is started here.
doubleStdNormalR :: Double
doubleStdNormalR = 3.852046150368388

-- | Volume constant passed to 'mkZiggurat_' alongside the cut-off.
-- NOTE(review): presumably the common area of each ziggurat layer for
-- 512 layers and cut-off 'doubleStdNormalR'. Confirm against 'mkZiggurat_'.
doubleStdNormalV :: Double
doubleStdNormalV = 0.0024567663515413507
{-# NOINLINE doubleStdNormalZ #-}
-- | Precomputed unboxed ziggurat used by 'doubleStdNormal'.
--
-- It is built by 'mkZiggurat_' from 'normalF' / 'normalFInv', the
-- constants 'doubleStdNormalC', 'doubleStdNormalR' and 'doubleStdNormalV',
-- and 'normalTail' (started at 'doubleStdNormalR') for the tail.
-- NOINLINE keeps it a single top-level value, presumably so the table is
-- built once rather than at each use site.
doubleStdNormalZ :: Ziggurat UV.Vector Double
doubleStdNormalZ =
  mkZiggurat_ True
    normalF normalFInv
    doubleStdNormalC doubleStdNormalR doubleStdNormalV
    drawIndexAndUniform
    (normalTail doubleStdNormalR)
  where
    -- One 64-bit word feeds both outputs. 'wordToDoubleWithExcess' splits
    -- it into a fraction and leftover bits (NOTE(review): presumably the
    -- fraction lies in [0, 1); confirm). The leftover bits are masked to an
    -- index below 'doubleStdNormalC', and the fraction f is mapped to 2f - 1.
    drawIndexAndUniform :: RVarT m (Int, Double)
    drawIndexAndUniform = do
      !bits <- getRandomWord64
      let (frac, excess) = wordToDoubleWithExcess bits
          idx = fromIntegral excess .&. (doubleStdNormalC - 1)
      return $! (idx, frac + frac - 1)
-- | Standard normal 'Float' variate, sampled by running the precomputed
-- ziggurat 'floatStdNormalZ'.
floatStdNormal :: RVarT m Float
floatStdNormal = runZiggurat floatStdNormalZ
-- | Size argument given to 'mkZiggurat_' for 'floatStdNormalZ'.
-- @floatStdNormalC - 1@ is also used as a bit mask when drawing the
-- index, so this must remain a power of two.
floatStdNormalC :: Int
floatStdNormalC = 512

-- | Tail cut-off for 'floatStdNormalZ'; 'normalTail' is started here.
floatStdNormalR :: Float
floatStdNormalR = 3.852046150368388

-- | Volume constant passed to 'mkZiggurat_' alongside the cut-off.
-- NOTE(review): presumably the common area of each ziggurat layer for
-- 512 layers and cut-off 'floatStdNormalR'. Confirm against 'mkZiggurat_'.
floatStdNormalV :: Float
floatStdNormalV = 0.0024567663515413507
{-# NOINLINE floatStdNormalZ #-}
-- | Precomputed unboxed ziggurat used by 'floatStdNormal'.
--
-- It is built by 'mkZiggurat_' from 'normalF' / 'normalFInv', the
-- constants 'floatStdNormalC', 'floatStdNormalR' and 'floatStdNormalV',
-- and 'normalTail' (started at 'floatStdNormalR') for the tail.
-- NOINLINE keeps it a single top-level value, presumably so the table is
-- built once rather than at each use site.
floatStdNormalZ :: Ziggurat UV.Vector Float
floatStdNormalZ =
  mkZiggurat_ True
    normalF normalFInv
    floatStdNormalC floatStdNormalR floatStdNormalV
    drawIndexAndUniform
    (normalTail floatStdNormalR)
  where
    -- One 32-bit word feeds both outputs. 'word32ToFloatWithExcess' splits
    -- it into a fraction and leftover bits (NOTE(review): presumably the
    -- fraction lies in [0, 1); confirm). The leftover bits are masked to an
    -- index below 'floatStdNormalC', and the fraction f is mapped to 2f - 1.
    drawIndexAndUniform :: RVarT m (Int, Float)
    drawIndexAndUniform = do
      !bits <- getRandomWord32
      let (frac, excess) = word32ToFloatWithExcess bits
          idx = fromIntegral excess .&. (floatStdNormalC - 1)
      return (idx, frac + frac - 1)
-- | CDF at @x@ of a normal distribution with mean @m@ and standard
-- deviation @s@, returned as a 'Double'.
--
-- The standard deviation enters only through its absolute value. As a
-- result, @Normal m s@ and @Normal m (negate s)@ share one CDF. This agrees
-- with 'normalPdf', which is also insensitive to the sign of @s@, and with
-- sampling, where @x * s + m@ of a symmetric standard normal @x@ has the
-- same law for either sign. Before this, a negative @s@ silently produced
-- @1 - F(x)@.
--
-- @s == 0@ still yields the degenerate step function, because
-- @(x - m) / 0@ is @±Infinity@.
normalCdf :: (Real a) => a -> a -> a -> Double
normalCdf m s x = normcdf ((realToFrac x - realToFrac m) / abs (realToFrac s))
-- | Density at @x@ of a normal distribution with mean @mu@ and standard
-- deviation @sigma@:
--
-- @exp (-z^2 / 2) / (|sigma| * sqrt (2 * pi))@ with @z = (x - mu) / sigma@.
--
-- The density is evaluated through the standardised value @z@ rather than
-- through @sigma^2@. Squaring @sigma@ overflowed or underflowed for extreme
-- standard deviations. For example, @sigma = 1e-200@ gave @sigma^2 = 0@ and
-- a NaN result, although the true density at the mean, ~3.99e199, is
-- representable. Only @|sigma|@ is used, so the sign of @sigma@ is ignored
-- as before.
normalPdf :: (Real a, Floating b) => a -> a -> a -> b
normalPdf mu sigma x = exp (negate (0.5 * z * z)) / (abs sd * sqrt (2 * pi))
  where
    sd = realToFrac sigma
    z = (realToFrac x - realToFrac mu) / sd
-- | Log-density at @x@ of a normal distribution with mean @mu@ and standard
-- deviation @sigma@:
--
-- @-z^2 / 2 - log |sigma| - log (2 * pi) / 2@ with @z = (x - mu) / sigma@.
--
-- Working in log space with @z@ avoids squaring @sigma@. The old form took
-- @log (recip (sqrt (2 * pi * sigma^2)))@, which returned NaN or
-- @-Infinity@ whenever @sigma^2@ under- or overflowed. For example,
-- @sigma = 1e-200@ gave NaN instead of ~459.6, and @sigma = 1e200@ gave
-- @-Infinity@ instead of ~-461.4. Only @|sigma|@ is used, so the sign of
-- @sigma@ is ignored as before.
normalLogPdf :: (Real a, Floating b) => a -> a -> a -> b
normalLogPdf mu sigma x =
  negate (0.5 * z * z) - log (abs sd) - 0.5 * log (2 * pi)
  where
    sd = realToFrac sigma
    z = (realToFrac x - realToFrac mu) / sd
-- | A normal (Gaussian) distribution over @a@.
data Normal a
  = StdNormal
    -- ^ The standard normal: mean 0, standard deviation 1.
  | Normal a a
    -- ^ @Normal m s@: mean @m@ and standard deviation @s@.
-- | 'Double' sampling: draw from 'doubleStdNormal', then shift and scale
-- by the distribution's mean and standard deviation.
instance Distribution Normal Double where
  rvarT dist = case dist of
    StdNormal  -> doubleStdNormal
    Normal m s -> fmap (\z -> z * s + m) doubleStdNormal
-- | 'Float' sampling: draw from 'floatStdNormal', then shift and scale
-- by the distribution's mean and standard deviation.
instance Distribution Normal Float where
  rvarT dist = case dist of
    StdNormal  -> floatStdNormal
    Normal m s -> fmap (\z -> z * s + m) floatStdNormal
-- | The CDF delegates to 'normalCdf'. 'StdNormal' is treated as mean 0,
-- standard deviation 1.
instance (Real a, Distribution Normal a) => CDF Normal a where
  cdf dist = case dist of
    StdNormal  -> normalCdf 0 1
    Normal m s -> normalCdf m s
-- | The density and log-density delegate to 'normalPdf' and 'normalLogPdf'.
-- 'StdNormal' is treated as mean 0, standard deviation 1.
instance (Real a, Floating a, Distribution Normal a) => PDF Normal a where
  pdf dist = case dist of
    StdNormal  -> normalPdf 0 1
    Normal m s -> normalPdf m s
  logPdf dist = case dist of
    StdNormal  -> normalLogPdf 0 1
    Normal m s -> normalLogPdf m s
{-# SPECIALIZE stdNormal :: RVar Double #-}
{-# SPECIALIZE stdNormal :: RVar Float #-}
-- | A standard normal variate (mean 0, standard deviation 1), obtained by
-- sampling 'StdNormal' as an 'RVar'.
stdNormal
  :: Distribution Normal a
  => RVar a
stdNormal = rvar StdNormal
-- | A standard normal variate (mean 0, standard deviation 1), obtained by
-- sampling 'StdNormal' as an 'RVarT' in any base monad @m@.
stdNormalT
  :: Distribution Normal a
  => RVarT m a
stdNormalT = rvarT StdNormal
-- | @normal m s@ samples a normal variate with mean @m@ and standard
-- deviation @s@ as an 'RVar'.
normal
  :: Distribution Normal a
  => a   -- ^ mean
  -> a   -- ^ standard deviation
  -> RVar a
normal m = rvar . Normal m
-- | @normalT m s@ samples a normal variate with mean @m@ and standard
-- deviation @s@ as an 'RVarT' in any base monad.
normalT
  :: Distribution Normal a
  => a   -- ^ mean
  -> a   -- ^ standard deviation
  -> RVarT m a
normalT m = rvarT . Normal m