{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-unticked-promoted-constructors #-}
#if __GLASGOW_HASKELL__ < 820
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}
#endif
module Data.Massiv.Core.Index.Internal
( Sz(SafeSz)
, pattern Sz
, pattern Sz1
, type Sz1
, unSz
, zeroSz
, oneSz
, liftSz
, liftSz2
, consSz
, unconsSz
, snocSz
, unsnocSz
, setSzM
, insertSzM
, pullOutSzM
, Dim(..)
, Dimension(DimN)
, pattern Dim1
, pattern Dim2
, pattern Dim3
, pattern Dim4
, pattern Dim5
, IsIndexDimension
, IsDimValid
, ReportInvalidDim
, Lower
, Index(..)
, Ix0(..)
, type Ix1
, pattern Ix1
, IndexException(..)
, SizeException(..)
, ShapeException(..)
, showsPrecWrapped
) where
import Control.DeepSeq
import Control.Exception (Exception(..), throw)
import Control.Monad.Catch (MonadThrow(..))
import Data.Coerce
import Data.Kind (Type)
import Data.Massiv.Core.Iterator
import Data.Typeable
import GHC.TypeLits
-- | Size of an array, represented as an index whose components are expected
-- to be non-negative. 'SafeSz' wraps an index without any check; the 'Sz'
-- pattern clamps negative components to zero.
newtype Sz ix = SafeSz ix
  deriving (Eq, Ord, NFData)
