{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module Data.Massiv.Array.Ops.Transform
(
transpose
, transposeInner
, transposeOuter
, backpermute
, resize
, resize'
, extract
, extract'
, extractFromTo
, extractFromTo'
, append
, append'
, splitAt
, splitAt'
, traverse
, traverse2
) where
import Control.Monad (guard)
import Data.Massiv.Array.Delayed.Internal
import Data.Massiv.Array.Ops.Construct
import Data.Massiv.Core.Common
import Data.Maybe (fromMaybe)
import Prelude hiding (splitAt, traverse)
-- | Extract a sub-array of size @newSz@ whose first element is at index
-- @sIx@ of @arr@.
--
-- Returns 'Nothing' unless all three 'isSafeIndex' checks pass:
--
--   * @sIx@ is safe for bound @size arr + 1@,
--   * @sIx@ is safe for bound @(sIx + newSz) + 1@,
--   * @sIx + newSz@ is safe for bound @size arr + 1@.
--
-- Assuming 'isSafeIndex' checks @0 <= i < bound@ in every dimension, this
-- amounts to @0 <= sIx <= sIx + newSz <= size arr@ componentwise.
extract :: Size r ix e
        => ix -- ^ Starting index
        -> ix -- ^ Size of the resulting array
        -> Array r ix e -- ^ Source array
        -> Maybe (Array (EltRepr r ix) ix e)
extract !sIx !newSz !arr =
  if startOk && sizeOk && endOk
    then Just (unsafeExtract sIx newSz arr)
    else Nothing
  where
    upperBound = liftIndex (+ 1) (size arr)
    endIx = liftIndex2 (+) sIx newSz
    startOk = isSafeIndex upperBound sIx
    sizeOk = isSafeIndex (liftIndex (+ 1) endIx) sIx
    endOk = isSafeIndex upperBound endIx
{-# INLINE extract #-}
-- | Like 'extract', but calls 'error' instead of returning 'Nothing' when the
-- requested region does not fit inside the source array.
extract' :: Size r ix e
         => ix -- ^ Starting index
         -> ix -- ^ Size of the resulting array
         -> Array r ix e -- ^ Source array
         -> Array (EltRepr r ix) ix e
extract' !sIx !newSz !arr = fromMaybe failure (extract sIx newSz arr)
  where
    failure =
      error $
      "Data.Massiv.Array.extract': Cannot extract an array of size " ++
      show newSz ++
      " starting at " ++ show sIx ++ " from within an array of size: " ++ show (size arr)
{-# INLINE extract' #-}
-- | Extract the region that starts at @sIx@ and has size @eIx - sIx@, so
-- @eIx@ is the exclusive end index. Returns 'Nothing' under the same
-- conditions as 'extract'.
extractFromTo :: Size r ix e =>
                 ix -- ^ Starting index
              -> ix -- ^ End index (exclusive)
              -> Array r ix e -- ^ Source array
              -> Maybe (Array (EltRepr r ix) ix e)
extractFromTo sIx eIx =
  let newSz = liftIndex2 (-) eIx sIx
   in extract sIx newSz
{-# INLINE extractFromTo #-}
-- | Like 'extractFromTo', but calls 'error' (through 'extract'') instead of
-- returning 'Nothing' when the region does not fit.
extractFromTo' :: Size r ix e =>
                  ix -- ^ Starting index
               -> ix -- ^ End index (exclusive)
               -> Array r ix e -- ^ Source array
               -> Array (EltRepr r ix) ix e
extractFromTo' sIx eIx =
  let newSz = liftIndex2 (-) eIx sIx
   in extract' sIx newSz
{-# INLINE extractFromTo' #-}
-- | Give the array the new size @sz@ via 'unsafeResize'. Returns 'Nothing'
-- when @sz@ describes a different total number of elements than the array.
resize :: (Index ix', Size r ix e) => ix' -> Array r ix e -> Maybe (Array r ix' e)
resize !sz !arr =
  if totalElem sz /= totalElem (size arr)
    then Nothing
    else Just (unsafeResize sz arr)
{-# INLINE resize #-}
-- | Like 'resize', but calls 'error' when the total number of elements
-- described by @sz@ does not match that of the source array.
resize' :: (Index ix', Size r ix e) => ix' -> Array r ix e -> Array r ix' e
resize' !sz !arr = fromMaybe failure (resize sz arr)
  where
    -- The message carries the module/function prefix, consistent with the
    -- other error-raising variants in this module (extract', splitAt').
    failure =
      error $
      "Data.Massiv.Array.resize': Total number of elements do not match: " ++
      show sz ++ " vs " ++ show (size arr)
{-# INLINE resize' #-}
-- | Transpose a two-dimensional ('Ix2') array. This is 'transposeInner'
-- specialised to 'Ix2'.
transpose :: Source r Ix2 e => Array r Ix2 e -> Array D Ix2 e
transpose = transposeInner
-- Phase-controlled inlining: 'transpose' is inlined only from phase 1 on.
-- This lets the rewrite rules below, which are active before phase 1
-- (@[~1]@), still see the un-inlined calls.
{-# INLINE [1] transpose #-}
-- Each transposition swaps two index components, and that swap is its own
-- inverse. So two nested applications of the same transposition are
-- rewritten to a plain 'delay' of the original array.
{-# RULES
"transpose . transpose" [~1] forall arr . transpose (transpose arr) = delay arr
"transposeInner . transposeInner" [~1] forall arr . transposeInner (transposeInner arr) = delay arr
"transposeOuter . transposeOuter" [~1] forall arr . transposeOuter (transposeOuter arr) = delay arr
 #-}
-- | Swap the index components at dimensions @dimensions ix@ and
-- @dimensions ix - 1@.
--
-- The same swap is applied to the source size, giving the result size, and
-- to every result index, giving the source index to read. This works
-- because the swap is its own inverse. If either dimension is rejected by
-- 'getDim' / 'setDim' (presumably for arrays of rank < 2), 'errorImpossible'
-- is raised.
transposeInner :: (Index (Lower ix), Source r' ix e)
               => Array r' ix e -> Array D ix e
transposeInner !arr =
  unsafeMakeArray (getComp arr) (swapInner (size arr)) (unsafeIndex arr . swapInner)
  where
    swapInner !ix =
      let dHi = dimensions ix
          dLo = dHi - 1
          swapped = do
            atHi <- getDim ix dHi
            atLo <- getDim ix dLo
            setDim ix dHi atLo >>= \ix' -> setDim ix' dLo atHi
       in fromMaybe (errorImpossible "transposeInner" ix) swapped
    {-# INLINE swapInner #-}
{-# INLINE [1] transposeInner #-}
-- | Swap the index components at dimensions 1 and 2.
--
-- As in 'transposeInner', the swap is applied both to the source size and
-- to every result index, relying on the swap being its own inverse. If
-- 'getDim' / 'setDim' reject dimension 1 or 2, 'errorImpossible' is raised.
transposeOuter :: (Index (Lower ix), Source r' ix e)
               => Array r' ix e -> Array D ix e
transposeOuter !arr =
  unsafeMakeArray (getComp arr) (swapOuter (size arr)) (unsafeIndex arr . swapOuter)
  where
    swapOuter !ix =
      let swapped = do
            atOne <- getDim ix 1
            atTwo <- getDim ix 2
            setDim ix 1 atTwo >>= \ix' -> setDim ix' 2 atOne
       in fromMaybe (errorImpossible "transposeOuter" ix) swapped
    {-# INLINE swapOuter #-}
{-# INLINE [1] transposeOuter #-}
-- | Build a delayed array of size @sz@. Its element at index @ix@ is read
-- from @arr@ at index @ixF ix@ using 'evaluateAt'. The result uses the
-- computation strategy of @arr@.
backpermute :: (Source r' ix' e, Index ix) =>
               ix -- ^ Size of the resulting array
            -> (ix -> ix') -- ^ Maps a result index to a source index
            -> Array r' ix' e -- ^ Source array
            -> Array D ix e
backpermute sz ixF !arr = makeArray (getComp arr) sz lookupSource
  where
    lookupSource ix = evaluateAt arr (ixF ix)
{-# INLINE backpermute #-}
-- | Concatenate two arrays along dimension @n@.
--
-- Along @n@, the result's extent is the sum of the two arrays' extents.
-- Every other dimension must agree between the two arrays. Indices below
-- @arr1@'s extent read from @arr1@; the rest read from @arr2@, shifted back
-- by that extent. The computation strategy of @arr1@ is used.
--
-- Returns 'Nothing' if @n@ is rejected by 'getDim' / 'setDim', or if the
-- sizes differ in any dimension other than @n@.
append :: (Source r1 ix e, Source r2 ix e) =>
          Dim -> Array r1 ix e -> Array r2 ix e -> Maybe (Array D ix e)
append n !arr1 !arr2 = do
  let sz1 = size arr1
      sz2 = size arr2
  k1 <- getDim sz1 n
  k2 <- getDim sz2 n
  -- Replacing arr2's n-th extent with arr1's must yield arr1's size, i.e.
  -- the sizes agree in every dimension except n.
  sz1' <- setDim sz2 n k1
  guard $ sz1 == sz1'
  newSz <- setDim sz1 n (k1 + k2)
  return $
    unsafeMakeArray (getComp arr1) newSz $ \ !ix ->
      fromMaybe (errorImpossible "append" ix) $ do
        k' <- getDim ix n
        if k' < k1
          then Just (unsafeIndex arr1 ix)
          -- Past the end of arr1: shift back by k1 to index into arr2.
          -- k' already holds the n-th component, so it is reused directly.
          else unsafeIndex arr2 <$> setDim ix n (k' - k1)
{-# INLINE append #-}
-- | Like 'append', but calls 'error' on failure.
--
-- The message reports an invalid dimension when @dim@ lies outside
-- @1 .. dimensions (size arr1)@, and a size mismatch otherwise. It carries
-- the same @Data.Massiv.Array.@ prefix as the other error-raising variants
-- in this module (extract', splitAt').
append' :: (Source r1 ix e, Source r2 ix e) =>
           Dim -> Array r1 ix e -> Array r2 ix e -> Array D ix e
append' dim arr1 arr2 =
  case append dim arr1 arr2 of
    Just arr -> arr
    Nothing ->
      error $
      "Data.Massiv.Array.append': " ++
      if 0 < dim && dim <= dimensions (size arr1)
        then "Dimension mismatch: " ++ show (size arr1) ++ " and " ++ show (size arr2)
        else "Invalid dimension: " ++ show dim
{-# INLINE append' #-}
-- | Split an array in two at position @i@ along dimension @dim@.
--
-- Along @dim@, the first part spans @[0, i)@ and the second spans
-- @[i, size)@. Both keep the full extent in every other dimension. Returns
-- 'Nothing' if @dim@ is rejected by 'setDim' or either region fails the
-- checks in 'extract' (e.g. @i@ out of range).
splitAt ::
     (Size r ix e, r' ~ EltRepr r ix)
  => Dim -- ^ Dimension along which to split
  -> Int -- ^ Position along that dimension where the second part starts
  -> Array r ix e -- ^ Source array
  -> Maybe (Array r' ix e, Array r' ix e)
splitAt dim i arr = do
  let sz = size arr
  firstEnd <- setDim sz dim i
  secondStart <- setDim zeroIndex dim i
  (,) <$> extractFromTo zeroIndex firstEnd arr <*> extractFromTo secondStart sz arr
{-# INLINE splitAt #-}
-- | Like 'splitAt', but calls 'error' on failure.
--
-- The message reports an invalid dimension when @dim@ lies outside
-- @1 .. dimensions (size arr)@, and an out-of-bounds index otherwise.
splitAt' :: (Size r ix e, r' ~ EltRepr r ix) =>
            Dim -> Int -> Array r ix e -> (Array r' ix e, Array r' ix e)
splitAt' dim i arr = fromMaybe failure (splitAt dim i arr)
  where
    sz = size arr
    validDim = 0 < dim && dim <= dimensions sz
    failure =
      error $
      "Data.Massiv.Array.splitAt': " ++
      if validDim
        then "Index out of bounds: " ++
             show i ++ " for dimension: " ++ show dim ++ " and array with size: " ++ show sz
        else "Invalid dimension: " ++ show dim ++ " for array with size: " ++ show sz
{-# INLINE splitAt' #-}
-- | Build a delayed array of size @sz@. Each element is produced by @f@,
-- which receives a lookup function into @arr1@ (backed by 'evaluateAt')
-- and the result index. The result uses the computation strategy of @arr1@.
traverse
  :: (Source r1 ix1 e1, Index ix)
  => ix -- ^ Size of the resulting array
  -> ((ix1 -> e1) -> ix -> e) -- ^ Element function, given a source lookup
  -> Array r1 ix1 e1 -- ^ Source array
  -> Array D ix e
traverse sz f arr1 = makeArray (getComp arr1) sz generate
  where
    lookup1 = evaluateAt arr1
    generate = f lookup1
{-# INLINE traverse #-}
-- | Like 'traverse', but @f@ receives lookup functions into two source
-- arrays (both backed by 'evaluateAt'). The result uses the computation
-- strategy of the first array, @arr1@.
traverse2
  :: (Source r1 ix1 e1, Source r2 ix2 e2, Index ix)
  => ix -- ^ Size of the resulting array
  -> ((ix1 -> e1) -> (ix2 -> e2) -> ix -> e) -- ^ Element function
  -> Array r1 ix1 e1 -- ^ First source array
  -> Array r2 ix2 e2 -- ^ Second source array
  -> Array D ix e
traverse2 sz f arr1 arr2 = makeArray (getComp arr1) sz generate
  where
    lookup1 = evaluateAt arr1
    lookup2 = evaluateAt arr2
    generate = f lookup1 lookup2
{-# INLINE traverse2 #-}
-- | Abort on a case that should be unreachable. The message names the
-- function @fName@ and shows the offending value @cause@.
errorImpossible :: Show c => String -> c -> a
errorImpossible fName cause =
  error $
    concat ["Data.Massiv.Array.", fName, ": Impossible happened ", show cause]
{-# NOINLINE errorImpossible #-}