module Data.Array.Accelerate.Utility.Arrange (
   mapWithIndex,
   gather,
   scatter,
   ) where

import qualified Data.Array.Accelerate as A
import Data.Array.Accelerate (Acc, Array, Exp)


-- | Like 'A.map', but the mapping function also receives the index of the
-- element it is applied to.
--
-- The result has the same shape as @xs@, and its element at index @ix@ is
-- @f ix (xs ! ix)@.
mapWithIndex ::
   (A.Shape sh, A.Elt a, A.Elt b) =>
   (Exp sh -> Exp a -> Exp b) ->
   Acc (Array sh a) -> Acc (Array sh b)
mapWithIndex f xs =
   A.generate (A.shape xs) atIndex
   where
      -- look up the source element and hand both index and value to f
      atIndex ix = f ix (xs A.! ix)


-- | Gather: build an array shaped like @indices@. Its element at position
-- @i@ is the element of @xs@ stored at index @indices ! i@.
--
-- This function performs no bounds checking of its own on the entries of
-- @indices@.
gather ::
   (A.Shape ix, A.Shape ix', A.Elt ix', A.Elt a) =>
   Acc (Array ix ix') -> Acc (Array ix' a) -> Acc (Array ix a)
gather indices xs =
   A.backpermute (A.shape indices) sourceIndex xs
   where
      -- each target position reads from the source index it stores
      sourceIndex targetIx = indices A.! targetIx

-- | Scatter: send every element of @xs@ to the position of the target array
-- named by the corresponding entry of @indices@.
--
-- The element at position @i@ of @xs@ is sent to position @indices ! i@.
-- Each incoming element is combined with the value at that position (which
-- starts out as the corresponding value from @deflt@) using the combining
-- function. Positions that no element is sent to keep their value from
-- @deflt@. Several elements may target the same position, so the combining
-- function should be associative and commutative, as 'A.permute' requires.
scatter ::
   (A.Shape ix, A.Shape ix', A.Elt ix', A.Elt a) =>
   (Exp a -> Exp a -> Exp a) ->
   Acc (Array ix ix') -> Acc (Array ix' a) ->
   Acc (Array ix a) -> Acc (Array ix' a)
scatter combine indices deflt xs =
   A.permute combine deflt targetOf xs
   where
      -- destination of the source element at srcIx
      targetOf srcIx = indices A.! srcIx