{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -fwarn-incomplete-patterns #-}
module Knit.Effect.RandomFu
(
Random
, sampleRVar
, sampleDist
, runRandomIOSimple
, runRandomIOPureMT
, runRandomFromSource
)
where
import qualified Polysemy as P
import Polysemy.Internal ( send )
import Data.IORef ( newIORef )
import qualified Data.Random as R
import qualified Data.Random.Source as R
import qualified Data.Random.Internal.Source as R
import qualified Data.Random.Source.PureMT as R
import Control.Monad.IO.Class ( MonadIO(..) )
-- | A polysemy effect for drawing random values through random-fu.
--
-- The @m@ parameter is the monad slot every polysemy effect carries;
-- neither constructor mentions it, so this is a first-order effect.
data Random m a where
  -- | Sample a random-fu random variable.
  SampleRVar    :: R.RVar a -> Random m a
  -- | Request a single random primitive (see 'R.Prim').
  GetRandomPrim :: R.Prim a -> Random m a
-- | Sample a random-fu 'R.RVar' in any effect stack that contains 'Random'.
-- The actual randomness comes from whichever interpreter handles 'Random'.
sampleRVar :: (P.Member Random effs) => R.RVar t -> P.Sem effs t
sampleRVar rv = send (SampleRVar rv)
-- | Sample from any random-fu 'R.Distribution'.
-- The distribution is first turned into an 'R.RVar' with 'R.rvar', and the
-- result is handed to 'sampleRVar'.
sampleDist :: (P.Member Random effs, R.Distribution d t) => d t -> P.Sem effs t
sampleDist dist = sampleRVar (R.rvar dist)
-- | Request one random primitive through the 'Random' effect.
--
-- Not exported from this module. It backs the 'R.MonadRandom' instance for
-- 'P.Sem' that this module defines.
getRandomPrim :: P.Member Random effs => R.Prim t -> P.Sem effs t
getRandomPrim prim = send (GetRandomPrim prim)
-- | Interpret 'Random' by performing every request in 'IO' and lifting it
-- into the remaining stack with 'liftIO'.
--
-- Randomness therefore comes from whatever 'R.MonadRandom' instance
-- random-fu provides for 'IO'.
runRandomIOSimple
  :: forall effs a
   . MonadIO (P.Sem effs)
  => P.Sem (Random ': effs) a
  -> P.Sem effs a
runRandomIOSimple = P.interpret handleRandom
 where
  -- The explicit signature keeps the GADT match rigid and pins the target
  -- stack to the outer 'effs'.
  handleRandom :: forall m x . Random m x -> P.Sem effs x
  handleRandom = \case
    SampleRVar rv    -> liftIO (R.sample rv)
    GetRandomPrim pt -> liftIO (R.getRandomPrim pt)
-- | Interpret 'Random' by drawing all randomness from the given @source@.
--
-- @source@ can be any value that is an 'R.RandomSource' for the remaining
-- stack, for example an @IORef PureMT@ as used by 'runRandomIOPureMT'.
-- Each request is served straight from the source: random variables are run
-- with 'R.runRVar', and primitives are fetched with 'R.getRandomPrimFrom'.
runRandomFromSource
  :: forall s effs a
   . R.RandomSource (P.Sem effs) s
  => s
  -> P.Sem (Random ': effs) a
  -> P.Sem effs a
runRandomFromSource source = P.interpret handleRandom
 where
  -- The explicit signature keeps the GADT match rigid and pins the target
  -- stack to the outer 'effs'.
  handleRandom :: forall m x . Random m x -> P.Sem effs x
  handleRandom = \case
    -- The request already holds an 'R.RVar', so run it against the source
    -- as-is. Wrapping it in 'R.sample' first only adds an extra RVar
    -- interpretation layer that produces the same result.
    SampleRVar rv    -> R.runRVar rv source
    -- Ask the source for the primitive directly instead of building and
    -- running a one-step 'R.RVar' just to fetch it.
    GetRandomPrim pt -> R.getRandomPrimFrom source pt
-- | Interpret 'Random' using a 'R.PureMT' generator that starts from the
-- given state.
--
-- The generator is placed in a fresh 'IORef', which then serves as the
-- random source for 'runRandomFromSource'. The final generator state is not
-- returned to the caller.
runRandomIOPureMT
  :: MonadIO (P.Sem effs)
  => R.PureMT
  -> P.Sem (Random ': effs) a
  -> P.Sem effs a
runRandomIOPureMT source re = do
  genRef <- liftIO (newIORef source)
  runRandomFromSource genRef re
-- Orphan instance: every 'P.Sem' stack that contains the 'Random' effect is
-- made an 'R.MonadRandom'. Neither the class nor 'P.Sem' is defined in this
-- module. This lets random-fu code that needs a MonadRandom run directly in
-- 'P.Sem'.
--
-- Only the primitive method is written by hand. The declaration is passed
-- to random-source's 'R.monadRandom' Template Haskell helper, which
-- presumably derives the remaining MonadRandom methods from it (confirm
-- against the random-source docs).
--
-- On the method body: the right-hand 'getRandomPrim' is this module's
-- top-level function, which sends a 'GetRandomPrim' request. It is not a
-- recursive call, because the class method is only in scope qualified, as
-- 'R.getRandomPrim'.
$(R.monadRandom [d|
    instance P.Member Random effs => R.MonadRandom (P.Sem effs) where
      getRandomPrim = getRandomPrim
  |])