{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unused-imports #-}

-- | A 'Random' carrier that performs its random number generation through the
-- 'Lift' effect: every draw runs in the lifted monad, which must implement
-- 'PrimMonad'. In practice, this means that your effect stack must terminate
-- with @LiftC IO@ or @LiftC (ST s)@.
module Control.Carrier.Random.Lifted
  ( RandomC (..),
    runRandomSystem,
    runRandomSeeded,

    -- * Random effect
    module Control.Effect.Random,
  )
where

import Control.Algebra
import Control.Carrier.Lift
import Control.Carrier.Reader
import Control.Effect.Random
import Control.Effect.Sum
import Control.Monad.Fail
import Control.Monad.IO.Class
import Control.Monad.Primitive
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as Dist

-- | A carrier for 'Random' that keeps an mwc-random generator in a reader
-- context and performs every draw in the lifted monad @prim@.
--
-- @prim@ is the 'PrimMonad' whose state the generator is tied to (in
-- practice 'IO' or @ST s@); @m@ is the remainder of the effect stack.
newtype RandomC prim m a = RandomC
  { -- | Unwrap the carrier to the reader computation over the generator.
    runRandomC :: ReaderC (MWC.Gen (PrimState prim)) m a
  }
  deriving (Functor, Applicative, Monad, MonadFail, MonadIO)

-- | Handles 'Random' by drawing from the generator stored in the carrier's
-- reader context. Each draw, and each 'Save', is executed in the lifted monad
-- @n@ via 'sendM', which is why @n@ must be a 'PrimMonad' reachable through a
-- 'Lift' effect. Every other effect is forwarded to the underlying algebra.
instance
  (Algebra sig m, Member (Lift n) sig, PrimMonad n) =>
  Algebra (Random :+: sig) (RandomC n m)
  where
  alg hdl sig ctx = case sig of
    L (Random dist) -> do
      gen <- RandomC ask
      res <- sendM @n (sampleDistrib dist gen)
      pure (res <$ ctx)
    L Save -> do
      gen <- RandomC ask
      -- Snapshot the generator state so it can later be fed to 'MWC.restore'.
      seed <- sendM @n (MWC.save gen)
      pure (seed <$ ctx)
    -- Foreign effects do not need the generator: peel off the newtype and let
    -- the reader carrier forward them (as the right-hand summand) downward.
    R other -> RandomC (alg (runRandomC . hdl) (R other) ctx)
  {-# INLINE alg #-}

-- | Draw a single value from the given distribution using an mwc-random
-- generator. Each 'Distrib' constructor maps onto the mwc-random sampler of
-- the same name; any class constraints the samplers need are taken from the
-- constructor being matched.
sampleDistrib :: PrimMonad n => Distrib a -> MWC.Gen (PrimState n) -> n a
sampleDistrib dist gen = case dist of
  Uniform -> MWC.uniform gen
  UniformR range -> MWC.uniformR range gen
  Normal mean stdDev -> Dist.normal mean stdDev gen
  Standard -> Dist.standard gen
  Exponential s -> Dist.exponential s gen
  TruncatedExp s range -> Dist.truncatedExp s range gen
  Gamma shape scale -> Dist.gamma shape scale gen
  ChiSquare dof -> Dist.chiSquare dof gen
  Beta alpha beta' -> Dist.beta alpha beta' gen
  Categorical weights -> Dist.categorical weights gen
  LogCategorical logWeights -> Dist.logCategorical logWeights gen
  Geometric0 p -> Dist.geometric0 p gen
  Geometric1 p -> Dist.geometric1 p gen
  Bernoulli p -> Dist.bernoulli p gen
  Dirichlet alphas -> Dist.dirichlet alphas gen
  Permutation len -> Dist.uniformPermutation len gen
  Shuffle vec -> Dist.uniformShuffle vec gen

-- | Run a computation, seeding its random values from the system random number generator.
--
-- This is the de facto standard way to use this carrier. Keep in mind that seeding the RNG
-- may be a computationally intensive process.
--
-- Because the generator lives in 'IO', the effect stack must also provide
-- @Lift IO@ (e.g. terminate in @LiftC IO@) for the 'Random' operations to be
-- handled.
runRandomSystem :: MonadIO m => RandomC IO m a -> m a
runRandomSystem (RandomC act) = do
  -- A fresh generator is created once per run and shared by every draw.
  gen <- liftIO MWC.createSystemRandom
  runReader gen act

-- | Run a computation, seeding its random values from an existing 'MWC.Seed'.
--
-- The generator is rebuilt from the seed with 'MWC.restore' inside the lifted
-- monad @n@. Restoring the same seed yields a generator in the same state, so
-- a seed captured with the 'Random' effect's save operation can be used to
-- make a run reproducible.
runRandomSeeded ::
  forall m n sig a.
  (Has (Lift n) sig m, PrimMonad n) =>
  MWC.Seed ->
  RandomC n m a ->
  m a
runRandomSeeded s (RandomC act) = do
  gen <- sendM @n (MWC.restore s)
  runReader gen act