{-# LANGUAGE BangPatterns #-}
module Data.Massiv.Core.Iterator
( loop
, loopA_
, loopM
, loopM_
, loopDeepM
, splitLinearly
, splitLinearlyWith_
, splitLinearlyWithM_
, splitLinearlyWithStartAtM_
, splitLinearlyWithStatefulM_
) where
import Control.Scheduler
-- | Pure, strict loop driven by an 'Int' counter.
--
-- Starting from the initial counter, the body is applied to the current
-- counter and the accumulator for as long as the predicate holds for the
-- counter. After each iteration the counter is advanced with the supplied
-- stepping function. The counter and the accumulator are both forced to weak
-- head normal form on every iteration.
--
-- >>> loop 0 (< 5) (+ 1) 0 (+)
-- 10
loop :: Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> a) -> a
loop !from keepGoing next !seed body = iter from seed
  where
    iter !i !acc =
      if keepGoing i
        then iter (next i) (body i acc)
        else acc
{-# INLINE loop #-}
-- | Monadic counterpart of a strict counter loop.
--
-- While the predicate holds for the counter, the action is run with the
-- current counter and accumulator, and its result becomes the accumulator for
-- the next iteration. Effects are sequenced from the first counter value to
-- the last. Once the predicate fails, the accumulator is returned.
loopM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
loopM !from keepGoing next !seed body = iter from seed
  where
    iter !i !acc =
      if keepGoing i
        then iter (next i) =<< body i acc
        else return acc
{-# INLINE loopM #-}
-- | Monadic counter loop that discards the results of the action.
--
-- The action is run for every counter value for which the predicate holds,
-- starting at the initial value and advancing with the stepping function.
-- Actions are sequenced with '>>' in counter order.
loopM_ :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> (Int -> m a) -> m ()
loopM_ !from keepGoing next action = iter from
  where
    iter !i =
      if keepGoing i
        then action i >> iter (next i)
        else pure ()
{-# INLINE loopM_ #-}
-- | Applicative counter loop that discards the results of the action.
--
-- Same iteration scheme as the monadic variant, but only an 'Applicative'
-- constraint is required: the actions are combined with '*>' in counter order.
loopA_ :: Applicative f => Int -> (Int -> Bool) -> (Int -> Int) -> (Int -> f a) -> f ()
loopA_ !from keepGoing next action = iter from
  where
    iter !i =
      if keepGoing i
        then action i *> iter (next i)
        else pure ()
{-# INLINE loopA_ #-}
-- | Monadic counter loop that recurses before running the action.
--
-- The recursive call for the next counter value is bound first, and its
-- result is then passed to the action for the current counter. As a
-- consequence the action for the last counter value that satisfies the
-- predicate runs first, and the action for the initial counter value runs
-- last and produces the final result.
loopDeepM :: Monad m => Int -> (Int -> Bool) -> (Int -> Int) -> a -> (Int -> a -> m a) -> m a
loopDeepM !from keepGoing next !seed body = descend from seed
  where
    descend !i !acc =
      if keepGoing i
        then body i =<< descend (next i) acc
        else return acc
{-# INLINE loopDeepM #-}
-- | Compute how a linear range of @totalLength@ elements divides into
-- @numChunks@ chunks of equal size, and hand the result to a continuation.
--
-- The continuation receives:
--
-- * the length of each full chunk, i.e. @totalLength \`quot\` numChunks@;
-- * the index at which the left-over "slack" begins, i.e. the chunk length
--   multiplied by the number of chunks.
--
-- A non-positive @numChunks@ is treated as "no full chunks": both values
-- passed to the continuation are @0@, so the whole range counts as slack.
-- This avoids a division by zero, and it avoids handing out a negative chunk
-- length that could not be used as a loop stride.
--
-- >>> splitLinearly 3 10 (,)
-- (3,9)
splitLinearly :: Int -- ^ Number of chunks
              -> Int -- ^ Total length
              -> (Int -> Int -> a) -- ^ Receives the chunk length and the slack start index
              -> a
splitLinearly numChunks totalLength action = action chunkLength slackStart
  where
    -- Only divide when there is at least one chunk to divide among.
    !chunkLength =
      if numChunks > 0
        then totalLength `quot` numChunks
        else 0
    !slackStart = chunkLength * numChunks
{-# INLINE splitLinearly #-}
-- | Variant of 'splitLinearlyWithM_' for a pure element producer.
--
-- The pure indexing function is lifted into the monad with 'pure' and all the
-- work is delegated to 'splitLinearlyWithM_'.
splitLinearlyWith_ ::
     Monad m => Scheduler m () -> Int -> (Int -> b) -> (Int -> b -> m ()) -> m ()
splitLinearlyWith_ scheduler totalLength index =
  splitLinearlyWithM_ scheduler totalLength (\ix -> pure (index ix))
{-# INLINE splitLinearlyWith_ #-}
-- | Split the index range @[0, totalLength)@ into chunks and schedule one job
-- per chunk on the supplied scheduler.
--
-- The range is divided with 'splitLinearly' using @numWorkers scheduler@ as
-- the number of chunks. Every job walks over the indices of its chunk in
-- increasing order and, for each index @k@, runs @make k@ and passes the
-- result to @write k@. Indices that do not fit into the equally sized chunks
-- (the slack) are handled by one additional job, which is scheduled only when
-- the slack is non-empty.
splitLinearlyWithM_ ::
     Monad m => Scheduler m () -> Int -> (Int -> m b) -> (Int -> b -> m c) -> m ()
splitLinearlyWithM_ scheduler totalLength make write =
  splitLinearly (numWorkers scheduler) totalLength $ \chunkLength slackStart -> do
    loopM_ 0 (< slackStart) (+ chunkLength) $ \ !start ->
      scheduleWork_ scheduler $
        loopM_ start (< (start + chunkLength)) (+ 1) $ \ !k -> make k >>= write k
    -- When the length divides evenly (or is zero) there is no slack, so do not
    -- enqueue a job that would perform no iterations.
    if slackStart < totalLength
      then scheduleWork_ scheduler $
             loopM_ slackStart (< totalLength) (+ 1) $ \ !k -> make k >>= write k
      else pure ()
{-# INLINE splitLinearlyWithM_ #-}
-- | Same chunking scheme as 'splitLinearlyWithM_', but for the index range
-- @[startAt, startAt + totalLength)@.
--
-- @totalLength@ is divided with 'splitLinearly' using @numWorkers scheduler@
-- as the number of chunks, and every resulting bound is shifted by @startAt@.
-- Each scheduled job walks over its indices in increasing order and, for each
-- index @k@, runs @make k@ and passes the result to @write k@. The slack is
-- handled by one additional job, which is scheduled only when the slack is
-- non-empty.
splitLinearlyWithStartAtM_ ::
     Monad m => Scheduler m () -> Int -> Int -> (Int -> m b) -> (Int -> b -> m c) -> m ()
splitLinearlyWithStartAtM_ scheduler startAt totalLength make write =
  splitLinearly (numWorkers scheduler) totalLength $ \chunkLength slackStart -> do
    let !slackBegin = slackStart + startAt
        !rangeEnd = totalLength + startAt
    loopM_ startAt (< slackBegin) (+ chunkLength) $ \ !start ->
      scheduleWork_ scheduler $
        loopM_ start (< (start + chunkLength)) (+ 1) $ \ !k -> make k >>= write k
    -- No slack means the job below would perform no iterations, so it is not
    -- enqueued at all.
    if slackBegin < rangeEnd
      then scheduleWork_ scheduler $
             loopM_ slackBegin (< rangeEnd) (+ 1) $ \ !k -> make k >>= write k
      else pure ()
{-# INLINE splitLinearlyWithStartAtM_ #-}
-- | Split the index range @[0, totalLength)@ into chunks and schedule one
-- stateful job per chunk.
--
-- The number of chunks is the number of workers of the scheduler obtained
-- with 'unwrapSchedulerWS'. Jobs are submitted with 'scheduleWorkState_';
-- the state value that it hands to a job is passed, together with each index
-- @k@ of the chunk, to @make@, and the produced value is then given to
-- @store k@. Indices within a job are visited in increasing order. The slack
-- is handled by one additional job, which is scheduled only when the slack is
-- non-empty.
splitLinearlyWithStatefulM_ ::
     Monad m
  => SchedulerWS s m ()
  -> Int -- ^ Total length
  -> (Int -> s -> m b) -- ^ Produce the element for an index, given the job's state
  -> (Int -> b -> m c) -- ^ Consume the element produced for an index
  -> m ()
splitLinearlyWithStatefulM_ schedulerWS totalLength make store =
  let nWorkers = numWorkers (unwrapSchedulerWS schedulerWS)
   in splitLinearly nWorkers totalLength $ \chunkLength slackStart -> do
        loopM_ 0 (< slackStart) (+ chunkLength) $ \ !start ->
          scheduleWorkState_ schedulerWS $ \s ->
            loopM_ start (< (start + chunkLength)) (+ 1) $ \ !k ->
              make k s >>= store k
        -- With no slack left there is nothing for the extra job to do, so it
        -- is not enqueued.
        if slackStart < totalLength
          then scheduleWorkState_ schedulerWS $ \s ->
                 loopM_ slackStart (< totalLength) (+ 1) $ \ !k ->
                   make k s >>= store k
          else pure ()
{-# INLINE splitLinearlyWithStatefulM_ #-}