{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE DeriveFunctor              #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving         #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE UndecidableInstances       #-}
{-# LANGUAGE FunctionalDependencies     #-}

-- | The Search monad and SearchT monad transformer allow computations
-- to be associated with costs and cost estimates, and explore
-- possible solutions in order of overall cost.  The solution space is
-- explored using the A* algorithm, or Dijkstra's if estimates are
-- omitted.  The order of exploring computations with equal cost is
-- not defined.
--
-- Costs must be monotonic (i.e. positive) and underestimated.  If the
-- cost of a computation is overestimated or a negative cost is
-- applied, sub-optimal solutions may be produced first.
--
-- Note that while 'runSearch' produces a lazy list of results, so that
-- the computation space is only explored as far as the list is forced,
-- 'runSearchT' over a strict base monad such as 'IO' explores the whole
-- computation space before returning any result.  In that case you
-- need to use 'collapse' or 'abandon' to prune the search space within
-- the monadic computation.
--
-- Example:
--
-- > import Control.Applicative ((<|>))
-- > import Control.Monad.Search
-- > import Data.Monoid (Sum(..))
-- >
-- > -- All naturals, weighted by the size of the number
-- > naturals :: Search (Sum Integer) Integer
-- > naturals = return 0 <|> (cost' (Sum 1) >> ((+ 1) <$> naturals))
-- >   -- [ 0, 1, 2, 3, 4, 5, ... ]
-- >
-- > -- All pairs of naturals
-- > pairs :: Search (Sum Integer) (Integer, Integer)
-- > pairs = (,) <$> naturals <*> naturals
-- >   --    [ (0, 0), (1, 0), (0, 1), (1, 1), (2, 0), ... ]
-- >   -- or [ (0, 0), (0, 1), (1, 0), (2, 0), (1, 1), ... ]
-- >   -- or ...
module Control.Monad.Search
    ( -- * The Search monad
      Search
    , runSearch
    , runSearchBest
      -- * The SearchT monad transformer
    , SearchT
    , runSearchT
    , runSearchBestT
      -- * MonadClass and search monad operations
    , MonadSearch
    , cost
    , cost'
    , junction
    , abandon
    , seal
    , collapse
    , winner
    ) where

import           Control.Applicative             ( Alternative(..) )
import           Control.Monad                   ( MonadPlus(..) )

import           Control.Monad.Cont              ( MonadCont )
import           Control.Monad.Except            ( ExceptT(..), MonadError
                                                 , runExceptT )
import           Control.Monad.IO.Class          ( MonadIO )

import qualified Control.Monad.RWS.Lazy          as Lazy ( MonadRWS, RWST(..)
                                                         , runRWST )
import qualified Control.Monad.RWS.Strict        as Strict ( RWST(..), runRWST )
import           Control.Monad.Reader            ( MonadReader, ReaderT(..)
                                                 , runReaderT )

import qualified Control.Monad.State.Lazy        as Lazy ( MonadState
                                                         , StateT(..)
                                                         , runStateT )
import qualified Control.Monad.State.Strict      as Strict ( StateT(..)
                                                           , runStateT )
import           Control.Monad.Trans.Class       ( MonadTrans, lift )
import           Control.Monad.Trans.Free        ( FreeF(Free, Pure), FreeT
                                                 , runFreeT, wrap )
import           Control.Monad.Trans.Free.Church ( FT, fromFT )
import           Control.Monad.Trans.State       ( evalStateT, gets, modify )

import qualified Control.Monad.Writer.Lazy       as Lazy ( MonadWriter
                                                         , WriterT(..)
                                                         , runWriterT )
import qualified Control.Monad.Writer.Strict     as Strict ( WriterT(..)
                                                           , runWriterT )

import           Data.Functor.Identity           ( Identity, runIdentity )
import qualified Data.IntPSQ                     as PSQ
import qualified Data.IntMap.Strict              as Map
import           Data.Maybe                      ( catMaybes, listToMaybe )
import qualified Data.IntSet                     as Set

-- | Index from a scope identifier to the identifiers of the candidates
-- currently registered in that scope.
newtype Scopemap = Scopemap
    { unScopemap :: Map.IntMap Set.IntSet
    }

-- | Build a scope map that holds exactly one candidate in one scope.
-- The first argument is the scope, the second the candidate.
singleton :: Int -> Int -> Scopemap
singleton scope candidate =
    Scopemap (Map.singleton scope (Set.singleton candidate))

-- | Register a candidate in a scope, creating the scope's entry when it
-- does not exist yet.
insert :: Int -> Int -> Scopemap -> Scopemap
insert scope candidate (Scopemap scopes) =
    Scopemap (Map.insertWith Set.union scope (Set.singleton candidate) scopes)

-- | Remove a candidate from a scope.  The scope's entry is dropped
-- altogether once no candidate remains in it; a scope that is not in
-- the map is left untouched.
delete :: Int -> Int -> Scopemap -> Scopemap
delete scope candidate (Scopemap scopes) =
    Scopemap (Map.update prune scope scopes)
  where
    prune members
        | Set.null remaining = Nothing
        | otherwise = Just remaining
      where
        remaining = Set.delete candidate members

-- | All candidates registered in the given scope, in ascending order;
-- empty when the scope is not in the map.
list :: Int -> Scopemap -> [Int]
list scope (Scopemap scopes) =
    Set.toList (Map.findWithDefault Set.empty scope scopes)

-- | All candidates registered in any scope, without duplicates and in
-- ascending order.
listAll :: Scopemap -> [Int]
listAll (Scopemap scopes) = Set.toList (Set.unions (Map.elems scopes))

-- | The Search monad: 'SearchT' specialised to the 'Identity' base
-- monad, with costs of type @c@.
type Search c = SearchT c Identity

-- | Run a pure search and list every solution together with its
-- accumulated cost, cheapest first.
runSearch :: (Ord c, Monoid c) => Search c a -> [(c, a)]
runSearch search = runIdentity (runSearchT search)

-- | Run a pure search and return only the best solution with its
-- cost, or 'Nothing' when the search has no solution.
runSearchBest :: (Ord c, Monoid c) => Search c a -> Maybe (c, a)
runSearchBest search = runIdentity (runSearchBestT search)

-- | Base functor of the free monad underlying 'SearchT'.  The type
-- parameter @a@ marks the positions that hold the rest of the
-- computation.
data SearchF c a
    = Cost c c a  -- ^ Apply a definitive cost and an estimate, then continue.
    | Alt a a     -- ^ Fork into two alternative continuations.
    | Enter a     -- ^ Open a sealed scope.
    | Exit a      -- ^ Leave the innermost sealed scope.
    | Collapse a  -- ^ Drop the other candidates of the current scope.
    | Abandon     -- ^ Give up on this computation; there is no continuation.
    deriving Functor

-- | The SearchT monad transformer: a Church-encoded free monad over
-- 'SearchF' on top of the base monad @m@.  The usual transformer
-- classes are lifted through it by newtype deriving.
newtype SearchT c m a = SearchT
    { unSearchT :: FT (SearchF c) m a
    }
    deriving ( Functor
             , Applicative
             , Monad
             , MonadTrans
             , MonadIO
             , MonadReader r
             , Lazy.MonadWriter w
             , Lazy.MonadState s
             , MonadError e
             , MonadCont
             )

-- | Choice between searches: 'empty' abandons the computation and
-- '<|>' introduces both operands as alternatives via 'junction'.
instance (Ord c, Monoid c, Monad m) => Alternative (SearchT c m) where
    empty = abandon
    lhs <|> rhs = junction lhs rhs

-- | No methods are defined here, so the class defaults apply: 'mzero'
-- and 'mplus' are taken from the 'Alternative' instance.
instance (Ord c, Monoid c, Monad m) => MonadPlus (SearchT c m)

-- | Lift the combined reader\/writer\/state interface of the base monad
-- through 'SearchT' (standalone newtype deriving).
deriving instance Lazy.MonadRWS r w s m => Lazy.MonadRWS r w s (SearchT c m)

-- | A candidate: one partially evaluated computation stored as the
-- value of an A*/Dijkstra priority-queue entry.
data Cand c m a = Cand
    { candCost  :: !c
      -- ^ Definitive cost accumulated so far; reported with the result.
    , candScope :: ![Int]
      -- ^ Sealed scopes the candidate is currently inside, innermost
      -- first.
    , candPath  :: FreeT (SearchF c) m a
      -- ^ The part of the computation that has not been run yet.
    }

-- | Interpreter state threaded through the evaluation of a 'SearchT'.
data St c m a = St
    { stNum    :: !Int
      -- ^ Most recently allocated candidate number.
    , stScope  :: !Int
      -- ^ Most recently allocated scope number.
    , stActive :: !Scopemap
      -- ^ Which candidates are registered in which scope.
    , stQueue  :: !(PSQ.IntPSQ c (Cand c m a))
      -- ^ Pending candidates, keyed by candidate number and ordered by
      -- priority.
    }

-- | Generate all solutions in order of increasing cost.
--
-- Every unfinished computation is kept as a candidate in a priority
-- queue, keyed by a unique candidate number.  The candidate with the
-- lowest priority (accumulated cost plus its latest estimate) is taken
-- out and advanced until it produces a result, is abandoned, or is
-- overtaken by a queued candidate, in which case it is put back.
--
-- Each result is paired with the definitive cost accumulated along its
-- path; estimates only influence the order of exploration.
runSearchT :: (Ord c, Monoid c, Monad m) => SearchT c m a -> m [(c, a)]
runSearchT m = catMaybes <$> evalStateT go state
  where
    -- Main loop: pop the cheapest candidate and advance it.  Every pop
    -- contributes one 'Maybe' slot to the output; the 'Nothing's are
    -- stripped by 'catMaybes' above.
    go = do
        mmin <- gets (PSQ.minView . stQueue)
        case mmin of
            Nothing -> return []
            Just (num, prio, cand, q) -> do
                updateQueue $ const q
                (:) <$> step num prio cand <*> go

    -- Advance candidate @num@ (currently not in the queue) by one
    -- 'SearchF' layer at a time.  Returns 'Just' a solution, or
    -- 'Nothing' when the candidate was abandoned or rescheduled.
    step num prio cand@Cand{..} = do
        path' <- lift $ runFreeT candPath
        case path' of
            Pure a -> do
                -- A finished candidate can no longer be collapsed, so
                -- unregister it to keep the scope registry from growing
                -- with every alternative ever explored.
                delScopes candScope num
                return $ Just (candCost, a)
            Free Abandon -> do
                delScopes candScope num
                return Nothing
            Free (Cost c e p) ->
                let newCost = candCost `mappend` c
                    -- The priority never drops below its previous value.
                    newPriority = max prio $ newCost `mappend` e
                in do
                    -- Yield when some queued candidate is at least as
                    -- cheap as this one has just become.
                    reschedule <- gets (maybe False
                                              (\(_, x, _) -> x <= newPriority) .
                                            PSQ.findMin . stQueue)
                    let cand' = cand { candCost = newCost, candPath = p }
                    if reschedule
                        then do
                            updateQueue $ PSQ.insert num newPriority cand'
                            return Nothing
                        else step num newPriority cand'
            Free (Alt lhs rhs) -> do
                -- The right branch becomes a new candidate that inherits
                -- cost, priority and scopes; the left branch continues
                -- under the current candidate number.
                num' <- nextNum
                addScopes candScope num'
                updateQueue $ PSQ.insert num' prio cand { candPath = rhs }
                step num prio cand { candPath = lhs }
            Free (Enter p) -> do
                scope <- nextScope
                addScope scope num
                step num
                     prio
                     cand { candScope = scope : candScope, candPath = p }
            Free (Exit p) -> case candScope of
                scope : outer -> do
                    delScope scope num
                    step num prio cand { candScope = outer, candPath = p }
                -- An 'Exit' with no scope left to leave: continue with
                -- the empty scope stack instead of crashing.
                [] -> step num prio cand { candPath = p }
            Free (Collapse p) -> do
                -- Drop every candidate registered in the innermost scope
                -- (or in any scope when the scope stack is empty) from
                -- the queue.  The current candidate is not queued at
                -- this point, so it is unaffected.
                cs <- listScope $ listToMaybe candScope
                updateQueue $ \q -> foldr PSQ.delete q cs
                step num prio cand { candPath = p }

    -- Allocate a fresh candidate number.
    nextNum = do
        modify $ \s -> s { stNum = stNum s + 1 }
        gets stNum

    -- Allocate a fresh scope number.
    nextScope = do
        modify $ \s -> s { stScope = stScope s + 1 }
        gets stScope

    addScope scope c = updateActive $ insert scope c

    addScopes scopes c = updateActive $ \sm -> foldr (`insert` c) sm scopes

    delScope scope c = updateActive $ delete scope c

    delScopes scopes c = updateActive $ \sm -> foldr (`delete` c) sm scopes

    listScope scope = gets $ maybe listAll list scope . stActive

    updateQueue f = modify $ \s -> s { stQueue = f (stQueue s) }

    updateActive f = modify $ \s -> s { stActive = f (stActive s) }

    -- Candidate 0 starts in scope 0 with zero cost and zero priority.
    state = St 0 0 (singleton 0 0) queue

    queue = PSQ.singleton 0 mempty (Cand mempty [ 0 ] (fromFT $ unSearchT m))

-- | Generate only the best solution.  The search collapses as soon as
-- its first result is reached, so at most one solution is produced.
runSearchBestT :: (Ord c, Monoid c, Monad m) => SearchT c m a -> m (Maybe (c, a))
runSearchBestT search = fmap listToMaybe (runSearchT (search <* collapse))

-- | Monads that support cost-directed search with costs of type @c@.
--
-- None of the methods has a default implementation, so an instance
-- must define all of @cost@, @junction@, @abandon@, @seal@ and
-- @collapse@.
class (Ord c, Monoid c, Monad m) => MonadSearch c m | m -> c where
    -- | Mark a computation with a definitive cost (first argument) and
    -- an additional estimated cost (second argument).  Definitive
    -- costs are accumulated and reported, while the estimate is reset
    -- with every call to 'cost' and will not be included in the final
    -- result.
    cost :: c -> c -> m ()

    -- | Introduce an alternative computational path to be evaluated
    -- concurrently.
    junction :: m a -> m a -> m a

    -- | Abandon a computation.
    abandon :: m a

    -- | Limit the effect of 'collapse' to alternatives within the
    -- sealed scope.
    seal :: m a -> m a

    -- | Abandon all other computations within the current sealed
    -- scope.
    collapse :: m ()

-- | The primitive search operations, each encoded as a single
-- 'SearchF' layer wrapped into the underlying free monad.
instance (Ord c, Monoid c, Monad m) => MonadSearch c (SearchT c m) where
    cost definite estimate =
        SearchT (wrap (Cost definite estimate (return ())))
    junction lhs rhs = SearchT (wrap (Alt (unSearchT lhs) (unSearchT rhs)))
    abandon = SearchT (wrap Abandon)
    -- Bracket the body with 'Enter' and 'Exit' so that the scope is
    -- left again once the body has produced its result.
    seal body = SearchT (wrap (Enter bracketed))
      where
        bracketed = do
            result <- unSearchT body
            wrap (Exit (return result))
    collapse = SearchT (wrap (Collapse (return ())))

-- | Search operations lifted through 'ReaderT'.  Both branches of a
-- junction and the body of a seal run in the same environment.
instance MonadSearch c m => MonadSearch c (ReaderT r m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs = ReaderT $ \env ->
        runReaderT lhs env `junction` runReaderT rhs env
    abandon = lift abandon
    seal body = ReaderT (seal . runReaderT body)
    collapse = lift collapse

-- | Search operations lifted through the lazy 'Lazy.WriterT'.  Each
-- branch of a junction carries its own output.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Lazy.WriterT w m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs =
        Lazy.WriterT (Lazy.runWriterT lhs `junction` Lazy.runWriterT rhs)
    abandon = lift abandon
    seal = Lazy.WriterT . seal . Lazy.runWriterT
    collapse = lift collapse

-- | Search operations lifted through the strict 'Strict.WriterT'.
-- Each branch of a junction carries its own output.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Strict.WriterT w m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs =
        Strict.WriterT (Strict.runWriterT lhs `junction` Strict.runWriterT rhs)
    abandon = lift abandon
    seal = Strict.WriterT . seal . Strict.runWriterT
    collapse = lift collapse

-- | Search operations lifted through the lazy 'Lazy.StateT'.  Both
-- branches of a junction start from the same incoming state.
instance MonadSearch c m => MonadSearch c (Lazy.StateT s m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs = Lazy.StateT $ \st ->
        Lazy.runStateT lhs st `junction` Lazy.runStateT rhs st
    abandon = lift abandon
    seal body = Lazy.StateT (seal . Lazy.runStateT body)
    collapse = lift collapse

-- | Search operations lifted through the strict 'Strict.StateT'.
-- Both branches of a junction start from the same incoming state.
instance MonadSearch c m => MonadSearch c (Strict.StateT s m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs = Strict.StateT $ \st ->
        Strict.runStateT lhs st `junction` Strict.runStateT rhs st
    abandon = lift abandon
    seal body = Strict.StateT (seal . Strict.runStateT body)
    collapse = lift collapse

-- | Search operations lifted through the lazy 'Lazy.RWST'.  Both
-- branches of a junction see the same environment and incoming state.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Lazy.RWST r w s m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs = Lazy.RWST $ \env st ->
        Lazy.runRWST lhs env st `junction` Lazy.runRWST rhs env st
    abandon = lift abandon
    seal body = Lazy.RWST $ \env st -> seal (Lazy.runRWST body env st)
    collapse = lift collapse

-- | Search operations lifted through the strict 'Strict.RWST'.  Both
-- branches of a junction see the same environment and incoming state.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Strict.RWST r w s m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs = Strict.RWST $ \env st ->
        Strict.runRWST lhs env st `junction` Strict.runRWST rhs env st
    abandon = lift abandon
    seal body = Strict.RWST $ \env st -> seal (Strict.runRWST body env st)
    collapse = lift collapse

-- | Search operations lifted through 'ExceptT'.  A junction combines
-- the underlying 'Either'-valued computations of its two branches.
instance MonadSearch c m => MonadSearch c (ExceptT e m) where
    cost definite estimate = lift (cost definite estimate)
    junction lhs rhs = ExceptT (runExceptT lhs `junction` runExceptT rhs)
    abandon = lift abandon
    seal = ExceptT . seal . runExceptT
    collapse = lift collapse

-- | Apply a definitive cost without any estimate of further cost.
--
-- > cost' c = cost c mempty
cost' :: MonadSearch c m => c -> m ()
cost' = (`cost` mempty)

-- | Restrict a computation to its first successful result: the
-- computation runs in its own sealed scope, and that scope is
-- collapsed as soon as a result is reached.
--
-- > winner m = seal (m <* collapse)
winner :: MonadSearch c m => m a -> m a
winner = seal . (<* collapse)