{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
-- The pragmas above are what the code in this module needs in order to parse
-- and type-check on its own: '#'-suffixed names (MagicHash), unboxed tuple
-- syntax (UnboxedTuples), and the mtl-style multi-parameter instances whose
-- heads contain a bare type variable and do not satisfy the coverage condition
-- (MultiParamTypeClasses, FlexibleInstances, UndecidableInstances).

-- | A monad transformer, 'IOT', whose actions are functions from a
-- @'State#' 'RealWorld'@ token to a computation in an underlying monad that
-- yields an 'IORet#' (a new token paired with a value).
module Control.Monad.Trans.IO
    ( -- * Result wrapper
      IORet#(..)
      -- * The transformer
    , IOT(..)
    , runIOT
    , hoistIOT
      -- * Building 'IOT' actions from 'IO'
    , fromIO
    , sequenceIO
    )
    where
import Control.Applicative
import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class
import GHC.Prim
import GHC.IO
-- | A state token paired with a value, wrapped in an ordinary data
-- constructor. This is the result type that 'IOT' actions produce inside the
-- underlying monad.
data IORet# a = IORet#
    { getIORet# :: (# State# RealWorld, a #)
      -- ^ The unboxed (state token, value) pair held by the wrapper.
    }
-- | Applies the function to the value component; the state token is carried
-- through untouched.
instance Functor IORet# where
    fmap f ret = case ret of
        IORet# (# tok#, val #) -> IORet# (# tok#, f val #)
-- | An action that, given a @'State#' 'RealWorld'@ token, produces a
-- computation in the underlying monad @m@ yielding an 'IORet#' (the next
-- token together with the result value).
--
-- Declared as a @newtype@ because it wraps exactly one lifted field (a
-- function): the constructor then costs nothing at run time, and matching on
-- 'IOT#' never forces the wrapped action, so @IOT# f@ patterns used by the
-- instances in this module do not add strictness of their own.
newtype IOT m a = IOT#
    { getIOT# :: State# RealWorld -> m (IORet# a)
      -- ^ Run the action from the given state token.
    }
-- | Maps over the result value by lifting the function through 'IORet#' and
-- then through the underlying functor @m@.
instance Functor m => Functor (IOT m) where
    fmap f (IOT# run) = IOT# (\tok# -> fmap f <$> run tok#)
-- | 'pure' returns the value together with the incoming state token.
-- '<*>' runs the function action first, feeds the token it yields to the
-- argument action, and returns the application with the final token.
instance (Functor m, Monad m) => Applicative (IOT m) where
    pure x = IOT# (\s# -> return (IORet# (# s#, x #)))
    IOT# runF <*> IOT# runX = IOT# $ \s0# ->
        runF s0# >>= \(IORet# (# s1#, f #)) ->
        runX s1# >>= \(IORet# (# s2#, x #)) ->
        return (IORet# (# s2#, f x #))
-- | Failure and choice are exactly those of the 'MonadPlus' instance for
-- 'IOT': 'empty' is 'mzero' and '<|>' is 'mplus'.
instance (Functor m, MonadPlus m) => Alternative (IOT m) where
    empty = mzero
    left <|> right = left `mplus` right
-- | 'return' yields the value with the incoming state token unchanged.
-- '>>=' runs the first action, then runs the continuation's action starting
-- from the state token the first action produced.
instance Monad m => Monad (IOT m) where
    return x = IOT# (\s# -> return (IORet# (# s#, x #)))
    IOT# runX >>= k = IOT# $ \s0# ->
        runX s0# >>= \(IORet# (# s1#, x #)) -> getIOT# (k x) s1#
-- | 'mzero' ignores the state token and is the underlying monad's 'mzero'.
-- 'mplus' gives /both/ alternatives the same incoming state token and combines
-- the two resulting computations with the underlying monad's 'mplus'.
instance MonadPlus m => MonadPlus (IOT m) where
    mzero = IOT# (\_ -> mzero)
    mplus (IOT# runLeft) (IOT# runRight) =
        IOT# (\s# -> runLeft s# `mplus` runRight s#)
-- | Unwraps an 'IOT' action into an 'IO' action that returns the underlying
-- computation with the state tokens stripped from its results.
--
-- The 'IO' action hands its own state token to the 'IOT' function and returns
-- that same token unchanged; the token produced inside @m@ is discarded.
--
-- NOTE(review): since the inner token is not threaded back out, effects
-- embedded in the 'IOT' action are presumably performed only when the returned
-- @m a@ is evaluated -- confirm that callers expect this.
runIOT :: Functor m => IOT m a -> IO (m a)
runIOT (IOT# run) = IO (\s# -> (# s#, fmap result (run s#) #))
  where
    -- Drop the state token, keeping only the value.
    result (IORet# (# _, x #)) = x
-- | Transforms the underlying computation of an 'IOT' action.
--
-- The supplied function is applied to the result of running the action on the
-- incoming state token; it sees (and may change) the 'IORet#' values, so both
-- the underlying monad and the result type may differ afterwards.
hoistIOT :: (m (IORet# a) -> n (IORet# b)) -> IOT m a -> IOT n b
hoistIOT transform action = IOT# run
  where
    run s# = transform (getIOT# action s#)
-- | Builds an 'IOT' action from an 'IO' action that yields a computation in
-- @m@: the 'IO' step is run on the incoming state token, and the token it
-- returns is paired (via 'fmap') with the result of the inner computation.
fromIO :: Functor m => IO (m a) -> IOT m a
fromIO (IO step) = IOT# run
  where
    run s0# = case step s0# of
        (# s1#, inner #) -> fmap (\x -> IORet# (# s1#, x #)) inner
-- | Builds an 'IOT' action from a computation in @m@ that yields 'IO'
-- actions: each yielded 'IO' action is run on the incoming state token and its
-- (token, value) result is wrapped in 'IORet#'.
sequenceIO :: Functor m => m (IO a) -> IOT m a
sequenceIO actions = IOT# (\s# -> runWith s# <$> actions)
  where
    -- Run a single 'IO' action from the given token.
    runWith tok# (IO step) = IORet# (step tok#)
-- | Same construction as 'fromIO', but requiring only 'Monad' on @m@ (it maps
-- with 'liftM' instead of 'fmap'). Used by the 'MonadTrans' and 'MonadIO'
-- instances of this module.
fromIO' :: Monad m => IO (m a) -> IOT m a
fromIO' (IO step) = IOT# $ \s0# ->
    case step s0# of
        (# s1#, inner #) -> liftM (wrap s1#) inner
  where
    -- Pair a value with the given state token.
    wrap tok# x = IORet# (# tok#, x #)
-- | Lifts a computation of the underlying monad: its result is paired with
-- the incoming state token, which is passed through unchanged.
instance MonadTrans IOT where
    lift inner = IOT# (\s# -> liftM (\x -> IORet# (# s#, x #)) inner)
-- | Lifts an 'IO' action: its result is injected into the underlying monad
-- with 'return' and the whole thing is converted with @fromIO'@, so the
-- action runs on the incoming state token and its outgoing token is kept.
instance Monad m => MonadIO (IOT m) where
    liftIO = fromIO' . fmap return
-- | 'callCC' delegates to the underlying monad's 'callCC'. The escape function
-- handed to the body passes its argument to the captured continuation together
-- with whatever state token is current at the point where it is invoked.
instance MonadCont m => MonadCont (IOT m) where
    callCC body = IOT# $ \s0# -> callCC $ \k ->
        let escape x = IOT# (\s1# -> k (IORet# (# s1#, x #)))
        in getIOT# (body escape) s0#
-- | Errors are those of the underlying monad. 'throwError' is lifted as is.
-- 'catchError' runs the protected action from the incoming state token and,
-- if the underlying monad reports an error, runs the handler starting from
-- that /same/ incoming token.
instance MonadError e m => MonadError e (IOT m) where
    throwError err = lift (throwError err)
    catchError action handler = IOT# $ \s# ->
        catchError (getIOT# action s#) (\err -> getIOT# (handler err) s#)
-- | The environment is that of the underlying monad: 'ask' and 'reader' are
-- lifted, and 'local' applies the underlying 'local' to the inner computation
-- through 'hoistIOT'.
instance MonadReader r m => MonadReader r (IOT m) where
    ask = lift ask
    local modify = hoistIOT (local modify)
    reader select = lift (reader select)
-- | The state is that of the underlying monad; every operation is simply
-- lifted with 'lift'.
instance MonadState s m => MonadState s (IOT m) where
    get = lift get
    put newState = lift (put newState)
    state step = lift (state step)
-- | The output is that of the underlying monad. 'writer' and 'tell' are
-- lifted. 'listen' and 'pass' run the action from the incoming state token
-- under the underlying 'listen' / 'pass' and rearrange the result so that the
-- state token stays in the 'IORet#' while the output (respectively the output
-- transformer) sits where the underlying operation expects it.
instance MonadWriter w m => MonadWriter w (IOT m) where
    writer pair = lift (writer pair)
    tell output = lift (tell output)
    listen action = IOT# $ \s0# ->
        listen (getIOT# action s0#) >>= \(IORet# (# s1#, x #), w) ->
        return (IORet# (# s1#, (x, w) #))
    -- The (value, function) pair is matched lazily, as in the do-block form.
    pass action = IOT# $ \s0# -> pass $
        getIOT# action s0# >>= \(IORet# (# s1#, ~(x, g) #)) ->
        return (IORet# (# s1#, x #), g)