{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.Massiv.Array.Ops.Map
( map
, imap
, traverseA
, itraverseA
, traverseAR
, itraverseAR
, mapM
, mapMR
, forM
, forMR
, imapM
, imapMR
, iforM
, iforMR
, mapM_
, forM_
, imapM_
, iforM_
, mapIO
, mapIO_
, imapIO
, imapIO_
, forIO
, forIO_
, iforIO
, iforIO_
, mapP_
, imapP_
, zip
, zip3
, unzip
, unzip3
, zipWith
, zipWith3
, izipWith
, izipWith3
, liftArray2
) where
import Control.Monad (void, when)
import Control.Monad.ST (runST)
import Data.Foldable (foldlM)
import Data.Massiv.Array.Delayed.Internal
import Data.Massiv.Array.Mutable
import Data.Massiv.Array.Ops.Fold.Internal (foldrFB)
import Data.Massiv.Core.Common
import Data.Massiv.Core.Scheduler
import Data.Monoid ((<>))
import GHC.Base (build)
import Prelude hiding (map, mapM, mapM_,
traverse, unzip, unzip3,
zip, zip3, zipWith,
zipWith3)
import qualified Prelude as Prelude (traverse)
-- | Apply a function to every element of an array, producing a delayed array.
-- This is 'imap' with a function that ignores the index.
map :: Source r ix e' => (e' -> e) -> Array r ix e' -> Array D ix e
map f = imap (\ _ x -> f x)
{-# INLINE map #-}
-- | Index-aware 'map'. The result is a delayed array with the same computation
-- strategy and size as the source. Each element is computed on demand by
-- applying the function to the index and the source element at that index.
imap :: Source r ix e' => (ix -> e' -> e) -> Array r ix e' -> Array D ix e
imap f !arr = DArray comp sz indexer
  where
    comp = getComp arr
    sz = size arr
    indexer !ix = f ix (unsafeIndex arr ix)
{-# INLINE imap #-}
-- | Pair up the elements of two arrays. The result size follows 'zipWith'
-- (the element-wise minimum of both sizes).
zip :: (Source r1 ix e1, Source r2 ix e2)
    => Array r1 ix e1 -> Array r2 ix e2 -> Array D ix (e1, e2)
zip arr1 arr2 = zipWith (\ x y -> (x, y)) arr1 arr2
{-# INLINE zip #-}
-- | Combine the elements of three arrays into triples, via 'zipWith3'.
zip3 :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
     => Array r1 ix e1 -> Array r2 ix e2 -> Array r3 ix e3 -> Array D ix (e1, e2, e3)
zip3 arr1 arr2 arr3 = zipWith3 (\ x y z -> (x, y, z)) arr1 arr2 arr3
{-# INLINE zip3 #-}
-- | Split an array of pairs into two delayed arrays: one of first components
-- and one of second components.
unzip :: Source r ix (e1, e2) => Array r ix (e1, e2) -> (Array D ix e1, Array D ix e2)
unzip arr = (firsts, seconds)
  where
    firsts = map fst arr
    seconds = map snd arr
{-# INLINE unzip #-}
-- | Split an array of triples into three delayed arrays, one per component.
unzip3 :: Source r ix (e1, e2, e3)
       => Array r ix (e1, e2, e3) -> (Array D ix e1, Array D ix e2, Array D ix e3)
unzip3 arr = (map first3 arr, map second3 arr, map third3 arr)
  where
    first3 (x, _, _) = x
    second3 (_, y, _) = y
    third3 (_, _, z) = z
{-# INLINE unzip3 #-}
-- | Combine two arrays element-wise with a function. This is 'izipWith' with an
-- index-ignoring function.
zipWith :: (Source r1 ix e1, Source r2 ix e2)
        => (e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
zipWith f arr1 arr2 = izipWith (const f) arr1 arr2
{-# INLINE zipWith #-}
-- | Index-aware element-wise combination of two arrays.
--
-- The result size is the element-wise minimum of both sizes
-- (@liftIndex2 min@), and its computation strategy is the two sources'
-- strategies combined with '<>'.
izipWith :: (Source r1 ix e1, Source r2 ix e2)
         => (ix -> e1 -> e2 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array D ix e
izipWith f arr1 arr2 = DArray comp sz indexer
  where
    comp = getComp arr1 <> getComp arr2
    sz = liftIndex2 min (size arr1) (size arr2)
    indexer !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix)
{-# INLINE izipWith #-}
-- | Combine three arrays element-wise with a function. This is 'izipWith3'
-- with an index-ignoring function.
zipWith3 :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
         => (e1 -> e2 -> e3 -> e) -> Array r1 ix e1 -> Array r2 ix e2 -> Array r3 ix e3 -> Array D ix e
zipWith3 f arr1 arr2 arr3 = izipWith3 (const f) arr1 arr2 arr3
{-# INLINE zipWith3 #-}
-- | Index-aware element-wise combination of three arrays.
--
-- The result size is the element-wise minimum of all three sizes, and the
-- computation strategy is the three sources' strategies combined with '<>'.
izipWith3
  :: (Source r1 ix e1, Source r2 ix e2, Source r3 ix e3)
  => (ix -> e1 -> e2 -> e3 -> e)
  -> Array r1 ix e1
  -> Array r2 ix e2
  -> Array r3 ix e3
  -> Array D ix e
izipWith3 f arr1 arr2 arr3 = DArray comp sz indexer
  where
    comp = getComp arr1 <> getComp arr2 <> getComp arr3
    -- Left-nested, exactly like @min (min s1 s2) s3@ per dimension.
    sz = size arr1 `smaller` size arr2 `smaller` size arr3
    smaller a b = liftIndex2 min a b
    indexer !ix = f ix (unsafeIndex arr1 ix) (unsafeIndex arr2 ix) (unsafeIndex arr3 ix)
{-# INLINE izipWith3 #-}
-- | Traverse an array with an 'Applicative' action and collect the results into
-- a new array of representation @r@.
--
-- The source elements are streamed into a list ('foldrFB' under 'build'), and
-- that list is traversed with 'Prelude.traverse'. The results are then written
-- one by one, starting at linear index 0, into a freshly allocated mutable
-- array of the source's size. Finally that array is frozen with the source's
-- computation strategy.
traverseA ::
     (Source r' ix a, Mutable r ix b, Applicative f)
  => (a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
traverseA f arr =
  fmap fillNew (Prelude.traverse f (build (\cons nil -> foldrFB cons nil arr)))
  where
    -- Write the traversal results sequentially into a new array of the same size.
    fillNew ys =
      runST $ do
        mdst <- unsafeNew (size arr)
        let step k y = unsafeLinearWrite mdst k y >> return (k + 1)
        _ <- foldlM step 0 ys
        unsafeFreeze (getComp arr) mdst
    {-# INLINE fillNew #-}
{-# INLINE traverseA #-}
-- | Index-aware 'traverseA'.
--
-- Each element is paired with its index (via 'zipWithIndex'), and the pairs
-- are streamed into a list and traversed with 'Prelude.traverse'. The results
-- are written sequentially from linear index 0 into a new array of the
-- source's size, which is frozen with the source's computation strategy.
itraverseA ::
     (Source r' ix a, Mutable r ix b, Applicative f)
  => (ix -> a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
itraverseA f arr =
  fillNew <$>
  Prelude.traverse
    -- Lazy pattern: matches 'uncurry' exactly.
    (\ ~(ix, e) -> f ix e)
    (build (\cons nil -> foldrFB cons nil (zipWithIndex arr)))
  where
    -- Write the traversal results sequentially into a new array of the same size.
    fillNew ys =
      runST $ do
        mdst <- unsafeNew (size arr)
        let step k y = unsafeLinearWrite mdst k y >> return (k + 1)
        _ <- foldlM step 0 ys
        unsafeFreeze (getComp arr) mdst
    {-# INLINE fillNew #-}
{-# INLINE itraverseA #-}
-- | 'traverseA' with the result representation selected by a proxy value.
-- The first argument is never inspected; it only fixes the type @r@.
traverseAR ::
     (Source r' ix a, Mutable r ix b, Applicative f)
  => r
  -> (a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
traverseAR _ f arr = traverseA f arr
{-# INLINE traverseAR #-}
-- | 'itraverseA' with the result representation selected by a proxy value.
-- The first argument is never inspected; it only fixes the type @r@.
itraverseAR ::
     (Source r' ix a, Mutable r ix b, Applicative f)
  => r
  -> (ix -> a -> f b)
  -> Array r' ix a
  -> f (Array r ix b)
itraverseAR _ f arr = itraverseA f arr
{-# INLINE itraverseAR #-}
-- | Pair every element with its own index.
--
-- The index array is generated with 'makeArray' using the 'mempty' computation
-- strategy, then zipped with the source array.
zipWithIndex :: forall r ix e . Source r ix e => Array r ix e -> Array D ix (ix, e)
zipWithIndex arr = zip indices arr
  where
    indices :: Array D ix ix
    indices = makeArray mempty (size arr) id
{-# INLINE zipWithIndex #-}
-- | Map a monadic action over an array and collect the results into a new array.
-- Same as 'traverseA', restricted to 'Monad'.
mapM ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapM f arr = traverseA f arr
{-# INLINE mapM #-}
-- | 'mapM' with the result representation selected by a proxy value that is
-- never inspected.
mapMR ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
mapMR _ f arr = traverseA f arr
{-# INLINE mapMR #-}
-- | 'mapM' with its arguments flipped: array first, action second.
forM ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forM arr f = traverseA f arr
{-# INLINE forM #-}
-- | 'forM' with the result representation selected by a proxy value that is
-- never inspected.
forMR ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> Array r' ix a
  -> (a -> m b)
  -> m (Array r ix b)
forMR _ arr f = traverseA f arr
{-# INLINE forMR #-}
-- | Index-aware 'mapM'. Same as 'itraverseA', restricted to 'Monad'.
imapM ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapM f arr = itraverseA f arr
{-# INLINE imapM #-}
-- | 'imapM' with the result representation selected by a proxy value that is
-- never inspected.
imapMR ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
imapMR _ f arr = itraverseA f arr
{-# INLINE imapMR #-}
-- | Index-aware monadic map that collects the results into a new array.
--
-- NOTE(review): despite the @for@ prefix, this takes the action BEFORE the
-- array and is therefore identical to 'imapM'. 'forM' and 'iforM_' both take
-- the array first. Reordering the arguments would break existing callers, so
-- the current order is kept; confirm the intended API before changing it.
iforM ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
iforM f arr = itraverseA f arr
{-# INLINE iforM #-}
-- | 'iforM' with the result representation selected by a proxy value that is
-- never inspected.
--
-- NOTE(review): like 'iforM', the action comes before the array (unlike
-- 'forMR'), which makes this identical to 'imapMR'. It is kept as-is for
-- caller compatibility.
iforMR ::
     (Source r' ix a, Mutable r ix b, Monad m)
  => r
  -> (ix -> a -> m b)
  -> Array r' ix a
  -> m (Array r ix b)
iforMR _ f arr = itraverseA f arr
{-# INLINE iforMR #-}
-- | Run a monadic action on every element and discard the results.
--
-- Indices are visited with 'iterM_', from 'zeroIndex' up to (but excluding)
-- the array size in steps of @pureIndex 1@.
mapM_ :: (Source r ix a, Monad m) => (a -> m b) -> Array r ix a -> m ()
mapM_ f !arr = iterM_ zeroIndex (size arr) (pureIndex 1) (<) visit
  where
    visit ix = f (unsafeIndex arr ix)
{-# INLINE mapM_ #-}
-- | 'mapM_' with its arguments flipped: array first, action second.
forM_ :: (Source r ix a, Monad m) => Array r ix a -> (a -> m b) -> m ()
forM_ arr f = mapM_ f arr
{-# INLINE forM_ #-}
-- | 'imapM_' with its arguments flipped: array first, index-aware action second.
iforM_ :: (Source r ix a, Monad m) => Array r ix a -> (ix -> a -> m b) -> m ()
iforM_ arr f = imapM_ f arr
{-# INLINE iforM_ #-}
-- | Apply an 'IO' action to every element and collect the results into a new
-- array. This is 'imapIO' with an action that ignores the index.
mapIO ::
     (Source r' ix a, Mutable r ix b) => (a -> IO b) -> Array r' ix a -> IO (Array r ix b)
mapIO action arr = imapIO (\ _ e -> action e) arr
{-# INLINE mapIO #-}
-- | Apply an 'IO' action to every element of an array and discard the results.
-- This is 'imapIO_' with an action that ignores the index, so it follows the
-- same computation-strategy handling.
--
-- The index type variable is named @ix@, consistent with the rest of this
-- module; it was previously named @b@. The type is unchanged up to renaming.
mapIO_ :: Source r ix e => (e -> IO a) -> Array r ix e -> IO ()
mapIO_ action = imapIO_ (const action)
{-# INLINE mapIO_ #-}
-- | Index-aware 'mapIO_': apply an 'IO' action to every element together with
-- its index and discard the results.
--
-- * 'Seq': delegates to 'imapM_'.
-- * 'ParOn' worker ids: creates a scheduler for those workers. It then hands
--   the linear range of @totalElem (size arr)@ elements to
--   'splitLinearlyWith_', presumably to divide it among the scheduler's
--   workers. Each element is read with 'unsafeLinearIndex', and its index is
--   recovered with 'fromLinearIndex'.
imapIO_ :: Source r ix e => (ix -> e -> IO a) -> Array r ix e -> IO ()
imapIO_ action arr =
  case getComp arr of
    Seq -> imapM_ action arr
    ParOn workerIds ->
      withScheduler_ workerIds $ \scheduler ->
        splitLinearlyWith_
          (numWorkers scheduler)
          (scheduleWork scheduler)
          (totalElem arrSize)
          (unsafeLinearIndex arr)
          runAt
  where
    arrSize = size arr
    runAt i e = void (action (fromLinearIndex arrSize i) e)
{-# INLINE imapIO_ #-}
-- | Index-aware 'mapIO'. The result is built with 'generateArrayIO' using the
-- source's computation strategy and size. The action receives every index
-- together with the source element at that index.
imapIO ::
     (Source r' ix a, Mutable r ix b) => (ix -> a -> IO b) -> Array r' ix a -> IO (Array r ix b)
imapIO action arr = generateArrayIO comp sz produce
  where
    comp = getComp arr
    sz = size arr
    produce ix = action ix (unsafeIndex arr ix)
{-# INLINE imapIO #-}
-- | 'mapIO' with its arguments flipped: array first, action second.
forIO ::
     (Source r' ix a, Mutable r ix b) => Array r' ix a -> (a -> IO b) -> IO (Array r ix b)
forIO arr action = mapIO action arr
{-# INLINE forIO #-}
-- | 'mapIO_' with its arguments flipped: array first, action second.
forIO_ :: Source r ix e => Array r ix e -> (e -> IO a) -> IO ()
forIO_ arr action = mapIO_ action arr
{-# INLINE forIO_ #-}
-- | 'imapIO' with its arguments flipped: array first, index-aware action second.
iforIO ::
     (Source r' ix a, Mutable r ix b) => Array r' ix a -> (ix -> a -> IO b) -> IO (Array r ix b)
iforIO arr action = imapIO action arr
{-# INLINE iforIO #-}
-- | 'imapIO_' with its arguments flipped: array first, index-aware action second.
iforIO_ :: Source r ix a => Array r ix a -> (ix -> a -> IO b) -> IO ()
iforIO_ arr action = imapIO_ action arr
{-# INLINE iforIO_ #-}
-- | Deprecated: 'imapP_' with an action that ignores the index. Use 'mapIO_'.
mapP_ :: Source r ix a => (a -> IO b) -> Array r ix a -> IO ()
mapP_ f arr = imapP_ (\ _ e -> f e) arr
{-# INLINE mapP_ #-}
{-# DEPRECATED mapP_ "In favor of 'mapIO_'" #-}
-- | Deprecated index-aware traversal with discarded results; use 'imapIO_'.
--
-- Worker ids come from a 'ParOn' computation strategy; any other strategy
-- yields an empty list. 'divideWork_' supplies a scheduler, a chunk length,
-- the total element count and the start of the trailing "slack" region.
-- Each chunk @[start, start + chunkLength)@ below @slackStart@ is scheduled as
-- its own job. The slack region @[slackStart, totalLength)@ is scheduled as one
-- more job when it is non-empty.
imapP_ :: Source r ix a => (ix -> a -> IO b) -> Array r ix a -> IO ()
imapP_ f arr =
  divideWork_ workerIds arrSize $ \ !scheduler !chunkLength !totalLength !slackStart -> do
    let scheduleRange lo hi =
          scheduleWork scheduler $
            iterLinearM_ arrSize lo hi 1 (<) $ \ !i ix ->
              void (f ix (unsafeLinearIndex arr i))
    loopM_ 0 (< slackStart) (+ chunkLength) $ \ !start ->
      scheduleRange start (start + chunkLength)
    when (slackStart < totalLength) $
      scheduleRange slackStart totalLength
  where
    arrSize = size arr
    workerIds =
      case getComp arr of
        ParOn ids -> ids
        _ -> []
{-# INLINE imapP_ #-}
{-# DEPRECATED imapP_ "In favor of 'imapIO_'" #-}