{-# LANGUAGE
    MultiParamTypeClasses,
    FlexibleInstances, FlexibleContexts,
    UndecidableInstances, BangPatterns
  #-}

{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}

module Data.Random.Distribution.Gamma
    ( Gamma(..)
    , gamma, gammaT

    , Erlang(..)
    , erlang, erlangT

    , mtGamma
    ) where

import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Data.Random.Distribution.Normal

import Data.Ratio

import Numeric.SpecFunctions

-- | Sample a gamma-distributed value with shape @a@ and scale @b@.
--
-- Derived from Marsaglia & Tsang, \"A Simple Method for Generating Gamma
-- Variables\", ACM Transactions on Mathematical Software, Vol 26, No 3 (2000),
-- pp. 363-372.
--
-- * For @a >= 1@ this is a rejection sampler.  Each attempt draws one
--   standard normal and, unless rejected immediately, one standard uniform.
--   It retries until a candidate is accepted.
-- * For @a < 1@ it draws a uniform @u@ and samples with shape @1 + a@ instead.
--   The factor @u ** (1/a)@ is folded into the scale.
--
-- The result is always the accepted unit-scale variate multiplied by @b@.
{-# SPECIALIZE mtGamma :: Double -> Double -> RVarT m Double #-}
{-# SPECIALIZE mtGamma :: Float  -> Float  -> RVarT m Float  #-}
mtGamma
    :: (Floating a, Ord a,
        Distribution StdUniform a,
        Distribution Normal a)
    => a -> a -> RVarT m a
mtGamma a b
    | a < 1     = do
        u <- stdUniformT
        -- Shape boost for a < 1: sample with shape (1 + a) and scale the
        -- result by u ** (1/a).  The extra factor is absorbed into the scale
        -- argument, which is forced strictly to avoid building a thunk.
        mtGamma (1+a) $! (b * u ** recip a)
    | otherwise = go
    where
        -- Per-shape constants of the method: d = a - 1/3, c = 1 / sqrt (9 d).
        !d = a - fromRational (1%3)
        !c = recip (sqrt (9*d))

        -- One rejection attempt.  Loops (via 'go') until a candidate is
        -- accepted, then returns scale * d * v^3.
        go = do
            x <- stdNormalT
            let !v   = 1 + c*x

            -- Candidates with v <= 0 are rejected outright; the log test
            -- below needs v^3 > 0.
            if v <= 0
                then go
                else do
                    u  <- stdUniformT
                    let !x_2 = x*x; !x_4 = x_2*x_2
                        v3 = v*v*v
                        dv = d * v3
                    -- Cheap polynomial squeeze first.  The exact logarithmic
                    -- acceptance test runs only when the squeeze fails.
                    if      u < 1 - 0.0331*x_4
                     || log u < 0.5 * x_2 + d - dv + d*log v3
                        then return (b*dv)
                        else go

-- | Build a pure 'RVar' that samples a 'Gamma' distribution.
-- The first argument becomes the shape field and the second the scale field
-- of the 'Gamma' value.
{-# SPECIALIZE gamma :: Double -> Double -> RVar Double #-}
{-# SPECIALIZE gamma :: Float  -> Float  -> RVar Float  #-}
gamma :: Distribution Gamma a => a -> a -> RVar a
gamma shape = rvar . Gamma shape

-- | Transformer variant of 'gamma'.  It builds an 'RVarT' over any monad
-- @m@ that samples @Gamma shape scale@.
gammaT :: Distribution Gamma a => a -> a -> RVarT m a
gammaT shape = rvarT . Gamma shape

-- | Build a pure 'RVar' sampling an 'Erlang' distribution with the given
-- shape.  The sample type @b@ is selected by the caller.
erlang :: Distribution (Erlang a) b => a -> RVar b
erlang = rvar . Erlang

-- | Transformer variant of 'erlang'.  It builds an 'RVarT' over any monad
-- @m@ that samples an 'Erlang' distribution with the given shape.
erlangT :: Distribution (Erlang a) b => a -> RVarT m b
erlangT = rvarT . Erlang

-- | Gamma distribution.  Its 'Distribution' instance samples via 'mtGamma',
-- which multiplies every variate by the second field.
data Gamma a = Gamma
    a -- ^ shape
    a -- ^ scale

-- | Erlang distribution: a gamma distribution with an integral shape and unit
-- scale.  The second type parameter is phantom and fixes the sample type.
newtype Erlang a b = Erlang
    a -- ^ integral shape

-- | Gamma variates are produced by 'mtGamma'.  The distribution's shape and
-- scale are forwarded unchanged.
instance ( Floating a
         , Ord a
         , Distribution Normal a
         , Distribution StdUniform a
         ) => Distribution Gamma a where
    {-# SPECIALIZE instance Distribution Gamma Float #-}
    {-# SPECIALIZE instance Distribution Gamma Double #-}
    rvarT (Gamma shape scale) = mtGamma shape scale

-- | CDF of @Gamma shape scale@ at @x@: the regularized lower incomplete gamma
-- function @P(shape, x / scale)@, computed in 'Double'.
--
-- 'incompleteGamma' is only defined for non-negative second arguments.
-- Points whose scaled value is <= 0 lie at or below the lower end of the
-- support, so they map to probability 0 directly instead of reaching it.
-- NaN inputs still fall through to 'incompleteGamma' unchanged.
instance (Real a, Distribution Gamma a) => CDF Gamma a where
    cdf (Gamma a b) x
        | z <= 0    = 0
        | otherwise = incompleteGamma (realToFrac a) z
      where
        z :: Double
        z = realToFrac x / realToFrac b

-- | An Erlang variate is a gamma variate whose integral shape is converted
-- to the sample type, with the scale fixed at 1.
instance ( Integral a
         , Floating b
         , Ord b
         , Distribution Normal b
         , Distribution StdUniform b
         ) => Distribution (Erlang a) b where
    rvarT (Erlang k) = mtGamma shape unitScale
      where
        shape     = fromIntegral k
        unitScale = 1

-- | CDF of an Erlang distribution with shape @k@ (unit scale) at @x@: the
-- regularized lower incomplete gamma function @P(k, x)@, computed in 'Double'.
--
-- 'incompleteGamma' is only defined for non-negative second arguments.
-- Points at or below 0 are outside or at the edge of the support, so they map
-- to probability 0 directly.  NaN inputs still fall through unchanged.
instance (Integral a, Real b, Distribution (Erlang a) b) => CDF (Erlang a) b where
    cdf (Erlang k) x
        | x' <= 0   = 0
        | otherwise = incompleteGamma (fromIntegral k) x'
      where
        x' :: Double
        x' = realToFrac x