-- |
-- Module      :  Mcmc.Internal.Gamma
-- Description :  Generalized gamma function for automatic differentiation
-- Copyright   :  (c) 2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Tue Jul 13 12:53:09 2021.
--
-- The code is taken from "Numeric.SpecFunctions".
module Mcmc.Internal.Gamma
  ( logGammaG,
  )
where

import Data.Typeable
import qualified Data.Vector as VB
import Numeric.Polynomial
import Numeric.SpecFunctions
import Unsafe.Coerce

-- | Square root of the machine epsilon of 'Double' (the literal below is the
-- 'Double' value @2^(-26)@). Used as the threshold below which the log-gamma
-- function switches to its small-argument expansion.
mSqrtEpsG :: RealFloat a => a
mSqrtEpsG =
  1.4901161193847656e-8

-- | The Euler-Mascheroni constant, polymorphic in the result type so that it
-- can be used with any 'RealFloat' number type.
mEulerMascheroniG :: RealFloat a => a
mEulerMascheroniG =
  0.5772156649015328606065121

-- | Generalized version of the log gamma function. See
-- 'Numeric.SpecFunctions.logGamma'.
--
-- If the argument type is 'Double', the call is forwarded to
-- 'Numeric.SpecFunctions.logGamma'. For every other type (for example, the
-- number types used for automatic differentiation), the generic
-- transcription 'logGammaNonDouble' is used instead.
logGammaG :: (Typeable a, RealFloat a) => a -> a
logGammaG z =
  -- 'cast' only succeeds when the runtime types agree, i.e. when @a@ is
  -- 'Double'. Casting the argument in and the result back out is a
  -- type-checked replacement for the previous 'typeOf' test plus
  -- 'unsafeCoerce'; it only compares the type representations of @a@ and
  -- 'Double', so no function-type representation is built per call.
  case cast z >>= cast . logGamma of
    Just r -> r
    Nothing -> logGammaNonDouble z
{-# SPECIALIZE logGammaG :: Double -> Double #-}

-- | Log-gamma function for any 'RealFloat' type. This is a transcription of
-- 'Numeric.SpecFunctions.logGamma' in which the 'Double'-specific helpers
-- are replaced by the generic (@G@-suffixed) helpers of this module.
logGammaNonDouble :: RealFloat a => a -> a
logGammaNonDouble z
  -- Non-positive arguments: @1 / 0@, i.e. positive infinity for IEEE types.
  | z <= 0 = 1 / 0
  -- Tiny arguments: Gamma(z) is approximately 1 / z - (Euler-Mascheroni).
  | z < mSqrtEpsG = log (1 / z - mEulerMascheroniG)
  -- (0, 1): evaluate at z + 1 and use Gamma(z + 1) = z * Gamma(z).
  | z < 0.5 = lgamma1_15G z zm1 - log z
  | z < 1 = lgamma15_2G z zm1 - log z
  -- [1, 2): rational approximations on [1, 1.5] and (1.5, 2).
  | z <= 1.5 = lgamma1_15G zm1 zm2
  | z < 2 = lgamma15_2G zm1 zm2
  -- [2, 15): recurrence down into [2, 3).
  | z < 15 = lgammaSmallG z
  -- Everything larger: Lanczos approximation.
  | otherwise = lanczosApproxG z
  where
    zm1 = z - 1
    zm2 = z - 2

-- | Rational approximation of log-gamma used by 'logGammaNonDouble' on
-- [1, 1.5] (called with @z - 1@ and @z - 2@, or with @z@ and @z - 1@ after a
-- shift by one). Computes @r * y + r * P(zm1) / Q(zm1)@ with
-- @r = zm1 * zm2@, a fixed constant @y@ and the coefficient tables
-- 'tableLogGamma_1_15PG' and 'tableLogGamma_1_15QG'.
lgamma1_15G :: RealFloat a => a -> a -> a
lgamma1_15G zm1 zm2 = r * y + r * (num / den)
  where
    r = zm1 * zm2
    y = 0.52815341949462890625
    num = evaluatePolynomial zm1 tableLogGamma_1_15PG
    den = evaluatePolynomial zm1 tableLogGamma_1_15QG

-- | Numerator (@P@) and denominator (@Q@) coefficients of the rational
-- approximation in 'lgamma1_15G', evaluated there with 'evaluatePolynomial'.
--
-- NOTE(review): because of the 'RealFloat' constraint these bindings take a
-- class dictionary, so the vectors are presumably rebuilt on each use rather
-- than shared as constants; confirm with Core if this matters.
tableLogGamma_1_15PG, tableLogGamma_1_15QG :: RealFloat a => VB.Vector a
tableLogGamma_1_15PG =
  VB.fromList
    [ 0.490622454069039543534e-1
    , -0.969117530159521214579e-1
    , -0.414983358359495381969e0
    , -0.406567124211938417342e0
    , -0.158413586390692192217e0
    , -0.240149820648571559892e-1
    , -0.100346687696279557415e-2
    ]
{-# NOINLINE tableLogGamma_1_15PG #-}

tableLogGamma_1_15QG =
  VB.fromList
    [ 1
    , 0.302349829846463038743e1
    , 0.348739585360723852576e1
    , 0.191415588274426679201e1
    , 0.507137738614363510846e0
    , 0.577039722690451849648e-1
    , 0.195768102601107189171e-2
    ]
{-# NOINLINE tableLogGamma_1_15QG #-}

-- | Rational approximation of log-gamma used by 'logGammaNonDouble' on
-- (1.5, 2) (called with @z - 1@ and @z - 2@, or with @z@ and @z - 1@ after a
-- shift by one). Computes @r * y + r * P(-zm2) / Q(-zm2)@ with
-- @r = zm1 * zm2@, a fixed constant @y@ and the coefficient tables
-- 'tableLogGamma_15_2PG' and 'tableLogGamma_15_2QG'.
lgamma15_2G :: RealFloat a => a -> a -> a
lgamma15_2G zm1 zm2 = r * y + r * (num / den)
  where
    r = zm1 * zm2
    y = 0.452017307281494140625
    -- Both polynomials are evaluated at the negated second argument.
    t = negate zm2
    num = evaluatePolynomial t tableLogGamma_15_2PG
    den = evaluatePolynomial t tableLogGamma_15_2QG

-- | Numerator (@P@) and denominator (@Q@) coefficients of the rational
-- approximation in 'lgamma15_2G', evaluated there with 'evaluatePolynomial'.
--
-- NOTE(review): because of the 'RealFloat' constraint these bindings take a
-- class dictionary, so the vectors are presumably rebuilt on each use rather
-- than shared as constants; confirm with Core if this matters.
tableLogGamma_15_2PG, tableLogGamma_15_2QG :: RealFloat a => VB.Vector a
tableLogGamma_15_2PG =
  VB.fromList
    [ -0.292329721830270012337e-1
    , 0.144216267757192309184e0
    , -0.142440390738631274135e0
    , 0.542809694055053558157e-1
    , -0.850535976868336437746e-2
    , 0.431171342679297331241e-3
    ]
{-# NOINLINE tableLogGamma_15_2PG #-}

tableLogGamma_15_2QG =
  VB.fromList
    [ 1
    , -0.150169356054485044494e1
    , 0.846973248876495016101e0
    , -0.220095151814995745555e0
    , 0.25582797155975869989e-1
    , -0.100666795539143372762e-2
    , -0.827193521891290553639e-6
    ]
{-# NOINLINE tableLogGamma_15_2QG #-}

-- | Log-gamma for the range [2, 15) used by 'logGammaNonDouble'.
--
-- The recurrence @log Gamma(z) = log (z - 1) + log Gamma(z - 1)@ is applied
-- until the argument drops below 3; the remainder is handled by 'lgamma2_3G'.
lgammaSmallG :: RealFloat a => a -> a
lgammaSmallG z0 = logs + lgamma2_3G zFinal
  where
    -- Iterate on (sum of logarithms so far, current argument).
    (logs, zFinal) = until ((< 3) . snd) shiftDown (0, z0)
    shiftDown (acc, z) =
      let zm1 = z - 1
       in (acc + log zm1, zm1)

-- | Rational approximation of log-gamma on [2, 3), reached from
-- 'lgammaSmallG'. Computes @r * y + r * P(z - 2) / Q(z - 2)@ with
-- @r = (z - 2) * (z + 1)@, a fixed constant @y@ and the coefficient tables
-- 'tableLogGamma_2_3PG' and 'tableLogGamma_2_3QG'.
lgamma2_3G :: RealFloat a => a -> a
lgamma2_3G z = r * y + r * (num / den)
  where
    zm2 = z - 2
    r = zm2 * (z + 1)
    y = 0.158963680267333984375e0
    num = evaluatePolynomial zm2 tableLogGamma_2_3PG
    den = evaluatePolynomial zm2 tableLogGamma_2_3QG

-- | Numerator (@P@) and denominator (@Q@) coefficients of the rational
-- approximation in 'lgamma2_3G', evaluated there with 'evaluatePolynomial'.
--
-- NOTE(review): because of the 'RealFloat' constraint these bindings take a
-- class dictionary, so the vectors are presumably rebuilt on each use rather
-- than shared as constants; confirm with Core if this matters.
tableLogGamma_2_3PG, tableLogGamma_2_3QG :: RealFloat a => VB.Vector a
tableLogGamma_2_3PG =
  VB.fromList
    [ -0.180355685678449379109e-1
    , 0.25126649619989678683e-1
    , 0.494103151567532234274e-1
    , 0.172491608709613993966e-1
    , -0.259453563205438108893e-3
    , -0.541009869215204396339e-3
    , -0.324588649825948492091e-4
    ]
{-# NOINLINE tableLogGamma_2_3PG #-}

tableLogGamma_2_3QG =
  VB.fromList
    [ 1
    , 0.196202987197795200688e1
    , 0.148019669424231326694e1
    , 0.541391432071720958364e0
    , 0.988504251128010129477e-1
    , 0.82130967464889339326e-2
    , 0.224936291922115757597e-3
    , -0.223352763208617092964e-6
    ]
{-# NOINLINE tableLogGamma_2_3QG #-}

-- | Lanczos approximation of log-gamma, used by 'logGammaNonDouble' for
-- arguments of 15 and above:
--
-- @(log (z + g - 0.5) - 1) * (z - 0.5) + log L(z)@
--
-- where @L@ is the rational function given by 'tableLanczosG' (evaluated
-- with 'evalRatioG') and @g@ is the shift constant paired with that table.
lanczosApproxG :: RealFloat a => a -> a
lanczosApproxG z = powerTerm + log (evalRatioG tableLanczosG z)
  where
    powerTerm = (log (z + g - 0.5) - 1) * (z - 0.5)
    g = 6.024680040776729583740234375

-- | Coefficient pairs @(numerator, denominator)@ of the rational function
-- that 'lanczosApproxG' evaluates with 'evalRatioG'.
--
-- NOTE(review): because of the 'RealFloat' constraint this binding takes a
-- class dictionary, so the vector is presumably rebuilt on each use rather
-- than shared as a constant; confirm with Core if this matters.
tableLanczosG :: RealFloat a => VB.Vector (a, a)
tableLanczosG =
  VB.fromList
    [ (56906521.91347156388090791033559122686859, 0)
    , (103794043.1163445451906271053616070238554, 39916800)
    , (86363131.28813859145546927288977868422342, 120543840)
    , (43338889.32467613834773723740590533316085, 150917976)
    , (14605578.08768506808414169982791359218571, 105258076)
    , (3481712.15498064590882071018964774556468, 45995730)
    , (601859.6171681098786670226533699352302507, 13339535)
    , (75999.29304014542649875303443598909137092, 2637558)
    , (6955.999602515376140356310115515198987526, 357423)
    , (449.9445569063168119446858607650988409623, 32670)
    , (19.51992788247617482847860966235652136208, 1925)
    , (0.5098416655656676188125178644804694509993, 66)
    , (0.006061842346248906525783753964555936883222, 1)
    ]
{-# NOINLINE tableLanczosG #-}

-- | Strict pair holding the running numerator and denominator while
-- 'evalRatioG' folds over its coefficients.
data LG a = LG !a !a

-- | Evaluate the rational function
--
-- @(a_0 + a_1 x + ... + a_n x^n) / (b_0 + b_1 x + ... + b_n x^n)@
--
-- where the coefficient vector holds the pairs @(a_i, b_i)@ in order of
-- increasing @i@. Both polynomials are evaluated together with Horner's
-- scheme. For @x > 1@ they are evaluated in @1 / x@ instead, which divides
-- numerator and denominator by @x^n@: the ratio is unchanged, but the
-- intermediate values no longer grow with @x@.
evalRatioG :: RealFloat a => VB.Vector (a, a) -> a -> a
evalRatioG coef x = ratio folded
  where
    folded
      | x > 1 = VB.foldl' (horner rx) (LG 0 0) coef
      | otherwise = VB.foldr' (flip (horner x)) (LG 0 0) coef
    -- One Horner step at point @t@ for both polynomials at once.
    horner t (LG num den) (a, b) = LG (num * t + a) (den * t + b)
    ratio (LG num den) = num / den
    rx = recip x