{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE UndecidableInstances #-}
#if __GLASGOW_HASKELL__ >= 702
{-# LANGUAGE Safe #-}
#endif
-- | The 'LogicT' monad transformer and its 'Logic' specialisation over
-- 'Identity', together with the functions that build and run them.
--
-- A computation is stored as a function of a success continuation and a
-- failure continuation (see 'LogicT'). The modules
-- "Control.Monad.Logic.Class", "Control.Monad" and "Control.Monad.Trans"
-- are re-exported.
module Control.Monad.Logic (
module Control.Monad.Logic.Class,
-- * The Logic monad
Logic,
logic,
runLogic,
observe,
observeMany,
observeAll,
-- * The LogicT monad transformer
LogicT(..),
runLogicT,
observeT,
observeManyT,
observeAllT,
-- * Re-exported modules
module Control.Monad,
module Control.Monad.Trans
) where
import Control.Applicative
import Control.Monad
import qualified Control.Monad.Fail as Fail
import Control.Monad.Identity
import Control.Monad.Trans
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Error.Class
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid (Monoid (..))
#endif
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup (Semigroup (..))
#endif
import qualified Data.Foldable as F
import qualified Data.Traversable as T
import Control.Monad.Logic.Class
-- | A computation over the base monad @m@ producing results of type @a@,
-- represented as a function of two continuations.
newtype LogicT m a = LogicT
  { -- | Run the computation for any answer type @r@.
    --
    -- The first argument is the success continuation: it receives one
    -- result together with the computation that continues with the
    -- remaining alternatives. The second argument is the failure
    -- continuation, used when no (further) result is available.
    unLogicT
      :: forall r
       . (a -> m r -> m r)
      -> m r
      -> m r
  }
-- | Run the computation and return its first result in the base monad.
--
-- When the computation has no result, the base monad's 'fail' is called
-- with the message @\"No answer.\"@.
#if !MIN_VERSION_base(4,13,0)
observeT :: Monad m => LogicT m a -> m a
#else
observeT :: MonadFail m => LogicT m a -> m a
#endif
observeT computation = unLogicT computation firstResult noResult
  where
    -- Return the first result; the remaining alternatives are never run.
    firstResult answer _rest = return answer
    noResult = fail "No answer."
-- | Run the computation and collect every result into a list, each result
-- placed in front of the results of the alternatives that follow it.
observeAllT :: Monad m => LogicT m a -> m [a]
observeAllT computation = unLogicT computation prepend (return [])
  where
    prepend answer rest = liftM (answer :) rest
-- | Run the computation and collect at most the given number of results.
--
-- A non-positive limit yields the empty list without running anything.
-- A limit of one takes the first result directly; larger limits peel off
-- one result at a time with 'msplit' and recurse on the remainder.
observeManyT :: Monad m => Int -> LogicT m a -> m [a]
observeManyT limit computation
  | limit <= 0 = return []
  | limit == 1 =
      unLogicT computation (\answer _ -> return [answer]) (return [])
  | otherwise = unLogicT (msplit computation) onSplit (return [])
  where
    -- 'msplit' yields exactly one value, so the continuation that would
    -- follow it is ignored.
    onSplit split _ = case split of
      Nothing -> return []
      Just (answer, rest) ->
        liftM (answer :) (observeManyT (limit - 1) rest)
-- | Run a 'LogicT' computation with the given success continuation and
-- failure continuation. This is 'unLogicT' at a fixed answer type.
runLogicT :: LogicT m a -> (a -> m r -> m r) -> m r -> m r
runLogicT computation = unLogicT computation
-- | 'LogicT' specialised to the 'Identity' base monad.
type Logic = LogicT Identity
-- | Build a 'Logic' computation from a fold over its results, given as a
-- function of a combining step and a final value.
--
-- The continuations of 'LogicT' work in 'Identity', so each one is
-- unwrapped before being handed to the fold and the overall answer is
-- wrapped again.
logic :: (forall r. (a -> r -> r) -> r -> r) -> Logic a
logic fold = LogicT $ \sk fk ->
  Identity
    ( fold
        (\answer rest -> runIdentity (sk answer (Identity rest)))
        (runIdentity fk)
    )
-- | Return the first result of a 'Logic' computation.
--
-- When the computation has no result, the returned value is
-- @error \"No answer.\"@.
observe :: Logic a -> a
observe computation = runIdentity (unLogicT computation firstResult noResult)
  where
    -- Return the first result; the remaining alternatives are ignored.
    firstResult answer _rest = return answer
    noResult = error "No answer."
-- | Collect every result of a 'Logic' computation; this is 'observeAllT'
-- run in 'Identity'.
observeAll :: Logic a -> [a]
observeAll computation = runIdentity (observeAllT computation)
-- | Collect at most the given number of results of a 'Logic' computation,
-- by taking a prefix of the list produced by 'observeAll'.
observeMany :: Int -> Logic a -> [a]
observeMany limit computation = take limit (observeAll computation)
-- | Run a 'Logic' computation with a pure success continuation and a pure
-- failure value; the 'Identity' wrapping is handled here.
runLogic :: Logic a -> (a -> r -> r) -> r -> r
runLogic computation onResult onEmpty =
  runIdentity (unLogicT computation step (Identity onEmpty))
  where
    -- Apply the pure continuation underneath the 'Identity' wrapper.
    step answer rest = fmap (onResult answer) rest
-- | Maps over results by transforming each one before it reaches the
-- success continuation; the failure continuation is passed through.
instance Functor (LogicT f) where
  fmap g computation = LogicT $ \sk fk ->
    unLogicT computation (\answer -> sk (g answer)) fk
-- | 'pure' yields exactly one result. '<*>' runs the argument computation
-- once for every function produced by the left-hand computation.
instance Applicative (LogicT f) where
  pure answer = LogicT $ \sk fk -> sk answer fk
  functions <*> arguments = LogicT $ \sk fk ->
    let -- For one function, feed every argument result to the success
        -- continuation, then continue with the remaining functions.
        applyEach g fkAfter =
          unLogicT arguments (\x -> sk (g x)) fkAfter
     in unLogicT functions applyEach fk
-- | 'empty' has no results. '<|>' produces the results of the first
-- computation and, once those are exhausted, those of the second.
instance Alternative (LogicT f) where
  empty = LogicT $ \_sk fk -> fk
  first <|> second = LogicT $ \sk fk ->
    -- The second computation becomes the failure continuation of the first.
    unLogicT first sk (unLogicT second sk fk)
-- | 'return' yields exactly one result. '>>=' runs the continuation
-- computation for every result of the first computation.
instance Monad (LogicT m) where
  return answer = LogicT $ \sk fk -> sk answer fk
  computation >>= next = LogicT $ \sk fk ->
    unLogicT
      computation
      (\answer fkAfter -> unLogicT (next answer) sk fkAfter)
      fk
#if !MIN_VERSION_base(4,13,0)
  -- Before base 4.13 'fail' is still a 'Monad' method; defer to the
  -- 'Fail.MonadFail' instance.
  fail = Fail.fail
#endif
-- | 'fail' discards its message and produces a computation with no
-- results: the failure continuation is returned unchanged.
instance Fail.MonadFail (LogicT m) where
  fail _message = LogicT $ \_sk fk -> fk
-- | 'mzero' has no results. 'mplus' produces the results of the first
-- computation followed by those of the second.
instance MonadPlus (LogicT m) where
  mzero = LogicT $ \_sk fk -> fk
  mplus first second = LogicT $ \sk fk ->
    -- The second computation becomes the failure continuation of the first.
    unLogicT first sk (unLogicT second sk fk)
#if MIN_VERSION_base(4,9,0)
-- | Combination is 'mplus'.
instance Semigroup (LogicT m a) where
  left <> right = left `mplus` right
  -- 'sconcat' receives a non-empty collection and folds it with 'mplus'.
  sconcat alternatives = foldr1 mplus alternatives
#endif
-- | The monoid given by 'mzero' and 'mplus'.
instance Monoid (LogicT m a) where
  mempty = mzero
  mappend left right = left `mplus` right
  mconcat alternatives = foldr mplus mzero alternatives
-- | Lifting runs the base-monad action and passes its value to the
-- success continuation as the only result.
instance MonadTrans LogicT where
  lift action = LogicT $ \sk fk -> do
    answer <- action
    sk answer fk
-- | 'IO' actions are lifted into the base monad first and then into
-- 'LogicT'.
instance (MonadIO m) => MonadIO (LogicT m) where
  liftIO action = lift (liftIO action)
-- | 'MonadLogic' operations expressed directly on the two continuations.
instance (Monad m) => MonadLogic (LogicT m) where
  -- Run the computation in the base monad: 'Nothing' when there is no
  -- result, otherwise the first result paired with a computation that
  -- resumes the captured failure continuation and 'reflect's its outcome.
  msplit computation =
    lift (unLogicT computation found (return Nothing))
    where
      found answer fk = return (Just (answer, lift fk >>= reflect))

  -- Pass only the first result on; what follows it is the outer failure
  -- continuation rather than the computation's remaining alternatives.
  once computation = LogicT $ \sk fk ->
    unLogicT computation (\answer _ -> sk answer fk) fk

  -- Succeed with () when the computation has no result; fail as soon as it
  -- produces one.
  lnot computation = LogicT $ \sk fk ->
    unLogicT computation (\_ _ -> fk) (sk () fk)
#if MIN_VERSION_base(4,8,0)
-- | General instance: accumulate the mapped results inside the base monad,
-- then fold the base monad itself.
instance {-# OVERLAPPABLE #-} (Monad m, F.Foldable m) => F.Foldable (LogicT m) where
  foldMap f computation =
    F.fold (unLogicT computation combine (return mempty))
    where
      combine answer rest = liftM (mappend (f answer)) rest

-- | Instance for 'Identity': a right fold is 'runLogic'.
instance {-# OVERLAPPING #-} F.Foldable (LogicT Identity) where
  foldr step z computation = runLogic computation step z
#else
-- | Accumulate the mapped results inside the base monad, then fold the
-- base monad itself.
instance (Monad m, F.Foldable m) => F.Foldable (LogicT m) where
  foldMap f computation =
    F.fold (unLogicT computation combine (return mempty))
    where
      combine answer rest = liftM (mappend (f answer)) rest
#endif
-- | Traverse a 'Logic' computation: apply the effectful function to each
-- result and rebuild a computation from the outputs with 'mplus'.
instance T.Traversable (LogicT Identity) where
  traverse g computation = runLogic computation visit (pure mzero)
    where
      -- Combine the effect for one result with the effects for the rest.
      visit answer rest = prepend <$> g answer <*> rest
      -- Put a single result in front of an already rebuilt computation.
      prepend x remaining = return x `mplus` remaining
-- | Reader operations lifted from the base monad.
instance MonadReader r m => MonadReader r (LogicT m) where
  ask = lift ask
  local f (LogicT run) = LogicT $ \sk fk -> do
    -- Capture the environment in force outside the 'local' call.
    env <- ask
    let -- Re-install the captured environment around a continuation, so
        -- the modified one is seen only by the wrapped computation.
        restore = local (const env)
    local f (run (\answer rest -> restore (sk answer rest)) (restore fk))
-- | State operations are those of the base monad, lifted.
instance MonadState s m => MonadState s (LogicT m) where
  get = lift get
  put newState = lift (put newState)
-- | Error operations lifted from the base monad.
instance MonadError e m => MonadError e (LogicT m) where
  throwError err = lift (throwError err)
  catchError computation handler = LogicT $ \sk fk ->
    let -- Run a base-monad action; on error, run the handler's computation
        -- with the same success and failure continuations.
        guarded action =
          action `catchError` \err -> unLogicT (handler err) sk fk
     in -- Guard the initial run and also every continuation handed to the
        -- success continuation.
        guarded (unLogicT computation (\answer rest -> sk answer (guarded rest)) fk)