{-# LANGUAGE CPP, Rank2Types, TypeFamilies #-}
module Control.Monad.ST.Logic.Internal
       ( LogicT
       , runLogicT
       , observeT
       , observeAllT
       , observeManyT
       , liftST
       , runLogicST
       , observeST
       , observeAllST
       , observeManyST
       , Ref
       , newRef
       , readRef
       , writeRef
       , modifyRef
       , modifyRef'
       ) where

import Control.Applicative
import Control.Monad
import Control.Monad.IO.Class
import qualified Control.Monad.Logic as Logic
import Control.Monad.Logic.Class
#ifdef MODULE_Control_Monad_ST_Safe
import Control.Monad.ST.Safe
#else
import Control.Monad.ST
#endif
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict (StateT, evalStateT)
import qualified Control.Monad.Trans.State.Strict as State

import qualified Data.STRef as ST

-- | Monads that can run 'ST' computations belonging to one fixed state thread.
class Monad m => MonadST m where
  -- | The state thread (the first argument of 'ST') whose computations @m@
  -- can run.
  type World m
  -- | Run an 'ST' computation from the state thread @World m@ inside @m@.
  liftST :: ST (World m) a -> m a

-- | 'ST' runs computations of its own state thread unchanged.
instance MonadST (ST s) where
  type World (ST s) = s
  liftST action = action

-- | 'IO' runs 'ST' computations over 'RealWorld' by converting them with
-- 'stToIO'.
instance MonadST IO where
  type World IO = RealWorld
  liftST action = stToIO action

-- | A backtracking computation over the base monad @m@.
--
-- The representation is a strict 'StateT' whose state is a 'Switch', layered
-- over @Logic.LogicT m@.  The @s@ parameter is phantom: it does not occur in
-- the representation and only serves to tie values to one run through the
-- rank-2 types of the runner functions.
newtype LogicT s m a =
  LogicT { unLogicT :: StateT (Switch m) (Logic.LogicT m) a
         }

-- | Run a 'LogicT' computation with the given continuations.
--
-- The two continuation arguments are passed on unchanged (through
-- 'unsafeRunLogicT') to @Logic.runLogicT@: the first receives a result
-- together with the computation for the remaining results, the second is
-- used when there are no results.  The rank-2 argument requires the
-- computation to be polymorphic in @s@.
runLogicT :: MonadST m => (forall s . LogicT s m a) -> (a -> m r -> m r) -> m r -> m r
runLogicT m = \ next zero -> unsafeRunLogicT m next zero
{-# INLINE runLogicT #-}

-- | Pure variant of the continuation-based runner: the computation is run in
-- 'ST' and the whole run is sealed with 'runST'.
--
-- The pure step function is applied, through 'liftM', to the outcome of the
-- computation for the remaining results; the pure final value is injected
-- with 'return'.
runLogicST :: (forall s . LogicT s (ST s) a) -> (a -> r -> r) -> r -> r
runLogicST m next zero =
  runST (unsafeRunLogicT m (\ a rest -> liftM (next a) rest) (return zero))
{-# INLINE runLogicST #-}

-- | Worker behind the continuation-based runners.
--
-- Allocates a fresh switch in the base monad, uses it as the initial state
-- of the wrapped 'StateT' computation, and hands the resulting computation
-- to @Logic.runLogicT@ together with the two continuations.  Unlike the
-- exported runners, the argument is not required to be polymorphic in @s@.
unsafeRunLogicT :: MonadST m => LogicT s m a -> (a -> m r -> m r) -> m r -> m r
unsafeRunLogicT m next zero =
  newSwitch >>= \ root ->
    Logic.runLogicT (evalStateT (unLogicT m) root) next zero
{-# SPECIALIZE unsafeRunLogicT :: LogicT s (ST s) a -> (a -> ST s r -> ST s r) -> ST s r -> ST s r #-}
{-# SPECIALIZE unsafeRunLogicT :: LogicT s IO a -> (a -> IO r -> IO r) -> IO r -> IO r #-}

-- | Run a 'LogicT' computation in the base monad through @Logic.observeT@
-- (which, per the logict documentation, extracts the first result).  The
-- rank-2 argument requires the computation to be polymorphic in @s@.
observeT :: MonadST m => (forall s . LogicT s m a) -> m a
observeT m = unsafeObserveT m
{-# INLINE observeT #-}

-- | Pure variant of 'observeT': the computation is run in 'ST' and the run
-- is sealed with 'runST'.
observeST :: (forall s . LogicT s (ST s) a) -> a
observeST m = runST (unsafeObserveT m)
{-# INLINE observeST #-}

-- | Worker behind 'observeT' and 'observeST'.
--
-- Allocates a fresh switch in the base monad, uses it as the initial state
-- of the wrapped 'StateT' computation, and passes the result to
-- @Logic.observeT@.  The argument is not required to be polymorphic in @s@.
unsafeObserveT :: MonadST m => LogicT s m a -> m a
unsafeObserveT m = newSwitch >>= Logic.observeT . evalStateT (unLogicT m)
{-# SPECIALIZE unsafeObserveT :: LogicT s (ST s) a -> ST s a #-}
{-# SPECIALIZE unsafeObserveT :: LogicT s IO a -> IO a #-}

-- | Run a 'LogicT' computation in the base monad through @Logic.observeAllT@
-- (which, per the logict documentation, collects all results).  The rank-2
-- argument requires the computation to be polymorphic in @s@.
observeAllT :: MonadST m => (forall s . LogicT s m a) -> m [a]
observeAllT m = unsafeObserveAllT m
{-# INLINE observeAllT #-}

-- | Pure variant of 'observeAllT': the computation is run in 'ST' and the
-- run is sealed with 'runST'.
observeAllST :: (forall s . LogicT s (ST s) a) -> [a]
observeAllST m = runST (unsafeObserveAllT m)
{-# INLINE observeAllST #-}

-- | Worker behind 'observeAllT' and 'observeAllST'.
--
-- Allocates a fresh switch in the base monad, uses it as the initial state
-- of the wrapped 'StateT' computation, and passes the result to
-- @Logic.observeAllT@.  The argument is not required to be polymorphic in
-- @s@.
unsafeObserveAllT :: MonadST m => LogicT s m a -> m [a]
unsafeObserveAllT m =
  newSwitch >>= \ root -> Logic.observeAllT (evalStateT (unLogicT m) root)
{-# SPECIALIZE unsafeObserveAllT :: LogicT s (ST s) a -> ST s [a] #-}
{-# SPECIALIZE unsafeObserveAllT :: LogicT s IO a -> IO [a] #-}

-- | Run a 'LogicT' computation in the base monad through
-- @Logic.observeManyT@, forwarding the 'Int' argument (per the logict
-- documentation, an upper bound on the number of results extracted).  The
-- rank-2 argument requires the computation to be polymorphic in @s@.
observeManyT :: MonadST m => Int -> (forall s . LogicT s m a) -> m [a]
observeManyT n m = unsafeObserveManyT n m
{-# INLINE observeManyT #-}

-- | Pure variant of 'observeManyT': the computation is run in 'ST' and the
-- run is sealed with 'runST'.
observeManyST :: Int -> (forall s . LogicT s (ST s) a) -> [a]
observeManyST n m = runST (unsafeObserveManyT n m)
{-# INLINE observeManyST #-}

-- | Worker behind 'observeManyT' and 'observeManyST'.
--
-- Allocates a fresh switch in the base monad, uses it as the initial state
-- of the wrapped 'StateT' computation, and passes the result, together with
-- the count @n@, to @Logic.observeManyT@.  The argument is not required to
-- be polymorphic in @s@.
unsafeObserveManyT :: MonadST m => Int -> LogicT s m a -> m [a]
unsafeObserveManyT n m =
  newSwitch >>= Logic.observeManyT n . evalStateT (unLogicT m)
{-# SPECIALIZE unsafeObserveManyT :: Int -> LogicT s (ST s) a -> ST s [a] #-}
{-# SPECIALIZE unsafeObserveManyT :: Int -> LogicT s IO a -> IO [a] #-}

-- | Maps over the result of the wrapped 'StateT' computation.
instance Functor (LogicT s m) where
  fmap f = \ m -> LogicT (f <$> unLogicT m)
  {-# INLINE fmap #-}

-- | Every operation unwraps its arguments, delegates to the corresponding
-- operation of the wrapped 'StateT' computation, and wraps the result again.
instance Applicative (LogicT s m) where
  pure = \ a -> LogicT (pure a)
  {-# INLINE pure #-}
  LogicT f <*> LogicT a = LogicT (f <*> a)
  {-# INLINE (<*>) #-}
#ifndef CLASS_OldApplicative
  LogicT a *> LogicT b = LogicT (a *> b)
  {-# INLINE (*>) #-}
  LogicT a <* LogicT b = LogicT (a <* b)
  {-# INLINE (<*) #-}
#endif

-- | Failure is the failure of the wrapped computation; choice is 'plusLogic',
-- which tracks each choice point with a fresh switch (hence the 'MonadST'
-- constraint).
instance MonadST m => Alternative (LogicT s m) where
  empty = LogicT empty
  {-# INLINE empty #-}
  (<|>) = plusLogic
  {-# INLINE (<|>) #-}

-- | Every operation unwraps its arguments, delegates to the corresponding
-- operation of the wrapped 'StateT' computation, and wraps the result again.
--
-- NOTE(review): defining @fail@ here assumes a base library in which it is
-- still a method of 'Monad' - confirm the supported compiler range.
instance Monad (LogicT s m) where
  return = \ a -> LogicT (return a)
  {-# INLINE return #-}
  LogicT m >>= k = LogicT (m >>= \ a -> unLogicT (k a))
  {-# INLINE (>>=) #-}
  LogicT m >> LogicT n = LogicT (m >> n)
  {-# INLINE (>>) #-}
  fail = \ msg -> LogicT (fail msg)
  {-# INLINE fail #-}

-- | Failure is the failure of the wrapped computation; choice is 'plusLogic',
-- which tracks each choice point with a fresh switch (hence the 'MonadST'
-- constraint).
instance MonadST m => MonadPlus (LogicT s m) where
  mzero = LogicT mzero
  {-# INLINE mzero #-}
  mplus = plusLogic
  {-# INLINE mplus #-}

-- | Choice between two computations, tracked by a freshly allocated switch.
--
-- The first alternative runs with the new switch installed as the current
-- switch (the 'StateT' state).  The second alternative starts by flipping
-- that switch and then runs the second computation.  The two prepared
-- computations are combined with the choice operator of the wrapped
-- computation.
plusLogic :: MonadST m => LogicT s m a -> LogicT s m a -> LogicT s m a
plusLogic m n =
  newSwitch >>= \ branch ->
    let firstChoice = unLogicT (put branch *> m)
        secondChoice = unLogicT (flipSwitch branch *> n)
    in LogicT (firstChoice <|> secondChoice)
{-# SPECIALIZE plusLogic :: LogicT s (ST s) a -> LogicT s (ST s) a -> LogicT s (ST s) a #-}
{-# SPECIALIZE plusLogic :: LogicT s IO a -> LogicT s IO a -> LogicT s IO a #-}

-- | 'msplit' is delegated to the wrapped computation; the computation for
-- the remaining results found inside the returned 'Maybe' pair is wrapped
-- back into 'LogicT'.
instance MonadST m => MonadLogic (LogicT s m) where
  msplit = \ m -> LogicT (fmap (fmap LogicT) <$> msplit (unLogicT m))
  {-# INLINE msplit #-}

-- | Lift an action of the base monad into 'LogicT': once into the logic
-- layer and once more into the 'StateT' layer above it.
liftLogic :: Monad m => m a -> LogicT s m a
liftLogic action = LogicT (lift (lift action))
{-# SPECIALIZE liftLogic :: ST s a -> LogicT s (ST s) a #-}
{-# SPECIALIZE liftLogic :: IO a -> LogicT s IO a #-}

-- | An 'IO' action is first lifted into the base monad and from there into
-- 'LogicT'.
instance MonadIO m => MonadIO (LogicT s m) where
  liftIO io = liftLogic (liftIO io)

-- | 'LogicT' shares the state thread of its base monad: an 'ST' computation
-- is first lifted into the base monad and from there into 'LogicT'.
instance MonadST m => MonadST (LogicT s m) where
  type World (LogicT s m) = World m
  liftST = \ st -> liftLogic (liftST st)
  {-# INLINE liftST #-}

-- | Fetch the current switch, i.e. the state of the wrapped 'StateT'
-- computation.
get :: Monad m => LogicT s m (Switch m)
get = LogicT State.get
{-# SPECIALIZE get :: LogicT s (ST s) (Switch (ST s)) #-}
{-# SPECIALIZE get :: LogicT s IO (Switch IO) #-}

-- | Install a switch as the current one (the state of the wrapped 'StateT'
-- computation).  The switch is forced to weak head normal form before the
-- resulting computation is built.
put :: Monad m => Switch m -> LogicT s m ()
put s = seq s (LogicT (State.put s))
{-# SPECIALIZE put :: Switch (ST s) -> LogicT s (ST s) () #-}
{-# SPECIALIZE put :: Switch IO -> LogicT s IO () #-}

-- | A mutable boolean flag living in the state thread of @m@.  A switch is
-- created unflipped ('False') by 'newSwitch' and set to 'True' by
-- 'flipSwitch'.
type Switch m = ST.STRef (World m) Bool

-- | Allocate a new switch in the unflipped state ('False').
newSwitch :: MonadST m => m (Switch m)
newSwitch = liftST (ST.newSTRef False)
{-# INLINE newSwitch #-}

-- | Mark a switch as flipped by storing 'True' in it.
flipSwitch :: MonadST m => Switch m -> m ()
flipSwitch = \ switch -> liftST (ST.writeSTRef switch True)
{-# INLINE flipSwitch #-}

-- | Read a switch and run the first action when it has been flipped, the
-- second one otherwise.
ifFlipped :: Switch (ST s) -> ST s a -> ST s a -> ST s a
ifFlipped switch t f =
  ST.readSTRef switch >>= \ flipped -> case flipped of
    True -> t
    False -> f

-- | A mutable reference usable inside 'LogicT'.  It wraps an 'ST.STRef' in
-- the state thread of @m@ that stores a 'Value', i.e. a history of writes.
-- The @s@ parameter is phantom.
newtype Ref s m a = Ref (ST.STRef (World m) (Value m a))

-- | The contents of a 'Ref': a non-empty history of writes, most recent
-- first.  All fields are strict.
data Value m a
  = New {-# UNPACK #-} !(Write m a)
    -- ^ The oldest entry; nothing lies beneath it.
  | {-# UNPACK #-} !(Write m a) :| !(Value m a)
    -- ^ A write stacked on top of the older history to its right.

-- | A single write: the stored value paired with the switch it is tagged
-- with.  The switch field is strict and unpacked; the value field is lazy.
data Write m a = Write {-# UNPACK #-} !(Switch m) a

-- | Create a reference holding the given value.  The initial entry is
-- tagged with the switch that is current at the time of creation.
newRef :: MonadST m => a -> LogicT s m (Ref s m a)
newRef a = do
  current <- get
  liftST (Ref <$> newSTRef a current)
{-# SPECIALIZE newRef :: a -> LogicT s (ST s) (Ref s (ST s) a) #-}
{-# SPECIALIZE newRef :: a -> LogicT s IO (Ref s IO a) #-}

-- | Allocate the underlying 'ST.STRef' of a reference, holding a one-entry
-- history made of the given value tagged with the given switch.  The switch
-- is forced (by the strict composition operator) before the allocation
-- action is built.
newSTRef :: a -> Switch m -> ST (World m) (ST.STRef (World m) (Value m a))
newSTRef a = ST.newSTRef .! (\ switch -> New (Write switch a))

infixr 9 .!
-- | Function composition that forces its argument to weak head normal form
-- before applying the composed functions to it.
(.!) :: (b -> c) -> (a -> b) -> a -> c
(f .! g) x = seq x (f (g x))

-- | Read the current value of a reference.  Entries whose switch has been
-- flipped are skipped by the underlying 'readSTRef'.
readRef :: MonadST m => Ref s m a -> LogicT s m a
readRef (Ref cell) = liftST (readSTRef cell)
{-# SPECIALIZE readRef :: Ref s (ST s) a -> LogicT s (ST s) a #-}
{-# SPECIALIZE readRef :: Ref s IO a -> LogicT s IO a #-}

-- | Read the live value out of a write history.
--
-- If the most recent entry is the oldest one, or its switch has not been
-- flipped, its value is returned and the cell is left untouched.  Otherwise
-- the history is walked towards older entries until one is found whose
-- switch is unflipped, or the oldest entry is reached (its switch is not
-- consulted).  The history starting at that entry is written back to the
-- cell, dropping the skipped entries, and its value is returned.
readSTRef :: ST.STRef (World m) (Value m a) -> ST (World m) a
readSTRef ref = ST.readSTRef ref >>= \ history -> case history of
  New (Write _ a) -> return a
  Write switch a :| older -> ifFlipped switch (rewind older) (return a)
  where
    -- Find the first live entry in an older part of the history.
    rewind live = case live of
      New (Write _ a) -> commit a
      Write switch a :| older -> ifFlipped switch (rewind older) (commit a)
      where
        -- Store the trimmed history, then yield the value found.
        commit a = ST.writeSTRef ref live >> return a

-- | Store a new value in a reference, ignoring the previous one.  The new
-- entry is tagged with the current switch.
writeRef :: MonadST m => Ref s m a -> a -> LogicT s m ()
writeRef ref a = modifyRef'' ref overwrite
  where
    overwrite switch _ = Write switch a
{-# SPECIALIZE writeRef :: Ref s (ST s) a -> a -> LogicT s (ST s) () #-}
{-# SPECIALIZE writeRef :: Ref s IO a -> a -> LogicT s IO () #-}

-- | Replace the value of a reference by the result of applying a function
-- to it.  The application is stored unevaluated; the new entry is tagged
-- with the current switch.
modifyRef :: MonadST m => Ref s m a -> (a -> a) -> LogicT s m ()
modifyRef ref f = modifyRef'' ref update
  where
    update switch old = Write switch (f old)
{-# SPECIALIZE modifyRef :: Ref s (ST s) a -> (a -> a) -> LogicT s (ST s) () #-}
{-# SPECIALIZE modifyRef :: Ref s IO a -> (a -> a) -> LogicT s IO () #-}

-- | Strict variant of 'modifyRef': the new value is forced to weak head
-- normal form before it is wrapped into the entry that gets stored.
modifyRef' :: MonadST m => Ref s m a -> (a -> a) -> LogicT s m ()
modifyRef' ref f = modifyRef'' ref update
  where
    update switch old = let new = f old in new `seq` Write switch new
{-# SPECIALIZE modifyRef' :: Ref s (ST s) a -> (a -> a) -> LogicT s (ST s) () #-}
{-# SPECIALIZE modifyRef' :: Ref s IO a -> (a -> a) -> LogicT s IO () #-}

-- | Shared worker of the write and modify operations: fetches the current
-- switch and runs 'modifySTRef' on the underlying cell with it.  The given
-- function builds the new entry from the current switch and the old value.
modifyRef'' :: MonadST m => Ref s m a -> (Switch m -> a -> Write m a) -> LogicT s m ()
modifyRef'' (Ref cell) f = do
  current <- get
  liftST (modifySTRef cell f current)
{-# INLINE modifyRef'' #-}

-- | Update a write history under the given current switch.
--
-- Entries whose switch has been flipped are skipped, walking towards older
-- entries; the oldest entry is always treated as live.  The new entry is
-- built from the current switch and the value of the live entry found.  If
-- that live entry is tagged with the current switch, the new entry replaces
-- it; otherwise the new entry is stacked on top of it.  Skipped entries are
-- dropped in either case.  The new history is forced to weak head normal
-- form before it is stored, which, because the history fields are strict,
-- also forces the new entry.
modifySTRef :: ST.STRef (World m) (Value m a) ->
               (Switch m -> a -> Write m a) ->
               Switch m ->
               ST (World m) ()
modifySTRef ref f = \ current -> ST.readSTRef ref >>= go current
  where
    go current live = case live of
      Write owner a :| older ->
        ifFlipped owner (go current older) $
          ST.writeSTRef ref $!
            f current a :| (if owner == current then older else live)
      New (Write owner a)
        | owner == current -> ST.writeSTRef ref $! New (f current a)
        | otherwise -> ST.writeSTRef ref $! f current a :| live
{-# INLINE modifySTRef #-}