{-# LANGUAGE TypeOperators, ExplicitForAll, FlexibleContexts #-}
module Data.Array.Repa.Operators.IndexSpace
( reshape
, append, (++)
, transpose
, extract
, backpermute, unsafeBackpermute
, backpermuteDft, unsafeBackpermuteDft
, extend, unsafeExtend
, slice, unsafeSlice)
where
import Data.Array.Repa.Index
import Data.Array.Repa.Slice
import Data.Array.Repa.Base
import Data.Array.Repa.Repr.Delayed
import Data.Array.Repa.Operators.Traversal
import Data.Array.Repa.Shape as S
import Prelude hiding ((++), traverse)
import qualified Prelude as P
-- | Fully qualified name of this module, used as a prefix for error messages
--   raised by the functions defined here.
stage :: String
stage = "Data.Array.Repa.Operators.IndexSpace"
-- | Impose a new shape on the elements of an array.
--
--   The result is a delayed array of extent @sh2@. Each index of the result is
--   converted with @toIndex sh2@, then with @fromIndex (extent arr)@, and the
--   resulting index is read from the source array with 'unsafeIndex'.
--
--   Calls 'error' when @S.size sh2@ differs from @S.size (extent arr)@.
reshape :: (Shape sh1, Shape sh2, Source r1 e)
        => sh2               -- ^ Extent of the result array.
        -> Array r1 sh1 e    -- ^ Source array.
        -> Array D sh2 e
reshape sh2 arr
  | S.size sh2 /= S.size srcExtent
  = error (stage P.++ ".reshape: reshaped array will not match size of the original")

  | otherwise
  = fromFunction sh2 (\ix -> unsafeIndex arr (fromIndex srcExtent (toIndex sh2 ix)))
  where
    -- Extent of the source array, used for both the size check and the
    -- index conversion.
    srcExtent = extent arr
{-# INLINE [2] reshape #-}
-- | Append two arrays along the last (rightmost) dimension of their shape.
--
--   The last dimension of the result is the sum of the last dimensions of the
--   two arguments; the remaining dimensions are combined with 'intersectDim'.
--   Result indices whose last component is below the last-dimension length of
--   the first array are read from the first array; all others are read from
--   the second array with that length subtracted from the last component.
--
--   '++' is an operator synonym for 'append'.
append, (++)
        :: (Shape sh, Source r1 e, Source r2 e)
        => Array r1 (sh :. Int) e
        -> Array r2 (sh :. Int) e
        -> Array D  (sh :. Int) e
append arr1 arr2
  = unsafeTraverse2 arr1 arr2 joinExtents pick
  where
    -- Length of the first array along the last dimension.
    (_ :. len1) = extent arr1

    joinExtents (outerA :. innerA) (outerB :. innerB)
      = intersectDim outerA outerB :. (innerA + innerB)

    pick getA getB (outer :. k)
      = if k < len1
          then getA (outer :. k)
          else getB (outer :. (k - len1))
{-# INLINE [2] append #-}

arr1 ++ arr2 = append arr1 arr2
{-# INLINE (++) #-}
-- | Swap the last two dimensions of an array.
--
--   The result's extent has the last two components of the source extent
--   exchanged, and the element at @(sh :. i :. j)@ of the result is the
--   element at @(sh :. j :. i)@ of the source.
transpose
        :: (Shape sh, Source r e)
        => Array r (sh :. Int :. Int) e
        -> Array D (sh :. Int :. Int) e
transpose arr
  = unsafeTraverse arr swapExtent swapIndex
  where
    swapExtent (sh :. rows :. cols) = sh :. cols :. rows
    swapIndex get (sh :. r :. c)    = get (sh :. c :. r)
{-# INLINE [2] transpose #-}
-- | Extract a sub-range of elements from an array.
--
--   The result is a delayed array of extent @sz@ whose element at @ix@ is the
--   source element at @addDim start ix@. Elements are read with 'unsafeIndex',
--   and this function itself does not check that the requested region lies
--   inside the source array.
extract :: (Shape sh, Source r e)
        => sh              -- ^ Starting index in the source array.
        -> sh              -- ^ Extent of the result.
        -> Array r sh e    -- ^ Source array.
        -> Array D sh e
extract start sz arr
  = fromFunction sz readShifted
  where
    readShifted ix = unsafeIndex arr (addDim start ix)
{-# INLINE [2] extract #-}
-- | Backwards permutation of an array's elements.
--
--   The result has extent @newExtent@, and its element at index @ix@ is the
--   source element at @perm ix@. 'backpermute' is built on 'traverse' and
--   'unsafeBackpermute' on 'unsafeTraverse'; otherwise the two are identical.
backpermute, unsafeBackpermute
        :: forall r sh1 sh2 e
        .  (Shape sh1, Source r e)
        => sh2                -- ^ Extent of the result array.
        -> (sh2 -> sh1)       -- ^ Maps each result index to a source index.
        -> Array r sh1 e      -- ^ Source array.
        -> Array D sh2 e
backpermute newExtent perm arr
  = traverse arr (\_ -> newExtent) (\get -> get . perm)
{-# INLINE [2] backpermute #-}

unsafeBackpermute newExtent perm arr
  = unsafeTraverse arr (\_ -> newExtent) (\get -> get . perm)
{-# INLINE [2] unsafeBackpermute #-}
-- | Backwards permutation with a fallback array of default values.
--
--   The result has the extent of @arrDft@. For each result index @ix@:
--   if @fnIndex ix@ is @Just ix'@ the element is taken from @arrSrc@ at @ix'@;
--   if it is @Nothing@ the element is taken from @arrDft@ at @ix@.
--
--   'backpermuteDft' reads elements with 'index', while
--   'unsafeBackpermuteDft' reads them with 'unsafeIndex'.
backpermuteDft, unsafeBackpermuteDft
        :: forall r1 r2 sh1 sh2 e
        .  ( Shape sh1, Shape sh2
           , Source r1 e, Source r2 e)
        => Array r2 sh2 e        -- ^ Default values; also fixes the result extent.
        -> (sh2 -> Maybe sh1)    -- ^ Maps a result index to a source index, if any.
        -> Array r1 sh1 e        -- ^ Source array.
        -> Array D sh2 e
backpermuteDft arrDft fnIndex arrSrc
  = fromFunction (extent arrDft) pick
  where
    pick ix = maybe (arrDft `index` ix) (arrSrc `index`) (fnIndex ix)
{-# INLINE [2] backpermuteDft #-}

unsafeBackpermuteDft arrDft fnIndex arrSrc
  = fromFunction (extent arrDft) pick
  where
    pick ix = maybe (arrDft `unsafeIndex` ix) (arrSrc `unsafeIndex`) (fnIndex ix)
{-# INLINE [2] unsafeBackpermuteDft #-}
-- | Extend an array according to a slice specification.
--
--   The result extent is @fullOfSlice sl (extent arr)@, and each result index
--   is mapped back to a source index with @sliceOfFull sl@. 'extend' delegates
--   to 'backpermute' and 'unsafeExtend' to 'unsafeBackpermute'.
extend, unsafeExtend
        :: ( Slice sl
           , Shape (SliceShape sl)
           , Source r e)
        => sl                            -- ^ Slice specification.
        -> Array r (SliceShape sl) e     -- ^ Source array.
        -> Array D (FullShape sl) e
extend sl arr
  = backpermute fullExtent toSliceIx arr
  where
    fullExtent   = fullOfSlice sl (extent arr)
    toSliceIx ix = sliceOfFull sl ix
{-# INLINE [2] extend #-}

unsafeExtend sl arr
  = unsafeBackpermute fullExtent toSliceIx arr
  where
    fullExtent   = fullOfSlice sl (extent arr)
    toSliceIx ix = sliceOfFull sl ix
{-# INLINE [2] unsafeExtend #-}
-- | Take a slice from an array according to a slice specification.
--
--   The result extent is @sliceOfFull sl (extent arr)@, and each result index
--   is mapped back to a source index with @fullOfSlice sl@. 'slice' delegates
--   to 'backpermute' and 'unsafeSlice' to 'unsafeBackpermute'.
slice, unsafeSlice
        :: ( Slice sl
           , Shape (FullShape sl)
           , Source r e)
        => Array r (FullShape sl) e      -- ^ Source array.
        -> sl                            -- ^ Slice specification.
        -> Array D (SliceShape sl) e
slice arr sl
  = backpermute sliceExtent toFullIx arr
  where
    sliceExtent = sliceOfFull sl (extent arr)
    toFullIx ix = fullOfSlice sl ix
{-# INLINE [2] slice #-}

unsafeSlice arr sl
  = unsafeBackpermute sliceExtent toFullIx arr
  where
    sliceExtent = sliceOfFull sl (extent arr)
    toFullIx ix = fullOfSlice sl ix
{-# INLINE [2] unsafeSlice #-}