{-# LANGUAGE DeriveAnyClass, DeriveFunctor, FlexibleInstances, GeneralizedNewtypeDeriving, KindSignatures, MultiParamTypeClasses, RankNTypes, TypeOperators, UndecidableInstances #-}
module Control.Effect.NonDet
( NonDet(..)
, Alternative(..)
, runNonDet
, NonDetC(..)
) where
import Control.Applicative (Alternative(..))
import Control.Effect.Carrier
import Control.Effect.Sum
import Control.Monad (MonadPlus(..))
import Control.Monad.Fail
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Data.Kind (Type)
import Prelude hiding (fail)
-- | The nondeterminism effect.
--
-- 'Empty' is a computation with no results; 'Choose' branches, passing a
-- 'Bool' to its continuation to select which branch is being run. The @m@
-- parameter is not used by any constructor but is part of the effect's kind;
-- it is declared with 'Type' rather than the deprecated @*@ kind syntax.
data NonDet (m :: Type -> Type) k
  = Empty
    -- ^ Fail, producing no results.
  | Choose (Bool -> k)
    -- ^ Branch two ways; the 'Bool' given to the continuation selects the branch.
  deriving (Functor, HFunctor, Effect)
-- | Run a nondeterministic computation, collecting all of its results into an
-- 'Alternative' container (for example a list or 'Maybe').
--
-- Each result is prepended with @pure result <|>@ onto the collection of the
-- remaining results, and the collection starts from 'empty', so results appear
-- in the order the computation produces them.
runNonDet :: (Alternative f, Applicative m) => NonDetC m a -> m (f a)
runNonDet computation = runNonDetC computation prepend (pure empty)
  where
    -- Put one result in front of the (effectful) collection of the rest.
    prepend result rest = (pure result <|>) <$> rest
-- | Carrier for 'NonDet'. A computation is represented by its own right fold
-- over its results, performed in the monad @m@.
newtype NonDetC m a = NonDetC
  { runNonDetC :: forall r . (a -> m r -> m r) -> m r -> m r
    -- ^ Fold the results: the first argument combines one result with the
    -- fold of the remaining results; the second argument is the value used
    -- when no results remain.
  }
  deriving Functor
-- | 'pure' produces exactly one result. '<*>' applies every function result of
-- the left computation to every value result of the right computation. The
-- outer iteration is over the functions, so all results for the first function
-- come before those for the second.
instance Applicative (NonDetC m) where
  pure x = NonDetC (\ k -> k x)
  {-# INLINE pure #-}
  NonDetC fs <*> NonDetC xs = NonDetC $ \ cons ->
    fs (\ g -> xs (\ x -> cons (g x)))
  {-# INLINE (<*>) #-}
-- | 'empty' has no results and returns the @nil@ case unchanged. '<|>' yields
-- all results of the left computation followed by all results of the right
-- one: the right computation's fold becomes the left one's @nil@.
instance Alternative (NonDetC m) where
  empty = NonDetC $ \ _ nil -> nil
  {-# INLINE empty #-}
  NonDetC lhs <|> NonDetC rhs = NonDetC $ \ cons nil -> lhs cons (rhs cons nil)
  {-# INLINE (<|>) #-}
-- | Binding runs the continuation on every result of the first computation
-- and yields the results of each continuation run, in order.
instance Monad (NonDetC m) where
  m >>= k = NonDetC $ \ cons ->
    runNonDetC m (\ x -> runNonDetC (k x) cons)
  {-# INLINE (>>=) #-}
-- | Failure is delegated to the underlying monad: the resulting fold ignores
-- both its @cons@ and @nil@ arguments and runs @fail@ in @m@.
instance MonadFail m => MonadFail (NonDetC m) where
  fail msg = NonDetC $ \ _cons _nil -> fail msg
  {-# INLINE fail #-}
-- | Runs the 'IO' action in the underlying monad and yields its result as the
-- computation's single result.
instance MonadIO m => MonadIO (NonDetC m) where
  liftIO action = NonDetC $ \ cons nil -> do
    x <- liftIO action
    cons x nil
  {-# INLINE liftIO #-}
-- | 'MonadPlus' coincides with the 'Alternative' instance.
instance MonadPlus (NonDetC m) where
  mzero = empty
  mplus = (<|>)
-- | Lifting runs the action in the underlying monad and yields its result as
-- the computation's single result.
instance MonadTrans NonDetC where
  lift action = NonDetC $ \ cons nil -> action >>= \ x -> cons x nil
  {-# INLINE lift #-}
-- | Interprets 'NonDet' in 'NonDetC' and forwards every other effect to the
-- underlying carrier @m@.
--
-- * 'Empty' becomes 'empty' (no results).
-- * 'Choose' runs the continuation with 'True' and with 'False' and combines
--   the two with '<|>', so the 'True' branch's results come first.
-- * Any other effect is passed to @m@'s 'eff' via 'handle'. The initial state
--   is @[()]@ and the handler runs a list of 'NonDetC' computations with
--   'runNonDet' (collecting each into a list) and concatenates them. The
--   resulting list of results is then folded into the current continuation
--   with @foldr cons nil@.
--
-- NOTE(review): because the handler collects into lists, all results of the
-- forwarded continuation are gathered before folding — confirm this is
-- acceptable for large or long-running searches.
instance (Carrier sig m, Effect sig) => Carrier (NonDet :+: sig) (NonDetC m) where
  eff op = case op of
    L Empty      -> empty
    L (Choose k) -> k True <|> k False
    R other      -> NonDetC $ \ cons nil -> do
      results <- eff (handle [()] (fmap concat . traverse runNonDet) other)
      foldr cons nil results
  {-# INLINE eff #-}