module Control.Cond
( CondT, Cond
, runCondT, runCond, execCondT, evalCondT, test
, MonadQuery(..), guardM, guard_, guardM_, apply, consider
, accept, ignore, norecurse, prune
, matches, ifM, whenM, unlessM
, if_, when_, unless_, or_, and_, not_
, recurse
)
where
import Control.Applicative
import Control.Arrow (second)
import Control.Monad hiding (mapM_, sequence_)
import Control.Monad.Base
import Control.Monad.Catch
import Control.Monad.Cont.Class as C
import Control.Monad.Error.Class as E
import Control.Monad.Fix
import Control.Monad.Morph as M
import Control.Monad.Reader.Class as R
import Control.Monad.State.Class as S
import Control.Monad.Trans
import Control.Monad.Trans.Cont (ContT(..))
import Control.Monad.Trans.Control
import Control.Monad.Trans.Error (ErrorT(..))
import Control.Monad.Trans.Except (ExceptT(..))
import Control.Monad.Trans.Identity (IdentityT(..))
import Control.Monad.Trans.List (ListT(..))
import Control.Monad.Trans.Maybe (MaybeT(..))
import qualified Control.Monad.Trans.RWS.Lazy as LazyRWS
import qualified Control.Monad.Trans.RWS.Strict as StrictRWS
import Control.Monad.Trans.Reader (ReaderT(..))
import Control.Monad.Trans.State (StateT(..))
import qualified Control.Monad.Trans.State.Lazy as Lazy
import qualified Control.Monad.Trans.State.Strict as Strict
import qualified Control.Monad.Trans.Writer.Lazy as Lazy
import qualified Control.Monad.Trans.Writer.Strict as Strict
import Control.Monad.Writer.Class
import Control.Monad.Zip
import Data.Foldable
import Data.Functor.Identity
import Data.Monoid hiding ((<>))
import Data.Semigroup
import Prelude hiding (mapM_, foldr1, sequence_)
-- | What 'runCondT' hands back as the condition for the next step.
--
-- * 'Stop': there is no follow-up condition ('runCondT' reports 'Nothing').
-- * 'Recurse': use the carried condition instead of the current one.
-- * 'Continue': reuse the condition that was just run.
data Recursor a m r
    = Stop
    | Recurse (CondT a m r)
    | Continue
    deriving Functor
-- 'Stop' absorbs everything; otherwise the leftmost 'Recurse' wins, and
-- 'Continue' is the neutral element.
instance Semigroup (Recursor a m r) where
    lhs <> rhs = case (lhs, rhs) of
        (Stop, _)      -> Stop
        (_, Stop)      -> Stop
        (Recurse n, _) -> Recurse n
        (_, Recurse n) -> Recurse n
        _              -> Continue

-- 'Continue' is the identity of the combination above.
instance Monoid (Recursor a m r) where
    mempty = Continue
    mappend lhs rhs = lhs <> rhs
-- Only 'Recurse' carries a computation in the underlying monad, so it is the
-- only constructor the natural transformation has to reach.
instance MFunctor (Recursor a) where
    hoist nat recursor = case recursor of
        Recurse next -> Recurse (hoist nat next)
        Stop         -> Stop
        Continue     -> Continue
-- | The outcome of one step of a 'CondT': an optional result paired with the
-- instruction for what condition to use next.
type CondR a m r = (Maybe r, Recursor a m r)

-- | A successful step yielding @x@ that keeps the current condition.
accept' :: r -> CondR a m r
accept' x = (Just x, Continue)

-- | A step with no result that keeps the current condition. Despite its
-- name, this is the value used for 'mzero' and 'empty'.
recurse' :: CondR a m r
recurse' = (Nothing, Continue)
-- | A condition over a value of type @a@: a 'StateT' whose state is the
-- examined value and whose step result is a 'CondR' (an optional @r@ plus
-- the next condition to use).
newtype CondT a m r = CondT
    { getCondT :: StateT a m (CondR a m r)
    }
    deriving Functor

-- | 'CondT' over 'Identity', for pure conditions (see 'runCond').
type Cond a = CondT a Identity
-- Combine two conditions by running both in sequence and joining their
-- results with the result type's '<>'.
instance (Monad m, Semigroup r) => Semigroup (CondT a m r) where
    x <> y = liftM2 (<>) x y
