-- |
-- Module      : Streamly.Internal.Data.Stream.StreamD.Lift
-- Copyright   : (c) 2018 Composewell Technologies
-- License     : BSD-3-Clause
-- Maintainer  : streamly@composewell.com
-- Stability   : experimental
-- Portability : GHC
--
-- Transform the underlying monad of a stream.

module Streamly.Internal.Data.Stream.StreamD.Lift
    (
    -- * Generalize Inner Monad
      hoist
    , generally -- XXX generalize

    -- * Transform Inner Monad
    , liftInner
    , runReaderT
    , evalStateT
    , runStateT
    )
where

#include "inline.hs"

import Control.Monad.Trans.Class (MonadTrans(lift))
import Control.Monad.Trans.Reader (ReaderT)
import Control.Monad.Trans.State.Strict (StateT)
import Data.Functor.Identity (Identity(..))
import Streamly.Internal.Data.SVar (adaptState)

import qualified Control.Monad.Trans.Reader as Reader
import qualified Control.Monad.Trans.State.Strict as State

import Streamly.Internal.Data.Stream.StreamD.Type

-------------------------------------------------------------------------------
-- Generalize Inner Monad
-------------------------------------------------------------------------------

{-# INLINE_NORMAL hoist #-}
-- | Transform the underlying monad of a stream using a natural
-- transformation @f@.
--
-- Every step action of the source stream is run through @f@. The stream's
-- internal state and its elements are passed through unchanged.
hoist :: Monad n => (forall x. m x -> n x) -> Stream m a -> Stream n a
hoist f (Stream step state) = Stream step' state
    where
    {-# INLINE_LATE step' #-}
    step' gst st = do
        -- 'adaptState' re-types the stream 'State' for the source monad.
        r <- f $ step (adaptState gst) st
        -- NOTE(review): this case rebuilds the same 'Step' value. It is
        -- presumably kept on purpose (e.g. to help stream fusion); confirm
        -- before simplifying it to a direct application of @f@.
        return $ case r of
            Yield x s -> Yield x s
            Skip  s   -> Skip s
            Stop      -> Stop

{-# INLINE generally #-}
-- | Generalize a pure stream, i.e. one whose underlying monad is 'Identity',
-- to a stream in an arbitrary monad @m@.
--
-- Implemented as 'hoist' with the natural transformation
-- @return . runIdentity@.
generally :: Monad m => Stream Identity a -> Stream m a
generally = hoist (return . runIdentity)

-------------------------------------------------------------------------------
-- Transform Inner Monad
-------------------------------------------------------------------------------

{-# INLINE_NORMAL liftInner #-}
-- | Lift the underlying monad @m@ of a stream into the transformed monad
-- @t m@.
--
-- Each step action of the source stream is run in @t m@ via 'lift'. The
-- stream's internal state and its elements are passed through unchanged.
liftInner :: (Monad m, MonadTrans t, Monad (t m))
    => Stream m a -> Stream (t m) a
liftInner (Stream step state) = Stream step' state
    where
    {-# INLINE_LATE step' #-}
    step' gst st = do
        r <- lift $ step (adaptState gst) st
        return $ case r of
            Yield x s -> Yield x s
            Skip  s   -> Skip s
            Stop      -> Stop

{-# INLINE_NORMAL runReaderT #-}
-- | Eliminate a 'ReaderT' layer from a stream by supplying its environment.
--
-- The environment is given as an action @env@ in the underlying monad. The
-- action is executed when the first step of the resulting stream runs. Its
-- result is then cached (as @return sv@) and reused for every later step, so
-- @env@ runs at most once per run of the stream.
runReaderT :: Monad m => m s -> Stream (ReaderT s m) a -> Stream m a
runReaderT env (Stream step state) = Stream step' (state, env)
    where
    -- The second component of the stream state is the action that yields
    -- the environment: initially @env@, afterwards @return sv@.
    {-# INLINE_LATE step' #-}
    step' gst (st, action) = do
        sv <- action
        r <- Reader.runReaderT (step (adaptState gst) st) sv
        return $ case r of
            Yield x s -> Yield x (s, return sv)
            Skip  s   -> Skip (s, return sv)
            Stop      -> Stop

{-# INLINE_NORMAL evalStateT #-}
-- | Eliminate a 'StateT' layer from a stream, keeping only the elements.
--
-- The initial state is given as an action @initial@ in the underlying monad.
-- It is executed when the first step of the resulting stream runs. The state
-- produced by each step is threaded into the next step. The state left when
-- the stream stops is discarded; use 'runStateT' to observe the state
-- alongside each element.
evalStateT :: Monad m => m s -> Stream (StateT s m) a -> Stream m a
evalStateT initial (Stream step state) = Stream step' (state, initial)
    where
    -- The second component of the stream state is the action producing the
    -- current state value: initially @initial@, afterwards @return sv'@.
    {-# INLINE_LATE step' #-}
    step' gst (st, action) = do
        sv <- action
        -- The bang forces the new state to WHNF on every step, presumably
        -- to avoid accumulating thunks in the threaded state.
        (r, !sv') <- State.runStateT (step (adaptState gst) st) sv
        return $ case r of
            Yield x s -> Yield x (s, return sv')
            Skip  s   -> Skip (s, return sv')
            Stop      -> Stop

{-# INLINE_NORMAL runStateT #-}
-- | Eliminate a 'StateT' layer from a stream, pairing every element with the
-- state.
--
-- The initial state is given as an action @initial@ in the underlying monad.
-- It is executed when the first step of the resulting stream runs. The state
-- produced by each step is threaded into the next step. Each output element
-- is emitted as @(state, x)@, where @state@ is the value right after the step
-- that produced @x@.
runStateT :: Monad m => m s -> Stream (StateT s m) a -> Stream m (s, a)
runStateT initial (Stream step state) = Stream step' (state, initial)
    where
    -- The second component of the stream state is the action producing the
    -- current state value: initially @initial@, afterwards @return sv'@.
    {-# INLINE_LATE step' #-}
    step' gst (st, action) = do
        sv <- action
        -- The bang forces the new state to WHNF on every step, presumably
        -- to avoid accumulating thunks in the threaded state.
        (r, !sv') <- State.runStateT (step (adaptState gst) st) sv
        return $ case r of
            Yield x s -> Yield (sv', x) (s, return sv')
            Skip  s   -> Skip (s, return sv')
            Stop      -> Stop