{-# LANGUAGE Safe #-}
module Test.SmallCheck.SeriesMonad where
import Control.Applicative (Applicative(..), Alternative(..), (<$>))
import Control.Monad (MonadPlus(..))
import Control.Monad.Logic (MonadLogic(..), LogicT)
import Control.Monad.Reader (MonadTrans(..), ReaderT, runReaderT)
import Control.Arrow (second)
-- | The depth bound made available to every 'Series' computation through
-- its 'ReaderT' environment (see 'runSeries', which supplies it).
type Depth = Int
-- | A computation over base monad @m@ that reads the current 'Depth' from a
-- 'ReaderT' environment and may yield zero or more results of type @a@ via
-- the 'LogicT' layer (see the 'Alternative' / 'MonadLogic' instances).
newtype Series m a = Series (ReaderT Depth (LogicT m) a)
-- | Maps over the results of a series by delegating to the 'Functor'
-- instance of the wrapped @ReaderT Depth (LogicT m)@ computation.
instance Functor (Series m) where
  fmap f (Series inner) = Series (f <$> inner)
-- | Sequencing delegates to the wrapped @ReaderT Depth (LogicT m)@ monad.
-- Each continuation's result is unwrapped from the 'Series' newtype so the
-- underlying binds compose directly.
instance Monad (Series m) where
  -- NOTE(review): presumably kept explicit for older base versions where
  -- 'return' has no default in terms of 'pure'; harmless on modern GHC.
  return = pure
  Series action >>= next = Series $ do
    value <- action
    let Series continuation = next value
    continuation
-- | Lifting and application both delegate to the 'Applicative' instance of
-- the wrapped @ReaderT Depth (LogicT m)@ computation.
instance Applicative (Series m) where
  pure value = Series (pure value)
  (<*>) (Series fs) (Series xs) = Series ((<*>) fs xs)
-- | Defined entirely in terms of the 'Alternative' instance, so the two
-- classes always agree for 'Series'.
instance MonadPlus (Series m) where
  mzero = empty
  mplus left right = left <|> right
-- | The empty series and choice between two series, both delegated to the
-- 'Alternative' instance of the wrapped @ReaderT Depth (LogicT m)@
-- computation.
instance Alternative (Series m) where
  empty = Series empty
  (<|>) (Series lhs) (Series rhs) = Series ((<|>) lhs rhs)
-- | Only 'msplit' is defined; the remaining 'MonadLogic' methods use the
-- class defaults. The split is performed on the wrapped computation, and
-- the leftover computation is re-wrapped so callers receive a 'Series'.
instance Monad m => MonadLogic (Series m) where
  msplit (Series inner) = Series (fmap rewrap (msplit inner))
    where
      -- Re-wrap the "rest" component of an optional (head, rest) pair.
      -- 'second' is used so the pair is matched lazily.
      rewrap split = fmap (second Series) split
-- | Lifts a base-monad action through both transformer layers: first into
-- @LogicT m@, then into @ReaderT Depth@, and finally wraps it as a 'Series'.
instance MonadTrans Series where
  lift = Series . lift . lift
-- | Supplies the given 'Depth' to a series and returns the underlying
-- 'LogicT' computation that produces its results.
runSeries :: Depth -> Series m a -> LogicT m a
runSeries depth series =
  let Series computation = series
  in runReaderT computation depth