-- | Bidirectional pattern for sizes. Matching exposes the stored index as is;
-- constructing clamps every negative component to @0@.
pattern Sz :: Index ix => ix -> Sz ix
pattern Sz ix <- SafeSz ix
  where
    Sz ix = SafeSz $ liftIndex (max 0) ix
{-# COMPLETE Sz #-}
-- | Size of a one-dimensional array.
type Sz1 = Sz Ix1

-- | Bidirectional pattern for 'Sz1'. Matching exposes the stored length;
-- constructing replaces a negative length with @0@.
pattern Sz1 :: Ix1 -> Sz1
pattern Sz1 ix <- SafeSz ix
  where
    Sz1 n = SafeSz $ if n < 0 then 0 else n
{-# COMPLETE Sz1 #-}
-- | One-dimensional sizes render as @Sz1 n@; all others as @Sz (ix)@.
-- The whole thing is parenthesised at precedence 1 and above.
instance Index ix => Show (Sz ix) where
  showsPrec n sz@(SafeSz usz) = showsPrecWrapped n (body ++)
    where
      body
        | dimensions sz == 1 = "Sz1 " ++ show usz
        | otherwise = "Sz (" ++ shows usz ")"
-- | Component-wise arithmetic on sizes.
--
-- '+', '-' and 'fromInteger' go through the 'Sz' pattern and therefore clamp
-- negative components to @0@. '*' and 'signum' rebuild the size with 'SafeSz'
-- directly. 'abs' is the identity, and 'negate' always yields @0@.
instance (Num ix, Index ix) => Num (Sz ix) where
  x + y = Sz (coerce x + coerce y)
  {-# INLINE (+) #-}
  x - y = Sz (coerce x - coerce y)
  {-# INLINE (-) #-}
  x * y = SafeSz (coerce x * coerce y)
  {-# INLINE (*) #-}
  abs !sz = sz
  {-# INLINE abs #-}
  negate !_ = 0
  {-# INLINE negate #-}
  signum sz = SafeSz (signum (coerce sz))
  {-# INLINE signum #-}
  fromInteger n = Sz (fromInteger n)
  {-# INLINE fromInteger #-}
-- | Unwrap a size into its underlying index.
unSz :: Sz ix -> ix
unSz = coerce
{-# INLINE unSz #-}
-- | Size with every component equal to @0@.
zeroSz :: Index ix => Sz ix
zeroSz = SafeSz $ pureIndex 0
{-# INLINE zeroSz #-}
-- | Size with every component equal to @1@.
oneSz :: Index ix => Sz ix
oneSz = SafeSz $ pureIndex 1
{-# INLINE oneSz #-}
-- | Apply a function to every component of a size. The result goes through
-- 'Sz', so negative results are clamped to @0@.
liftSz :: Index ix => (Int -> Int) -> Sz ix -> Sz ix
liftSz f sz = Sz $ liftIndex f (coerce sz)
{-# INLINE liftSz #-}
-- | Combine two sizes component-wise. The result goes through 'Sz', so
-- negative results are clamped to @0@.
liftSz2 :: Index ix => (Int -> Int -> Int) -> Sz ix -> Sz ix -> Sz ix
liftSz2 f (SafeSz a) (SafeSz b) = Sz $ liftIndex2 f a b
{-# INLINE liftSz2 #-}
-- | Prepend a one-dimensional size as the first component of a
-- lower-dimensional size.
consSz :: Index ix => Sz1 -> Sz (Lower ix) -> Sz ix
consSz (SafeSz n) (SafeSz lower) = SafeSz $ consDim n lower
{-# INLINE consSz #-}
-- | Append a one-dimensional size as the last component of a
-- lower-dimensional size.
snocSz :: Index ix => Sz (Lower ix) -> Sz1 -> Sz ix
snocSz (SafeSz lower) (SafeSz n) = SafeSz $ snocDim lower n
{-# INLINE snocSz #-}
-- | Replace the component of a size at the given dimension. Fails in @m@
-- whenever 'setDimM' does for that dimension.
setSzM :: (MonadThrow m, Index ix) => Sz ix -> Dim -> Sz Int -> m (Sz ix)
setSzM (SafeSz sz) dim (SafeSz n) = fmap SafeSz (setDimM sz dim n)
{-# INLINE setSzM #-}
-- | Insert a new component at the given dimension of a lower-dimensional size.
-- Fails in @m@ whenever 'insertDimM' does for that dimension.
insertSzM :: (MonadThrow m, Index ix) => Sz (Lower ix) -> Dim -> Sz Int -> m (Sz ix)
insertSzM (SafeSz lower) dim (SafeSz n) = fmap SafeSz (insertDimM lower dim n)
{-# INLINE insertSzM #-}
-- | Split a size into its first component and the lower-dimensional remainder.
-- Both halves are reinterpreted as sizes with a zero-cost 'coerce'.
unconsSz :: Index ix => Sz ix -> (Sz1, Sz (Lower ix))
unconsSz (SafeSz sz) = coerce $ unconsDim sz
{-# INLINE unconsSz #-}
-- | Split a size into the lower-dimensional remainder and its last component.
-- Both halves are reinterpreted as sizes with a zero-cost 'coerce'.
unsnocSz :: Index ix => Sz ix -> (Sz (Lower ix), Sz1)
unsnocSz (SafeSz sz) = coerce $ unsnocDim sz
{-# INLINE unsnocSz #-}
-- | Remove the component at the given dimension of a size, returning it along
-- with the remaining lower-dimensional size. Fails in @m@ whenever
-- 'pullOutDimM' does for that dimension.
pullOutSzM :: (MonadThrow m, Index ix) => Sz ix -> Dim -> m (Sz Ix1, Sz (Lower ix))
pullOutSzM (SafeSz sz) dim = coerce <$> pullOutDimM sz dim
{-# INLINE pullOutSzM #-}
-- | Runtime dimension number, used to address a single component of an index
-- (see 'getDimM', 'setDimM', 'pullOutDimM', 'insertDimM').
newtype Dim = Dim
  { unDim :: Int
  } deriving (Eq, Ord, Num, Real, Integral, Enum, NFData)

-- | Always rendered in parentheses, e.g. @(Dim 2)@, whatever the precedence.
instance Show Dim where
  show (Dim d) = concat ["(Dim ", show d, ")"]
-- | Value-level witness for the type-level dimension number @n@.
data Dimension (n :: Nat) where
  -- Building 'DimN' requires, and matching on it provides, @1 <= n@ and
  -- @KnownNat n@.
  DimN :: (1 <= n, KnownNat n) => Dimension n

-- | Dimension number 1.
pattern Dim1 :: Dimension 1
pattern Dim1 = DimN

-- | Dimension number 2.
pattern Dim2 :: Dimension 2
pattern Dim2 = DimN

-- | Dimension number 3.
pattern Dim3 :: Dimension 3
pattern Dim3 = DimN

-- | Dimension number 4.
pattern Dim4 :: Dimension 4
pattern Dim4 = DimN

-- | Dimension number 5.
pattern Dim5 :: Dimension 5
pattern Dim5 = DimN
-- | Constraint that @n@ is a valid dimension number for index type @ix@:
-- @1 <= n <= 'Dimensions' ix@, with @n@ known at the type level.
type IsIndexDimension ix n =
  ( 1 <= n
  , n <= Dimensions ix
  , Index ix
  , KnownNat n
  )
-- | Index type with one dimension fewer than @ix@ (e.g. @Lower Int = Ix0@).
--
-- The kind is spelled 'Type' rather than @*@: @*@ is deprecated under
-- StarIsType, and this module also imports the type-level multiplication
-- operator @*@ from "GHC.TypeLits" with TypeOperators enabled.
type family Lower ix :: Type
-- | Type-level check of dimension @n@ against the number of dimensions @dims@.
-- @isNotZero@ and @isLess@ are expected to be @1 <=? n@ and @n <=? dims@, as
-- supplied by 'IsDimValid'. Reduces to 'True when both hold; otherwise it
-- produces a custom type error.
type family ReportInvalidDim (dims :: Nat) (n :: Nat) isNotZero isLess :: Bool where
  ReportInvalidDim _ _ 'True 'True = 'True
  ReportInvalidDim dims n 'True 'False =
    TypeError (Text "Dimension " :<>: ShowType n :<>: Text " is higher than " :<>:
               Text "the maximum expected " :<>: ShowType dims)
  ReportInvalidDim _ _ 'False _ =
    TypeError (Text "Zero dimensional indices are not supported")

-- | 'True when @n@ is a valid dimension number for @ix@; a custom type error
-- otherwise.
type family IsDimValid ix n :: Bool where
  IsDimValid ix n =
    ReportInvalidDim (Dimensions ix) n (1 <=? n) (n <=? Dimensions ix)
-- | Operations on multi-dimensional indices.
--
-- An index @ix@ has @'Dimensions' ix@ 'Int' components. It can be split into
-- its first component and a 'Lower'-dimensional remainder ('unconsDim'), or
-- into a remainder and its last component ('unsnocDim'). Most default methods
-- recurse through 'Lower' in this way, so they require an
-- @'Index' ('Lower' ix)@ instance. The 'Ix1' instance in this module overrides
-- all of them.
class ( Eq ix
      , Ord ix
      , Show ix
      , NFData ix
      , Eq (Lower ix)
      , Ord (Lower ix)
      , Show (Lower ix)
      , NFData (Lower ix)
      , KnownNat (Dimensions ix)
      ) =>
      Index ix
  where
  -- 'getDimM' and 'setDimM' default to 'modifyDimM', and the default of
  -- 'modifyDimM' is built from 'getDimM' and 'setDimM'. An instance that
  -- defines neither 'modifyDimM' nor both of the other two would loop forever
  -- at runtime. Spelling out the minimal set makes GHC warn about that at
  -- compile time instead.
  {-# MINIMAL dimensions, totalElem, consDim, unconsDim, snocDim, unsnocDim,
              pullOutDimM, insertDimM, pureIndex, liftIndex2,
              (modifyDimM | (getDimM, setDimM)) #-}

  -- | Number of dimensions, as a type-level natural.
  type Dimensions ix :: Nat

  -- | Number of dimensions of @ix@, as a value. Only the type of the proxy
  -- argument is used.
  dimensions :: proxy ix -> Dim

  -- | Total number of elements for a size. For 'Ix1' this is the size itself
  -- (presumably the product of all components in general).
  totalElem :: Sz ix -> Int

  -- | Prepend a first component to a lower-dimensional index.
  consDim :: Int -> Lower ix -> ix

  -- | Split an index into its first component and the remainder.
  unconsDim :: ix -> (Int, Lower ix)

  -- | Append a last component to a lower-dimensional index.
  snocDim :: Lower ix -> Int -> ix

  -- | Split an index into the remainder and its last component.
  unsnocDim :: ix -> (Lower ix, Int)

  -- | Remove the component at the given dimension and return it together with
  -- the remaining lower-dimensional index. The 'Ix1' instance throws
  -- 'IndexDimensionException' for any dimension other than @1@.
  pullOutDimM :: MonadThrow m => ix -> Dim -> m (Int, Lower ix)

  -- | Insert a component at the given dimension of a lower-dimensional index.
  insertDimM :: MonadThrow m => Lower ix -> Dim -> Int -> m ix

  -- | Read the component at the given dimension.
  getDimM :: MonadThrow m => ix -> Dim -> m Int
  getDimM ix dim = fst <$> modifyDimM ix dim id
  {-# INLINE [1] getDimM #-}

  -- | Replace the component at the given dimension.
  setDimM :: MonadThrow m => ix -> Dim -> Int -> m ix
  setDimM ix dim i = snd <$> modifyDimM ix dim (const i)
  {-# INLINE [1] setDimM #-}

  -- | Apply a function to the component at the given dimension. Returns the
  -- component's previous value together with the updated index.
  modifyDimM :: MonadThrow m => ix -> Dim -> (Int -> Int) -> m (Int, ix)
  modifyDimM ix dim f = do
    i <- getDimM ix dim
    ix' <- setDimM ix dim (f i)
    pure (i, ix')
  {-# INLINE [1] modifyDimM #-}

  -- | Index with every component set to the given value.
  pureIndex :: Int -> ix

  -- | Combine two indices component-wise.
  liftIndex2 :: (Int -> Int -> Int) -> ix -> ix -> ix

  -- | Map a function over every component of an index.
  liftIndex :: (Int -> Int) -> ix -> ix
  -- The first index handed to 'liftIndex2' is a dummy and is ignored.
  liftIndex f = liftIndex2 (\_ i -> f i) (pureIndex 0)
  {-# INLINE [1] liftIndex #-}

  -- | Left fold over the components of an index, from first to last.
  foldlIndex :: (a -> Int -> a) -> a -> ix -> a
  default foldlIndex :: Index (Lower ix) =>
    (a -> Int -> a) -> a -> ix -> a
  foldlIndex f !acc !ix = foldlIndex f (f acc i0) ixL
    where
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] foldlIndex #-}

  -- | Check that every component @i@ of the index satisfies @0 <= i < n@,
  -- where @n@ is the corresponding component of the size.
  isSafeIndex ::
       Sz ix
    -> ix
    -> Bool
  default isSafeIndex :: Index (Lower ix) =>
    Sz ix -> ix -> Bool
  isSafeIndex sz !ix = isSafeIndex n0 i0 && isSafeIndex szL ixL
    where
      !(n0, szL) = unconsSz sz
      !(i0, ixL) = unconsDim ix
  {-# INLINE [1] isSafeIndex #-}

  -- | Convert an index into a row-major linear offset for the given size: the
  -- last component varies fastest.
  toLinearIndex ::
       Sz ix
    -> ix
    -> Int
  default toLinearIndex :: Index (Lower ix) =>
    Sz ix -> ix -> Int
  toLinearIndex (SafeSz sz) !ix = toLinearIndex (SafeSz szL) ixL * n + i
    where
      !(szL, n) = unsnocDim sz
      !(ixL, i) = unsnocDim ix
  {-# INLINE [1] toLinearIndex #-}

  -- | Accumulator-passing variant of 'toLinearIndex'. It walks the components
  -- from first to last and computes @acc * n + i@ at each step. The second
  -- argument is the size, given as a plain index.
  toLinearIndexAcc :: Int -> ix -> ix -> Int
  default toLinearIndexAcc :: Index (Lower ix) =>
    Int -> ix -> ix -> Int
  toLinearIndexAcc !acc !sz !ix = toLinearIndexAcc (acc * n + i) szL ixL
    where
      !(n, szL) = unconsDim sz
      !(i, ixL) = unconsDim ix
  {-# INLINE [1] toLinearIndexAcc #-}

  -- | Convert a row-major linear offset back into an index for the given
  -- size. The first size component is never consulted, so the first index
  -- component is whatever quotient remains. An out-of-range offset is
  -- therefore not detected here.
  fromLinearIndex :: Sz ix -> Int -> ix
  default fromLinearIndex :: Index (Lower ix) =>
    Sz ix -> Int -> ix
  fromLinearIndex (SafeSz sz) k = consDim q ixL
    where
      !(q, ixL) = fromLinearIndexAcc (snd (unconsDim sz)) k
  {-# INLINE [1] fromLinearIndex #-}

  -- | Helper for 'fromLinearIndex'. Given the size of the trailing dimensions
  -- and an offset, it returns the quotient left over for the enclosing
  -- dimension together with the decoded trailing index.
  fromLinearIndexAcc :: ix -> Int -> (Int, ix)
  default fromLinearIndexAcc :: Index (Lower ix) =>
    ix -> Int -> (Int, ix)
  fromLinearIndexAcc ix' !k = (q, consDim r ixL)
    where
      !(m, ix) = unconsDim ix'
      !(kL, ixL) = fromLinearIndexAcc ix k
      !(q, r) = quotRem kL m
  {-# INLINE [1] fromLinearIndexAcc #-}

  -- | Repair an index component by component for the given size. A component
  -- below @0@ goes to the first handler, and one at or above its size
  -- component goes to the second. The 'Ix1' instance throws
  -- 'IndexZeroException' when its size is @<= 0@.
  repairIndex ::
       Sz ix
    -> ix
    -> (Sz Int -> Int -> Int)
    -> (Sz Int -> Int -> Int)
    -> ix
  default repairIndex :: Index (Lower ix) =>
    Sz ix -> ix -> (Sz Int -> Int -> Int) -> (Sz Int -> Int -> Int) -> ix
  repairIndex sz !ix rBelow rOver =
    consDim (repairIndex n i rBelow rOver) (repairIndex szL ixL rBelow rOver)
    where
      !(n, szL) = unconsSz sz
      !(i, ixL) = unconsDim ix
  {-# INLINE [1] repairIndex #-}

  -- | Monadic loop over indices, threading an accumulator. The arguments are
  -- start, end, increment, and a comparison that 'loopM' uses as
  -- @(`cond` end)@ per component. The first component drives the outermost
  -- loop.
  iterM ::
       Monad m
    => ix
    -> ix
    -> ix
    -> (Int -> Int -> Bool)
    -> a
    -> (ix -> a -> m a)
    -> m a
  default iterM :: (Index (Lower ix), Monad m) =>
    ix -> ix -> ix -> (Int -> Int -> Bool) -> a -> (ix -> a -> m a) -> m a
  iterM !sIx eIx !incIx cond !acc f =
    loopM s (`cond` e) (+ inc) acc $ \ !i !acc0 ->
      iterM sIxL eIxL incIxL cond acc0 $ \ !ix -> f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM #-}

  -- | Like 'iterM', but with no accumulator, and the action's results are
  -- discarded.
  iterM_ :: Monad m => ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  default iterM_ :: (Index (Lower ix), Monad m) =>
    ix -> ix -> ix -> (Int -> Int -> Bool) -> (ix -> m a) -> m ()
  iterM_ !sIx eIx !incIx cond f =
    loopM_ s (`cond` e) (+ inc) $ \ !i -> iterM_ sIxL eIxL incIxL cond $ \ !ix -> f (consDim i ix)
    where
      !(s, sIxL) = unconsDim sIx
      !(e, eIxL) = unconsDim eIx
      !(inc, incIxL) = unconsDim incIx
  {-# INLINE iterM_ #-}
-- | Zero-dimensional index: the 'Lower' of a one-dimensional index.
data Ix0
  = Ix0
  deriving (Eq, Ord, Show)
-- | 'Ix0' has a single nullary constructor, so WHNF is already normal form.
instance NFData Ix0 where
  rnf ix0 = ix0 `seq` ()
-- | One-dimensional index: a plain 'Int'.
type Ix1 = Int

-- | Pattern for 'Ix1'. Both matching and construction are the identity.
pattern Ix1 :: Int -> Ix1
pattern Ix1 n = n
{-# COMPLETE Ix1 #-}

-- | Removing the only dimension of a one-dimensional index leaves 'Ix0'.
type instance Lower Int = Ix0
-- | One-dimensional indices. Every recursive default of 'Index' bottoms out
-- here, so this instance overrides all of them. Dimension @1@ is the only
-- valid 'Dim'; any other throws 'IndexDimensionException'.
instance Index Ix1 where
  type Dimensions Ix1 = 1
  dimensions _ = 1
  {-# INLINE [1] dimensions #-}
  totalElem (SafeSz len) = len
  {-# INLINE [1] totalElem #-}
  isSafeIndex (SafeSz len) !i = i >= 0 && i < len
  {-# INLINE [1] isSafeIndex #-}
  toLinearIndex _ i = i
  {-# INLINE [1] toLinearIndex #-}
  toLinearIndexAcc !acc len i = acc * len + i
  {-# INLINE [1] toLinearIndexAcc #-}
  fromLinearIndex _ k = k
  {-# INLINE [1] fromLinearIndex #-}
  fromLinearIndexAcc len k = quotRem k len
  {-# INLINE [1] fromLinearIndexAcc #-}
  -- A non-positive size cannot be repaired; otherwise dispatch to the
  -- below-range or over-range handler, or keep an in-range index.
  repairIndex sz@(SafeSz len) !i onBelow onOver
    | len <= 0 = throw (IndexZeroException len)
    | i < 0 = onBelow sz i
    | i >= len = onOver sz i
    | otherwise = i
  {-# INLINE [1] repairIndex #-}
  consDim = const
  {-# INLINE [1] consDim #-}
  unconsDim i = (i, Ix0)
  {-# INLINE [1] unconsDim #-}
  snocDim _ = id
  {-# INLINE [1] snocDim #-}
  unsnocDim = (,) Ix0
  {-# INLINE [1] unsnocDim #-}
  getDimM i d
    | d == 1 = pure i
    | otherwise = throwM (IndexDimensionException i d)
  {-# INLINE [1] getDimM #-}
  setDimM i d new
    | d == 1 = pure new
    | otherwise = throwM (IndexDimensionException i d)
  {-# INLINE [1] setDimM #-}
  modifyDimM i d f
    | d == 1 = pure (i, f i)
    | otherwise = throwM (IndexDimensionException i d)
  {-# INLINE [1] modifyDimM #-}
  pullOutDimM i d
    | d == 1 = pure (i, Ix0)
    | otherwise = throwM (IndexDimensionException i d)
  {-# INLINE [1] pullOutDimM #-}
  -- Matching on 'Ix0' first keeps this strict in the lower index.
  insertDimM Ix0 d i
    | d == 1 = pure i
    | otherwise = throwM (IndexDimensionException Ix0 d)
  {-# INLINE [1] insertDimM #-}
  pureIndex = id
  {-# INLINE [1] pureIndex #-}
  liftIndex = id
  {-# INLINE [1] liftIndex #-}
  liftIndex2 = id
  {-# INLINE [1] liftIndex2 #-}
  foldlIndex = id
  {-# INLINE [1] foldlIndex #-}
  iterM start end step cond = loopM start (\i -> cond i end) (\i -> i + step)
  {-# INLINE iterM #-}
  iterM_ start end step cond = loopM_ start (\i -> cond i end) (\i -> i + step)
  {-# INLINE iterM_ #-}
-- | Exceptions describing an invalid index, size or dimension.
data IndexException where
  -- Carries an offending non-positive value. In this module it is thrown by
  -- the 'Ix1' 'repairIndex' when the size is @<= 0@.
  IndexZeroException :: Index ix => !ix -> IndexException
  -- An index together with a dimension number that is not valid for it.
  IndexDimensionException :: (NFData ix, Show ix, Typeable ix) => !ix -> !Dim -> IndexException
  -- A size and an index that is not safe for it.
  IndexOutOfBoundsException :: Index ix => !(Sz ix) -> !ix -> IndexException
-- | Human-readable messages. 'showsPrec' parenthesises the whole message at
-- precedence 1 and above.
instance Show IndexException where
  show = \case
    IndexZeroException ix ->
      "IndexZeroException: " ++ showsPrec 1 ix ""
    IndexDimensionException ix dim ->
      "IndexDimensionException: " ++ showsPrec 1 dim " for " ++ showsPrec 1 ix ""
    IndexOutOfBoundsException sz ix ->
      "IndexOutOfBoundsException: " ++ showsPrec 1 ix " is not safe for " ++ showsPrec 1 sz ""
  showsPrec n exc = showsPrecWrapped n (show exc ++)
-- | Payloads are existentially typed, so indices and sizes are compared by
-- their 'show' output. The 'Dim' is compared directly. Different constructors
-- are never equal.
instance Eq IndexException where
  IndexZeroException a == IndexZeroException b =
    show a == show b
  IndexDimensionException a da == IndexDimensionException b db =
    show a == show b && da == db
  IndexOutOfBoundsException sa a == IndexOutOfBoundsException sb b =
    show sa == show sb && show a == show b
  _ == _ = False
-- | Fully evaluates every payload field.
instance NFData IndexException where
  rnf (IndexZeroException i) = rnf i
  rnf (IndexDimensionException i d) = rnf i `seq` rnf d
  rnf (IndexOutOfBoundsException sz i) = rnf sz `seq` rnf i

instance Exception IndexException
-- | Exceptions describing invalid or incompatible sizes.
data SizeException where
  -- Two sizes that were expected to be equal.
  SizeMismatchException :: Index ix => !(Sz ix) -> !(Sz ix) -> SizeException
  -- Two sizes, possibly of different dimensionality, that presumably should
  -- describe the same number of elements.
  SizeElementsMismatchException :: (Index ix, Index ix') => !(Sz ix) -> !(Sz ix') -> SizeException
  -- A size, a starting index and a subregion size that do not fit inside it.
  SizeSubregionException :: Index ix => !(Sz ix) -> !ix -> !(Sz ix) -> SizeException
  -- A size that corresponds to an empty array.
  SizeEmptyException :: Index ix => !(Sz ix) -> SizeException
-- | Payloads are existentially typed, so sizes and indices are compared by
-- their 'show' output. Different constructors are never equal.
instance Eq SizeException where
  SizeMismatchException a a' == SizeMismatchException b b' =
    show a == show b && show a' == show b'
  SizeElementsMismatchException a a' == SizeElementsMismatchException b b' =
    show a == show b && show a' == show b'
  SizeSubregionException a i a' == SizeSubregionException b j b' =
    show a == show b && show i == show j && show a' == show b'
  SizeEmptyException a == SizeEmptyException b =
    show a == show b
  _ == _ = False
-- | Fully evaluates every payload field, left to right.
instance NFData SizeException where
  rnf (SizeMismatchException sz sz') = rnf sz `seq` rnf sz'
  rnf (SizeElementsMismatchException sz sz') = rnf sz `seq` rnf sz'
  rnf (SizeSubregionException sz i sz') = rnf sz `seq` rnf i `seq` rnf sz'
  rnf (SizeEmptyException sz) = rnf sz

instance Exception SizeException
-- | Human-readable messages. 'showsPrec' parenthesises the whole message at
-- precedence 1 and above, via 'showsPrecWrapped'.
instance Show SizeException where
  show (SizeMismatchException sz sz') =
    "SizeMismatchException: (" ++ show sz ++ ") vs (" ++ show sz' ++ ")"
  show (SizeElementsMismatchException sz sz') =
    "SizeElementsMismatchException: (" ++ show sz ++ ") vs (" ++ show sz' ++ ")"
  -- Fields: the enclosing size, the start index, then the subregion size.
  -- Fixed typo: the message previously read "is to small".
  show (SizeSubregionException sz' ix sz) =
    "SizeSubregionException: (" ++
    show sz' ++ ") is too small for " ++ show ix ++ " (" ++ show sz ++ ")"
  show (SizeEmptyException sz) =
    "SizeEmptyException: (" ++ show sz ++ ") corresponds to an empty array"
  showsPrec n exc = showsPrecWrapped n (show exc ++)
-- | Exceptions about the shape of data along a single dimension.
data ShapeException
  -- Expected length, then the actual (shorter) length.
  = DimTooShortException !Sz1 !Sz1
  -- Data along a dimension was longer than expected.
  | DimTooLongException
  deriving (Eq)
-- | 'DimTooLongException' is never parenthesised. 'DimTooShortException' is
-- parenthesised at precedence 1 and above.
instance Show ShapeException where
  showsPrec n = \case
    DimTooLongException -> showString "DimTooLongException"
    DimTooShortException sz sz' ->
      showsPrecWrapped n $
        showString "DimTooShortException: expected (" . shows sz .
        showString "), got (" . shows sz' . showString ")"

instance Exception ShapeException
-- | Wrap a 'ShowS' in parentheses when the precedence @n@ is at least @1@;
-- otherwise return it unchanged.
showsPrecWrapped :: Int -> ShowS -> ShowS
showsPrecWrapped n = showParen (n >= 1)