{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
module Control.Monad.Choice.Random
( UniformRandom
( UniformRandom
)
, lift
, colift
)
where
import Control.Applicative
( Alternative
( empty
, (<|>)
, some
, many
)
#if MIN_VERSION_base(4,10,0)
, Applicative
( liftA2
)
#endif
)
import Control.Monad
( MonadPlus
( mzero
, mplus
)
)
import Control.Monad.Class.Choice
( MonadChoice
( choose
)
)
import Control.Monad.Cont.Class
( MonadCont
( callCC
)
)
import Control.Monad.Error.Class
( MonadError
( throwError
, catchError
)
)
import qualified Control.Monad.Fail as Fail
( MonadFail
( fail
)
)
import Control.Monad.Fix
( MonadFix
( mfix
)
)
import Control.Monad.IO.Class
( MonadIO
( liftIO
)
)
import Control.Monad.Primitive
( PrimMonad
( PrimState
, primitive
)
)
import Control.Monad.Random.Class
( MonadRandom
( getRandomR
, getRandom
, getRandomRs
, getRandoms
)
, MonadSplit
( getSplit
)
, MonadInterleave
( interleave
)
, uniform
)
import Control.Monad.Reader.Class
( MonadReader
( ask
, local
, reader
)
)
import Control.Monad.RWS.Class
( MonadRWS
)
import Control.Monad.State.Class
( MonadState
( get
, put
, state
)
)
import Control.Monad.Writer.Class
( MonadWriter
( writer
, tell
, listen
, pass
)
)
import Control.Monad.Zip
( MonadZip
( mzip
, mzipWith
, munzip
)
)
import Data.Foldable
( Foldable
( fold
, foldr'
, foldl'
, toList
)
)
import Data.Functor.Classes
( Eq1
( liftEq
)
, Ord1
( liftCompare
)
)
#if !MIN_VERSION_base(4,13,0)
import Data.Semigroup
( Semigroup
( (<>)
)
)
#endif
-- | A transparent wrapper around a computation of type @r a@.
--
-- The wrapper exists to select a particular 'MonadChoice' instance:
-- 'choose' on a @UniformRandom m@ delegates to 'uniform' from
-- "Control.Monad.Random.Class". All other instances in this module
-- forward to the wrapped computation.
newtype UniformRandom r a = UniformRandom { runUniformRandom :: r a }

-- | Wrap a computation in 'UniformRandom'.
lift :: r a -> UniformRandom r a
lift inner = UniformRandom inner
{-# INLINE lift #-}

-- | Unwrap a 'UniformRandom' computation; inverse of 'lift'.
colift :: UniformRandom r a -> r a
colift (UniformRandom inner) = inner
{-# INLINE colift #-}

-- | Turn a binary operation on wrapped computations into one on
-- 'UniformRandom' values: unwrap both operands, apply @op@, re-wrap.
lift2 :: (r a -> s b -> t c) -> UniformRandom r a -> UniformRandom s b -> UniformRandom t c
lift2 op x y = lift (op (colift x) (colift y))
{-# INLINE lift2 #-}
-- | Maps over the result of the wrapped computation.
instance Functor f => Functor (UniformRandom f) where
  fmap g (UniformRandom inner) = UniformRandom (fmap g inner)
  {-# INLINE fmap #-}
-- | Forwards every 'Applicative' method, including '(*>)' and '(<*)',
-- to the wrapped applicative.
instance Applicative f => Applicative (UniformRandom f) where
  pure x = UniformRandom (pure x)
  {-# INLINE pure #-}
  UniformRandom mf <*> UniformRandom mx = UniformRandom (mf <*> mx)
  {-# INLINE (<*>) #-}
#if MIN_VERSION_base(4,10,0)
  liftA2 g (UniformRandom mx) (UniformRandom my) = UniformRandom (liftA2 g mx my)
  {-# INLINE liftA2 #-}
#endif
  UniformRandom mx *> UniformRandom my = UniformRandom (mx *> my)
  {-# INLINE (*>) #-}
  UniformRandom mx <* UniformRandom my = UniformRandom (mx <* my)
  {-# INLINE (<*) #-}
-- | Sequences the wrapped monadic computations. The continuation's
-- result is unwrapped before being bound in the underlying monad.
--
-- '(>>)' is given its canonical definition '(*>)'. The 'Applicative'
-- instance forwards '(*>)' to the wrapped applicative, so discarding
-- sequencing reuses the underlying type's own operator instead of the
-- class default, which is built from '(>>=)'. This keeps this instance
-- consistent with the other forwarding instances in the module.
instance Monad m => Monad (UniformRandom m) where
  UniformRandom action >>= next = UniformRandom (action >>= \x -> colift (next x))
  {-# INLINE (>>=) #-}
  (>>) = (*>)
  {-# INLINE (>>) #-}
-- | The reason this wrapper exists: 'choose' samples an element of any
-- 'Foldable' container with 'uniform' from "Control.Monad.Random.Class",
-- running in the wrapped 'MonadRandom'.
--
-- NOTE(review): 'uniform' is presumably partial on an empty container;
-- confirm against the MonadRandom version in use.
instance (Foldable f, MonadRandom m) => MonadChoice f (UniformRandom m) where
  choose xs = UniformRandom (uniform xs)
  {-# INLINE choose #-}
-- | Every random-value request is delegated to the wrapped 'MonadRandom'.
instance MonadRandom m => MonadRandom (UniformRandom m) where
  getRandomR bounds = UniformRandom (getRandomR bounds)
  {-# INLINE getRandomR #-}
  getRandom = UniformRandom getRandom
  {-# INLINE getRandom #-}
  getRandomRs bounds = UniformRandom (getRandomRs bounds)
  {-# INLINE getRandomRs #-}
  getRandoms = UniformRandom getRandoms
  {-# INLINE getRandoms #-}
-- | Ties the knot in the wrapped monad, unwrapping the body's result.
instance MonadFix m => MonadFix (UniformRandom m) where
  mfix body = UniformRandom (mfix (\x -> colift (body x)))
  {-# INLINE mfix #-}
-- | Pattern-match failure is reported through the wrapped monad's 'Fail.fail'.
instance Fail.MonadFail m => Fail.MonadFail (UniformRandom m) where
  fail msg = UniformRandom (Fail.fail msg)
  {-# INLINE fail #-}
-- | Forwards 'empty', '(<|>)', 'some' and 'many' to the wrapped
-- 'Alternative', so the underlying type's own definitions are used.
instance Alternative f => Alternative (UniformRandom f) where
  empty = UniformRandom empty
  {-# INLINE empty #-}
  UniformRandom a <|> UniformRandom b = UniformRandom (a <|> b)
  {-# INLINE (<|>) #-}
  some (UniformRandom v) = UniformRandom (some v)
  {-# INLINE some #-}
  many (UniformRandom v) = UniformRandom (many v)
  {-# INLINE many #-}
-- | Forwards 'mzero' and 'mplus' to the wrapped 'MonadPlus'.
instance MonadPlus m => MonadPlus (UniformRandom m) where
  mzero = UniformRandom mzero
  {-# INLINE mzero #-}
  mplus (UniformRandom a) (UniformRandom b) = UniformRandom (mplus a b)
  {-# INLINE mplus #-}
-- | Embeds an 'IO' action via the wrapped monad's 'liftIO'.
instance MonadIO m => MonadIO (UniformRandom m) where
  liftIO action = UniformRandom (liftIO action)
  {-# INLINE liftIO #-}
-- | Combines two wrapped values with the underlying '(<>)'.
instance Semigroup (r a) => Semigroup (UniformRandom r a) where
  UniformRandom x <> UniformRandom y = UniformRandom (x <> y)
  {-# INLINE (<>) #-}
-- | 'mempty' wraps the underlying 'mempty'; 'mappend' is the canonical
-- alias for this type's '(<>)'. Before base 4.11, 'Semigroup' was not a
-- superclass of 'Monoid', so that constraint is requested explicitly there.
instance
  ( Monoid (r a)
#if !MIN_VERSION_base(4,11,0)
  , Semigroup (r a)
#endif
  ) => Monoid (UniformRandom r a) where
  mempty = UniformRandom mempty
  {-# INLINE mempty #-}
  mappend = (<>)
  {-# INLINE mappend #-}
-- | Throws and catches through the wrapped 'MonadError'. The handler's
-- result is unwrapped before being passed to the underlying 'catchError'.
instance MonadError e m => MonadError e (UniformRandom m) where
  throwError err = UniformRandom (throwError err)
  {-# INLINE throwError #-}
  catchError (UniformRandom action) handler =
    UniformRandom (catchError action (\err -> colift (handler err)))
  {-# INLINE catchError #-}
-- | Reads the environment of the wrapped 'MonadReader'; 'local' runs the
-- wrapped computation under the modified environment.
instance MonadReader r m => MonadReader r (UniformRandom m) where
  ask = UniformRandom ask
  {-# INLINE ask #-}
  local f (UniformRandom action) = UniformRandom (local f action)
  {-# INLINE local #-}
  reader sel = UniformRandom (reader sel)
  {-# INLINE reader #-}
-- | Reads and writes the state of the wrapped 'MonadState'.
instance MonadState s m => MonadState s (UniformRandom m) where
  get = UniformRandom get
  {-# INLINE get #-}
  put newState = UniformRandom (put newState)
  {-# INLINE put #-}
  state transition = UniformRandom (state transition)
  {-# INLINE state #-}
-- | Every 'Foldable' method is forwarded to the wrapped container, so
-- results (including partiality of e.g. 'maximum' or 'foldr1' on an empty
-- container) match the underlying instance exactly.
instance Foldable f => Foldable (UniformRandom f) where
  fold (UniformRandom xs) = fold xs
  {-# INLINE fold #-}
  foldMap g (UniformRandom xs) = foldMap g xs
  {-# INLINE foldMap #-}
  foldr g z (UniformRandom xs) = foldr g z xs
  {-# INLINE foldr #-}
  foldr' g z (UniformRandom xs) = foldr' g z xs
  {-# INLINE foldr' #-}
  foldl g z (UniformRandom xs) = foldl g z xs
  {-# INLINE foldl #-}
  foldl' g z (UniformRandom xs) = foldl' g z xs
  {-# INLINE foldl' #-}
  foldr1 g (UniformRandom xs) = foldr1 g xs
  {-# INLINE foldr1 #-}
  foldl1 g (UniformRandom xs) = foldl1 g xs
  {-# INLINE foldl1 #-}
  toList (UniformRandom xs) = toList xs
  {-# INLINE toList #-}
  null (UniformRandom xs) = null xs
  {-# INLINE null #-}
  length (UniformRandom xs) = length xs
  {-# INLINE length #-}
  elem x (UniformRandom xs) = elem x xs
  {-# INLINE elem #-}
  maximum (UniformRandom xs) = maximum xs
  {-# INLINE maximum #-}
  minimum (UniformRandom xs) = minimum xs
  {-# INLINE minimum #-}
  sum (UniformRandom xs) = sum xs
  {-# INLINE sum #-}
  product (UniformRandom xs) = product xs
  {-# INLINE product #-}
-- | Traverses the wrapped container with the underlying instance and
-- re-wraps the rebuilt container inside the effect.
instance Traversable t => Traversable (UniformRandom t) where
  traverse g (UniformRandom xs) = UniformRandom <$> traverse g xs
  {-# INLINE traverse #-}
  sequenceA (UniformRandom xs) = UniformRandom <$> sequenceA xs
  {-# INLINE sequenceA #-}
  mapM g (UniformRandom xs) = UniformRandom <$> mapM g xs
  {-# INLINE mapM #-}
  sequence (UniformRandom xs) = UniformRandom <$> sequence xs
  {-# INLINE sequence #-}
-- | Lifted equality compares the wrapped values with the underlying 'liftEq'.
instance Eq1 f => Eq1 (UniformRandom f) where
  liftEq eq (UniformRandom xs) (UniformRandom ys) = liftEq eq xs ys
  {-# INLINE liftEq #-}
-- | Lifted comparison orders the wrapped values with the underlying 'liftCompare'.
instance Ord1 f => Ord1 (UniformRandom f) where
  liftCompare cmp (UniformRandom xs) (UniformRandom ys) = liftCompare cmp xs ys
  {-# INLINE liftCompare #-}
-- | Zips and unzips via the wrapped 'MonadZip'. 'munzip' matches the pair
-- returned by the underlying 'munzip' (to WHNF) before re-wrapping both halves.
instance MonadZip m => MonadZip (UniformRandom m) where
  mzip (UniformRandom ma) (UniformRandom mb) = UniformRandom (mzip ma mb)
  {-# INLINE mzip #-}
  mzipWith g (UniformRandom ma) (UniformRandom mb) = UniformRandom (mzipWith g ma mb)
  {-# INLINE mzipWith #-}
  munzip (UniformRandom mab) =
    case munzip mab of
      (ma, mb) -> (UniformRandom ma, UniformRandom mb)
  {-# INLINE munzip #-}
-- | 'callCC' in the wrapped monad. The escape continuation handed to the
-- body is re-wrapped, and the body's result is unwrapped.
instance MonadCont m => MonadCont (UniformRandom m) where
  callCC body = UniformRandom (callCC (\k -> colift (body (\x -> UniformRandom (k x)))))
  {-# INLINE callCC #-}
-- | Equality of the wrapped values.
instance Eq (r a) => Eq (UniformRandom r a) where
  UniformRandom x == UniformRandom y = x == y
  {-# INLINE (==) #-}
  UniformRandom x /= UniformRandom y = x /= y
  {-# INLINE (/=) #-}
-- | Ordering of the wrapped values. Every method forwards to the
-- underlying 'Ord' instance; 'max' and 'min' re-wrap the chosen value.
instance Ord (r a) => Ord (UniformRandom r a) where
  compare (UniformRandom x) (UniformRandom y) = compare x y
  {-# INLINE compare #-}
  UniformRandom x < UniformRandom y = x < y
  {-# INLINE (<) #-}
  UniformRandom x <= UniformRandom y = x <= y
  {-# INLINE (<=) #-}
  UniformRandom x > UniformRandom y = x > y
  {-# INLINE (>) #-}
  UniformRandom x >= UniformRandom y = x >= y
  {-# INLINE (>=) #-}
  max (UniformRandom x) (UniformRandom y) = UniformRandom (max x y)
  {-# INLINE max #-}
  min (UniformRandom x) (UniformRandom y) = UniformRandom (min x y)
  {-# INLINE min #-}
-- | Defines no methods; it relies on the 'MonadReader', 'MonadWriter'
-- and 'MonadState' instances for 'UniformRandom'.
instance
  MonadRWS r w s m =>
  MonadRWS r w s (UniformRandom m)
-- | Writes to, listens to and censors the output of the wrapped 'MonadWriter'.
instance MonadWriter w m => MonadWriter w (UniformRandom m) where
  writer pair = UniformRandom (writer pair)
  {-# INLINE writer #-}
  tell output = UniformRandom (tell output)
  {-# INLINE tell #-}
  listen (UniformRandom action) = UniformRandom (listen action)
  {-# INLINE listen #-}
  pass (UniformRandom action) = UniformRandom (pass action)
  {-# INLINE pass #-}
-- | Obtains a split generator from the wrapped 'MonadSplit'.
instance MonadSplit g m => MonadSplit g (UniformRandom m) where
  getSplit = UniformRandom getSplit
  {-# INLINE getSplit #-}
-- | Shares the wrapped monad's state token type and embeds primitive
-- state-token actions through its 'primitive'.
instance PrimMonad m => PrimMonad (UniformRandom m) where
  type PrimState (UniformRandom m) = PrimState m
  primitive stateFn = UniformRandom (primitive stateFn)
  {-# INLINE primitive #-}
-- | Delegates 'interleave' to the wrapped 'MonadInterleave'.
instance MonadInterleave m => MonadInterleave (UniformRandom m) where
  interleave (UniformRandom action) = UniformRandom (interleave action)
  {-# INLINE interleave #-}