{-# LANGUAGE
    MultiParamTypeClasses, FlexibleInstances, FlexibleContexts,
    UndecidableInstances, ForeignFunctionInterface, BangPatterns,
    RankNTypes
  #-}

{-# OPTIONS_GHC -fno-warn-type-defaults #-}

module Data.Random.Distribution.Normal
    ( Normal(..)
    , normal, normalT
    , stdNormal, stdNormalT

    , doubleStdNormal
    , floatStdNormal
    , realFloatStdNormal

    , normalTail

    , normalPair
    , boxMullerNormalPair
    , knuthPolarNormalPair
    ) where

import Data.Bits

import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Ziggurat
import Data.Random.RVar
import Data.Word

import Data.Vector.Generic (Vector)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV

import Data.Number.Erf

import qualified System.Random.Stateful as Random

-- |A random variable that produces a pair of independent
-- normally-distributed values.
--
-- This is currently an alias for 'boxMullerNormalPair', so it needs only
-- 'Floating' and a 'StdUniform' source (no 'Ord' instance).
normalPair :: (Floating a, Distribution StdUniform a) => RVar (a,a)
normalPair = boxMullerNormalPair

-- |A random variable that produces a pair of independent
-- normally-distributed values, computed using the Box-Muller method.
-- This algorithm is slightly slower than Knuth's method but uses a
-- constant amount of entropy (exactly two 'stdUniform' draws per pair;
-- Knuth's method is a rejection method).
-- It is also slightly more general (Knuth's method requires an 'Ord'
-- instance).
--
-- NOTE(review): if the 'StdUniform' instance for @a@ can return exactly 0,
-- @log u@ is @-Infinity@ and the pair is infinite/NaN.  Rejecting that
-- draw would need an 'Eq' constraint, so it is left to the instance.
{-# INLINE boxMullerNormalPair #-}
boxMullerNormalPair :: (Floating a, Distribution StdUniform a) => RVar (a,a)
boxMullerNormalPair = do
    u <- stdUniform
    t <- stdUniform
    let -- radius: -2 * log u is chi-squared with 2 degrees of freedom
        r     = sqrt (-2 * log u)
        -- angle: t scaled to a full turn
        theta = (2 * pi) * t

        x = r * cos theta
        y = r * sin theta
    return (x, y)

-- |A random variable that produces a pair of independent
-- normally-distributed values, computed using Knuth's polar method.
-- Slightly faster than 'boxMullerNormalPair' when it accepts on the
-- first try, but does not always do so.
--
-- Each coordinate of the candidate point @(v1, v2)@ is drawn with
-- @'uniform' (-1) 1@.  The point is accepted only when it lies strictly
-- inside the unit disc and is not the origin; otherwise it is redrawn.
{-# INLINE knuthPolarNormalPair #-}
knuthPolarNormalPair :: (Floating a, Ord a, Distribution Uniform a) => RVar (a,a)
knuthPolarNormalPair = do
    v1 <- uniform (-1) 1
    v2 <- uniform (-1) 1

    let s = v1*v1 + v2*v2
    -- s >= 1: the point is outside the unit disc.
    -- s == 0: the origin, where @log s / s@ is undefined.  Rejecting it
    -- (instead of returning a fixed (0,0)) avoids a point mass at zero.
    if s >= 1 || s == 0
        then knuthPolarNormalPair
        else do
            let scale = sqrt (-2 * log s / s)
            return (v1 * scale, v2 * scale)

-- |Draw from the tail of a normal distribution: the region beyond the
-- provided value @r@.
--
-- This is a rejection loop.  Two 'stdUniformT' draws @u@ and @v@ give
-- @x = log u / r@ and @y = log v@.  The candidate @r - x@ is accepted when
-- @x*x + 2*y <= 0@ and redrawn otherwise.  For @r > 0@ and @u@ in (0,1],
-- @x <= 0@, so accepted results are @>= r@.
--
-- NOTE(review): assumes @r > 0@ (in this module it is always the
-- ziggurat's tail start); for @r <= 0@ the result is not a tail sample.
{-# INLINE normalTail #-}
normalTail :: (Distribution StdUniform a, Floating a, Ord a) =>
              a -> RVarT m a
normalTail r = go
    where
        go = do
            !u <- stdUniformT
            let !x = log u / r
            !v <- stdUniformT
            let !y = log v
            -- reject unless x^2 <= -2y  (y+y is 2y)
            if x*x + y+y > 0
                then go
                else return (r - x)

-- |Construct a 'Ziggurat' for sampling a normal distribution, given
-- @p = logBase 2 c@ (the table has @c = 2^p@ layers) and the 'zGetIU'
-- implementation.
--
-- The table is built by 'mkZigguratRec' from the target function
-- 'normalF', its inverse 'normalFInv', its integral 'normalFInt' and its
-- total volume 'normalFVol'.  The 'True' flag presumably selects a
-- mirrored (two-sided) ziggurat; confirm against 'mkZigguratRec'.
normalZ ::
  (RealFloat a, Erf a, Vector v a, Distribution Uniform a, Integral b) =>
  b -> (forall m. RVarT m (Int, a)) -> Ziggurat v a
normalZ p getIU =
    mkZigguratRec True normalF normalFInv normalFInt normalFVol (2 ^ p) getIU

-- | Ziggurat target function: the upper half of a non-normalized Gaussian
-- PDF, @exp (-x^2 / 2)@ for @x > 0@.  For @x <= 0@ it returns the peak
-- value 1.
normalF :: (Floating a, Ord a) => a -> a
normalF x
    | x <= 0    = 1
    | otherwise = exp ((-0.5) * x * x)
-- | Inverse of 'normalF' on @(0, 1]@: @sqrt (-2 * log y)@.
normalFInv :: Floating a => a -> a
normalFInv y = sqrt ((-2) * log y)
-- | Integral of 'normalF' from 0 to @x@ (0 for @x <= 0@).
-- It uses @'normalFVol' * erf (x / sqrt 2)@, written with @sqrt 0.5@.
normalFInt :: (Floating a, Erf a, Ord a) => a -> a
normalFInt x
    | x <= 0    = 0
    | otherwise = normalFVol * erf (x * sqrt 0.5)
-- | Total volume (area) under 'normalF' for @x > 0@: @sqrt (pi / 2)@.
normalFVol :: Floating a => a
normalFVol = sqrt (0.5 * pi)

-- |A random variable sampling from the standard normal distribution
-- over any 'RealFloat' type (subject to the rest of the constraints -
-- it builds and uses a 'Ziggurat' internally, which requires the 'Erf'
-- class).
--
-- Because it computes a 'Ziggurat', it is very expensive to use for
-- just one evaluation, or even for multiple evaluations if not used and
-- reused monomorphically (to enable the ziggurat table to be let-floated
-- out).  If you don't know whether your use case fits this description
-- then you're probably better off using a different algorithm, such as
-- 'boxMullerNormalPair' or 'knuthPolarNormalPair'.  And of course if
-- you don't need the full generality of this definition then you're much
-- better off using 'doubleStdNormal' or 'floatStdNormal'.
--
-- As far as I know, this should be safe to use in any monomorphic
-- @Distribution Normal@ instance declaration.
realFloatStdNormal :: (RealFloat a, Erf a, Distribution Uniform a) => RVarT m a
realFloatStdNormal = runZiggurat (boxed (normalZ p getIU))
    where
        -- log2 of the number of ziggurat layers (2^6 = 64).  It must be at
        -- most 8, because 'getIU' takes the layer index from one 'Word8'.
        p :: Int
        p = 6

        -- Pins the table's vector type to boxed 'V.Vector'.  The signature
        -- carries no 'UV.Unbox' constraint, so an unboxed table is not an
        -- option here.  This replaces an @asTypeOf undefined@ trick.
        boxed :: Ziggurat V.Vector t -> Ziggurat V.Vector t
        boxed = id

        -- Layer index from the low @p@ bits of a random byte, plus a
        -- @uniformT (-1) 1@ draw for the position within the layer.
        getIU :: (Num a, Distribution Uniform a) => RVarT m (Int, a)
        getIU = do
            i <- Random.uniformWord8 RGen
            u <- uniformT (-1) 1
            return (fromIntegral i .&. (2 ^ p - 1), u)

-- |A random variable sampling from the standard normal distribution
-- over the 'Double' type, using the shared precomputed table
-- 'doubleStdNormalZ'.
doubleStdNormal :: RVarT m Double
doubleStdNormal = runZiggurat doubleStdNormalZ

-- | Number of ziggurat layers in 'doubleStdNormalZ'.  The layer index is
-- taken with the bit mask @doubleStdNormalC - 1@, so this must be a power
-- of two.  It must not be over 2^12 if using 'wordToDoubleWithExcess',
-- which yields only 12 excess bits.
doubleStdNormalC :: Int
doubleStdNormalC = 512

-- | 'doubleStdNormalR' is where the tail region begins (it is passed to
-- 'normalTail').  'doubleStdNormalV' is presumably the common area of each
-- layer; TODO confirm against 'mkZiggurat_'.  Both are presumably derived
-- for @doubleStdNormalC = 512@ layers and would need recomputing if it
-- changes.
doubleStdNormalR, doubleStdNormalV :: Double
doubleStdNormalR = 3.852046150368388
doubleStdNormalV = 2.4567663515413507e-3

{-# NOINLINE doubleStdNormalZ #-}
-- | Precomputed ziggurat table for 'doubleStdNormal' (NOINLINE so it is
-- shared rather than rebuilt at each use site).
doubleStdNormalZ :: Ziggurat UV.Vector Double
doubleStdNormalZ =
    mkZiggurat_ True
        normalF normalFInv
        doubleStdNormalC doubleStdNormalR doubleStdNormalV
        getIU
        (normalTail doubleStdNormalR)
    where
        -- One 'Word64' supplies both values: the low 52 bits become a
        -- uniform @u@ in [0,1), and the 12 excess bits (masked to
        -- @doubleStdNormalC - 1@) select the layer.  @u + u - 1@ maps u
        -- into [-1,1).  Both components are forced here; the previous
        -- @return $!@ only forced the tuple constructor, not its contents.
        getIU :: RVarT m (Int, Double)
        getIU = do
            !w <- Random.uniformWord64 RGen
            let (u, excess) = wordToDoubleWithExcess w
                !layer      = fromIntegral excess .&. (doubleStdNormalC - 1)
                !signedU    = u + u - 1
            return (layer, signedU)

-- NOTE: inlined from random-source
{-# INLINE wordToDouble #-}
-- |Pack the low 52 bits from a 'Word64' into a 'Double' in the range [0,1).
-- Used to convert a 'stdUniform' 'Word64' to a 'stdUniform' 'Double'.
-- The conversion is exact: a 52-bit integer fits in a 'Double' mantissa
-- and is scaled by @2^-52@.
wordToDouble :: Word64 -> Double
wordToDouble x = encodeFloat (toInteger (x .&. mantissaMask)) (-52)
  where
    mantissaMask = 0x000fffffffffffff -- 2^52 - 1

{-# INLINE wordToDoubleWithExcess #-}
-- |Same as 'wordToDouble', but also return the unused bits: the top 12
-- bits of the input, shifted down to be the 12 least significant bits of
-- the returned 'Word64'.
wordToDoubleWithExcess :: Word64 -> (Double, Word64)
wordToDoubleWithExcess x = (wordToDouble x, x `shiftR` 52)


-- |A random variable sampling from the standard normal distribution
-- over the 'Float' type, using the shared precomputed table
-- 'floatStdNormalZ'.
floatStdNormal :: RVarT m Float
floatStdNormal = runZiggurat floatStdNormalZ

-- | Number of ziggurat layers in 'floatStdNormalZ'.  The layer index is
-- taken with the bit mask @floatStdNormalC - 1@, so this must be a power
-- of two.  It must not be over 2^9 if using 'word32ToFloatWithExcess',
-- which yields only 9 excess bits.
floatStdNormalC :: Int
floatStdNormalC = 512

-- | 'floatStdNormalR' is where the tail region begins (it is passed to
-- 'normalTail').  'floatStdNormalV' is presumably the common area of each
-- layer; TODO confirm against 'mkZiggurat_'.  The values are the same
-- literals as the 'Double' table's, rounded to 'Float'.
floatStdNormalR, floatStdNormalV :: Float
floatStdNormalR = 3.852046150368388
floatStdNormalV = 2.4567663515413507e-3

{-# NOINLINE floatStdNormalZ #-}
-- | Precomputed ziggurat table for 'floatStdNormal' (NOINLINE so it is
-- shared rather than rebuilt at each use site).
floatStdNormalZ :: Ziggurat UV.Vector Float
floatStdNormalZ =
    mkZiggurat_ True
        normalF normalFInv
        floatStdNormalC floatStdNormalR floatStdNormalV
        getIU
        (normalTail floatStdNormalR)
    where
        -- One 'Word32' supplies both values: the low 23 bits become a
        -- uniform @u@ in [0,1), and the 9 excess bits (masked to
        -- @floatStdNormalC - 1@) select the layer.  @u + u - 1@ maps u
        -- into [-1,1).  Components are forced, as in 'doubleStdNormalZ'.
        getIU :: RVarT m (Int, Float)
        getIU = do
            !w <- Random.uniformWord32 RGen
            let (u, excess) = word32ToFloatWithExcess w
                !layer      = fromIntegral excess .&. (floatStdNormalC - 1)
                !signedU    = u + u - 1
            return (layer, signedU)

-- NOTE: inlined from random-source
{-# INLINE word32ToFloat #-}
-- |Pack the low 23 bits from a 'Word32' into a 'Float' in the range [0,1).
-- Used to convert a 'stdUniform' 'Word32' to a 'stdUniform' 'Float'.
-- The conversion is exact: a 23-bit integer fits in a 'Float' mantissa
-- and is scaled by @2^-23@.
word32ToFloat :: Word32 -> Float
word32ToFloat x = encodeFloat (toInteger (x .&. mantissaMask)) (-23)
  where
    mantissaMask = 0x007fffff -- 2^23 - 1

{-# INLINE word32ToFloatWithExcess #-}
-- |Same as 'word32ToFloat', but also return the unused bits: the top 9
-- bits of the input, shifted down to be the 9 least significant bits of
-- the returned 'Word32'.
word32ToFloatWithExcess :: Word32 -> (Float, Word32)
word32ToFloatWithExcess x = (word32ToFloat x, x `shiftR` 23)


-- | Cumulative distribution function of a normal distribution with mean
-- @m@ and standard deviation @s@, evaluated at @x@.  It is computed in
-- 'Double' by applying 'normcdf' to the standardized value.
--
-- The sign of @s@ is ignored.  This matches the sampler (@x * s + m@ with
-- a symmetric standard normal @x@ has the same distribution for @s@ and
-- @-s@) and 'normalPdf', which is also insensitive to the sign of @s@.
-- Previously a negative @s@ produced @1 - cdf@.
-- NOTE(review): @s == 0@ divides by zero.
normalCdf :: (Real a) => a -> a -> a -> Double
normalCdf m s x = normcdf ((realToFrac x - realToFrac m) / realToFrac (abs s))

-- | Density of a normal distribution with mean @mu@ and standard deviation
-- @sigma@, evaluated at @x@.
--
-- It is computed as @exp (-z^2 / 2) / (|sigma| * sqrt (2*pi))@ with
-- @z = (x - mu) / |sigma|@.  This never squares @sigma@ on its own, so
-- very small or very large @sigma@ no longer under/overflows: for example
-- @sigma = 1e-200@ gave NaN before.  Only @|sigma|@ is used, so the sign
-- of @sigma@ does not matter.
normalPdf :: (Real a, Floating b) => a -> a -> a -> b
normalPdf mu sigma x = exp (-0.5 * z * z) / (s * sqrt (2 * pi))
  where
    s = abs (realToFrac sigma)
    z = (realToFrac x - realToFrac mu) / s

-- | Natural log of the density of a normal distribution with mean @mu@ and
-- standard deviation @sigma@, evaluated at @x@.
--
-- It is computed directly in log space as
-- @-(log |sigma| + log (2*pi) / 2) - z^2 / 2@, with
-- @z = (x - mu) / |sigma|@.  This never squares @sigma@ on its own, so
-- extreme @sigma@ no longer under/overflows: for example @sigma = 1e-200@
-- gave NaN before.  Only @|sigma|@ is used, so the sign of @sigma@ does
-- not matter.
normalLogPdf :: (Real a, Floating b) => a -> a -> a -> b
normalLogPdf mu sigma x = negate (log s + 0.5 * log (2 * pi)) - 0.5 * z * z
  where
    s = abs (realToFrac sigma)
    z = (realToFrac x - realToFrac mu) / s

-- |A specification of a normal distribution over the type 'a'.
data Normal a
    -- |The \"standard\" normal distribution - mean 0, stddev 1
    = StdNormal
    -- |@Normal m s@ is a normal distribution with mean @m@ and stddev @s@.
    | Normal a a -- mean, sd

-- | 'Double' normals use the precomputed ziggurat 'doubleStdNormal',
-- scaled by the standard deviation and shifted by the mean.
instance Distribution Normal Double where
    rvarT StdNormal    = doubleStdNormal
    rvarT (Normal m s) = (\x -> x * s + m) <$> doubleStdNormal

-- | 'Float' normals use the precomputed ziggurat 'floatStdNormal',
-- scaled by the standard deviation and shifted by the mean.
instance Distribution Normal Float where
    rvarT StdNormal    = floatStdNormal
    rvarT (Normal m s) = (\x -> x * s + m) <$> floatStdNormal

-- | CDF of a 'Normal', delegating to 'normalCdf'; 'StdNormal' is treated
-- as mean 0, standard deviation 1.
instance (Real a, Distribution Normal a) => CDF Normal a where
    cdf StdNormal    = normalCdf 0 1
    cdf (Normal m s) = normalCdf m s

-- | Density and log-density of a 'Normal', delegating to 'normalPdf' and
-- 'normalLogPdf'; 'StdNormal' is treated as mean 0, standard deviation 1.
instance (Real a, Floating a, Distribution Normal a) => PDF Normal a where
  pdf StdNormal       = normalPdf 0 1
  pdf (Normal m s)    = normalPdf m s
  logPdf StdNormal    = normalLogPdf 0 1
  logPdf (Normal m s) = normalLogPdf m s

{-# SPECIALIZE stdNormal :: RVar Double #-}
{-# SPECIALIZE stdNormal :: RVar Float #-}
-- |'stdNormal' is a normal variable with distribution 'StdNormal'.
stdNormal :: Distribution Normal a => RVar a
stdNormal = rvar StdNormal

-- |'stdNormalT' is a normal process with distribution 'StdNormal'.
stdNormalT :: Distribution Normal a => RVarT m a
stdNormalT = rvarT StdNormal

-- |@normal m s@ is a random variable with distribution @'Normal' m s@
-- (mean @m@, standard deviation @s@).
normal :: Distribution Normal a => a -> a -> RVar a
normal m s = rvar (Normal m s)

-- |@normalT m s@ is a random process with distribution @'Normal' m s@
-- (mean @m@, standard deviation @s@).
normalT :: Distribution Normal a => a -> a -> RVarT m a
normalT m s = rvarT (Normal m s)