{-# LANGUAGE UndecidableInstances, Rank2Types, FlexibleInstances, FlexibleContexts, GADTs, ScopedTypeVariables, FunctionalDependencies #-}

-------------------------------------------------------------------------
-- |
-- Module      : Control.Monad.LogicState
-- Copyright   : (c) Atze Dijkstra
-- License     : BSD3
--
-- Maintainer  : atzedijkstra@gmail.com
-- Stability   : experimental, (as of 20160218) under development
-- Portability : non-portable
--
-- A backtracking, logic-programming monad, partially derived from and built on top of logict, that adds backtrackable state.
--
--    LogicT (and thus this library as well) is adapted from the paper
--    /Backtracking, Interleaving, and Terminating
--        Monad Transformers/, by
--    Oleg Kiselyov, Chung-chieh Shan, Daniel P. Friedman, Amr Sabry
--    (<http://www.cs.rutgers.edu/~ccshan/logicprog/LogicT-icfp2005.pdf>).
--
-- 
-------------------------------------------------------------------------

module Control.Monad.LogicState (
    module Control.Monad.Logic.Class,
    module Control.Monad,
    module Control.Monad.Trans,
    module Control.Monad.LogicState.Class,
    module Control.Monad.TransLogicState.Class,
    -- * The LogicState monad
    LogicState,
    LogicStateT(..),
  ) where

import Data.Maybe
import Data.Typeable

import Control.Applicative

import Control.Monad
import Control.Monad.Identity
import Control.Monad.Trans

import Control.Monad.State
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Error.Class

import Data.Monoid (Monoid(mappend, mempty))
import qualified Data.Foldable as F
import qualified Data.Traversable as T

import Control.Monad.Logic.Class

import Control.Monad.LogicState.Class
import Control.Monad.TransLogicState.Class

-------------------------------------------------------------------------
-- | A monad transformer for performing backtracking computations layered
-- over another monad @m@. It threads a global state @gs@ and a backtracking
-- state @bs@, used respectively for e.g. freshness\/uniqueness and for
-- maintaining variable mappings.
--
-- The representation is continuation based (see 'LogicContS'). '<|>' passes
-- both states through unchanged. Rolling back @bs@ is done explicitly via
-- 'backtrackWithRoll'.
newtype LogicStateT gs bs m a =
    LogicStateT { unLogicStateT :: forall r. LogicContS gs bs r m a }

-- | Convenience types.
--
-- 'LogicStateS' is a 'StateT' over the pair of global state @gs@ and
-- backtracking state @bs@. Morally it is @(gs,bs) -> m (r,(gs,bs))@.
type LogicStateS gs bs r m = StateT (gs, bs) m r

-- | Continuation-passing representation underlying 'LogicStateT'. Given a
-- success continuation and a failure continuation, it yields a state
-- computation producing the final answer of type @r@.
type LogicContS gs bs r m a
  =  (a -> LogicStateS gs bs r m -> LogicStateS gs bs r m)
     -- ^ success continuation: receives a result and the failure continuation
  -> LogicStateS gs bs r m
     -- ^ failure continuation
  -> LogicStateS gs bs r m
     -- ^ computation over the global + backtracking state

-- | Map over results by post-composing the function onto the success
-- continuation.
instance Functor (LogicStateT gs bs f) where
    fmap f lt = LogicStateT $ \sk -> unLogicStateT lt (sk . f)

-- | 'pure' succeeds once with its argument. '<*>' runs the function
-- computation, then the argument computation for each function result.
instance Applicative (LogicStateT gs bs f) where
    pure a = LogicStateT $ \sk -> sk a
    f <*> a = LogicStateT $ \sk ->
      unLogicStateT f (\g -> unLogicStateT a (sk . g))

-- | Sequencing: each result of the first computation is fed, via its success
-- continuation, into the second.
instance Monad (LogicStateT gs bs m) where
    -- Canonical definition; 'pure' already is @LogicStateT ($ a)@.
    return = pure
    m >>= f = LogicStateT $ \sk ->
      unLogicStateT m (\a -> unLogicStateT (f a) sk)

-- | 'fail' ignores its message and immediately invokes the failure
-- continuation, i.e. it behaves like 'empty'.
instance MonadFail (LogicStateT gs bs m) where
    fail _ = LogicStateT $ flip const

-- | Choice without state restoration. Both branches share the success
-- continuation, and @f2@ is the failure continuation of @f1@. So @f2@ starts
-- from whatever state is in effect when @f1@ (and its continuation) fail.
instance Alternative (LogicStateT gs bs f) where
    -- Ignore the success continuation and fail straight away.
    empty = LogicStateT $ flip const
    -- state backtracking variant, but in general interacts badly with other combinators using msplit. Backtracking separately available.
    -- f1 <|> f2 = LogicStateT $ \sk fk -> StateT $ \s@(_,bs) -> runStateT (unLogicStateT f1 sk (StateT $ \(gs',_) -> runStateT (unLogicStateT f2 sk fk) (gs',bs))) s
    f1 <|> f2 = LogicStateT $ \sk fk ->
      unLogicStateT f1 sk (unLogicStateT f2 sk fk)

-- | 'MonadPlus' coincides with the 'Alternative' instance.
instance MonadPlus (LogicStateT gs bs m) where
  mzero = empty
  {-# INLINE mzero #-}
  mplus = (<|>)
  {-# INLINE mplus #-}

-- | Lifting runs the base action once, in the current state, and passes its
-- result to the success continuation.
instance MonadTrans (LogicStateT gs bs) where
    lift m = LogicStateT $ \sk fk -> StateT $ \s ->
      m >>= \a -> runStateT (sk a fk) s

-- | IO actions are lifted through the base monad.
instance (MonadIO m) => MonadIO (LogicStateT gs bs m) where
    liftIO = lift . liftIO

-- | Reader operations of the base monad, lifted.
--
-- @local f m@ runs @m@ (including any of its alternatives resumed later by
-- backtracking) under the modified environment. The success and failure
-- continuations, i.e. the code after @local@, run under the original
-- environment.
instance MonadReader r m => MonadReader r (LogicStateT gs bs m) where
    ask = lift ask
    local f m = LogicStateT $ \sk fk -> StateT $ \s -> do
      outer <- ask
      let inner = f outer
          -- Run a state computation under a fixed environment.
          under e st = StateT $ local (const e) . runStateT st
      local (const inner) $ runStateT
        (unLogicStateT m
          -- Leaving 'm' with a result: restore the caller's environment.
          -- If the continuation backtracks into 'm' (fk'), re-enter the
          -- modified environment.
          (\a fk' -> under outer (sk a (under inner fk')))
          -- 'm' exhausted: the caller's failure continuation sees the
          -- caller's environment.
          (under outer fk))
        s

-- | 'msplit' runs @m@ with a success continuation that stops at the first
-- result. The remaining alternatives (the failure continuation) are
-- repackaged as a computation that resumes from whatever state is current
-- when it is run.
instance (Monad m) => MonadLogic (LogicStateT gs bs m) where
    msplit m =
      liftWithState $ runStateT $ unLogicStateT m
        (\a fk -> return (Just (a, liftWithState (runStateT fk) >>= reflect)))
        (return Nothing)

-- | Running 'LogicStateT' computations from an initial @(gs,bs)@ state.
instance TransLogicState (gs,bs) (LogicStateT gs bs) where
  -- observe s lt = runIdentity $ evalStateT (unLogicStateT lt (\a _ -> return a) (error "No answer.")) s

  -- First answer only; the base monad's 'fail' is used when there is none.
  observeT s lt =
    evalStateT (unLogicStateT lt (\a _ -> return a) (fail "No answer.")) s

  -- All answers, together with the final state. The failure continuation
  -- is run first, so the remaining answers are collected before the
  -- current one is consed on.
  observeStateAllT s m = runStateT
    (unLogicStateT m (\a fk -> fk >>= \as -> return (a:as)) (return []))
    s

  -- At most @n@ answers, together with the state after the last one
  -- observed.
  observeStateManyT s n m = runStateT (obs n m) s
   where
     obs k lt
       | k <= 0    = return []
       | k == 1    = unLogicStateT lt (\a _ -> return [a]) (return [])
       | otherwise = unLogicStateT (msplit lt) (sk k) (return [])

     -- After splitting off the first answer, observe the remaining @k-1@
     -- from the current state. The state of that run is propagated, so the
     -- reported state reflects the last answer. This is consistent with
     -- the @k == 1@ case and with 'observeStateAllT'.
     sk _ Nothing        _ = return []
     sk k (Just (a, m')) _ = StateT $ \s' ->
       (\(as, s'') -> (a:as, s'')) <$> observeStateManyT s' (k-1) m'

  -- Lift a raw state transformer: its result and new state are passed on
  -- to the success continuation.
  liftWithState m = LogicStateT $ \sk fk -> StateT $ \s ->
    m s >>= \(a, s') -> runStateT (sk a fk) s'
  {-# INLINE liftWithState #-}

-- | Access to the combined @(global, backtracking)@ state. The state lives
-- in the underlying 'StateT' of the continuations.
instance Monad m => MonadState (gs,bs) (LogicStateT gs bs m) where
    get   = LogicStateT $ \sk fk -> get >>= \s -> sk s fk
    put s = LogicStateT $ \sk fk -> put s >>= \a -> sk a fk

-- | @backtrackWithRoll roll m@ captures the current backtracking state
-- @bs1@ and returns a computation which, when later run, does two things.
-- First it replaces the then-current backtracking state @bs2@ by
-- @roll gs2 bs2 bs1@, where @gs2@ is the then-current global state.
-- Then it runs @m@. The global state is left untouched.
instance (Monad m) => MonadLogicState (,) gs bs m (LogicStateT gs bs m) where
    backtrackWithRoll roll m = do
      (_, bs1) <- get
      return $ LogicStateT $ \sk fk ->
        StateT $ \(gs2, bs2) -> do
          bs <- roll gs2 bs2 bs1
          runStateT (unLogicStateT m sk fk) (gs2, bs)


-------------------------------------------------------------------------
-- | The basic 'LogicStateT' monad over 'Identity'. It performs backtracking
-- computations with global state @gs@ and backtracking state @bs@,
-- returning values of type @a@.
type LogicState gs bs = LogicStateT gs bs Identity