{-# LANGUAGE MagicHash, UnboxedTuples, MultiParamTypeClasses, FlexibleInstances, UndecidableInstances #-}
-------------------------------------------------------------------------------
-- |
-- Module      : Control.Monad.Trans.IO
-- Copyright   : (c) mniip 2016
-- License     : MIT
-- Maintainer  : mniip@mniip.com
-- Stability   : none
-- Portability : non-portable
--
-- IO transformer capable of adding IO capabilities to any monad.
--
-- The resulting computations are lazy in the sense of being lazy IO.
--
-------------------------------------------------------------------------------

module Control.Monad.Trans.IO
    (
        -- * Unboxed result wrapper
        IORet#(..),
        -- * The IOT monad transformer
        IOT(..),
        runIOT, hoistIOT,
        fromIO, sequenceIO
    )
    where

import Control.Applicative
import Control.Monad
import Control.Monad.Cont.Class
import Control.Monad.Error.Class
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Reader.Class
import Control.Monad.State.Class
import Control.Monad.Writer.Class

import GHC.Prim
import GHC.IO

-- | Lifted box around the unboxed @(# State# RealWorld, a #)@ pair that a
-- primitive IO step yields.
--
-- Unboxed tuples cannot be used as the element type of an ordinary container
-- or monad, so they are wrapped in this constructor wherever they must be
-- stored, e.g. as the result type @m (IORet# a)@ of an 'IOT' step.
data IORet# a = IORet#
    { getIORet# :: (# State# RealWorld, a #)
      -- ^ The state token together with the produced value.
    }

-- | Applies the function to the stored value; the state token is passed
-- through untouched.
instance Functor IORet# where
    fmap f boxed = case getIORet# boxed of
        (# token#, a #) -> IORet# (# token#, f a #)

-- | An IO transformer parameterized by @m@, the inner monad.
--
-- A computation takes the incoming 'RealWorld' state token and yields, inside
-- @m@, its result paired with the outgoing token. The pair is boxed in
-- 'IORet#' so that it can appear as the element type of @m@.
--
-- Declared as a @newtype@ because the only field is a function: this avoids an
-- extra heap box around every computation. A consequence is that matching on
-- 'IOT#' never forces anything.
newtype IOT m a = IOT# { getIOT# :: State# RealWorld -> m (IORet# a) }

-- | Maps over the final value, leaving the inner monad's structure and the
-- threaded state token as they are.
instance Functor m => Functor (IOT m) where
    fmap f (IOT# run) = IOT# $ \token# ->
        fmap f <$> run token#

-- | 'pure' yields the value and hands back the incoming state token unchanged.
-- '<*>' runs the function computation first, then runs the argument
-- computation with the token the first one produced, and finally applies the
-- function to the argument.
instance (Functor m, Monad m) => Applicative (IOT m) where
    pure a = IOT# $ \token# -> return (IORet# (# token#, a #))
    IOT# runF <*> IOT# runX = IOT# $ \s0# ->
        runF s0# >>= \(IORet# (# s1#, f #)) ->
        runX s1# >>= \(IORet# (# s2#, x #)) ->
        return (IORet# (# s2#, f x #))

-- | Choice and failure are delegated to the 'MonadPlus' instance for 'IOT',
-- which builds on the inner monad's 'mzero' and 'mplus'.
instance (Functor m, MonadPlus m) => Alternative (IOT m) where
    l <|> r = l `mplus` r
    empty = mzero

-- | Binding runs the first computation in the inner monad, then runs the
-- continuation with the state token that the first computation produced.
--
-- 'return' is written out rather than delegated to 'pure'. This is
-- presumably for compatibility with pre-AMP @base@, where 'Monad' did not
-- imply 'Applicative' (TODO confirm).
instance Monad m => Monad (IOT m) where
    return a = IOT# $ \token# -> return (IORet# (# token#, a #))
    IOT# run >>= k = IOT# $ \s0# ->
        run s0# >>= \(IORet# (# s1#, a #)) -> getIOT# (k a) s1#

-- | Failure and choice come from the inner monad. Both alternatives of
-- 'mplus' are started from the same incoming state token.
instance MonadPlus m => MonadPlus (IOT m) where
    mzero = IOT# (\_ -> mzero)
    IOT# l `mplus` IOT# r = IOT# $ \token# -> l token# `mplus` r token#

-- | Run an 'IOT' computation, producing the inner-monad computation inside
-- 'IO'.
--
-- The incoming state token is passed to the computation and also returned
-- unchanged. The final values are then extracted with 'fmap'. Like lazy IO,
-- the embedded IO steps are not sequenced before this action returns. Note
-- that this might return an interleaved result.
runIOT :: Functor m => IOT m a -> IO (m a)
runIOT (IOT# f) = IO $ \s# -> (# s#, fmap (\(IORet# (# _, x #)) -> x) (f s#) #)

-- | Change the underlying monad of an 'IOT' computation.
--
-- The supplied function operates on the raw inner computation, so it sees the
-- boxed 'IORet#' results. If @MagicHash@ is not in use, a natural
-- transformation could be supplied, or the 'Functor' instance of 'IORet#'
-- could be made use of (e.g. by building the function with 'fmap').
--
-- The 'IOT' argument is only unwrapped when the resulting computation is run.
hoistIOT :: (m (IORet# a) -> n (IORet# b)) -> IOT m a -> IOT n b
hoistIOT f m = IOT# $ \s# -> f (getIOT# m s#)

-- | Create an 'IOT' computation from an 'IO' action that returns an
-- inner-monad computation.
--
-- When the computation is run, the 'IO' action receives the incoming state
-- token. Every value of the returned @m a@ is then paired with the token the
-- action produced.
fromIO :: Functor m => IO (m a) -> IOT m a
fromIO (IO io) = IOT# $ \s0# -> case io s0# of
    (# s1#, mx #) -> fmap (\x -> IORet# (# s1#, x #)) mx

-- | Create an 'IOT' computation from an inner-monad computation that returns
-- an 'IO' action.
--
-- Each 'IO' action produced by @m@ is run with the incoming state token. If
-- @m@ produces several actions (e.g. a list), they all receive that same
-- token. NOTE(review): confirm this sharing is intended for such monads.
sequenceIO :: Functor m => m (IO a) -> IOT m a
sequenceIO mx = IOT# $ \s# -> fmap (\(IO io) -> IORet# (io s#)) mx

-- | Counterpart of 'fromIO' that needs only a 'Monad' constraint. It is used
-- by the 'MonadTrans' and 'MonadIO' instances. Presumably it is kept separate
-- for @base@ versions where 'Monad' did not imply 'Functor' (TODO confirm).
fromIO' :: Monad m => IO (m a) -> IOT m a
fromIO' (IO io) = IOT# $ \s0# -> case io s0# of
    (# s1#, inner #) -> inner >>= \x -> return (IORet# (# s1#, x #))

-- | Lifting hands 'fromIO'' an 'IO' action that simply returns the inner
-- computation.
instance MonadTrans IOT where
    lift = fromIO' . return

-- | Lifting runs the 'IO' action and returns its result in the inner monad.
instance Monad m => MonadIO (IOT m) where
    liftIO = fromIO' . fmap return

-- | Uses the inner monad's 'callCC'. The escape function handed to @f@ boxes
-- the escaping value, together with the state token current at the point of
-- escape, before invoking the inner continuation.
instance MonadCont m => MonadCont (IOT m) where
    callCC f = IOT# $ \s0# -> callCC $ \k ->
        let escape x = IOT# $ \s1# -> k (IORet# (# s1#, x #))
        in getIOT# (f escape) s0#

-- | Errors are thrown in the inner monad. 'catchError' starts both the guarded
-- computation and the handler from the same incoming state token.
instance MonadError e m => MonadError e (IOT m) where
    throwError err = lift (throwError err)
    catchError act handler = IOT# $ \s# ->
        catchError (getIOT# act s#) (\err -> getIOT# (handler err) s#)

-- | 'ask' and 'reader' are lifted from the inner monad. 'local' applies the
-- inner monad's 'local' to the wrapped computation via 'hoistIOT'.
instance MonadReader r m => MonadReader r (IOT m) where
    ask = lift ask
    local adjust = hoistIOT (local adjust)
    reader select = lift (reader select)

-- | Every state operation is lifted unchanged from the inner monad.
instance MonadState s m => MonadState s (IOT m) where
    get = lift get
    put new = lift (put new)
    state step = lift (state step)

-- | 'writer' and 'tell' are lifted from the inner monad. 'listen' and 'pass'
-- run the wrapped step under the inner monad's 'listen' / 'pass' and move the
-- boxed state token across the extra output component.
instance MonadWriter w m => MonadWriter w (IOT m) where
    writer entry = lift (writer entry)
    tell output = lift (tell output)
    listen act = IOT# $ \s0# ->
        listen (getIOT# act s0#) >>= \(IORet# (# s1#, x #), out) ->
            return (IORet# (# s1#, (x, out) #))
    -- The lazy pattern means the (value, modifier) pair is not forced when the
    -- step finishes; it is only taken apart when its components are demanded.
    pass act = IOT# $ \s0# -> pass $
        getIOT# act s0# >>= \(IORet# (# s1#, ~(x, g) #)) ->
            return (IORet# (# s1#, x #), g)