{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE PatternSynonyms #-}

-- |
-- Module      : Data.Massiv.Core.Index.Stride
-- Copyright   : (c) Alexey Kuleshevich 2018-2022
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
module Data.Massiv.Core.Index.Stride (
  Stride (SafeStride),
  pattern Stride,
  unStride,
  oneStride,
  toLinearIndexStride,
  strideStart,
  strideSize,
) where

import Control.DeepSeq (NFData)
import Data.Massiv.Core.Index.Internal
import System.Random.Stateful (Random, Uniform (..), UniformRange (..))

-- | Stride provides a way to skip elements of an array: along every dimension only the elements
-- whose index is divisible by the corresponding value of the stride are kept, all others are
-- ignored. So, for a @Stride (i :. j)@ only elements with the following indices will be kept
-- around:
--
-- @
-- ( 0 :. 0) ( 0 :. j) ( 0 :. 2j) ( 0 :. 3j) ...
-- ( i :. 0) ( i :. j) ( i :. 2j) ( i :. 3j) ...
-- (2i :. 0) (2i :. j) (2i :. 2j) (2i :. 3j) ...
-- ...
-- @
--
-- Only positive strides make sense, so the `Stride` pattern synonym constructor will prevent a
-- user from creating a stride with negative or zero values, thus promoting safety of the library.
-- Note that the raw `SafeStride` constructor is exported as well and performs no such check, so
-- it should only be used with values already known to be positive.
--
-- ==== __Examples:__
--
-- * Default and minimal stride of @`Stride` (`pureIndex` 1)@ will have no effect and all elements
--   will be kept.
--
-- * If stride is @`Stride` 2@, then every 2nd element (i.e. with index 1, 3, 5, ..) will be skipped
--   and only elements with indices divisible by 2 will be kept around.
--
-- * In case of two dimensions, if what you want is to keep only the rows whose index is divisible
--   by 5, but keep every column intact, then you'd use @Stride (5 :. 1)@.
--
-- @since 0.2.1
newtype Stride ix = SafeStride ix
  deriving (Eq, Ord, NFData)

-- | A safe bidirectional pattern synonym for `Stride` construction that will make sure stride
-- elements are always positive. When used as a constructor, every component of the supplied
-- index (as mapped by `liftIndex`) that is smaller than @1@ is replaced with @1@. When used as a
-- pattern, it simply unwraps the stored index.
--
-- @since 0.2.1
pattern Stride :: Index ix => ix -> Stride ix
pattern Stride ix <- SafeStride ix
  where
    -- Clamp from below so that zero or negative strides can never be built through this path.
    Stride ix = SafeStride (liftIndex (max 1) ix)

{-# COMPLETE Stride #-}

-- | Renders a stride as @Stride \<index\>@. Whether the whole expression gets wrapped in
-- parentheses is decided by `showsPrecWrapped` based on the surrounding precedence.
instance Index ix => Show (Stride ix) where
  showsPrec n (SafeStride ix) = showsPrecWrapped n (("Stride " ++) . showsPrec 1 ix)

-- | Generates a random stride by sampling the underlying index uniformly from the range
-- @(`pureIndex` 1, `pureIndex` `maxBound`)@. The lower bound is @1@, so the result is positive
-- in every component, assuming the `UniformRange` instance of @ix@ samples component-wise.
instance (UniformRange ix, Index ix) => Uniform (Stride ix) where
  uniformM g = SafeStride <$> uniformRM (pureIndex 1, pureIndex maxBound) g
  {-# INLINE uniformM #-}

-- | Generates a random stride between the two given strides by delegating to the `UniformRange`
-- instance of the underlying index type. The sampled index is wrapped as is; no clamping is
-- applied, so the result is only as positive as the supplied bounds.
instance UniformRange ix => UniformRange (Stride ix) where
  uniformRM (SafeStride l, SafeStride u) g = SafeStride <$> uniformRM (l, u) g
  {-# INLINE uniformRM #-}

-- | No methods are defined here, so the class defaults are used.
-- NOTE(review): in random >= 1.2 those defaults are implemented in terms of `Uniform` and
-- `UniformRange`, which is presumably why both constraints are required; confirm against the
-- random version in use.
instance (Index ix, UniformRange ix) => Random (Stride ix)

-- | Just a helper function for unwrapping `Stride`: returns the stored index unchanged.
--
-- @since 0.2.1
unStride :: Stride ix -> ix
unStride (SafeStride ix) = ix
{-# INLINE unStride #-}

-- | Adjust starting index according to the stride. Each component of the index is rounded up to
-- the nearest multiple of the matching stride component. Components that already are multiples
-- are left unchanged. For example, with a stride of @3@ a start of @4@ becomes @6@, while a start
-- of @6@ stays @6@.
--
-- @since 0.2.1
strideStart :: Index ix => Stride ix -> ix -> ix
strideStart (SafeStride stride) ix = liftIndex2 roundUp ix stride
  where
    roundUp :: Int -> Int -> Int
    -- @(s - i `mod` s)@ is the distance to the next multiple of @s@. The outer `mod` turns a
    -- full step of @s@ (when @i@ is already a multiple) back into @0@.
    roundUp i s = i + (s - i `mod` s) `mod` s
{-# INLINE strideStart #-}

-- | Adjust size according to the stride. Every dimension of length @k@ with stride component @s@
-- becomes @(k - 1) `div` s + 1@, i.e. the number of indices @0, s, 2s, ..@ that are smaller than
-- @k@. Thanks to flooring division an empty dimension (@k == 0@) stays empty. Stride components
-- are expected to be positive (as guaranteed by the `Stride` pattern).
--
-- @since 0.2.1
strideSize :: Index ix => Stride ix -> Sz ix -> Sz ix
strideSize (SafeStride stride) (SafeSz sz) = SafeSz (liftIndex2 stridedLength sz stride)
  where
    stridedLength :: Int -> Int -> Int
    stridedLength k s = (k - 1) `div` s + 1
{-# INLINE strideSize #-}

-- | Compute linear index with stride. The index is first divided component-wise by the stride,
-- and the result is then converted into a linear index with respect to the supplied size.
--
-- NOTE(review): because the /divided/ index is linearized against @sz@, @sz@ presumably needs to
-- be the size of the strided result (e.g. as computed by `strideSize`) for the linear indices to
-- be contiguous. Confirm against callers.
--
-- @since 0.2.1
toLinearIndexStride
  :: Index ix
  => Stride ix
  -- ^ Stride
  -> Sz ix
  -- ^ Size used for linearization
  -> ix
  -- ^ Index to be divided by the stride
  -> Int
toLinearIndexStride (SafeStride stride) sz ix = toLinearIndex sz (liftIndex2 div ix stride)
{-# INLINE toLinearIndexStride #-}

-- | A default stride of @1@ (built with @`pureIndex` 1@), where all elements are kept.
--
-- @since 0.2.1
oneStride :: Index ix => Stride ix
oneStride = SafeStride (pureIndex 1)
{-# INLINE oneStride #-}