{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE PatternSynonyms #-}
module Data.Massiv.Core.Index
( module Data.Massiv.Core.Index.Ix
, Stride
, pattern Stride
, unStride
, toLinearIndexStride
, strideStart
, strideSize
, oneStride
, Border(..)
, handleBorderIndex
, module Data.Massiv.Core.Index.Class
, zeroIndex
, isSafeSize
, isNonEmpty
, headDim
, tailDim
, lastDim
, initDim
, getIndex'
, setIndex'
, iterLinearM
, iterLinearM_
, module Data.Massiv.Core.Iterator
) where
import Control.DeepSeq
import Data.Massiv.Core.Index.Class
import Data.Massiv.Core.Index.Ix
import Data.Massiv.Core.Index.Stride
import Data.Massiv.Core.Iterator
-- | Strategy used by 'handleBorderIndex' to decide which element to produce
-- for an index that may lie outside of an array.
--
-- The one-dimensional pictures below show a size-4 array holding @1 2 3 4@.
-- The left side shows the values produced for indices @-4 .. -1@ and the
-- right side those for @4 .. 7@. They are worked out from the modular
-- arithmetic in 'handleBorderIndex'. That derivation assumes 'repairIndex'
-- applies its first function to components below zero and its second to
-- components at or beyond the size (TODO confirm).
data Border e
  = Fill e    -- ^ Use the given constant for every out-of-bounds index.
              --
              -- @
              --              outside |  Array  | outside
              -- 'Fill' 0 :   0 0 0 0 | 1 2 3 4 | 0 0 0 0
              -- @
  | Wrap      -- ^ Wrap around to the opposite side of the array (modulo the size).
              --
              -- @
              --              outside |  Array  | outside
              -- 'Wrap' :     1 2 3 4 | 1 2 3 4 | 1 2 3 4
              -- @
  | Edge      -- ^ Repeat the element at the nearest edge.
              --
              -- @
              --              outside |  Array  | outside
              -- 'Edge' :     1 1 1 1 | 1 2 3 4 | 4 4 4 4
              -- @
  | Reflect   -- ^ Mirror-like reflection that repeats the edge element.
              --
              -- @
              --              outside |  Array  | outside
              -- 'Reflect' :  4 3 2 1 | 1 2 3 4 | 4 3 2 1
              -- @
  | Continue  -- ^ Mirror-like reflection that does not repeat the edge element.
              --
              -- @
              --              outside |  Array  | outside
              -- 'Continue' : 1 4 3 2 | 1 2 3 4 | 3 2 1 4
              -- @
  deriving (Eq, Show)
-- | Normal-form evaluation of a 'Border'.
--
-- Only 'Fill' carries a payload, which is forced with its own 'rnf'. Every
-- other constructor is field-less, so matching on it already fully evaluates
-- the value.
instance NFData e => NFData (Border e) where
  rnf (Fill fillValue) = rnf fillValue
  rnf Wrap = ()
  rnf Edge = ()
  rnf Reflect = ()
  rnf Continue = ()
-- | Produce the element for index @ix@ of an array of size @sz@, resolving
-- out-of-bounds indices according to the given 'Border'.
--
-- * 'Fill': when @ix@ is safe for @sz@ (per 'isSafeIndex'), @getVal ix@ is
--   returned. Otherwise the fill value is returned and @getVal@ is not called.
-- * Every other strategy rewrites @ix@ with 'repairIndex', using the pair of
--   per-component functions defined below, and calls @getVal@ on the result.
--
-- NOTE(review): the helpers receive @k@ (a size component) and @i@ (an index
-- component). Judging by the 'Edge' pair, which yields @0@ and @k - 1@, the
-- first helper of each pair handles components below zero and the second
-- handles components at or past the size. Confirm against 'repairIndex'.
handleBorderIndex ::
     Index ix
  => Border e -- ^ Border resolution strategy
  -> ix -- ^ Size of the array
  -> (ix -> e) -- ^ Element lookup, only ever called on a resolved index
  -> ix -- ^ Index to resolve
  -> e
handleBorderIndex border !sz getVal !ix =
  case border of
    Fill val
      | isSafeIndex sz ix -> getVal ix
      | otherwise -> val
    Wrap -> lookupRepaired wrapBelow wrapAbove
    Edge -> lookupRepaired edgeBelow edgeAbove
    Reflect -> lookupRepaired reflectBelow reflectAbove
    Continue -> lookupRepaired continueBelow continueAbove
  where
    -- Repair the index with a (below, above) pair, then look the element up.
    lookupRepaired below above = getVal (repairIndex sz ix below above)
    -- Wrap: reduce the component modulo the size on both sides.
    wrapBelow k i = i `mod` k
    wrapAbove k i = i `mod` k
    -- Edge: clamp to the first or the last position.
    edgeBelow _ _ = 0
    edgeAbove !k _ = k - 1
    -- Reflect: mirror around the edge, repeating the edge element.
    reflectBelow !k !i = (abs i - 1) `mod` k
    reflectAbove !k !i = (-i - 1) `mod` k
    -- Continue: mirror around the edge without repeating it.
    continueBelow !k !i = abs i `mod` k
    continueAbove !k !i = (-i - 2) `mod` k
{-# INLINE [1] handleBorderIndex #-}
-- | The index with every component equal to @0@, built via 'pureIndex'.
zeroIndex :: Index ix => ix
zeroIndex =
  pureIndex 0
{-# INLINE [1] zeroIndex #-}
-- | Compare a size against 'zeroIndex' using the index type's 'Ord'
-- instance. The result is 'True' exactly when @zeroIndex >= sz@.
--
-- NOTE(review): as written, this holds for sizes less than or equal to
-- 'zeroIndex' (e.g. the all-zero size). It would be 'False' for an
-- all-positive size under a lexicographic or component-wise ordering, which
-- reads as the opposite of what the name suggests. The 'Ord' comparison is
-- also not necessarily component-wise. Behaviour is deliberately left
-- unchanged here; confirm the intent against callers.
isSafeSize :: Index ix => ix -> Bool
isSafeSize sz = zeroIndex >= sz
{-# INLINE [1] isSafeSize #-}
-- | 'True' when 'zeroIndex' is a safe index (per 'isSafeIndex') for an
-- array of size @sz@, i.e. when an array of that size has a first element.
-- The size argument is forced before the check.
isNonEmpty :: Index ix => ix -> Bool
isNonEmpty sz =
  sz `seq` isSafeIndex sz zeroIndex
{-# INLINE [1] isNonEmpty #-}
-- | The 'Int' component that 'unconsDim' returns in the first slot of its
-- result pair.
headDim :: Index ix => ix -> Int
headDim ix =
  case unconsDim ix of
    (leading, _) -> leading
{-# INLINE [1] headDim #-}
-- | The lower-dimensional remainder that 'unconsDim' returns in the second
-- slot of its result pair.
tailDim :: Index ix => ix -> Lower ix
tailDim ix =
  case unconsDim ix of
    (_, rest) -> rest
{-# INLINE [1] tailDim #-}
-- | The 'Int' component that 'unsnocDim' returns in the second slot of its
-- result pair.
lastDim :: Index ix => ix -> Int
lastDim ix =
  case unsnocDim ix of
    (_, trailing) -> trailing
{-# INLINE [1] lastDim #-}
-- | The lower-dimensional prefix that 'unsnocDim' returns in the first slot
-- of its result pair.
initDim :: Index ix => ix -> Lower ix
initDim ix =
  case unsnocDim ix of
    (front, _) -> front
{-# INLINE [1] initDim #-}
-- | Partial variant of 'setIndex': replace the component of @ix@ at dimension
-- @dim@ with @i@ and return the new index directly.
--
-- Calls 'error' when 'setIndex' reports (by returning 'Nothing') that @dim@
-- does not exist for this index type.
setIndex' :: Index ix => ix -> Dim -> Int -> ix
setIndex' ix dim i =
  case setIndex ix dim i of
    Just ix' -> ix'
    Nothing -> error $ "setIndex': Dimension is out of reach: " ++ show dim
-- Inlined like the other helpers in this module so the 'Index' dictionary
-- can be specialised at call sites.
{-# INLINE [1] setIndex' #-}
-- | Partial variant of 'getIndex': return the component of @ix@ at dimension
-- @dim@.
--
-- Calls 'error' when 'getIndex' reports (by returning 'Nothing') that @dim@
-- does not exist for this index type.
getIndex' :: Index ix => ix -> Dim -> Int
getIndex' ix dim =
  case getIndex ix dim of
    Just component -> component
    Nothing -> error $ "getIndex': Dimension is out of reach: " ++ show dim
-- Inlined like the other helpers in this module so the 'Index' dictionary
-- can be specialised at call sites.
{-# INLINE [1] getIndex' #-}
-- | Monadic loop over linear indices that threads an accumulator.
--
-- Delegates to 'loopM' with start @k0@, continuation test @\\i -> cond i k1@
-- and step @\\i -> i + inc@. For each visited linear index @i@, the action
-- @f@ receives @i@, the index @fromLinearIndex sz i@ and the current
-- accumulator. Presumably 'loopM' feeds @f@'s result back in as the next
-- accumulator; confirm against its definition. All arguments except @cond@
-- and @f@ are forced on entry.
iterLinearM :: (Index ix, Monad m)
            => ix -- ^ Size used to convert linear indices
            -> Int -- ^ Starting linear index (@k0@)
            -> Int -- ^ Bound passed as the second argument of @cond@ (@k1@)
            -> Int -- ^ Increment between consecutive linear indices (@inc@)
            -> (Int -> Int -> Bool) -- ^ Continuation test, called as @cond i k1@
            -> a -- ^ Initial accumulator
            -> (Int -> ix -> a -> m a) -- ^ Action run for each linear index
            -> m a
iterLinearM !sz !k0 !k1 !inc cond !acc f =
  loopM k0 keepGoing advance acc visit
  where
    keepGoing i = cond i k1
    advance i = i + inc
    visit !i !acc0 = f i (fromLinearIndex sz i) acc0
{-# INLINE iterLinearM #-}
-- | Like 'iterLinearM' but without an accumulator, run for effects only.
--
-- Delegates to 'loopM_' with start @k0@, continuation test @\\i -> cond i k1@
-- and step @\\i -> i + inc@. For each visited linear index @i@, the action
-- @f@ receives @i@ and the index @fromLinearIndex sz i@. All arguments except
-- @cond@ and @f@ are forced on entry.
iterLinearM_ :: (Index ix, Monad m)
             => ix -- ^ Size used to convert linear indices
             -> Int -- ^ Starting linear index (@k0@)
             -> Int -- ^ Bound passed as the second argument of @cond@ (@k1@)
             -> Int -- ^ Increment between consecutive linear indices (@inc@)
             -> (Int -> Int -> Bool) -- ^ Continuation test, called as @cond i k1@
             -> (Int -> ix -> m ()) -- ^ Action run for each linear index
             -> m ()
iterLinearM_ !sz !k0 !k1 !inc cond f =
  loopM_ k0 keepGoing advance visit
  where
    keepGoing i = cond i k1
    advance i = i + inc
    visit !i = f i (fromLinearIndex sz i)
{-# INLINE iterLinearM_ #-}