{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}

-- | The @Random@ effect provides access to uniformly distributed random values of
-- user-specified types or from well-known numerical distributions.
--
-- This is the “fancy” syntax that hides most details of randomness
-- behind a nice API.
module Control.Effect.Random
  ( Random (..),

    -- * Uniform distributions
    uniform,
    uniformR,

    -- * Continuous distributions
    normal,
    standard,
    exponential,
    truncatedExp,
    gamma,
    chiSquare,
    beta,

    -- * Discrete distributions
    categorical,
    logCategorical,
    geometric0,
    geometric1,
    bernoulli,
    dirichlet,

    -- * Permutations
    uniformPermutation,
    uniformShuffle,

    -- * Introspection
    save,
    Distrib (..),

    -- * Re-exports
    MWC.Variate,
    Has,
  )
where

import Control.Algebra
import Data.Kind
import Data.Vector.Generic (Vector)
import qualified System.Random.MWC as MWC

-- | GADT representing the functions provided by mwc-random.
--
-- Each constructor describes one sampling request; the type index @a@ is the
-- type of the value the request produces. The smart constructors exported
-- from this module wrap these constructors in 'Random' and 'send' them.
data Distrib a where
  -- | Request sent by 'uniform': a uniformly distributed variate.
  Uniform :: MWC.Variate a => Distrib a
  -- | Request sent by 'uniformR': a uniformly distributed variate in the
  -- given range.
  UniformR :: MWC.Variate a => (a, a) -> Distrib a
  -- | Request sent by 'normal': mean, then standard deviation.
  Normal :: Double -> Double -> Distrib Double
  -- | Request sent by 'standard': normal variate with zero mean and unit
  -- variance.
  Standard :: Distrib Double
  -- | Request sent by 'exponential': carries the scale parameter.
  Exponential :: Double -> Distrib Double
  -- | Request sent by 'truncatedExp': scale parameter, then the range to
  -- which the distribution is truncated.
  TruncatedExp :: Double -> (Double, Double) -> Distrib Double
  -- | Request sent by 'gamma': shape parameter, then scale parameter.
  Gamma :: Double -> Double -> Distrib Double
  -- | Request sent by 'chiSquare': carries the number of degrees of freedom.
  ChiSquare :: Int -> Distrib Double
  -- | Request sent by 'beta': alpha, then beta.
  Beta :: Double -> Double -> Distrib Double
  -- | Request sent by 'categorical': carries the vector of weights.
  Categorical :: Vector v Double => v Double -> Distrib Int
  -- | Request sent by 'logCategorical': carries the vector of logarithms of
  -- weights.
  LogCategorical :: Vector v Double => v Double -> Distrib Int
  -- | Request sent by 'geometric0': carries the success probability.
  Geometric0 :: Double -> Distrib Int
  -- | Request sent by 'geometric1': carries the success probability.
  Geometric1 :: Double -> Distrib Int
  -- | Request sent by 'bernoulli': carries the probability of success.
  Bernoulli :: Double -> Distrib Bool
  -- | Request sent by 'dirichlet': carries the container of parameters and
  -- produces a container of the same shape type.
  Dirichlet :: Traversable t => t Double -> Distrib (t Double)
  -- | Request sent by 'uniformPermutation': carries the @n@ of the
  -- permutation of @[0 .. n-1]@.
  Permutation :: Vector v Int => Int -> Distrib (v Int)
  -- | Request sent by 'uniformShuffle': carries the vector to shuffle.
  Shuffle :: Vector v a => v a -> Distrib (v a)

-- | Signature of the @Random@ effect.
--
-- The carrier parameter @m@ does not appear in any constructor field; the
-- second parameter is the type of the value the request produces.
data Random (m :: Type -> Type) k where
  -- | Sample a value from the distribution described by the given 'Distrib'.
  Random :: Distrib a -> Random m a
  -- | Request sent by 'save'; produces an 'MWC.Seed'.
  Save :: Random m MWC.Seed