-- Conditions form a monoid by running both in sequence and combining their
-- results. The identity of that lifted combination must itself succeed: it
-- yields 'mempty' and keeps the current condition. A failing identity
-- (no result) would make @mempty <> x@ fail for every @x@, and would make
-- 'mconcat' of any list fail.
instance (Monad m, Monoid r) => Monoid (CondT a m r) where
    mempty  = return mempty
    mappend = liftM2 mappend
-- Applicative operations are derived from the Monad instance.
instance Monad m => Applicative (CondT a m) where
    pure x = return x
    cf <*> cx = do
        f <- cf
        x <- cx
        return (f x)
-- Sequencing keeps the next-condition bookkeeping consistent:
--
-- * Left side produced no result: the continuation @k@ is not run. A carried
--   'Recurse' condition is extended with @k@ so that the next condition
--   also runs it.
-- * Result under 'Stop': @k@ runs, but the combined step still reports
--   'Stop'.
-- * Result under 'Continue': @k@ decides both result and recursor.
-- * Result under 'Recurse': @k@ runs. If @k@ left the recursor at
--   'Continue', the pending 'Recurse' (extended with @k@) takes its place.
--   Otherwise, whatever @k@ reported wins.
instance Monad m => Monad (CondT a m) where
    -- Yield the value and keep the current condition.
    return = CondT . return . accept'
    -- A failed pattern match behaves like 'mzero' (no result).
    -- NOTE(review): on GHC 8.8+ fail is no longer a Monad method; this would
    -- have to move to a MonadFail instance there. Confirm the targeted GHC.
    fail _ = mzero
    CondT m >>= k = CondT $ m >>= \case
        (Nothing, Stop)      -> return (Nothing, Stop)
        (Nothing, Continue)  -> return (Nothing, Continue)
        (Nothing, Recurse n) -> return (Nothing, Recurse (n >>= k))
        -- fmap over the pair replaces only the recursor with Stop.
        (Just r,  Stop)      -> fmap (const Stop) `liftM` getCondT (k r)
        (Just r,  Continue)  -> getCondT (k r)
        (Just r,  Recurse n) -> getCondT (k r) >>= \case
            (v, Continue) -> return (v, Recurse (n >>= k))
            x             -> return x
-- Reader operations are lifted from the underlying monad. 'local' changes
-- the environment for this step only; a carried 'Recurse' condition in the
-- step result is returned as-is and is not itself wrapped in 'local'.
instance MonadReader r m => MonadReader r (CondT a m) where
    ask      = lift R.ask
    local f  = CondT . R.local f . getCondT
    reader f = lift (R.reader f)
