{-| Module : Control.Monad.AStar Description : Provides a monad transformer for A* search. See README for details and examples. Copyright : (c) Chris Penner, 2019 License : BSD3 See the for usage info and examples. -} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE DerivingStrategies #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE FlexibleContexts #-} module Control.Monad.AStar ( -- * Types AStar , AStarT , BranchState (..) -- * Methods , MonadAStar(..) , branch , failure -- * Executing Search , runAStar , runAStarT , evalAStar , evalAStarT , execAStar , execAStarT -- , tryWhile -- , tryWhileT ) where import Control.Monad.AStar.Class import Control.Monad.Except import Control.Monad.Fail import Control.Monad.Logic import Control.Applicative import Control.Monad.State import Data.Functor.Identity import Data.Maybe data Step r a = Pure a | Checkpoint | Solved r deriving (Show, Functor, Eq) data BranchState s c = BranchState { branchState :: s , cumulativeCost :: c , estimateTillDone :: c } deriving (Show, Eq) totalCost :: Semigroup c => BranchState s c -> c totalCost bs = cumulativeCost bs <> estimateTillDone bs -- | Non-transformer version of 'AStarT' type AStar s c r a = AStarT s c r Identity a -- | The 'AStar' search monad transformer -- -- Lots of type variables here: -- -- @s@: State; keep anything you want in here, it will stay coherent across -- branch switches. -- -- @c@: Cost measure: The type you'll use for both your heuristic estimate and for your -- accumulated total cost. -- Usually requires 'Ord' (for comparing branch costs) and 'Monoid' (for appending and -- instantiating cumulative costs). -- -- @r@: Result type, this is often redundant to State but is provided for convenience. -- This is the type you pass to 'done' when you've found a solution. -- -- @m@: An arbitrary monad which will be threaded through. -- -- Be wary that effects will be run in seemingly non-deterministic ordering as we switch -- chaotically between branches. newtype AStarT s c r m a = AStarT { unAStarT :: StateT (BranchState s c) (LogicT m) (Step r a) } deriving stock Functor instance MonadTrans (AStarT s c r) where lift m = AStarT . lift . lift $ (Pure <$> m) instance (MonadIO m, Semigroup c, Ord c) => MonadIO (AStarT s c r m) where liftIO io = lift $ liftIO io instance (Monad m, Semigroup c, Ord c) => Applicative (AStarT s c r m) where pure = return (<*>) = ap instance (Ord c, Semigroup c, Monad m) => MonadPlus (AStarT s c r m) where mzero = empty mplus = (<|>) instance (Ord c, Semigroup c, Monad m) => MonadFail (AStarT s c r m) where fail _ = empty instance (Ord c, Semigroup c, Monad m) => MonadState s (AStarT s c r m) where get = AStarT $ Pure . branchState <$> get put s = AStarT $ Pure <$> modify (\bs -> bs{branchState=s}) instance (Monad m, Semigroup c, Ord c) => Monad (AStarT s c r m) where return = AStarT . return . Pure AStarT m >>= f = AStarT $ do next <- msplit m case next of Nothing -> empty (Just (Solved r, _)) -> pure (Solved r) -- Should I interleave these instead? (Just (Pure a, continue)) -> (unAStarT $ f a) <|> unAStarT (AStarT continue >>= f) (Just (Checkpoint, continue)) -> pure Checkpoint <|> unAStarT (AStarT continue >>= f) instance (Ord c, Monad m, Semigroup c) => Alternative (AStarT s c r m) where empty = AStarT empty (<|>) = weightedInterleave weightedInterleave :: (Ord c, Semigroup c, Monad m) => AStarT s c r m a -> AStarT s c r m a -> AStarT s c r m a weightedInterleave (AStarT a) (AStarT b) = AStarT $ weightedInterleave' a b weightedInterleave' :: (Ord c, Semigroup c, MonadLogic m, MonadState (BranchState s c) m) => m (Step r a) -> m (Step r a) -> m (Step r a) weightedInterleave' ma mb = do beforeBoth <- get (rA, lState) <- liftA2 (,) (msplit ma) get put beforeBoth (rB, rState) <- liftA2 (,) (msplit mb) get case (rA, rB) of (m, Nothing) -> put lState >> reflect m (Nothing, m) -> put rState >> reflect m (Just (Solved r, _), _) -> put lState >> pure (Solved r) (_, Just (Solved r, _)) -> put rState >> pure (Solved r) (l@(Just (Checkpoint, lm)), r@(Just (Checkpoint, rm))) | totalCost lState <= totalCost rState ->do (put lState >> pure Checkpoint) <|> ((put lState >> lm) `weightedInterleave'` (put rState >> reflect r)) | otherwise -> (put rState >> pure Checkpoint) <|> ((put rState >> rm) `weightedInterleave'` (put lState >> reflect l)) ((Just (Pure la, lm)), r) -> (put lState >> pure (Pure la)) `interleave` ((put lState >> lm) `weightedInterleave'` (put rState >> reflect r)) (l, (Just (Pure ra, rm))) -> (put rState >> pure (Pure ra)) `interleave` ((put rState >> rm) `weightedInterleave'` (put lState >> reflect l)) -- | Run a pure A* computation returning the solution and branch state if one was found. runAStar :: (Monoid c) => AStar s c r a -> s -> Maybe (r, s) runAStar m s = runIdentity $ runAStarT m s -- | Run an A* computation effect returning the solution and branch state if one was found. runAStarT :: (Monad m, Monoid c) => AStarT s c r m a -> s -> m (Maybe (r, s)) runAStarT (AStarT m) s = fmap listToMaybe . observeManyT 1 $ do runStateT m (BranchState s mempty mempty) >>= \case (Solved r, s) -> return (r, branchState s) _ -> empty evalAStar :: (Monoid c) => AStar s c r a -> s -> Maybe r evalAStar m s = fst <$> runAStar m s evalAStarT :: (Monad m, Monoid c) => AStarT s c r m a -> s -> m (Maybe r) evalAStarT m s = fmap fst <$> runAStarT m s execAStar :: (Monoid c) => AStar s c r a -> s -> Maybe s execAStar m s = fmap snd $ runAStar m s execAStarT :: (Monad m, Monoid c) => AStarT s c r m a -> s -> m (Maybe s) execAStarT m s = fmap snd <$> runAStarT m s -- | Run a pure A* search but short-circuit when the lowest cost fails a predicate. -- -- This is useful for detecting if your search is diverging, or is likely to fail. -- tryWhile :: Monoid c => (s -> c -> Bool) -> AStar s c r a -> s -> (Maybe (r, s)) -- tryWhile p m s = runIdentity $ tryWhileT p m s -- | Effectful version of 'tryWhile' -- tryWhileT :: (Monoid c, Monad m) => (s -> c -> Bool) -> AStarT s c r m a -> s -> m (Maybe (r, s)) -- tryWhileT p m s = fmap (second branchState) <$> tryWhileT' p m (BranchState s mempty mempty) -- tryWhileT' :: (Monoid c, Monad m) => (s -> c -> Bool) -> AStarT s c r m a -> BranchState s c -> m (Maybe (r, BranchState s c)) -- tryWhileT' p m s = do -- stepAStar m s >>= \case -- Nothing -> return Nothing -- Just ((Pure _, s), continue) -> tryWhileT' p continue s -- Just ((Checkpoint, bs), continue) -> do -- if p (branchState bs) (totalCost bs) then tryWhileT' p continue s -- else return Nothing -- Just ((Solved r, s), _) -> return (Just (r, s)) -- NOTE probably doesn't handle state properly in returned continuation... -- stepAStar :: (Monoid c, Monad m) => AStarT s c r m a -> BranchState s c -> m (Maybe ((Step r a, BranchState s c), AStarT s c r m a)) -- stepAStar (AStarT m) s = fmap (fmap go) . observeT . _ $ msplit (evalStateT m s) -- where -- go (v, x) = (v, AStarT (lift x)) instance (Ord c, Monoid c, Monad m) => MonadAStar c r (AStarT s c r m) where estimate c = AStarT $ modify (\bs -> bs{estimateTillDone=c}) >> (pure Checkpoint <|> pure (Pure ())) spend c = AStarT $ modify (\bs -> bs{cumulativeCost=cumulativeCost bs <> c}) >> (pure Checkpoint <|> pure (Pure ())) done = AStarT . pure . Solved