{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}
module Data.Massiv.Core.Index.Stride (
Stride (SafeStride),
pattern Stride,
unStride,
oneStride,
toLinearIndexStride,
strideStart,
strideSize,
) where
import Control.DeepSeq (NFData)
import Data.Massiv.Core.Index.Internal
import System.Random.Stateful (Random, Uniform (..), UniformRange (..))
-- | A per-dimension step size over an index type @ix@.
--
-- Prefer building values with the 'Stride' pattern, which clamps every
-- component to at least @1@. The exported raw 'SafeStride' constructor
-- performs no such check.
--
-- The 'NFData' instance is derived through the newtype, which is why the
-- module enables @GeneralizedNewtypeDeriving@.
newtype Stride ix = SafeStride ix
  deriving (Eq, Ord, NFData)
-- | Bidirectional pattern synonym for 'Stride'.
--
-- * Matching (@case s of Stride ix -> ...@) simply unwraps the stored index.
-- * Constructing (@Stride ix@) applies @'liftIndex' ('max' 1)@. This raises
--   every component below @1@ up to @1@, so the stride is never zero or
--   negative.
pattern Stride :: Index ix => ix -> Stride ix
pattern Stride ix <- SafeStride ix
  where
    Stride ix = SafeStride (liftIndex (max 1) ix)

-- 'Stride' is the only pattern, so a match on it alone is exhaustive.
{-# COMPLETE Stride #-}
-- | Renders a stride as @Stride @ followed by the index (shown at precedence
-- @1@). Parenthesisation for the surrounding precedence @n@ is delegated to
-- 'showsPrecWrapped'.
instance Index ix => Show (Stride ix) where
  showsPrec n (SafeStride ix) =
    showsPrecWrapped n (("Stride " ++) . showsPrec 1 ix)
-- | Draws a random stride by sampling the underlying index between
-- @'pureIndex' 1@ and @'pureIndex' 'maxBound'@.
--
-- NOTE(review): this yields positive components only if @ix@'s
-- 'UniformRange' instance samples componentwise. Confirm for custom index
-- types.
instance (UniformRange ix, Index ix) => Uniform (Stride ix) where
  uniformM g = SafeStride <$> uniformRM (pureIndex 1, pureIndex maxBound) g
  {-# INLINE uniformM #-}
-- | Samples a stride between two bounds by sampling between their underlying
-- indices.
--
-- No clamping is applied here. Whether the result is a valid (positive)
-- stride depends on the bounds passed in and on @ix@'s 'UniformRange'
-- instance.
instance UniformRange ix => UniformRange (Stride ix) where
  uniformRM (SafeStride l, SafeStride u) g = SafeStride <$> uniformRM (l, u) g
  {-# INLINE uniformRM #-}
-- | Empty 'Random' instance: every method comes from the class defaults.
-- (In random >= 1.2 those defaults go through 'UniformRange' / 'Uniform'.
-- Confirm against the random version in use.)
instance (Index ix, UniformRange ix) => Random (Stride ix)
-- | Unwrap a 'Stride', returning the stored index unchanged.
unStride :: Stride ix -> ix
unStride (SafeStride ix) = ix
{-# INLINE unStride #-}
-- | Round every component of an index up to the nearest multiple of the
-- matching stride component.
--
-- Per component this computes @ix + (stride - ix mod stride) mod stride@.
-- A component that is already a multiple of its stride is left unchanged.
-- The result is meaningful for positive stride components, which the
-- 'Stride' pattern guarantees. A zero component, reachable only through
-- 'SafeStride', makes 'mod' throw.
strideStart :: Index ix => Stride ix -> ix -> ix
strideStart (SafeStride stride) ix = liftIndex2 (+) ix offset
  where
    -- Distance from ix up to the next stride multiple; 0 when already aligned.
    offset =
      liftIndex2 mod (liftIndex2 subtract (liftIndex2 mod ix stride) stride) stride
{-# INLINE strideStart #-}
-- | Size of an index space after applying a stride.
--
-- Per component this computes @(sz - 1) div stride + 1@. For a positive
-- stride that is the ceiling of @sz / stride@, i.e. the number of stride
-- multiples in @[0, sz)@. A zero-length dimension stays zero, because
-- @(-1) div stride + 1 == 0@.
strideSize :: Index ix => Stride ix -> Sz ix -> Sz ix
strideSize (SafeStride stride) (SafeSz sz) =
  SafeSz (liftIndex (+ 1) $ liftIndex2 div (liftIndex (subtract 1) sz) stride)
{-# INLINE strideSize #-}
-- | Convert an index to a linear index after scaling it down by the stride.
--
-- The index is first divided componentwise by the stride (integer 'div').
-- The result is then linearised against @sz@ with 'toLinearIndex'.
--
-- NOTE(review): @sz@ is presumably the already-strided size (as produced by
-- 'strideSize'). Confirm against callers.
toLinearIndexStride
  :: Index ix
  => Stride ix
  -- ^ Stride to divide the index by
  -> Sz ix
  -- ^ Size used for linearisation
  -> ix
  -- ^ Index to convert
  -> Int
toLinearIndexStride (SafeStride stride) sz ix =
  toLinearIndex sz (liftIndex2 div ix stride)
{-# INLINE toLinearIndexStride #-}
-- | The unit stride: every component is @1@ (@'pureIndex' 1@).
oneStride :: Index ix => Stride ix
oneStride = SafeStride (pureIndex 1)
{-# INLINE oneStride #-}