{-# LANGUAGE CPP #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-------------------------------------------------------------------------
-- |
-- Module      : ListT (adapted from Control.Monad.Logic in the logict package)
-- Copyright   : (c) Dan Doel
-- License     : BSD3
--
-- Maintainer  : dan.doel@gmail.com
-- Stability   : experimental
-- Portability : non-portable (multi-parameter type classes)
--
-- A backtracking, logic programming monad.
--
--    Adapted from the paper
--    /Backtracking, Interleaving, and Terminating
--        Monad Transformers/, by
--    Oleg Kiselyov, Chung-chieh Shan, Daniel P. Friedman, Amr Sabry
--    (<http://www.cs.rutgers.edu/~ccshan/logicprog/ListT-icfp2005.pdf>).
-------------------------------------------------------------------------

module ListT (
    ListT(..),
    runListT,
    select,
    fold
  ) where

import GhcPrelude

import Control.Applicative

import Control.Monad
import Control.Monad.Fail as MonadFail

-------------------------------------------------------------------------
-- | A monad transformer for performing backtracking computations
-- layered over another monad @m@.
--
-- A computation is represented in continuation-passing style. It is given
-- a success continuation, of type @a -> m r -> m r@, which receives each
-- result together with the action that continues the search. It is also
-- given a failure continuation, of type @m r@, which is run when there are
-- no (further) results.
newtype ListT m a =
    ListT { unListT :: forall r. (a -> m r -> m r) -> m r -> m r }

-- | Lift a list into 'ListT', producing its elements left to right.
--
-- Each element becomes a single-result computation ('pure'), and these are
-- chained with '<|>', ending in 'mzero' (no results).
--
-- The @Monad m@ constraint is not needed by the body. It is kept so that
-- the signature, including its type-variable order, is unchanged for
-- callers.
select :: Monad m => [a] -> ListT m a
select xs = foldr (<|>) mzero (map pure xs)

-- | Fold over the results of a 'ListT' computation. The first argument
-- after the computation is the success continuation, which combines each
-- result with the rest of the fold. The second is the failure
-- continuation, used as the initial value when there are no more results.
--
-- This is a synonym for 'runListT'.
fold :: ListT m a -> (a -> m r -> m r) -> m r -> m r
fold = runListT

-------------------------------------------------------------------------
-- | Runs a 'ListT' computation with the specified initial success and
-- failure continuations.
--
-- This is the record selector 'unListT', instantiated at a specific
-- result type @r@.
runListT :: ListT m a -> (a -> m r -> m r) -> m r -> m r
runListT = unListT

-- | Maps over every result by pre-composing the function onto the success
-- continuation. The failure continuation is passed through unchanged.
instance Functor (ListT f) where
    fmap f lt = ListT $ \sk fk -> unListT lt (sk . f) fk

instance Applicative (ListT f) where
    -- Exactly one result: call the success continuation once, then fall
    -- through to the failure continuation.
    pure a = ListT $ \sk fk -> sk a fk

    -- For every function @g@ produced by @f@, run @a@ and pass each of its
    -- results through @g@. @fk'@ resumes @f@ to get its next function once
    -- @a@ is exhausted.
    f <*> a = ListT $ \sk fk ->
        unListT f (\g fk' -> unListT a (sk . g) fk') fk

instance Alternative (ListT f) where
    -- No results: go straight to the failure continuation.
    empty = ListT $ \_ fk -> fk

    -- All results of @f1@ followed by all results of @f2@. Running @f2@
    -- becomes the failure continuation of @f1@.
    f1 <|> f2 = ListT $ \sk fk -> unListT f1 sk (unListT f2 sk fk)

instance Monad (ListT m) where
    -- For each result @a@ of @m@, run @f a@ with the outer success
    -- continuation. @fk'@ resumes @m@ for its next result once @f a@ is
    -- exhausted.
    m >>= f = ListT $ \sk fk ->
        unListT m (\a fk' -> unListT (f a) sk fk') fk
#if !MIN_VERSION_base(4,13,0)
    -- Before base 4.13, 'fail' was still a method of 'Monad'; delegate it
    -- to the 'MonadFail.MonadFail' instance.
    fail = MonadFail.fail
#endif

-- | Failure yields no results: the message is discarded and control goes
-- straight to the failure continuation. As a result, a failed pattern match
-- in a @do@ block only prunes that branch of the search.
instance MonadFail.MonadFail (ListT m) where
    fail _ = ListT $ \_ fk -> fk

-- | Identical to the 'Alternative' instance. The methods delegate to it so
-- that the two can never drift apart:
--
--   * 'mzero' yields no results.
--   * 'mplus' yields all results of the left computation followed by all
--     results of the right one.
instance MonadPlus (ListT m) where
    mzero = empty
    mplus = (<|>)