-- | Generate a single uniformly distributed random variate.  The
-- range of values produced varies by type:
--
-- * For fixed-width integral types, the type's entire range is
--   used.
--
-- * For floating point numbers, the range (0,1] is used. Zero is
--   explicitly excluded, to allow variates to be used in
--   statistical calculations that require non-zero values
--   (e.g. uses of the 'log' function).
--
-- To generate a 'Float' variate with a range of [0,1), subtract
-- 2**(-33).  To do the same with 'Double' variates, subtract
-- 2**(-53).
uniform :: (MWC.Variate a, Has Random sig m) => m a
uniform = send (Random Uniform)
{-# INLINE uniform #-}

-- | Generate a single uniformly distributed random variate in a
-- given range.
--
-- * For integral types inclusive range is used.
--
-- * For floating point numbers range (a,b] is used if one ignores
--   rounding errors.
uniformR :: (MWC.Variate a, Has Random sig m) => (a, a) -> m a
uniformR range = send (Random (UniformR range))
{-# INLINE uniformR #-}

-- | Generate a normally distributed random variate with given mean and standard deviation.
normal ::
  Has Random sig m =>
  -- | Mean
  Double ->
  -- | Standard deviation
  Double ->
  m Double
normal mean stdDev = send (Random (Normal mean stdDev))

-- | Generate a normally distributed random variate with zero mean and unit variance.
standard :: Has Random sig m => m Double
standard = send (Random Standard)

-- | Generate an exponentially distributed random variate.
exponential ::
  Has Random sig m =>
  -- | Scale parameter
  Double ->
  m Double
exponential scale = send (Random (Exponential scale))

-- | Generate a truncated exponentially distributed random variate.
truncatedExp ::
  Has Random sig m =>
  -- | Scale parameter
  Double ->
  -- | Range to which distribution is
  --   truncated. Values may be negative.
  (Double, Double) ->
  m Double
truncatedExp scale range = send (Random (TruncatedExp scale range))

-- | Random variate generator for gamma distribution.
gamma ::
  Has Random sig m =>
  -- | Shape parameter
  Double ->
  -- | Scale parameter
  Double ->
  m Double
gamma shape scale = send (Random (Gamma shape scale))

-- | Random variate generator for the chi square distribution.
chiSquare ::
  Has Random sig m =>
  -- | Number of degrees of freedom
  Int ->
  m Double
chiSquare dof = send (Random (ChiSquare dof))

-- | Random variate generator for the geometric distribution,
-- computing the number of failures before success. Distribution's
-- support is [0..].
geometric0 ::
  Has Random sig m =>
  -- | /p/ success probability lies in (0,1]
  Double ->
  m Int
geometric0 prob = send (Random (Geometric0 prob))

-- | Random variate generator for geometric distribution for number of
-- trials. Distribution's support is [1..] (i.e. just 'geometric0'
-- shifted by 1).
geometric1 ::
  Has Random sig m =>
  -- | /p/ success probability lies in (0,1]
  Double ->
  m Int
geometric1 prob = send (Random (Geometric1 prob))

-- | Random variate generator for Beta distribution
beta ::
  Has Random sig m =>
  -- | alpha (>0)
  Double ->
  -- | beta  (>0)
  Double ->
  m Double
beta a b = send (Random (Beta a b))
{-# INLINE beta #-}

-- | Random variate generator for Dirichlet distribution
dirichlet ::
  (Has Random sig m, Traversable t) =>
  -- | container of parameters
  t Double ->
  m (t Double)
{-# INLINE dirichlet #-}
dirichlet params = send (Random (Dirichlet params))

-- | Random variate generator for Bernoulli distribution
bernoulli ::
  Has Random sig m =>
  -- | Probability of success (returning True)
  Double ->
  m Bool
{-# INLINE bernoulli #-}
bernoulli prob = send (Random (Bernoulli prob))

-- | Random variate generator for categorical distribution.
categorical ::
  (Has Random sig m, Vector v Double) =>
  -- | List of weights [>0]
  v Double ->
  m Int
{-# INLINE categorical #-}
categorical weights = send (Random (Categorical weights))

-- | Random variate generator for categorical distribution where the
--   weights are in the log domain. It's implemented in terms of
--   'categorical'.
logCategorical ::
  (Has Random sig m, Vector v Double) =>
  -- | List of logarithms of weights
  v Double ->
  m Int
logCategorical logWeights = send (Random (LogCategorical logWeights))

-- | Save the state of the random number generator to be used by subsequent
-- carrier invocations.
save :: Has Random sig m => m MWC.Seed
save = send Save

-- | Random variate generator for uniformly distributed permutations.
-- It returns a random permutation of the vector @[0 .. n-1]@. This is
-- the Fisher-Yates shuffle.
uniformPermutation ::
  (Has Random sig m, Vector v Int) =>
  -- | /n/, the number of elements in the permuted vector
  Int ->
  m (v Int)
uniformPermutation n = send (Random (Permutation n))

-- | Random variate generator for a uniformly distributed shuffle (all
--   shuffles are equiprobable) of a vector. It uses Fisher-Yates
--   shuffle algorithm.
--
-- Implementation details prevent a native implementation of the 'MWC.uniformShuffleM'
-- function. Use the native API if this is required.
uniformShuffle ::
  (Has Random sig m, Vector v a) =>
  -- | Vector to shuffle
  v a ->
  m (v a)
uniformShuffle vec = send (Random (Shuffle vec))