-- Writer operations on the underlying monad.
--
-- 'listen' and 'pass' act on the output produced by the wrapped step itself,
-- via the 'StateT' layer's own 'listen' / 'pass'. They are then re-applied to
-- any carried 'Recurse' condition, so the next condition has the same shape.
-- Lifting @listen (return r)@ instead would always report 'mempty', and would
-- apply the function given to 'pass' to 'mempty' instead of the real output.
instance MonadWriter w m => MonadWriter w (CondT a m) where
    writer = lift . writer
    tell   = lift . tell

    listen (CondT m) = CondT $ do
        ((mr, nxt), out) <- listen m
        return (fmap (\r -> (r, out)) mr, listenNext nxt)
      where
        listenNext Stop        = Stop
        listenNext Continue    = Continue
        listenNext (Recurse n) = Recurse (listen n)

    pass (CondT m) = CondT $ pass $ do
        (mrf, nxt) <- m
        let nxt' = passNext nxt
        -- Without a result there is no function to apply, so the output is
        -- kept unchanged (id), as MaybeT does.
        return $ case mrf of
            Nothing     -> ((Nothing, nxt'), id)
            Just (r, f) -> ((Just r, nxt'), f)
      where
        passNext Stop        = Stop
        passNext Continue    = Continue
        passNext (Recurse n) = Recurse (pass n)
-- State operations target the underlying monad's state. The value being
-- examined lives in the CondT's own 'StateT' layer and is reached through
-- the MonadQuery instance instead.
instance MonadState s m => MonadState s (CondT a m) where
    get     = lift S.get
    put s   = lift (S.put s)
    state f = lift (S.state f)
-- Try the left condition first. If it yields no result, run the right one.
-- The examined state left behind by the left branch is kept, and the left
-- branch's recursor is discarded when it yields no result.
instance Monad m => Alternative (CondT a m) where
    empty = CondT (return recurse')
    CondT lhs <|> CondT rhs = CondT $ do
        outcome <- lhs
        maybe rhs (const (return outcome)) (fst outcome)
-- Identical semantics to the Alternative instance: the first branch that
-- yields a result wins.
instance Monad m => MonadPlus (CondT a m) where
    mzero = empty
    mplus = (<|>)
-- Errors are thrown and caught around the underlying 'StateT' step; the
-- handler's condition replaces the failed one.
instance MonadError e m => MonadError e (CondT a m) where
    throwError err = CondT (throwError err)
    catchError (CondT body) handler =
        CondT (catchError body (getCondT . handler))
-- Throw in the underlying 'StateT' layer.
instance MonadThrow m => MonadThrow (CondT a m) where
    throwM ex = CondT (throwM ex)
-- Exceptions are caught around the underlying 'StateT' step; the handler's
-- condition replaces the failed one.
instance MonadCatch m => MonadCatch (CondT a m) where
    catch (CondT m) c = CondT $ m `catch` \e -> getCondT (c e)
-- From exceptions 0.6 onward, the masking operations belong to MonadMask.
-- For older versions, the instance head below is compiled out, and the two
-- methods that follow become part of the MonadCatch instance above.
#if MIN_VERSION_exceptions(0,6,0)
instance MonadMask m => MonadMask (CondT a m) where
#endif
    -- NOTE(review): newer exceptions releases (0.9+) also require
    -- generalBracket in MonadMask, and it is not defined here. Confirm
    -- against the exceptions version this package builds with.
    -- q re-wraps the restore function so it works on CondT actions.
    mask a = CondT $ mask $ \u -> getCondT (a $ q u)
      where q u = CondT . u . getCondT
    uninterruptibleMask a =
        CondT $ uninterruptibleMask $ \u -> getCondT (a $ q u)
      where q u = CondT . u . getCondT
-- Lifting instances: run the lifted action in the underlying monad and
-- accept its value, keeping the current condition.
instance MonadBase b m => MonadBase b (CondT a m) where
    liftBase = CondT . liftM accept' . liftBase

instance MonadIO m => MonadIO (CondT a m) where
    liftIO = CondT . liftM accept' . liftIO

instance MonadTrans (CondT a) where
    lift = CondT . liftM accept' . lift
-- Base-monad control. The saved state (StM) is the underlying monad's saved
-- state for the whole step: the 'CondR' together with the final examined
-- value. liftBaseWith accepts the base action's result and leaves the
-- examined value unchanged. restoreM ignores the incoming examined value
-- and reinstates the saved one.
#if MIN_VERSION_monad_control(1,0,0)
instance MonadBaseControl b m => MonadBaseControl b (CondT r m) where
    type StM (CondT r m) a = StM m (CondR r m a, r)
    liftBaseWith f = CondT $ StateT $ \s ->
        liftM (\x -> (accept' x, s)) $ liftBaseWith $ \runInBase ->
            f $ \k -> runInBase $ runStateT (getCondT k) s
    restoreM = CondT . StateT . const . restoreM
#else
-- Pre-1.0 monad-control declares StM as a data family, so the saved state
-- is wrapped in a newtype. It holds the same pair as the branch above.
instance MonadBaseControl b m => MonadBaseControl b (CondT r m) where
    newtype StM (CondT r m) a =
        CondTStM { unCondTStM :: StM m (CondR r m a, r) }
    liftBaseWith f = CondT $ StateT $ \s ->
        liftM (\x -> (accept' x, s)) $ liftBaseWith $ \runInBase -> f $ \k ->
            liftM CondTStM $ runInBase $ runStateT (getCondT k) s
    restoreM = CondT . StateT . const . restoreM . unCondTStM
#endif
-- Apply the natural transformation to the current step and, through the
-- Recursor instance, to any carried 'Recurse' condition.
instance MFunctor (CondT a) where
    hoist nat (CondT step) =
        CondT (hoist nat (liftM (\(mr, next) -> (mr, hoist nat next)) step))
-- The escape continuation passed to @f@ jumps out with an accepted result,
-- keeping whatever examined value is current when it is invoked.
instance MonadCont m => MonadCont (CondT a m) where
    callCC f = CondT $ StateT $ \s0 ->
        callCC $ \exit ->
            let escape r = CondT (StateT (\s1 -> exit (accept' r, s1)))
            in  runStateT (getCondT (f escape)) s0
-- Zipping is plain monadic sequencing of both conditions.
instance Monad m => MonadZip (CondT a m) where
    mzipWith f lhs rhs = liftM2 f lhs rhs
-- Tie the knot through the underlying 'StateT': @f@ receives the lazily
-- available result of its own step. If that step produces no result and @f@
-- forces its argument, there is no value to feed back, so an error is raised
-- (the same approach as the MonadFix instance for MaybeT). The previous
-- version scrutinised its own not-yet-computed result and so always
-- diverged.
instance MonadFix m => MonadFix (CondT a m) where
    mfix f = CondT $ mfix $ \outcome -> getCondT (f (fromOutcome outcome))
      where
        fromOutcome (Just r, _)  = r
        fromOutcome (Nothing, _) =
            error "mfix (CondT): inner computation returned Nothing"
-- | Run a condition against the value @a@.
--
-- Returns the optional result, the condition to use next (if any), and the
-- final examined value. The next condition is 'Nothing' after 'Stop', the
-- same condition after 'Continue', and the carried one after 'Recurse'.
runCondT :: Monad m => a -> CondT a m r -> m ((Maybe r, Maybe (CondT a m r)), a)
runCondT a c = liftM finish (runStateT (getCondT c) a)
  where
    -- The lazy inner pattern keeps the step result unforced until needed.
    finish (~(mr, recursor), final) = ((mr, nextCondition recursor), final)
    nextCondition Stop        = Nothing
    nextCondition Continue    = Just c
    nextCondition (Recurse n) = Just n
-- | Run a pure condition and keep only its optional result.
runCond :: a -> Cond a r -> Maybe r
runCond a c = fst (fst (runIdentity (runCondT a c)))
-- | Run a condition and return the final examined value and the next
-- condition, if any. The final value is reported only when the condition
-- produced a result.
execCondT :: Monad m => a -> CondT a m r -> m (Maybe a, Maybe (CondT a m r))
execCondT a c = liftM keepFinal (runCondT a c)
  where
    keepFinal ((mr, mnext), final) =
        (maybe Nothing (const (Just final)) mr, mnext)
-- | Run a condition and return only its optional result.
evalCondT :: Monad m => a -> CondT a m r -> m (Maybe r)
evalCondT a c = liftM (fst . fst) (runCondT a c)
-- | Whether the condition produces a result for the value @a@.
test :: Monad m => a -> CondT a m r -> m Bool
test a c = liftM (\((mr, _), _) -> maybe False (const True) mr) (runCondT a c)
-- | Monads that can read and replace the value currently being examined.
--
-- Only 'query' and 'update' are required. 'queries' and 'updates' have
-- default definitions in terms of them, and instances may override them.
class Monad m => MonadQuery a m | m -> a where
    -- | The current value.
    query :: m a

    -- | The current value, projected through a function.
    queries :: (a -> b) -> m b
    queries f = liftM f query

    -- | Replace the current value.
    update :: a -> m ()

    -- | Modify the current value with a function.
    updates :: (a -> a) -> m ()
    updates f = query >>= update . f

    {-# MINIMAL query, update #-}
-- The examined value is the state of the 'StateT' inside 'CondT'. Every
-- operation succeeds with accept', so the current condition is kept.
instance Monad m => MonadQuery a (CondT a m) where
    query     = CondT (gets accept')
    queries f = CondT (state (\s -> (accept' (f s), s)))
    update s  = CondT (liftM accept' (put s))
    updates f = CondT (liftM accept' (modify f))
-- Lifting instances for reader-like transformers. The query type @q@ comes
-- from the wrapped monad and is independent of the reader environment @r@.
-- For example, @ReaderT Config (CondT FilePath IO)@ can query the FilePath.
-- The former heads required @q@ and @r@ to be the same type.
instance MonadQuery q m => MonadQuery q (ReaderT r m) where
    query   = lift query
    queries = lift . queries
    update  = lift . update
    updates = lift . updates

instance (MonadQuery q m, Monoid w)
    => MonadQuery q (LazyRWS.RWST r w s m) where
    query   = lift query
    queries = lift . queries
    update  = lift . update
    updates = lift . updates

instance (MonadQuery q m, Monoid w)
    => MonadQuery q (StrictRWS.RWST r w s m) where
    query   = lift query
    queries = lift . queries
    update  = lift . update
    updates = lift . updates
-- Lifting instances: each operation runs the wrapped monad's version through
-- the transformer's 'lift'.
instance MonadQuery q m => MonadQuery q (ContT r m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance (Error e, MonadQuery q m) => MonadQuery q (ErrorT e m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance MonadQuery q m => MonadQuery q (ExceptT e m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance MonadQuery q m => MonadQuery q (IdentityT m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance MonadQuery q m => MonadQuery q (ListT m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance MonadQuery q m => MonadQuery q (MaybeT m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance MonadQuery q m => MonadQuery q (Lazy.StateT s m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance MonadQuery q m => MonadQuery q (Strict.StateT s m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance (Monoid w, MonadQuery q m) => MonadQuery q (Lazy.WriterT w m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)

instance (Monoid w, MonadQuery q m) => MonadQuery q (Strict.WriterT w m) where
    query     = lift query
    queries f = lift (queries f)
    update x  = lift (update x)
    updates f = lift (updates f)
-- | Like 'guard', but the condition is itself a monadic action.
guardM :: MonadPlus m => m Bool -> m ()
guardM cond = cond >>= guard
-- | Succeed with @()@ only when the predicate holds for the current value.
guard_ :: (MonadPlus m, MonadQuery a m) => (a -> Bool) -> m ()
guard_ p = do
    x <- query
    guard (p x)

-- | Like 'guard_', with a monadic predicate.
guardM_ :: (MonadPlus m, MonadQuery a m) => (a -> m Bool) -> m ()
guardM_ p = do
    x <- query
    guardM (p x)

-- | Run the action built from the current value. Succeed with its 'Just'
-- payload, or fail with 'mzero' on 'Nothing'.
apply :: (MonadPlus m, MonadQuery a m) => (a -> m (Maybe r)) -> m r
apply f = do
    action <- queries f
    mr <- action
    maybe mzero return mr

-- | Like 'apply', but a 'Just' also carries a replacement for the current
-- value, which is stored with 'update' before returning the result.
consider :: (MonadPlus m, MonadQuery a m) => (a -> m (Maybe (r, a))) -> m r
consider f = do
    action <- queries f
    found <- action
    case found of
        Nothing        -> mzero
        Just (r, next) -> update next >> return r
-- | Succeed with @()@.
accept :: MonadPlus m => m ()
accept = return ()

-- | Fail, producing no result ('mzero').
ignore :: MonadPlus m => m r
ignore = mzero
-- | Succeed with @()@ but report 'Stop', so 'runCondT' returns no next
-- condition.
norecurse :: Monad m => CondT a m ()
norecurse = CondT (return (Just (), Stop))

-- | Produce no result and report 'Stop'.
prune :: Monad m => CondT a m r
prune = CondT (return (Nothing, Stop))
-- | 'True' when the action yields a value. The fallback 'False' is combined
-- in with 'mplus'.
matches :: MonadPlus m => m r -> m Bool
matches m = mplus (liftM (const True) m) (return False)

-- | Branch on the result of a monadic 'Bool'.
ifM :: Monad m => m Bool -> m s -> m s -> m s
ifM c x y = do
    b <- c
    if b then x else y

-- | Branch on whether an action succeeds (see 'matches').
if_ :: MonadPlus m => m r -> m s -> m s -> m s
if_ c x y = do
    succeeded <- matches c
    if succeeded then x else y

-- | Run @x@ (discarding its value) when the monadic 'Bool' is 'True'.
whenM :: Monad m => m Bool -> m s -> m ()
whenM c x = ifM c (x >> done) done
  where done = return ()

-- | Run @x@ (discarding its value) when the action @c@ succeeds.
when_ :: MonadPlus m => m r -> m s -> m ()
when_ c x = if_ c (x >> done) done
  where done = return ()

-- | Run @x@ (discarding its value) when the monadic 'Bool' is 'False'.
unlessM :: Monad m => m Bool -> m s -> m ()
unlessM c x = ifM c done (x >> done)
  where done = return ()

-- | Run @x@ (discarding its value) when the action @c@ fails.
unless_ :: MonadPlus m => m r -> m s -> m ()
unless_ c x = if_ c done (x >> done)
  where done = return ()
-- | Combine the alternatives with 'msum'. For 'CondT', the first one that
-- yields a result wins.
or_ :: MonadPlus m => [m r] -> m r
or_ conds = Data.Foldable.msum conds

-- | Run every action in sequence, discarding their values.
and_ :: MonadPlus m => [m r] -> m ()
and_ = foldr (>>) (return ())
-- | Succeed with @()@ exactly when the given action fails.
not_ :: MonadPlus m => m r -> m ()
not_ c = if_ c ignore accept
-- | Run @c@ now and make @c@ itself the next condition. This overrides
-- whatever recursor @c@ reported, including 'Stop'.
recurse :: Monad m => CondT a m r -> CondT a m r
recurse c = CondT (liftM (\(mr, _) -> (mr, Recurse c)) (getCondT c))