{-# LANGUAGE BangPatterns #-}
module Data.Massiv.Core.Iterator
( loop
, loopM
, loopM_
, loopDeepM
, splitLinearly
, splitLinearlyWith_
, splitLinearlyWithM_
) where
-- | Pure index loop with an accumulator.
--
-- Starting at index @init'@: while @condition@ holds for the current index,
-- replace the accumulator with @f index acc@ and advance the index with
-- @increment@. The final accumulator is returned; it is @initAcc@ itself
-- when @condition init'@ is already 'False'.
--
-- The index and the accumulator are forced to weak head normal form on every
-- iteration so the accumulator does not pile up thunks.
loop :: Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> a) -> a
loop !init' condition increment !initAcc f = step init' initAcc
  where
    step !ix !current
      | condition ix = step (increment ix) (f ix current)
      | otherwise = current
{-# INLINE loop #-}
-- | Monadic counterpart of 'loop'.
--
-- Starting at index @init'@: while @condition@ holds, run @f index acc@,
-- continue with its result as the new accumulator, and advance the index
-- with @increment@. The effects of @f@ run in index order (first visited
-- index first). When the loop stops, the last accumulator is returned.
--
-- The index and the accumulator are forced to weak head normal form on each
-- iteration.
loopM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
loopM !init' condition increment !initAcc f = iterateFrom init' initAcc
  where
    iterateFrom !ix !current
      | condition ix = do
          next <- f ix current
          iterateFrom (increment ix) next
      | otherwise = return current
{-# INLINE loopM #-}
-- | Run a monadic action for every index of a loop, discarding its results.
--
-- Starting at index @init'@: while @condition@ holds, run @f index@ and then
-- advance the index with @increment@. Actions run in index order.
loopM_ :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> (Int -> m a) -> m ()
loopM_ !init' condition increment f = walk init'
  where
    walk !ix
      | condition ix = f ix >> walk (increment ix)
      | otherwise = return ()
{-# INLINE loopM_ #-}
-- | Like 'loopM', but @f@ is applied in reverse index order.
--
-- The loop first recurses through every index accepted by @condition@
-- (advancing with @increment@). @initAcc@ is fed to @f@ at the last visited
-- index, and each result is passed to @f@ at the preceding index while the
-- recursion unwinds. The effects of @f@ therefore run from the last index
-- back to the first. If @condition init'@ is 'False', @initAcc@ is returned.
--
-- This is not tail recursive: the bind for every visited index is pending
-- until the end of the range is reached.
loopDeepM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
loopDeepM !init' condition increment !initAcc f = descend init'
  where
    descend !ix
      | condition ix = descend (increment ix) >>= f ix
      | otherwise = return initAcc
{-# INLINE loopDeepM #-}
-- | Split @totalLength@ elements into equally sized chunks and hand the
-- layout to @action@.
--
-- @action@ receives the length of every full chunk and the index where the
-- leftover ("slack") elements begin, i.e. @chunkLength * chunks@. Indices in
-- @[slackStart, totalLength)@ do not fit into a full chunk. When
-- @totalLength < numChunks@ the chunk length is @0@ and @slackStart@ is @0@,
-- so every element is slack.
--
-- A non-positive @numChunks@ is treated as a single chunk.
splitLinearly ::
     Int -- ^ Requested number of chunks; values below @1@ are treated as @1@
  -> Int -- ^ Total number of elements to split
  -> (Int -> Int -> a) -- ^ Receives the chunk length and the slack start index
  -> a
splitLinearly numChunks totalLength action = action chunkLength slackStart
  where
    -- Clamp to at least one chunk. 'quot' by zero throws, and a negative
    -- count yields a negative chunk length with a non-negative slackStart.
    -- A loop that starts at 0 and steps by chunkLength while below
    -- slackStart (as 'splitLinearlyWithM_' does) would then move away from
    -- slackStart and effectively never terminate.
    !chunks = max 1 numChunks
    !chunkLength = totalLength `quot` chunks
    !slackStart = chunkLength * chunks
{-# INLINE splitLinearly #-}
-- | Variant of 'splitLinearlyWithM_' whose elements come from a pure
-- function: @index k@ is lifted with 'pure' and passed to the consumer
-- together with @k@.
splitLinearlyWith_ ::
     Monad m
  => Int -- ^ Number of chunks
  -> (m () -> m a) -- ^ Wrapper applied to the action of each chunk
  -> Int -- ^ Total number of elements
  -> (Int -> b) -- ^ Pure producer of the element at an index
  -> (Int -> b -> m ()) -- ^ Consumer of each index and its element
  -> m a
splitLinearlyWith_ numChunks with totalLength index =
  splitLinearlyWithM_ numChunks with totalLength produce
  where
    produce k = pure (index k)
{-# INLINE splitLinearlyWith_ #-}
-- | Visit every index in @[0, totalLength)@ in chunks. For each index @k@,
-- run @make k@ and pass its result to @write k@.
--
-- The layout comes from 'splitLinearly'. Indices below @slackStart@ are
-- processed in consecutive chunks of @chunkLength@. Each chunk's loop is
-- wrapped in its own call to @with@, and those results are discarded. The
-- remaining indices @[slackStart, totalLength)@ are processed by one final
-- loop, also wrapped in @with@, and its result is returned.
--
-- NOTE(review): @with@ presumably schedules each chunk as a separate unit of
-- work (e.g. on a worker pool); confirm against the call sites.
splitLinearlyWithM_ ::
     Monad m => Int -> (m () -> m a) -> Int -> (Int -> m b) -> (Int -> b -> m c) -> m a
splitLinearlyWithM_ numChunks with totalLength make write =
  splitLinearly numChunks totalLength runAll
  where
    -- Produce the element at index k and hand it to the consumer.
    visit !k = make k >>= write k
    runAll chunkLength slackStart = do
      loopM_ 0 (< slackStart) (+ chunkLength) $ \ !start ->
        let stop = start + chunkLength
         in with (loopM_ start (< stop) (+ 1) visit)
      with (loopM_ slackStart (< totalLength) (+ 1) visit)
{-# INLINE splitLinearlyWithM_ #-}