module Control.Monad.ST.Trans.Internal where
import GHC.Base
import GHC.ST hiding (liftST)
import Control.Monad.Fix
import Control.Monad.Trans
import Control.Monad.Error.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class
import Control.Applicative
import Data.Array.ST
import Data.Array.Base
import GHC.Int (Int8, Int16, Int32, Int64)
import GHC.Word (Word8, Word16, Word32, Word64)
import GHC.Ptr (Ptr, FunPtr)
import GHC.Stable (StablePtr)
-- | Outcome of running one 'STT' step: the updated state token paired with
-- the value that step produced.
data STTRet s a = STTRet (State# s) a

-- | The state-thread monad transformer. An @STT s m a@ action is a function
-- that takes a state token of type @State# s@ and yields, inside the
-- underlying monad @m@, an 'STTRet' carrying the next token and a result.
newtype STT s m a = STT (State# s -> m (STTRet s a))

-- | Expose the raw state-passing function wrapped by an 'STT' action.
unSTT :: STT s m a -> (State# s -> m (STTRet s a))
unSTT action = case action of
  STT run -> run
-- | Embed a plain 'ST' computation into 'STT'. The 'ST' step is run on the
-- current state token, and the resulting token and value are handed back
-- through 'pure' in the underlying monad.
liftST :: Applicative m => ST s a -> STT s m a
liftST (ST step) = STT $ \token ->
  case step token of
    (# token', result #) -> pure (STTRet token' result)
-- | Sequencing threads the state token from left to right: the continuation
-- is run with the token produced by the preceding action.
instance Monad m => Monad (STT s m) where
  return x = STT (\token -> return (STTRet token x))

  STT run >>= next = STT $ \token ->
    run token >>= \(STTRet token' x) -> unSTT (next x) token'

  -- Failure is delegated to the underlying monad via 'lift'.
  -- NOTE(review): 'fail' is a 'Monad' method only in base < 4.13; newer
  -- compilers reject this definition and need a 'MonadFail' instance instead.
  fail msg = lift (fail msg)
-- | Lift an action of the underlying monad into 'STT'. The state token is
-- passed through unchanged and paired with the lifted action's result.
instance MonadTrans (STT s) where
  lift action = STT $ \token ->
    action >>= \x -> return (STTRet token x)
-- | Run an 'STT' action's underlying function on the given state token.
-- This is 'unSTT' with the token supplied as a second argument.
liftSTT :: STT s m a -> State# s -> m (STTRet s a)
liftSTT action token = unSTT action token
-- | Value recursion for 'STT', built on the underlying monad's 'MonadFix'
-- (the @mdo@ block below runs in @m@). The action @k r@ is run from the
-- incoming state token @s@, where @r@ is the value component of that very
-- run's 'STTRet' result. The whole result, including the new token, is
-- returned.
instance (MonadFix m) => MonadFix (STT s m) where
  -- NOTE(review): as with any 'mfix', @k@ is presumably expected to be lazy
  -- in its argument @r@; forcing @r@ too early would diverge.
  mfix k = STT $ \ s -> mdo
    ans@(STTRet _ r) <- liftSTT (k r) s
    return ans
-- | Apply a function to the carried value, leaving the state token alone.
instance Functor (STTRet s) where
  fmap g ret = case ret of
    STTRet token x -> STTRet token (g x)
-- | Map over the result of an 'STT' action. Only 'Functor' is required of
-- the underlying monad; the mapping is pushed through @m@ and then through
-- 'STTRet'.
instance Functor m => Functor (STT s m) where
  fmap f (STT run) = STT (\token -> fmap (fmap f) (run token))
-- | Applicative sequencing threads the state token first through the action
-- producing the function, then through the action producing its argument.
instance (Monad m, Functor m) => Applicative (STT s m) where
  pure x = STT (\token -> return (STTRet token x))

  STT runF <*> STT runX = STT $ \token0 ->
    runF token0 >>= \(STTRet token1 f) ->
    runX token1 >>= \(STTRet token2 x) ->
    return (STTRet token2 (f x))
-- | Errors are raised in the underlying monad. A handler passed to
-- 'catchError' is run with the state token that was current when
-- 'catchError' was entered.
instance MonadError e m => MonadError e (STT s m) where
  throwError = lift . throwError
  catchError action handler = STT $ \token ->
    catchError (unSTT action token) (\err -> unSTT (handler err) token)
-- | Reader operations are delegated to the underlying monad; 'local'
-- modifies the environment seen by the wrapped action's underlying step.
instance MonadReader r m => MonadReader r (STT s m) where
  ask = lift ask
  local f action = STT (local f . unSTT action)
-- | 'get' and 'put' act on the underlying monad's state (type @s@), which is
-- unrelated to the state-thread parameter @s'@ of 'STT'.
instance MonadState s m => MonadState s (STT s' m) where
  get = lift get
  put = lift . put
-- | Writer operations are delegated to the underlying monad. 'listen' and
-- 'pass' unwrap the 'STTRet' produced by the wrapped step and re-wrap the
-- state token around the adjusted result.
instance MonadWriter w m => MonadWriter w (STT s m) where
  tell = lift . tell

  listen action = STT $ \token ->
    listen (unSTT action token) >>= \(STTRet token' x, w) ->
      return (STTRet token' (x, w))

  pass action = STT $ \token ->
    pass (unSTT action token >>= \(STTRet token' (x, f)) ->
            return (STTRet token' x, f))
-- 'MArray' instances for 'STArray' and 'STUArray' running inside 'STT'.
-- Every method is forwarded to the corresponding 'ST' instance via 'liftST'.
--
-- 'newArray_' and 'unsafeNewArray_' are forwarded explicitly. The class
-- defaults implement them as @newArray bounds arrEleBottom@, and the unboxed
-- ('STUArray') 'ST' instances evaluate the initial element when storing it.
-- Relying on the defaults therefore makes e.g. 'newArray_' on an 'STUArray'
-- fail with "MArray: undefined array element".
--
-- The 'Applicative' constraint is spelled out because 'liftST' requires it.
-- Where 'Applicative' is already a superclass of 'Monad' it is redundant.
instance (Applicative m, Monad m) => MArray (STArray s) e (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Bool (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Char (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Int (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Word (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) (Ptr a) (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) (FunPtr a) (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Float (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Double (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) (StablePtr a) (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Int8 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Int16 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Int32 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Int64 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Word8 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Word16 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Word32 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)

instance (Applicative m, Monad m) => MArray (STUArray s) Word64 (STT s m) where
  getBounds arr = liftST (getBounds arr)
  getNumElements arr = liftST (getNumElements arr)
  newArray bnds e = liftST (newArray bnds e)
  newArray_ bnds = liftST (newArray_ bnds)
  unsafeNewArray_ bnds = liftST (unsafeNewArray_ bnds)
  unsafeRead arr i = liftST (unsafeRead arr i)
  unsafeWrite arr i e = liftST (unsafeWrite arr i e)