{-# LANGUAGE CPP                    #-}
{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE MultiParamTypeClasses  #-}
{-# LANGUAGE Safe                   #-}
{-# LANGUAGE UndecidableInstances   #-}
module Control.Monad.Random.Class (
    
    MonadRandom(..),
    
    MonadSplit(..),
    
    MonadInterleave(..),
    
    fromList,
    fromListMay,
    uniform,
    uniformMay,
    weighted,
    weightedMay
    ) where
import           Control.Monad
import           Control.Monad.Trans.Class
import           Control.Monad.Trans.Cont
#if !MIN_VERSION_transformers(0,6,0)
-- ErrorT and ListT were removed in transformers-0.6.
import           Control.Monad.Trans.Error
#endif
import           Control.Monad.Trans.Except
import           Control.Monad.Trans.Identity
#if !MIN_VERSION_transformers(0,6,0)
import           Control.Monad.Trans.List
#endif
import           Control.Monad.Trans.Maybe
import           Control.Monad.Trans.Reader
import qualified Control.Monad.Trans.RWS.Lazy      as LazyRWS
import qualified Control.Monad.Trans.RWS.Strict    as StrictRWS
import qualified Control.Monad.Trans.State.Lazy    as LazyState
import qualified Control.Monad.Trans.State.Strict  as StrictState
import qualified Control.Monad.Trans.Writer.Lazy   as LazyWriter
import qualified Control.Monad.Trans.Writer.Strict as StrictWriter
import qualified Data.Foldable                     as F
import           System.Random
#if MIN_VERSION_base(4,8,0)
#else
import           Data.Monoid                       (Monoid)
#endif
-- | Monads that can produce random values of any type with a 'Random'
-- instance (from "System.Random").
class (Monad m) => MonadRandom m where
  -- | Produce a random value within the given @(lo, hi)@ range.
  --
  -- How the bounds are treated is decided by the 'Random' instance for
  -- @a@ (see 'randomR'). The 'IO' instance in this module delegates to
  -- 'randomRIO'.
  getRandomR :: (Random a) => (a, a) -> m a
  -- | Produce a random value using the default range of the type's
  -- 'Random' instance (see 'random').
  --
  -- The 'IO' instance in this module delegates to 'randomIO'.
  getRandom :: (Random a) => m a
  -- | Produce a list of random values, each within the given @(lo, hi)@
  -- range. The 'IO' instance builds it with 'randomRs' from a generator
  -- obtained via 'newStdGen'.
  getRandomRs :: (Random a) => (a, a) -> m [a]
  -- | Produce a list of random values using the type's default range. The
  -- 'IO' instance builds it with 'randoms' from a generator obtained via
  -- 'newStdGen'.
  getRandoms :: (Random a) => m [a]
-- | 'IO' draws its random values through the 'IO' helpers of
-- "System.Random".
instance MonadRandom IO where
  getRandom          = randomIO
  getRandomR bounds  = randomRIO bounds
  -- The list-producing methods build their list from a generator
  -- obtained with 'newStdGen'.
  getRandoms         = fmap randoms newStdGen
  getRandomRs bounds = fmap (randomRs bounds) newStdGen
-- Lifting instances: each standard transformer applied to a 'MonadRandom'
-- monad is itself a 'MonadRandom'. Every method is forwarded to the
-- underlying monad with 'lift'.
--
-- 'ErrorT' and 'ListT' were removed in transformers-0.6, so their
-- instances are only compiled against older versions of that package.
instance (MonadRandom m) => MonadRandom (ContT r m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
#if !MIN_VERSION_transformers(0,6,0)
instance (Error e, MonadRandom m) => MonadRandom (ErrorT e m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
#endif
instance (MonadRandom m) => MonadRandom (ExceptT e m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (MonadRandom m) => MonadRandom (IdentityT m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
#if !MIN_VERSION_transformers(0,6,0)
instance (MonadRandom m) => MonadRandom (ListT m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
#endif
instance (MonadRandom m) => MonadRandom (MaybeT m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (Monoid w, MonadRandom m) => MonadRandom (LazyRWS.RWST r w s m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (Monoid w, MonadRandom m) => MonadRandom (StrictRWS.RWST r w s m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (MonadRandom m) => MonadRandom (ReaderT r m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (MonadRandom m) => MonadRandom (LazyState.StateT s m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (MonadRandom m) => MonadRandom (StrictState.StateT s m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (MonadRandom m, Monoid w) => MonadRandom (LazyWriter.WriterT w m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
instance (MonadRandom m, Monoid w) => MonadRandom (StrictWriter.WriterT w m) where
  getRandomR  = lift . getRandomR
  getRandom   = lift getRandom
  getRandomRs = lift . getRandomRs
  getRandoms  = lift getRandoms
-- | Monads that can hand out a random generator of type @g@.
--
-- The functional dependency @m -> g@ means the monad determines the
-- generator type. Here, 'IO' produces a 'StdGen' via 'newStdGen', and each
-- transformer instance lifts 'getSplit' from the underlying monad.
class (Monad m) => MonadSplit g m | m -> g where
  -- | Obtain a generator.
  --
  -- NOTE(review): presumably this is a generator split off from the
  -- monad's own one, so that using it does not disturb the monad's later
  -- draws. Confirm against the defining instances.
  getSplit :: m g
-- | In 'IO', the generator is obtained with 'newStdGen'.
instance MonadSplit StdGen IO where
  getSplit = newStdGen
-- Lifting instances: each standard transformer over a 'MonadSplit' monad
-- forwards 'getSplit' to the underlying monad with 'lift'.
--
-- 'ErrorT' and 'ListT' were removed in transformers-0.6, so their
-- instances are only compiled against older versions of that package.
instance (MonadSplit g m) => MonadSplit g (ContT r m) where
  getSplit = lift getSplit
#if !MIN_VERSION_transformers(0,6,0)
instance (Error e, MonadSplit g m) => MonadSplit g (ErrorT e m) where
  getSplit = lift getSplit
#endif
instance (MonadSplit g m) => MonadSplit g (ExceptT e m) where
  getSplit = lift getSplit
instance (MonadSplit g m) => MonadSplit g (IdentityT m) where
  getSplit = lift getSplit
#if !MIN_VERSION_transformers(0,6,0)
instance (MonadSplit g m) => MonadSplit g (ListT m) where
  getSplit = lift getSplit
#endif
instance (MonadSplit g m) => MonadSplit g (MaybeT m) where
  getSplit = lift getSplit
instance (Monoid w, MonadSplit g m) => MonadSplit g (LazyRWS.RWST r w s m) where
  getSplit = lift getSplit
instance (Monoid w, MonadSplit g m) => MonadSplit g (StrictRWS.RWST r w s m) where
  getSplit = lift getSplit
instance (MonadSplit g m) => MonadSplit g (ReaderT r m) where
  getSplit = lift getSplit
instance (MonadSplit g m) => MonadSplit g (LazyState.StateT s m) where
  getSplit = lift getSplit
instance (MonadSplit g m) => MonadSplit g (StrictState.StateT s m) where
  getSplit = lift getSplit
instance (Monoid w, MonadSplit g m) => MonadSplit g (LazyWriter.WriterT w m) where
  getSplit = lift getSplit
instance (Monoid w, MonadSplit g m) => MonadSplit g (StrictWriter.WriterT w m) where
  getSplit = lift getSplit
-- | Random monads that support an 'interleave' operation.
--
-- Every 'MonadInterleave' is also a 'MonadRandom' (superclass constraint).
-- The instances alongside this class are for monad transformers only. Each
-- one applies the underlying monad's 'interleave' through that
-- transformer's @mapXT@ function (e.g. 'mapReaderT'). No base-monad
-- instance appears next to this class.
class MonadRandom m => MonadInterleave m where
  -- | Transform a random computation into one of the same type.
  --
  -- NOTE(review): the intended semantics are not visible here. Presumably
  -- @interleave m@ runs @m@ with a generator split off from the current
  -- one and continues with the other half. The result of @m@ could then be
  -- consumed lazily without ordering its generator use against the rest of
  -- the computation. Confirm against the base instance.
  interleave :: m a -> m a
-- Lifting instances: each transformer applies the underlying monad's
-- 'interleave' to its wrapped computation via the transformer's mapping
-- function.
--
-- 'ErrorT' and 'ListT' were removed in transformers-0.6, so their
-- instances are only compiled against older versions of that package.
instance (MonadInterleave m) => MonadInterleave (ContT r m) where
  interleave = mapContT interleave
#if !MIN_VERSION_transformers(0,6,0)
instance (Error e, MonadInterleave m) => MonadInterleave (ErrorT e m) where
  interleave = mapErrorT interleave
#endif
instance (MonadInterleave m) => MonadInterleave (ExceptT e m) where
  interleave = mapExceptT interleave
instance (MonadInterleave m) => MonadInterleave (IdentityT m) where
  interleave = mapIdentityT interleave
#if !MIN_VERSION_transformers(0,6,0)
instance (MonadInterleave m) => MonadInterleave (ListT m) where
  interleave = mapListT interleave
#endif
instance (MonadInterleave m) => MonadInterleave (MaybeT m) where
  interleave = mapMaybeT interleave
instance (Monoid w, MonadInterleave m) => MonadInterleave (LazyRWS.RWST r w s m) where
  interleave = LazyRWS.mapRWST interleave
instance (Monoid w, MonadInterleave m) => MonadInterleave (StrictRWS.RWST r w s m) where
  interleave = StrictRWS.mapRWST interleave
instance (MonadInterleave m) => MonadInterleave (ReaderT r m) where
  interleave = mapReaderT interleave
instance (MonadInterleave m) => MonadInterleave (LazyState.StateT s m) where
  interleave = LazyState.mapStateT interleave
instance (MonadInterleave m) => MonadInterleave (StrictState.StateT s m) where
  interleave = StrictState.mapStateT interleave
instance (Monoid w, MonadInterleave m) => MonadInterleave (LazyWriter.WriterT w m) where
  interleave = LazyWriter.mapWriterT interleave
instance (Monoid w, MonadInterleave m) => MonadInterleave (StrictWriter.WriterT w m) where
  interleave = StrictWriter.mapWriterT interleave
-- | Sample a random value from a weighted collection of @(value, weight)@
-- pairs. This is the partial counterpart of 'weightedMay'.
--
-- Calls 'error' whenever 'weightedMay' yields 'Nothing', i.e. when the
-- collection is empty or its total weight is 0.
weighted :: (F.Foldable t, MonadRandom m) => t (a, Rational) -> m a
weighted t = weightedMay t >>= maybe noChoice return
  where
    noChoice = error "Control.Monad.Random.Class.weighted: empty collection, or total weight = 0"
-- | Sample a random value from a weighted collection of @(value, weight)@
-- pairs. The collection is flattened to a list in 'F.toList' order and
-- handed to 'fromListMay', which yields 'Nothing' when the collection is
-- empty or its total weight is 0.
weightedMay :: (F.Foldable t, MonadRandom m) => t (a, Rational) -> m (Maybe a)
weightedMay t = fromListMay (F.toList t)
-- | Sample a random value from a weighted list of @(value, weight)@ pairs.
-- This is the partial counterpart of 'fromListMay'.
--
-- Calls 'error' whenever 'fromListMay' yields 'Nothing', i.e. when the
-- list is empty or its total weight is 0.
fromList :: (MonadRandom m) => [(a, Rational)] -> m a
fromList ws = fromListMay ws >>= maybe noChoice return
  where
    noChoice = error "Control.Monad.Random.Class.fromList: empty list, or total weight = 0"
-- | Sample a random value from a weighted list of @(value, weight)@ pairs.
-- Weights are relative: each element is chosen with probability
-- proportional to its weight.
--
-- Returns 'Nothing' when the list is empty or its total weight is 0. The
-- zero test is made on a 'Double' approximation of the total. A positive
-- total too small to be represented as a 'Double' therefore also yields
-- 'Nothing'.
--
-- Elements with weight 0 are never chosen. Negative weights are not
-- supported and give unspecified results.
fromListMay :: (MonadRandom m) => [(a, Rational)] -> m (Maybe a)
fromListMay xs = do
  let -- Zero-weight entries do not change the total. Dropping them stops a
      -- leading one from being returned when the draw is exactly 0.
      ws    = filter ((/= 0) . snd) xs
      total = sum (map snd ws)
      s     = fromRational total :: Double
      -- Pair each element with the running total of the weights up to and
      -- including it. The last running total is exactly 'total'. The lazy
      -- patterns avoid forcing the pairs more than needed.
      cums  = scanl1 (\ ~(_,q) ~(y,s') -> (y, s'+q)) ws
  case s of
    0 -> return Nothing
    _ -> do
      -- 's' is a rounded approximation of 'total' and can round above it.
      -- If the draw then reaches 's', it would exceed every running total
      -- and the old 'head' crashed. Clamping to the exact 'total' keeps the
      -- draw at or below the last running total, so an element is found.
      p <- liftM (min total . toRational) $ getRandomR (0, s)
      return $ case dropWhile ((< p) . snd) cums of
        (y, _) : _ -> Just y
        []         -> Nothing -- unreachable: p <= total, the last running total
-- | Sample a value uniformly from a collection. This is the partial
-- counterpart of 'uniformMay'.
--
-- Calls 'error' whenever 'uniformMay' yields 'Nothing', i.e. when the
-- collection is empty.
uniform :: (F.Foldable t, MonadRandom m) => t a -> m a
uniform t = uniformMay t >>= maybe noChoice return
  where
    noChoice = error "Control.Monad.Random.Class.uniform: empty collection"
-- | Sample a value uniformly from a collection by giving every element
-- weight 1 and delegating to 'fromListMay'. Yields 'Nothing' exactly when
-- the collection is empty.
uniformMay :: (F.Foldable t, MonadRandom m) => t a -> m (Maybe a)
uniformMay t = fromListMay [ (x, 1) | x <- F.toList t ]