{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeOperators #-}
module Data.Massiv.Core.Index
( module Data.Massiv.Core.Index.Ix
, Stride
, pattern Stride
, unStride
, toLinearIndexStride
, strideStart
, strideSize
, oneStride
, Border(..)
, handleBorderIndex
, module Data.Massiv.Core.Index.Class
, zeroIndex
, isSafeSize
, isNonEmpty
, headDim
, tailDim
, lastDim
, initDim
, getIndex'
, setIndex'
, getDim'
, setDim'
, dropDim'
, pullOutDim'
, insertDim'
, fromDimension
, getDimension
, setDimension
, dropDimension
, pullOutDimension
, insertDimension
, iterLinearM
, iterLinearM_
, module Data.Massiv.Core.Iterator
) where
import Control.DeepSeq
import Data.Massiv.Core.Index.Class
import Data.Massiv.Core.Index.Ix
import Data.Massiv.Core.Index.Stride
import Data.Massiv.Core.Iterator
import GHC.TypeLits
-- | Strategy for producing a value at an index that lies outside an array's
-- bounds. The index arithmetic for every strategy is in 'handleBorderIndex';
-- below, @k@ stands for the size of the offending dimension.
data Border e
  = Fill e
    -- ^ Use the supplied constant for every out-of-bounds index.
  | Wrap
    -- ^ Wrap around: the offending component is reduced modulo @k@.
  | Edge
    -- ^ Clamp: components below zero become @0@, those too large become @k - 1@.
  | Reflect
    -- ^ Mirror, repeating the edge element (@-1@ maps to @0@, @k@ to @k - 1@).
  | Continue
    -- ^ Mirror without repeating the edge element (with @k >= 2@: @-1@ maps
    -- to @1@, @k@ to @k - 2@).
  deriving (Eq, Show)
-- | Normal-form evaluation of a 'Border'. Only 'Fill' carries a field that
-- needs deep forcing; matching on it in the first equation also forces the
-- constructor itself for every other case.
instance NFData e => NFData (Border e) where
  rnf (Fill e) = rnf e
  rnf _ = ()
-- | Look up an element at an index that may fall outside of @sz@, resolving
-- out-of-bounds indices according to the 'Border' strategy.
--
-- * 'Fill' checks the index with 'isSafeIndex' and returns the fill value
--   when it is not safe.
-- * Every other strategy rewrites the index with 'repairIndex' and then calls
--   @getVal@ on the result. 'repairIndex' receives two handlers of the form
--   @\\k i -> ...@. Judging by 'Edge', the first handles components below zero
--   and the second handles components at or beyond the size @k@. TODO:
--   confirm against 'repairIndex'.
handleBorderIndex ::
     Index ix
  => Border e -- ^ What to do with out-of-bounds indices
  -> Sz ix -- ^ Size that indices are checked/repaired against
  -> (ix -> e) -- ^ Element lookup
  -> ix -- ^ Requested index
  -> e
handleBorderIndex border !sz getVal !ix =
  case border of
    Fill val
      | isSafeIndex sz ix -> getVal ix
      | otherwise -> val
    Wrap -> repairWith wrapAround wrapAround
    Edge -> repairWith (\ _ _ -> 0) (\ !k _ -> k - 1)
    Reflect ->
      repairWith (\ !k !i -> (abs i - 1) `mod` k) (\ !k !i -> (-i - 1) `mod` k)
    Continue ->
      repairWith (\ !k !i -> abs i `mod` k) (\ !k !i -> (-i - 2) `mod` k)
  where
    -- Rewrite the index with the given handlers, then fetch the element.
    repairWith below above = getVal (repairIndex sz ix below above)
    -- Equivalent to @flip mod@: reduce component @i@ modulo size @k@.
    wrapAround k i = i `mod` k
{-# INLINE [1] handleBorderIndex #-}
-- | The index with every component set to @0@, built with 'pureIndex'.
zeroIndex :: Index ix => ix
zeroIndex = pureIndex (0 :: Int)
{-# INLINE [1] zeroIndex #-}
-- | Checks whether a size is valid, i.e. none of its components is negative.
-- A size with a zero component is still valid; it describes an empty array
-- (use 'isNonEmpty' to test for that).
--
-- The test is done per component with 'foldlIndex'. The previous definition,
-- @(zeroIndex >=)@, compared in the wrong direction: it accepted only sizes at
-- or below zero. It also went through 'Ord', which need not be a
-- per-component ordering for multi-dimensional indices.
isSafeSize :: Index ix => Sz ix -> Bool
isSafeSize sz = foldlIndex (\ acc i -> acc && i >= 0) True sz
{-# INLINE [1] isSafeSize #-}
-- | Does an array of this size hold at least one element? This is answered by
-- asking whether the all-zero index is a safe index for @sz@.
isNonEmpty :: Index ix => Sz ix -> Bool
isNonEmpty !sz = sz `isSafeIndex` zeroIndex
{-# INLINE [1] isNonEmpty #-}
-- | The 'Int' component that 'unconsDim' splits off the front of an index.
headDim :: Index ix => ix -> Int
headDim ix =
  case unconsDim ix of
    (i, _) -> i
{-# INLINE [1] headDim #-}
-- | What remains of an index after 'unconsDim' removes its front component.
tailDim :: Index ix => ix -> Lower ix
tailDim ix =
  case unconsDim ix of
    (_, rest) -> rest
{-# INLINE [1] tailDim #-}
-- | The 'Int' component that 'unsnocDim' splits off the back of an index.
lastDim :: Index ix => ix -> Int
lastDim ix =
  case unsnocDim ix of
    (_, i) -> i
{-# INLINE [1] lastDim #-}
-- | What remains of an index after 'unsnocDim' removes its back component.
initDim :: Index ix => ix -> Lower ix
initDim ix =
  case unsnocDim ix of
    (rest, _) -> rest
{-# INLINE [1] initDim #-}
-- | Partial variant of 'setDim': replaces the component at @dim@ with @i@, and
-- raises an 'error' through 'errorDim' whenever 'setDim' yields 'Nothing'
-- (i.e. the dimension is out of reach).
setDim' :: Index ix => ix -> Dim -> Int -> ix
setDim' ix dim i = orDie (setDim ix dim i)
  where
    orDie (Just updated) = updated
    orDie Nothing = errorDim "setDim'" dim
{-# INLINE [1] setDim' #-}
-- | Partial variant of 'getDim': reads the component at @dim@, and raises an
-- 'error' through 'errorDim' whenever 'getDim' yields 'Nothing'.
getDim' :: Index ix => ix -> Dim -> Int
getDim' ix dim = orDie (getDim ix dim)
  where
    orDie (Just component) = component
    orDie Nothing = errorDim "getDim'" dim
{-# INLINE [1] getDim' #-}
-- | Deprecated twin of 'setDim''. It behaves the same, except that its error
-- message names @setIndex'@.
setIndex' :: Index ix => ix -> Dim -> Int -> ix
setIndex' ix dim i = orDie (setDim ix dim i)
  where
    orDie (Just updated) = updated
    orDie Nothing = errorDim "setIndex'" dim
{-# INLINE [1] setIndex' #-}
{-# DEPRECATED setIndex' "In favor of `setDim'`" #-}
-- | Deprecated twin of 'getDim''. It behaves the same, except that its error
-- message names @getIndex'@.
getIndex' :: Index ix => ix -> Dim -> Int
getIndex' ix dim = orDie (getDim ix dim)
  where
    orDie (Just component) = component
    orDie Nothing = errorDim "getIndex'" dim
{-# INLINE [1] getIndex' #-}
{-# DEPRECATED getIndex' "In favor of `getDim'`" #-}
-- | Partial variant of 'dropDim': removes the component at @dim@, and raises
-- an 'error' through 'errorDim' whenever 'dropDim' yields 'Nothing'.
dropDim' :: Index ix => ix -> Dim -> Lower ix
dropDim' ix dim = orDie (dropDim ix dim)
  where
    orDie (Just lowered) = lowered
    orDie Nothing = errorDim "dropDim'" dim
{-# INLINE [1] dropDim' #-}
-- | Partial variant of 'pullOutDim': returns the component at @dim@ together
-- with the remaining lower-rank index. It raises an 'error' through 'errorDim'
-- whenever 'pullOutDim' yields 'Nothing'.
pullOutDim' :: Index ix => ix -> Dim -> (Int, Lower ix)
pullOutDim' ix dim = orDie (pullOutDim ix dim)
  where
    orDie (Just pulled) = pulled
    orDie Nothing = errorDim "pullOutDim'" dim
{-# INLINE [1] pullOutDim' #-}
-- | Partial variant of 'insertDim': inserts @i@ at @dim@ into a lower-rank
-- index. It raises an 'error' through 'errorDim' whenever 'insertDim' yields
-- 'Nothing'.
insertDim' :: Index ix => Lower ix -> Dim -> Int -> ix
insertDim' ix dim i = orDie (insertDim ix dim i)
  where
    orDie (Just raised) = raised
    orDie Nothing = errorDim "insertDim'" dim
{-# INLINE [1] insertDim' #-}
-- | Abort with a message naming the calling function and the unreachable
-- dimension. It is marked NOINLINE, presumably to keep this cold error path
-- out of the INLINE'd callers.
errorDim :: String -> Dim -> a
errorDim funName dim =
  error (concat [funName, ": Dimension is out of reach: ", show dim])
{-# NOINLINE errorDim #-}
-- | Demote a type-level dimension to a value-level 'Dim' by reading its 'Nat'
-- with 'natVal' and converting it with 'fromIntegral'.
fromDimension :: KnownNat n => Dimension n -> Dim
fromDimension d = fromIntegral (natVal d)
{-# INLINE [1] fromDimension #-}
-- | Type-level-dimension counterpart of 'setDim''. The dimension is given as a
-- 'Dimension' and converted with 'fromDimension'.
setDimension :: IsIndexDimension ix n => ix -> Dimension n -> Int -> ix
setDimension ix d i = setDim' ix (fromDimension d) i
{-# INLINE [1] setDimension #-}
-- | Type-level-dimension counterpart of 'getDim''.
getDimension :: IsIndexDimension ix n => ix -> Dimension n -> Int
getDimension ix = getDim' ix . fromDimension
{-# INLINE [1] getDimension #-}
-- | Type-level-dimension counterpart of 'dropDim''.
dropDimension :: IsIndexDimension ix n => ix -> Dimension n -> Lower ix
dropDimension ix = dropDim' ix . fromDimension
{-# INLINE [1] dropDimension #-}
-- | Type-level-dimension counterpart of 'pullOutDim''.
pullOutDimension :: IsIndexDimension ix n => ix -> Dimension n -> (Int, Lower ix)
pullOutDimension ix = pullOutDim' ix . fromDimension
{-# INLINE [1] pullOutDimension #-}
-- | Type-level-dimension counterpart of 'insertDim''.
insertDimension :: IsIndexDimension ix n => Lower ix -> Dimension n -> Int -> ix
insertDimension ix d i = insertDim' ix (fromDimension d) i
{-# INLINE [1] insertDimension #-}
-- | Monadic fold over a range of linear indices. For every visited linear
-- index @i@, the step function receives @i@, the corresponding full index
-- (@fromLinearIndex sz i@) and the current accumulator.
--
-- Iteration is delegated to 'loopM', seeded with @k0@, the test @cond i k1@,
-- the step @(+ inc)@ and @acc@. This assumes 'loopM' takes start, condition,
-- step, accumulator and body in that order; TODO confirm.
iterLinearM ::
     (Index ix, Monad m)
  => Sz ix -- ^ Size used to convert linear indices to full indices
  -> Int -- ^ Starting linear index (@k0@)
  -> Int -- ^ Bound passed as the second argument of @cond@ (@k1@)
  -> Int -- ^ Increment between visited linear indices
  -> (Int -> Int -> Bool) -- ^ Loop test, applied as @cond i k1@
  -> a -- ^ Initial accumulator
  -> (Int -> ix -> a -> m a) -- ^ Step: linear index, full index, accumulator
  -> m a
iterLinearM !sz !k0 !k1 !inc cond !acc f = loopM k0 keepGoing (+ inc) acc step
  where
    keepGoing i = cond i k1
    step !i !acc0 = f i (fromLinearIndex sz i) acc0
{-# INLINE iterLinearM #-}
-- | Effect-only counterpart of 'iterLinearM'. It visits the same linear
-- indices through 'loopM_' and hands each action the linear index together
-- with its full index (@fromLinearIndex sz i@).
iterLinearM_ ::
     (Index ix, Monad m)
  => Sz ix -- ^ Size used to convert linear indices to full indices
  -> Int -- ^ Starting linear index (@k0@)
  -> Int -- ^ Bound passed as the second argument of @cond@ (@k1@)
  -> Int -- ^ Increment between visited linear indices
  -> (Int -> Int -> Bool) -- ^ Loop test, applied as @cond i k1@
  -> (Int -> ix -> m ()) -- ^ Action: linear index, full index
  -> m ()
iterLinearM_ !sz !k0 !k1 !inc cond f = loopM_ k0 keepGoing (+ inc) visit
  where
    keepGoing i = cond i k1
    visit !i = f i (fromLinearIndex sz i)
{-# INLINE iterLinearM_ #-}