{-# LANGUAGE
MultiParamTypeClasses,
RankNTypes,
FlexibleInstances, FlexibleContexts,
RecordWildCards, BangPatterns
#-}
module Data.Random.Distribution.Ziggurat
( Ziggurat(..)
, mkZigguratRec
, mkZiggurat
, mkZiggurat_
, findBin0
, runZiggurat
) where
import Data.Random.Internal.Find
import Data.Random.Distribution.Uniform
import Data.Random.Distribution
import Data.Random.RVar
import Data.Vector.Generic as Vec
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as UV
-- | All the precomputed data needed to run Marsaglia & Tang's
-- \"ziggurat\" rejection-sampling algorithm. 'runZiggurat' is the
-- authoritative specification of how each field is used.
--
-- Values are normally built with 'mkZiggurat_', 'mkZiggurat' or
-- 'mkZigguratRec' rather than by hand; the invariants described below are
-- the ones those builders establish, not ones the type enforces.
data Ziggurat v t = Ziggurat
    { zTable_xs       :: !(v t)
      -- ^ X coordinate of each bin edge. Entry 0 is not a real edge: the
      -- builders set it to @V / f(R)@, so bin 0 goes through the same
      -- fast-accept test in 'runZiggurat' as every other bin.
    , zTable_y_ratios :: !(v t)
      -- ^ Element @i@ is @zTable_xs ! (i+1) / zTable_xs ! i@. A fraction
      -- @u@ with @|u|@ below it is accepted without evaluating 'zFunc'.
      -- (Despite the name, these are ratios of X values.)
    , zTable_ys       :: !(v t)
      -- ^ 'zFunc' evaluated at each entry of 'zTable_xs'.
    , zGetIU          :: !(forall m. RVarT m (Int, t))
      -- ^ Produces a bin index @i@ (used to index the tables) and a
      -- fraction @u@ that scales @zTable_xs ! i@. NOTE(review): presumably
      -- @i@ is uniform over @[0, c)@, and @u@ is uniform over @(-1, 1)@
      -- when mirrored and @[0, 1)@ otherwise. The builders do not check this.
    , zTailDist       :: (forall m. RVarT m t)
      -- ^ Sampler used when bin 0 is selected and the fast test fails.
      -- Deliberately lazy (no strictness annotation); presumably this lets
      -- 'mkZigguratRec' build a tail that refers back to the ziggurat itself.
    , zUniform        :: !(forall m. t -> t -> RVarT m t)
      -- ^ Uniform sampler between two bounds. Storing it here means
      -- 'runZiggurat' needs no @Distribution Uniform t@ constraint.
    , zFunc           :: !(t -> t)
      -- ^ The (not necessarily normalised) density, evaluated at @|x|@ in
      -- the accept/reject test. Expected to be non-increasing on @[0, inf)@.
    , zMirror         :: !Bool
      -- ^ Whether the distribution is mirrored about the origin. When set,
      -- tail samples are negated if the drawn fraction was negative.
    }
-- | Draw one sample from the distribution described by a 'Ziggurat'.
--
-- Each attempt takes a bin @i@ and a fraction @u@ from 'zGetIU' and forms
-- the candidate @x = u * zTable_xs ! i@:
--
--  * if @|u| < zTable_y_ratios ! i@, @x@ is accepted immediately;
--
--  * otherwise, in bin 0 the result comes from 'zTailDist' (negated when
--    'zMirror' is set and @u < 0@);
--
--  * otherwise, a height is drawn uniformly between @zTable_ys ! (i+1)@ and
--    @zTable_ys ! i@. @x@ is accepted if that height lies below
--    @zFunc (abs x)@; if not, the whole procedure restarts.
{-# INLINE runZiggurat #-}
{-# SPECIALIZE runZiggurat :: Ziggurat UV.Vector Float -> RVarT m Float #-}
{-# SPECIALIZE runZiggurat :: Ziggurat UV.Vector Double -> RVarT m Double #-}
{-# SPECIALIZE runZiggurat :: Ziggurat V.Vector Float -> RVarT m Float #-}
{-# SPECIALIZE runZiggurat :: Ziggurat V.Vector Double -> RVarT m Double #-}
runZiggurat :: (Num a, Ord a, Vector v a) =>
               Ziggurat v a -> RVarT m a
runZiggurat !Ziggurat{..} = go
    where
        -- One full attempt; also the restart point after a rejection.
        {-# NOINLINE go #-}
        go = do
            (!i, !u) <- zGetIU
            if abs u < zTable_y_ratios ! i
                then return $! (u * zTable_xs ! i)
                else if i == 0
                    then sampleTail u
                    else sampleGreyArea i $! (u * zTable_xs ! i)

        -- Accept/reject against the density for the sliver between the
        -- heights of bins i+1 and i.
        {-# INLINE sampleGreyArea #-}
        sampleGreyArea i x = do
            !y <- zUniform (zTable_ys ! (i+1)) (zTable_ys ! i)
            if y < zFunc (abs x)
                then return $! x
                else go

        -- Bin 0 missed the fast test: defer to the tail sampler, keeping the
        -- sign already chosen by u when the distribution is mirrored.
        {-# INLINE sampleTail #-}
        sampleTail u
            | zMirror && u < 0 = fmap negate zTailDist
            | otherwise        = zTailDist
-- | Build a 'Ziggurat' from explicitly supplied @R@ and @V@ values.
--
-- Arguments, in order:
--
--  * whether to mirror the distribution about the origin ('zMirror');
--
--  * the density @f@ ('zFunc'), not necessarily normalised;
--
--  * the inverse of @f@;
--
--  * the number of bins @c@;
--
--  * @R@, which becomes entry 1 of the X table (see 'zigguratXs');
--
--  * @V@, the area assigned to each strip when the table is built;
--
--  * the sampler stored as 'zGetIU';
--
--  * the tail sampler stored as 'zTailDist'.
--
-- The X table comes from 'zigguratTable'. The ratio and Y tables are
-- derived from it, and 'zUniform' is filled in with 'uniformT'.
{-# INLINE mkZiggurat_ #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Float -> Float) -> (Float -> Float) -> Int -> Float -> Float -> (forall m. RVarT m (Int, Float)) -> (forall m. RVarT m Float ) -> Ziggurat UV.Vector Float #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Double -> Double) -> (Double -> Double) -> Int -> Double -> Double -> (forall m. RVarT m (Int, Double)) -> (forall m. RVarT m Double) -> Ziggurat UV.Vector Double #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Float -> Float) -> (Float -> Float) -> Int -> Float -> Float -> (forall m. RVarT m (Int, Float)) -> (forall m. RVarT m Float ) -> Ziggurat V.Vector Float #-}
{-# SPECIALIZE mkZiggurat_ :: Bool -> (Double -> Double) -> (Double -> Double) -> Int -> Double -> Double -> (forall m. RVarT m (Int, Double)) -> (forall m. RVarT m Double) -> Ziggurat V.Vector Double #-}
mkZiggurat_ :: (RealFloat t, Vector v t,
                Distribution Uniform t) =>
               Bool
               -> (t -> t)
               -> (t -> t)
               -> Int
               -> t
               -> t
               -> (forall m. RVarT m (Int, t))
               -> (forall m. RVarT m t)
               -> Ziggurat v t
mkZiggurat_ m f fInv c r v getIU tailDist = Ziggurat
    { zTable_xs       = xs
    , zTable_y_ratios = precomputeRatios xs
    , zTable_ys       = Vec.map f xs
    , zGetIU          = getIU
    , zUniform        = uniformT
    , zFunc           = f
    , zTailDist       = tailDist
    , zMirror         = m
    }
    where
        -- shared by the three tables so it is computed only once
        xs = zigguratTable f fInv c r v
-- | Build a 'Ziggurat', computing @R@ and @V@ automatically with
-- 'findBin0'.
--
-- Arguments, in order:
--
--  * whether to mirror the distribution about the origin;
--
--  * the density @f@, not necessarily normalised;
--
--  * the inverse of @f@;
--
--  * the integral of @f@, passed to 'findBin0' (presumably measured from
--    0);
--
--  * the total volume under @f@, passed to 'findBin0';
--
--  * the number of bins;
--
--  * the sampler stored as 'zGetIU';
--
--  * a function producing the tail sampler, which is given the computed
--    @R@.
mkZiggurat :: (RealFloat t, Vector v t,
               Distribution Uniform t) =>
              Bool
              -> (t -> t)
              -> (t -> t)
              -> (t -> t)
              -> t
              -> Int
              -> (forall m. RVarT m (Int, t))
              -> (forall m. t -> RVarT m t)
              -> Ziggurat v t
mkZiggurat m f fInv fInt fVol c getIU tailDist =
    mkZiggurat_ m f fInv c r v getIU (tailDist r)
    where
        (r, v) = findBin0 c f fInv fInt fVol
-- | Like 'mkZiggurat', but builds the tail sampler itself instead of taking
-- one.
--
-- The tail beyond @R@ is sampled from another ziggurat built by 'mkTail'
-- for the density shifted by @R@. That ziggurat's own tail is produced the
-- same way, through the rank-N fixed point below. The result @z@ is also
-- handed to 'mkTail', but only as a type witness (used via 'asTypeOf') so
-- the nested ziggurats use the same vector type.
--
-- Arguments are those of 'mkZiggurat' minus the tail-sampler function.
mkZigguratRec ::
    (RealFloat t, Vector v t,
     Distribution Uniform t) =>
    Bool
    -> (t -> t)
    -> (t -> t)
    -> (t -> t)
    -> t
    -> Int
    -> (forall m. RVarT m (Int, t))
    -> Ziggurat v t
mkZigguratRec m f fInv fInt fVol c getIU = z
    where
        -- Data.Function.fix cannot be used: the fixed point here is a
        -- rank-N (m-polymorphic) function, which needs this explicit type.
        fix :: ((forall m. a -> RVarT m a) -> (forall m. a -> RVarT m a)) -> (forall m. a -> RVarT m a)
        fix g = g (fix g)

        z = mkZiggurat m f fInv fInt fVol c getIU
                (fix (mkTail m f fInv fInt fVol c getIU z))
-- | One level of the recursive tail construction used by 'mkZigguratRec'.
--
-- Given the original density's parameters and the bin-0 edge @r@, this
-- builds a ziggurat with 'mkZiggurat', using @nextTail@ as that ziggurat's
-- own tail sampler. The ziggurat is for the density shifted left by @r@:
--
--  * @f' y = f (y + r)@ for @y >= 0@ (and @f r@ for negative @y@);
--
--  * @fInv' = subtract r . fInv@;
--
--  * @fInt' y = fInt (y + r) - fInt r@ for @y >= 0@ (0 for negative @y@);
--
--  * @fVol' = fVol - fInt r@.
--
-- A sample @x@ drawn from it is pushed back out past @r@, away from the
-- origin. An exact-zero sample maps to @r@: the previous form
-- @x + r * signum x@ returned 0 in that case, which lies inside the body of
-- the distribution rather than its tail.
--
-- @typeRep@ is never inspected. It only fixes the vector type of the inner
-- ziggurat through 'asTypeOf'.
mkTail ::
    (RealFloat a, Vector v a, Distribution Uniform a) =>
    Bool
    -> (a -> a) -> (a -> a) -> (a -> a)
    -> a
    -> Int
    -> (forall m. RVarT m (Int, a))
    -> Ziggurat v a
    -> (forall m. a -> RVarT m a)
    -> (forall m. a -> RVarT m a)
mkTail m f fInv fInt fVol c getIU typeRep nextTail r = do
    x <- rvarT (mkZiggurat m f' fInv' fInt' fVol' c getIU nextTail
                    `asTypeOf` typeRep)
    -- move the shifted sample back beyond r, preserving its sign
    return (if x < 0 then x - r else x + r)
    where
        fIntR = fInt r

        f' y | y < 0     = f r
             | otherwise = f (y + r)

        fInv' = subtract r . fInv

        fInt' y | y < 0     = 0
                | otherwise = fInt (y + r) - fIntR

        fVol' = fVol - fIntR
-- | The bin edges computed by 'zigguratXs', packed into a vector. The
-- excess value is discarded.
zigguratTable :: (Fractional a, Vector v a, Ord a) =>
                 (a -> a) -> (a -> a) -> Int -> a -> a -> v a
zigguratTable f fInv c r v = fromList (fst (zigguratXs f fInv c r v))
-- | The excess value computed by 'zigguratXs' for the given density,
-- inverse, bin count, @R@ and @V@. The bin edges are discarded.
zigguratExcess :: (Fractional a, Ord a) => (a -> a) -> (a -> a) -> Int -> a -> a -> a
zigguratExcess f fInv c r v = snd (zigguratXs f fInv c r v)
-- | Compute the bin edges of a ziggurat table together with its excess.
--
-- Arguments: the density @f@, its inverse @fInv@, the bin count @c@, @R@
-- and @V@.
--
-- For @c >= 2@ the edge list has @c + 1@ entries:
--
--  * entry 0 is @V / f R@;
--
--  * entry 1 is @R@;
--
--  * each later entry is @fInv (f x + V / x)@ of its predecessor @x@, so
--    the strip of width @x@ between the two heights has area @V@. Once the
--    predecessor is no longer positive, the entry is @-1@ instead;
--
--  * the last entry is 0.
--
-- For @c == 1@ the list is @[V / f R, R]@. For @c == 0@ it is
-- @[V / f R]@, and for negative @c@ it is empty.
--
-- The excess is @x * (f 0 - f x) - V@ where @x@ is entry @c - 1@: the area
-- of the top strip minus the target area @V@. 'findBin0' searches for an
-- @R@ that makes it non-negative. For @c <= 0@ forcing it raises a
-- list-index error.
--
-- The edges are produced by a single left-to-right unfold (O(c)), rather
-- than by indexing back into the list being built (O(c^2)).
zigguratXs :: (Fractional a, Ord a) => (a -> a) -> (a -> a) -> Int -> a -> a -> ([a], a)
zigguratXs f fInv c r v = (xs, excess)
    where
        x0 = v / f r

        -- next edge after xi; -1 marks that the table has run past 0
        step xi
            | xi <= 0   = -1
            | otherwise = fInv (f xi + v / xi)

        -- n successive edges starting at xi, followed by the final 0
        walk n xi
            | n <= 0    = [0]
            | otherwise = xi : walk (n - 1) (step xi)

        xs = case compare c 1 of
            LT -> [x0 | c == 0]
            EQ -> [x0, r]
            GT -> x0 : walk (c - 1) r

        excess = let xTop = xs !! (c - 1)
                 in xTop * (f 0 - f xTop) - v
-- | Ratio of each table entry to the one before it. Element @i@ of the
-- result is @xs ! (i+1) / xs ! i@, for @i@ from 0 to @length xs - 2@.
--
-- The parameter used to be named @zTable_xs@, which shadowed the record
-- accessor of the same name; it is now simply @xs@.
precomputeRatios :: (Vector v a, Fractional a) => v a -> v a
precomputeRatios xs = generate (Vec.length xs - 1) ratioAt
    where
        ratioAt i = xs ! (i + 1) / xs ! i
-- | Search for the ziggurat parameters @(R, V)@.
--
-- Arguments: the number of bins, the density @f@, its inverse, its
-- integral @fInt@, and the total volume @fVol@.
--
-- @V@ is always derived from @R@ as @R * f R + fVol - fInt R@. Assuming
-- @fInt@ integrates @f@ from 0 and @fVol@ is the whole area, this is the
-- rectangle under @f R@ plus the area of the tail beyond @R@.
--
-- The search has two stages:
--
--  * an initial guess @r0@ is found by 'findMin' with the predicate
--    @V(r) <= fVol / c@;
--
--  * @R@ is refined by 'findMinFrom' starting at @r0@ (step 1), requiring
--    the 'zigguratExcess' of the resulting table to be non-negative and
--    not NaN.
--
-- See "Data.Random.Internal.Find" for the exact search semantics.
-- NOTE(review): the original author marked this procedure as possibly not
-- completely right, though it works for the distributions in use.
findBin0 :: (RealFloat b) =>
    Int -> (b -> b) -> (b -> b) -> (b -> b) -> b -> (b, b)
findBin0 cInt f fInv fInt fVol = (rMin, v rMin)
    where
        c = fromIntegral cInt

        v r = r * f r + fVol - fInt r

        r0 = findMin (\r -> v r <= fVol / c)

        rMin = findMinFrom r0 1 $ \r ->
            let e = exc r
             in e >= 0 && not (isNaN e)

        exc x = zigguratExcess f fInv cInt x (v x)
-- | Sampling a 'Ziggurat' runs the ziggurat algorithm ('runZiggurat').
-- No method signature is given here: the module does not enable
-- InstanceSigs, so one would be rejected.
instance (Num t, Ord t, Vector v t) => Distribution (Ziggurat v) t where
    rvar = runZiggurat