{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE RankNTypes #-}
#include "inline.hs"
module Streamly.Internal.Data.Stream.StreamD.Type
(
Step (..)
#if __GLASGOW_HASKELL__ >= 800
, Stream (Stream, UnStream)
#else
, Stream (UnStream)
, pattern Stream
#endif
, fromStreamK
, toStreamK
, fromStreamD
, map
, mapM
, yield
, yieldM
, concatMap
, concatMapM
, foldrT
, foldrM
, foldrMx
, foldr
, foldrS
, foldl'
, foldlM'
, foldlx'
, foldlMx'
, toList
, fromList
, eqBy
, cmpBy
, take
, GroupState (..)
, groupsOf
, groupsOf2
)
where
import Control.Applicative (liftA2)
import Control.Monad (when)
import Control.Monad.Catch (MonadThrow, throwM)
import Control.Monad.Trans (lift, MonadTrans)
import Data.Functor.Identity (Identity(..))
import GHC.Base (build)
import GHC.Types (SPEC(..))
import Prelude hiding (map, mapM, foldr, take, concatMap)
import Fusion.Plugin.Types (Fuse(..))
import Streamly.Internal.Data.SVar (State(..), adaptState, defState)
import Streamly.Internal.Data.Fold.Types (Fold(..), Fold2(..))
import qualified Streamly.Internal.Data.Stream.StreamK as K
{-# ANN type Step Fuse #-}
-- | One transition of a stream's state machine.
--
-- * 'Yield' emits an element together with the next state.
-- * 'Skip' moves to the next state without emitting anything.
-- * 'Stop' ends the stream.
--
-- The 'Fuse' annotation (from "Fusion.Plugin.Types") marks this type for the
-- fusion plugin.
data Step s a = Yield a s | Skip s | Stop

-- | Maps over the emitted element only; the state is carried through as-is.
instance Functor (Step s) where
    {-# INLINE fmap #-}
    fmap f step = case step of
        Yield x s -> Yield (f x) s
        Skip s    -> Skip s
        Stop      -> Stop
-- | A stream expressed as a state machine: a step function together with its
-- initial state. The state type @s@ is existentially hidden, so different
-- streams of the same element type may use different state types.
--
-- The step function receives a 'State' (from "Streamly.Internal.Data.SVar")
-- and the current state, and returns the next 'Step' in the monad @m@.
--
-- Matching through the 'Stream' pattern (rather than 'UnStream') wraps the
-- step function so that the incoming 'State' is first passed through
-- 'adaptState'.
data Stream m a =
    forall s. UnStream (State K.Stream m a -> s -> m (Step s a)) s
-- | Wrap the step function of a stream so that whatever 'State' it is handed
-- is first passed through 'adaptState'. Used as the view function of the
-- 'Stream' pattern synonym.
unShare :: Stream m a -> Stream m a
unShare (UnStream step state) = UnStream (step . adaptState) state
-- | Bidirectional pattern for 'Stream'. Used as a constructor it is simply
-- 'UnStream'. When matching, it first applies 'unShare', so the step function
-- obtained from the match passes its 'State' argument through 'adaptState'.
pattern Stream :: (State K.Stream m a -> s -> m (Step s a)) -> s -> Stream m a
pattern Stream step state <- (unShare -> UnStream step state)
    where Stream = UnStream
#if __GLASGOW_HASKELL__ >= 802
-- Tell the pattern-match checker (GHC >= 8.2) that matching on 'Stream'
-- alone is exhaustive.
{-# COMPLETE Stream #-}
#endif
-- | Convert a continuation-based 'K.Stream' into a direct-style 'Stream'.
-- The state of the resulting stream is the remaining 'K.Stream'. Each step
-- deconstructs it once with 'K.foldStreamShared'.
{-# INLINE_LATE fromStreamK #-}
fromStreamK :: Monad m => K.Stream m a -> Stream m a
fromStreamK = Stream step
  where
    step gst rest = K.foldStreamShared gst onYield onSingle onStop rest
      where
        -- More elements follow: the tail becomes the new state.
        onYield x tl = return $ Yield x tl
        -- Last element: the next state is the empty stream.
        onSingle x = return $ Yield x K.nil
        onStop = return Stop
-- | Convert a direct-style 'Stream' into a continuation-based 'K.Stream'.
--
-- 'Skip' steps are consumed internally. Each 'Yield' is passed to the yield
-- continuation together with the rest of the stream, and 'Stop' invokes the
-- stop continuation. The single-element continuation is not used.
{-# INLINE_LATE toStreamK #-}
toStreamK :: Monad m => Stream m a -> K.Stream m a
toStreamK (Stream step state) = fromState state
  where
    fromState s0 = K.mkStream $ \gst yld _ stp ->
        let drive s = do
                r <- step gst s
                case r of
                    Yield x next -> yld x (fromState next)
                    Skip next -> drive next
                    Stop -> stp
        in drive s0
#ifndef DISABLE_FUSION
{-# RULES "fromStreamK/toStreamK fusion"
forall s. toStreamK (fromStreamK s) = s #-}
{-# RULES "toStreamK/fromStreamK fusion"
forall s. fromStreamK (toStreamK s) = s #-}
#endif
-- | Convert a direct-style 'Stream' into any 'K.IsStream' type, going through
-- 'toStreamK'.
{-# INLINE fromStreamD #-}
fromStreamD :: (K.IsStream t, Monad m) => Stream m a -> t m a
fromStreamD s = K.fromStream (toStreamK s)
-- | Replace every element with the result of running a monadic action on it.
-- The action runs once per yielded element, in stream order. 'Skip' and
-- 'Stop' pass through unchanged.
{-# INLINE_NORMAL mapM #-}
mapM :: Monad m => (a -> m b) -> Stream m a -> Stream m b
mapM f (Stream step state) = Stream step' state
  where
    {-# INLINE_LATE step' #-}
    step' gst st = step (adaptState gst) st >>= \r -> case r of
        Yield x next -> do
            y <- f x
            return $ Yield y next
        Skip next -> return $ Skip next
        Stop -> return Stop
-- | Apply a pure function to every element. Defined via 'mapM'.
{-# INLINE map #-}
map :: Monad m => (a -> b) -> Stream m a -> Stream m b
map f = mapM (\x -> return (f x))
-- | Maps a pure function over every element. Only 'Functor' is required of
-- the underlying monad, whereas 'map' requires 'Monad'.
instance Functor m => Functor (Stream m) where
    {-# INLINE fmap #-}
    fmap f (Stream step state) = Stream mapped state
      where
        {-# INLINE_LATE mapped #-}
        mapped gst st = fmap f <$> step (adaptState gst) st
-- | Map each element to a stream (produced by a monadic action) and flatten
-- the resulting streams in order.
--
-- The state is @Left outerState@ while the next element is pulled from the
-- outer stream. It is @Right (innerStream, outerState)@ while the inner stream
-- built for that element is drained. When the inner stream stops, control
-- returns to the outer stream.
{-# INLINE_NORMAL concatMapM #-}
concatMapM :: Monad m => (a -> m (Stream m b)) -> Stream m a -> Stream m b
concatMapM f (Stream step state) = Stream step' (Left state)
  where
    {-# INLINE_LATE step' #-}
    -- Outer phase: fetch the next outer element and build its inner stream.
    step' gst (Left st) = do
        r <- step (adaptState gst) st
        case r of
            Yield a s -> do
                b_stream <- f a
                return $ Skip (Right (b_stream, s))
            Skip s -> return $ Skip (Left s)
            Stop -> return Stop
    -- Inner phase: the inner stream is matched with 'UnStream' (not the
    -- 'Stream' pattern) and its step is called with 'adaptState gst'
    -- explicitly. The stream is rebuilt around its new state after every step.
    step' gst (Right (UnStream inner_step inner_st, st)) = do
        r <- inner_step (adaptState gst) inner_st
        case r of
            Yield b inner_s ->
                return $ Yield b (Right (Stream inner_step inner_s, st))
            Skip inner_s ->
                return $ Skip (Right (Stream inner_step inner_s, st))
            Stop -> return $ Skip (Left st)
-- | Pure-function variant of 'concatMapM': map every element to a stream and
-- flatten the results in order.
{-# INLINE concatMap #-}
concatMap :: Monad m => (a -> Stream m b) -> Stream m a -> Stream m b
concatMap f = concatMapM (\x -> return (f x))
-- | A stream containing exactly one element.
--
-- The state is a 'Bool' that stays 'True' until the element has been
-- emitted. The 'State' argument given to the step function is ignored.
{-# INLINE_NORMAL yield #-}
yield :: Applicative m => a -> Stream m a
yield x = Stream (\_ s -> pure $ step s) True
  where
    -- Pure transition on the Bool state. The previous version threaded an
    -- ignored 'undefined' argument through here; it has been dropped.
    {-# INLINE_LATE step #-}
    step True = Yield x False
    step False = Stop
-- | Applicative @<*>@ for 'Stream'. For every function produced by the first
-- stream, the second stream is run from its initial state and the function
-- is applied to each of its elements. The functions form the outer loop.
--
-- The state is @Left outerState@ while the next function is fetched. It is
-- @Right (g, outerState, innerState)@ while @g@ is applied across the second
-- stream.
{-# INLINE_NORMAL concatAp #-}
concatAp :: Functor f => Stream f (a -> b) -> Stream f a -> Stream f b
concatAp (Stream stepa statea) (Stream stepb stateb) = Stream step' (Left statea)
  where
    {-# INLINE_LATE step' #-}
    step' gst (Left st) = fromOuter <$> stepa (adaptState gst) st
      where
        fromOuter r = case r of
            Yield g s -> Skip (Right (g, s, stateb))
            Skip s -> Skip (Left s)
            Stop -> Stop
    step' gst (Right (g, os, st)) = fromInner <$> stepb (adaptState gst) st
      where
        fromInner r = case r of
            Yield a s -> Yield (g a) (Right (g, os, s))
            Skip s -> Skip (Right (g, os, s))
            Stop -> Skip (Left os)
-- | Applicative @*>@ for 'Stream'. For every element of the first stream
-- (whose value is discarded), the whole second stream is run from its
-- initial state.
--
-- The state is @Left outerState@ while the first stream advances, and
-- @Right (outerState, innerState)@ while the second stream is drained.
{-# INLINE_NORMAL apSequence #-}
apSequence :: Functor f => Stream f a -> Stream f b -> Stream f b
apSequence (Stream stepa statea) (Stream stepb stateb) = Stream step (Left statea)
  where
    {-# INLINE_LATE step #-}
    step gst (Left st) = onOuter <$> stepa (adaptState gst) st
      where
        onOuter r = case r of
            Yield _ s -> Skip (Right (s, stateb))
            Skip s -> Skip (Left s)
            Stop -> Stop
    -- NOTE(review): unlike the other branch, this passes 'gst' without
    -- 'adaptState'. 'stepb' was obtained through the 'Stream' pattern, which
    -- already wraps it with 'adaptState', so this is presumably harmless.
    -- Confirm before changing.
    step gst (Right (ostate, st)) = onInner <$> stepb gst st
      where
        onInner r = case r of
            Yield b s -> Yield b (Right (ostate, s))
            Skip s -> Skip (Right (ostate, s))
            Stop -> Skip (Left ostate)
-- | 'pure' builds a one-element stream. '<*>' is the nested-loop 'concatAp',
-- and '*>' is 'apSequence'.
instance Applicative f => Applicative (Stream f) where
    {-# INLINE pure #-}
    pure x = yield x
    {-# INLINE (<*>) #-}
    fs <*> xs = concatAp fs xs
    {-# INLINE (*>) #-}
    xs *> ys = apSequence xs ys
-- | '>>=' is 'concatMap' with its arguments flipped, and '>>' reuses '*>'.
instance Monad m => Monad (Stream m) where
    {-# INLINE return #-}
    return = pure
    {-# INLINE (>>=) #-}
    xs >>= k = concatMap k xs
    {-# INLINE (>>) #-}
    (>>) = (*>)
-- | Lifting an action gives a single-element stream ('yieldM') that runs the
-- action when first stepped.
instance MonadTrans Stream where
    lift act = yieldM act
-- | Throwing in a 'Stream' lifts the underlying monad's 'throwM'. The
-- exception is raised when the resulting stream is first stepped.
instance (MonadThrow m) => MonadThrow (Stream m) where
    throwM e = lift (throwM e)
-- | Right fold with a monadic step.
--
-- The step function receives the current element and the not-yet-run action
-- that folds the rest of the stream. It can stop early by not running that
-- action. @z@ is the result at the end of the stream.
{-# INLINE_NORMAL foldrM #-}
foldrM :: Monad m => (a -> m b -> m b) -> m b -> Stream m a -> m b
foldrM f z (Stream step state) = loop SPEC state
  where
    -- The SPEC argument lets GHC's SpecConstr pass specialise the loop.
    {-# INLINE_LATE loop #-}
    loop !_ s = step defState s >>= \r -> case r of
        Yield x next -> f x (loop SPEC next)
        Skip next -> loop SPEC next
        Stop -> z
-- | Like 'foldrM', but the fold's result action is passed to @convert@ rather
-- than returned directly. @final@ is the result at the end of the stream.
{-# INLINE_NORMAL foldrMx #-}
foldrMx :: Monad m
    => (a -> m x -> m x) -> m x -> (m x -> m b) -> Stream m a -> m b
foldrMx fstep final convert (Stream step state) = convert (loop SPEC state)
  where
    {-# INLINE_LATE loop #-}
    loop !_ s = step defState s >>= \r -> case r of
        Yield x next -> fstep x (loop SPEC next)
        Skip next -> loop SPEC next
        Stop -> final
-- | Right fold with a pure step function, built on 'foldrM'. The element and
-- the folded rest are combined with 'liftA2'. Whether the rest of the stream
-- is forced therefore depends on the monad's 'liftA2'.
{-# INLINE_NORMAL foldr #-}
foldr :: Monad m => (a -> b -> b) -> b -> Stream m a -> m b
foldr f z = foldrM combine (return z)
  where
    combine a rest = liftA2 f (return a) rest
-- | A single-element stream whose element is produced by running the given
-- action when the stream is first stepped. The 'Bool' state stays 'True'
-- until the element has been emitted.
{-# INLINE_NORMAL yieldM #-}
yieldM :: Monad m => m a -> Stream m a
yieldM m = Stream step True
  where
    {-# INLINE_LATE step #-}
    step _ False = return Stop
    step _ True = do
        x <- m
        return $ Yield x False
-- | Right fold that builds a 'Stream'. Each step of the source runs as a
-- single-element stream ('yieldM') and is bound with the 'Stream' monad.
-- @f@ then combines the element with the fold of the rest, and @final@ is
-- used at the end of the source.
{-# INLINE_NORMAL foldrS #-}
foldrS
    :: Monad m
    => (a -> Stream m b -> Stream m b)
    -> Stream m b
    -> Stream m a
    -> Stream m b
foldrS f final (Stream step state) = loop SPEC state
  where
    {-# INLINE_LATE loop #-}
    loop !_ s = yieldM (step defState s) >>= \r -> case r of
        Yield x next -> f x (loop SPEC next)
        Skip next -> loop SPEC next
        Stop -> final
-- | Right fold into a monad transformer @t m@. Each step of the source is
-- lifted with 'lift'. @f@ combines the element with the fold of the rest,
-- and @final@ is used at the end of the source.
{-# INLINE_NORMAL foldrT #-}
foldrT :: (Monad m, Monad (t m), MonadTrans t)
    => (a -> t m b -> t m b) -> t m b -> Stream m a -> t m b
foldrT f final (Stream step state) = loop SPEC state
  where
    {-# INLINE_LATE loop #-}
    loop !_ s = lift (step defState s) >>= \r -> case r of
        Yield x next -> f x (loop SPEC next)
        Skip next -> loop SPEC next
        Stop -> final
-- | Collect all elements into a list, in stream order. For 'Identity'
-- streams, the "toList Identity" rewrite rule may replace this with
-- 'toListId'.
{-# INLINE_NORMAL toList #-}
toList :: Monad m => Stream m a -> m [a]
toList = foldr (\x xs -> x : xs) []
-- | Walk an 'Identity' stream with the given cons and nil. This is the
-- producer that 'toListId' passes to 'build', so the result can take part
-- in list fusion.
{-# INLINE_LATE toListFB #-}
toListFB :: (a -> b -> b) -> b -> Stream Identity a -> b
toListFB cons end (Stream step state) = walk state
  where
    walk s = case runIdentity (step defState s) of
        Yield x next -> x `cons` walk next
        Skip next -> walk next
        Stop -> end
{-# RULES "toList Identity" toList = toListId #-}
-- | 'toList' specialised to 'Identity'. The list is produced through 'build'
-- so that a consuming list function can fuse with it.
{-# INLINE_EARLY toListId #-}
toListId :: Stream Identity a -> Identity [a]
toListId s = Identity (build (\cons end -> toListFB cons end s))
-- | Strict left fold with a monadic step, an effectful initial accumulator
-- @begin@, and a final extraction step @done@ applied to the last
-- accumulator. The accumulator is forced to WHNF before each step.
{-# INLINE_NORMAL foldlMx' #-}
foldlMx' :: Monad m => (x -> a -> m x) -> m x -> (x -> m b) -> Stream m a -> m b
foldlMx' fstep begin done (Stream step state) =
    begin >>= \acc0 -> loop SPEC acc0 state
  where
    {-# INLINE_LATE loop #-}
    loop !_ !acc s = do
        r <- step defState s
        case r of
            Yield x next -> fstep acc x >>= \acc' -> loop SPEC acc' next
            Skip next -> loop SPEC acc next
            Stop -> done acc
-- | Pure-step variant of 'foldlMx''. It is a strict left fold with initial
-- accumulator @begin@ and a final extraction @done@.
{-# INLINE foldlx' #-}
foldlx' :: Monad m => (x -> a -> x) -> x -> (x -> b) -> Stream m a -> m b
foldlx' fstep begin done = foldlMx' step (return begin) finish
  where
    step acc a = return (fstep acc a)
    finish acc = return (done acc)
-- | Strict left fold with a monadic step. The accumulator is forced to WHNF
-- before each step, and the final accumulator is returned.
{-# INLINE_NORMAL foldlM' #-}
foldlM' :: Monad m => (b -> a -> m b) -> b -> Stream m a -> m b
foldlM' fstep begin (Stream step state) = loop SPEC begin state
  where
    {-# INLINE_LATE loop #-}
    loop !_ !acc s = do
        r <- step defState s
        case r of
            Yield x next -> fstep acc x >>= \acc' -> loop SPEC acc' next
            Skip next -> loop SPEC acc next
            Stop -> return acc
-- | Strict left fold with a pure step function, built on 'foldlM''.
{-# INLINE foldl' #-}
foldl' :: Monad m => (b -> a -> b) -> b -> Stream m a -> m b
foldl' fstep = foldlM' step
  where
    step acc x = return (fstep acc x)
-- | A stream of the elements of a list, in order. The state is the part of
-- the list not yet emitted.
{-# INLINE_LATE fromList #-}
fromList :: Applicative m => [a] -> Stream m a
fromList = Stream step
  where
    {-# INLINE_LATE step #-}
    step _ [] = pure Stop
    step _ (x:rest) = pure $ Yield x rest
-- | Compare two streams element-wise with @eq@. The result is 'True' only if
-- both streams yield the same number of elements and every corresponding
-- pair satisfies @eq@. Stepping stops at the first mismatch.
{-# INLINE_NORMAL eqBy #-}
eqBy :: Monad m => (a -> b -> Bool) -> Stream m a -> Stream m b -> m Bool
eqBy eq (Stream step1 t1) (Stream step2 t2) = eq_loop0 SPEC t1 t2
  where
    -- Pull the next element from the first stream.
    eq_loop0 !_ s1 s2 = do
        r <- step1 defState s1
        case r of
            Yield x s1' -> eq_loop1 SPEC x s1' s2
            Skip s1' -> eq_loop0 SPEC s1' s2
            Stop -> eq_null s2
    -- Holding @x@ from the first stream, pull its counterpart from the second.
    -- A mismatch, or the second stream ending first, gives 'False'.
    eq_loop1 !_ x s1 s2 = do
        r <- step2 defState s2
        case r of
            Yield y s2'
                | eq x y -> eq_loop0 SPEC s1 s2'
                | otherwise -> return False
            Skip s2' -> eq_loop1 SPEC x s1 s2'
            Stop -> return False
    -- The first stream has ended. The streams are equal only if the second
    -- stream also stops (after any 'Skip's) without yielding another element.
    eq_null s2 = do
        r <- step2 defState s2
        case r of
            Yield _ _ -> return False
            Skip s2' -> eq_null s2'
            Stop -> return True
-- | Lexicographically compare two streams using @cmp@ on corresponding
-- elements. The first non-'EQ' result is returned. If one stream ends while
-- the other still has elements, the shorter one compares as smaller ('LT'
-- when the first ends first, 'GT' when the second does). Two streams that end
-- together with all elements 'EQ' give 'EQ'.
{-# INLINE_NORMAL cmpBy #-}
cmpBy
    :: Monad m
    => (a -> b -> Ordering) -> Stream m a -> Stream m b -> m Ordering
cmpBy cmp (Stream step1 t1) (Stream step2 t2) = cmp_loop0 SPEC t1 t2
  where
    -- Pull the next element from the first stream.
    cmp_loop0 !_ s1 s2 = do
        r <- step1 defState s1
        case r of
            Yield x s1' -> cmp_loop1 SPEC x s1' s2
            Skip s1' -> cmp_loop0 SPEC s1' s2
            Stop -> cmp_null s2
    -- Holding @x@ from the first stream, compare it with the next element of
    -- the second. The second stream ending first means the first is greater.
    cmp_loop1 !_ x s1 s2 = do
        r <- step2 defState s2
        case r of
            Yield y s2' -> case x `cmp` y of
                EQ -> cmp_loop0 SPEC s1 s2'
                c -> return c
            Skip s2' -> cmp_loop1 SPEC x s1 s2'
            Stop -> return GT
    -- The first stream has ended. It is smaller if the second still yields.
    cmp_null s2 = do
        r <- step2 defState s2
        case r of
            Yield _ _ -> return LT
            Skip s2' -> cmp_null s2'
            Stop -> return EQ
-- | Emit at most @n@ elements of the stream. The state pairs the source
-- state with the number of elements emitted so far. Once that count reaches
-- @n@, the stream stops without stepping the source again. An @n@ of zero or
-- less gives an empty stream.
{-# INLINE_NORMAL take #-}
take :: Monad m => Int -> Stream m a -> Stream m a
take n (Stream step state) = n `seq` Stream step' (state, 0)
  where
    {-# INLINE_LATE step' #-}
    step' gst (st, taken)
        | taken < n = do
            r <- step gst st
            return $ case r of
                Yield x next -> Yield x (next, taken + 1)
                Skip next -> Skip (next, taken)
                Stop -> Stop
        | otherwise = return Stop
-- | Internal state machine of 'groupsOf' and 'groupsOf2'. @s@ is the source
-- stream state and @fs@ is the fold state.
data GroupState s fs
    = GroupStart s
    -- ^ About to start a new group: the fold state is initialised next.
    | GroupBuffer s fs Int
    -- ^ Accumulating a group: source state, fold state, and the number of
    -- elements folded into this group so far.
    | GroupYield fs (GroupState s fs)
    -- ^ A group is complete: extract its result from the fold state, then
    -- continue with the given state.
    | GroupFinish
    -- ^ The source is exhausted: the stream stops.
-- | Split the stream into consecutive groups of @n@ elements, fold each group
-- with the given 'Fold', and emit the fold's result.
--
-- The last group may contain fewer than @n@ elements. No group is emitted for
-- an empty input. No trailing empty group is emitted when the input length is
-- an exact multiple of @n@.
--
-- Calls 'error' when the stream is first stepped if @n <= 0@.
{-# INLINE_NORMAL groupsOf #-}
groupsOf
    :: Monad m
    => Int
    -> Fold m a b
    -> Stream m a
    -> Stream m b
groupsOf n (Fold fstep initial extract) (Stream step state) =
    n `seq` Stream step' (GroupStart state)
  where
    {-# INLINE_LATE step' #-}
    step' _ (GroupStart st) = do
        when (n <= 0) $
            error $ "Streamly.Internal.Data.Stream.StreamD.Type.groupsOf: the size of "
                ++ "groups [" ++ show n ++ "] must be a natural number"
        fs <- initial
        return $ Skip (GroupBuffer st fs 0)
    step' gst (GroupBuffer st fs i) = do
        r <- step (adaptState gst) st
        case r of
            Yield x s -> do
                !fs' <- fstep fs x
                let i' = i + 1
                return $
                    if i' >= n
                    then Skip (GroupYield fs' (GroupStart s))
                    else Skip (GroupBuffer s fs' i')
            Skip s -> return $ Skip (GroupBuffer s fs i)
            -- Nothing has been folded into this group. Either the input was
            -- empty, or the previous group ended exactly at the end of the
            -- input. Stop instead of emitting the result of an empty fold.
            -- ('initial' has still been run for this group.)
            Stop
                | i == 0 -> return Stop
                | otherwise -> return $ Skip (GroupYield fs GroupFinish)
    step' _ (GroupYield fs next) = do
        r <- extract fs
        return $ Yield r next
    step' _ GroupFinish = return Stop
-- | Like 'groupsOf', but for a 'Fold2'. At the start of every group the
-- action @input@ is run, and its result is passed to the fold's @inject@ to
-- obtain the initial fold state.
--
-- The last group may contain fewer than @n@ elements. No group is emitted for
-- an empty input. No trailing empty group is emitted when the input length is
-- an exact multiple of @n@.
--
-- Calls 'error' when the stream is first stepped if @n <= 0@.
{-# INLINE_NORMAL groupsOf2 #-}
groupsOf2
    :: Monad m
    => Int
    -> m c
    -> Fold2 m c a b
    -> Stream m a
    -> Stream m b
groupsOf2 n input (Fold2 fstep inject extract) (Stream step state) =
    n `seq` Stream step' (GroupStart state)
  where
    {-# INLINE_LATE step' #-}
    step' _ (GroupStart st) = do
        -- The message names this function (it previously said "groupsOf").
        when (n <= 0) $
            error $ "Streamly.Internal.Data.Stream.StreamD.Type.groupsOf2: the size of "
                ++ "groups [" ++ show n ++ "] must be a natural number"
        fs <- input >>= inject
        return $ Skip (GroupBuffer st fs 0)
    step' gst (GroupBuffer st fs i) = do
        r <- step (adaptState gst) st
        case r of
            Yield x s -> do
                !fs' <- fstep fs x
                let i' = i + 1
                return $
                    if i' >= n
                    then Skip (GroupYield fs' (GroupStart s))
                    else Skip (GroupBuffer s fs' i')
            Skip s -> return $ Skip (GroupBuffer s fs i)
            -- Nothing has been folded into this group. Either the input was
            -- empty, or the previous group ended exactly at the end of the
            -- input. Stop instead of emitting the result of an empty fold.
            -- ('input' and 'inject' have still been run for this group.)
            Stop
                | i == 0 -> return Stop
                | otherwise -> return $ Skip (GroupYield fs GroupFinish)
    step' _ (GroupYield fs next) = do
        r <- extract fs
        return $ Yield r next
    step' _ GroupFinish = return Stop