{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Representation.Array
where
import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape hiding ( zip )
import Data.Array.Accelerate.Representation.Type
import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import System.IO.Unsafe
import Text.Show ( showListWith )
import Prelude hiding ( (!!) )
import qualified Data.Vector.Unboxed as U
-- | An array value: its extent (the shape, of type @sh@) paired with the
-- element payload, whose representation is given by 'ArrayData' for the
-- element type @e@.
data Array sh e where
  Array :: sh            -- extent of each dimension (the shape)
        -> ArrayData e   -- array payload
        -> Array sh e
-- | Alias for 'Vector'; the name suggests it is used for segment-descriptor
-- arrays (NOTE(review): confirm against users of this alias).
type Segments = Vector
-- | Array whose shape is 'DIM0' (rendered as @Scalar@ by 'showArray').
type Scalar = Array DIM0
-- | Array whose shape is 'DIM1' (rendered as @Vector@ by 'showArray').
type Vector = Array DIM1
-- | Array whose shape is 'DIM2' (rendered as @Matrix@ by 'showArray').
type Matrix = Array DIM2
-- | Runtime representation of an array type: the representation of its
-- shape type together with the representation of its element type. Matching
-- on 'ArrayR' reveals that the indexed type is @Array sh e@.
data ArrayR a where
  ArrayR :: { arrayRshape :: ShapeR sh   -- representation of the shape type
            , arrayRtype  :: TypeR e     -- representation of the element type
            }
         -> ArrayR (Array sh e)
-- | Renders as @Array DIM\<rank\> \<element type\>@, e.g. the rank comes from
-- 'rank' on the shape representation.
instance Show (ArrayR a) where
  show (ArrayR shapeR eltR) = concat ["Array DIM", show (rank shapeR), " ", show eltR]
-- | Representation of a (possibly nested) tuple of arrays: a 'TupR' whose
-- leaves are 'ArrayR' representations.
type ArraysR = TupR ArrayR
-- | Renders the unit tuple as @()@, a single array via its 'ArrayR' instance,
-- and a pair as @(left,right)@ with no space after the comma.
instance Show (TupR ArrayR e) where
  show reprs = case reprs of
    TupRunit                -> "()"
    TupRsingle single       -> show single
    TupRpair left right     -> concat ["(", show left, ",", show right, ")"]
-- | 'ShowS' form of the 'Show' instance for 'ArraysR': prepends the rendered
-- representation to the remaining output.
showArraysR :: ArraysR a -> ShowS
showArraysR reprs rest = show reprs ++ rest
-- | Representation of a single array (a one-leaf 'TupR') from its shape and
-- element representations.
arraysRarray :: ShapeR sh -> TypeR e -> ArraysR (Array sh e)
arraysRarray shapeR = TupRsingle . ArrayR shapeR
-- | Representation of two arrays as the left-nested tuple @(((), a), b)@.
arraysRpair :: ArrayR a -> ArrayR b -> ArraysR (((), a), b)
arraysRpair first second =
  TupRpair (TupRpair TupRunit (TupRsingle first)) (TupRsingle second)
-- | Allocate a new array of the given shape. The payload is created by
-- 'newArrayData' with room for @size shR sh@ elements; its initial contents
-- are whatever 'newArrayData' provides.
allocateArray :: ArrayR (Array sh e) -> sh -> IO (Array sh e)
allocateArray (ArrayR shapeR eltR) extent =
  newArrayData eltR (size shapeR extent) >>= \payload ->
    return $! Array extent payload
-- | Pure counterpart of 'fromFunctionM': build an array of the given shape by
-- applying a pure function at every index.
--
-- 'unsafePerformIO' is used only to allocate and fill a fresh payload that is
-- not visible anywhere else until the finished array is returned.
fromFunction :: ArrayR (Array sh e) -> sh -> (sh -> e) -> Array sh e
fromFunction arrR extent gen =
  unsafePerformIO $! fromFunctionM arrR extent (\ix -> pure (gen ix))
-- | Build an array of the given shape by running @f@ once per element, in
-- increasing linear-index order. Each linear index is converted back to a
-- shape index with 'fromIndex' before being passed to @f@, and the result is
-- written into a freshly allocated payload.
fromFunctionM :: ArrayR (Array sh e) -> sh -> (sh -> IO e) -> IO (Array sh e)
fromFunctionM (ArrayR shR eR) sh f = do
  let !total = size shR sh
  payload <- newArrayData eR total
  let fill !ix
        | ix < total = do
            x <- f (fromIndex shR sh ix)
            writeArrayData eR payload ix x
            fill (ix + 1)
        | otherwise  = return ()
  fill 0
  return $! payload `seq` Array sh payload
-- | Build an array of the given shape from a list, filling the payload in
-- linear-index order. Exactly @size shR sh@ elements are consumed; any
-- further list elements are ignored. Calls 'error' if the list is shorter.
fromList :: forall sh e. ArrayR (Array sh e) -> sh -> [e] -> Array sh e
fromList (ArrayR shR eR) sh xs = payload `seq` Array sh payload
  where
    !total = size shR sh
    -- only the payload is used; the second component is never inspected
    (payload, _) = runArrayData @e $ do
      dst <- newArrayData eR total
      let fill !ix ys
            | ix >= total = return ()
            | otherwise   = case ys of
                (y:rest) -> writeArrayData eR dst ix y >> fill (ix + 1) rest
                []       -> error "Data.Array.Accelerate.fromList: not enough input data"
      fill 0 xs
      return (dst, undefined)
-- | Lazily list all elements of an array in linear-index order
-- (indices @0 .. size shR sh - 1@).
toList :: ArrayR (Array sh e) -> Array sh e -> [e]
toList (ArrayR shapeR eltR) (Array extent payload) =
  map (indexArrayData eltR payload) [0 .. size shapeR extent - 1]
-- | Concatenate vectors into a single vector, preserving element order. The
-- result length is the sum of the input lengths; each input is copied to the
-- offset given by the running total of the lengths before it.
concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e
concatVectors tR vs = payload `seq` Array ((), total) payload
  where
    starts       = scanl (+) 0 [ size dim1 (shape v) | v <- vs ]
    total        = last starts  -- scanl never yields an empty list
    (payload, _) = runArrayData @e $ do
      dst <- newArrayData tR total
      let copyAt (Array ((), n) src, k) =
            mapM_ (\i -> writeArrayData tR dst (i + k) (indexArrayData tR src i)) [0 .. n - 1]
      mapM_ copyAt (vs `zip` starts)
      return (dst, undefined)
-- | The extent (shape) stored in an array.
shape :: Array sh e -> sh
shape arr = case arr of
  Array extent _ -> extent
-- | Reinterpret an array's payload under a new shape. The payload is reused
-- unchanged; the requirement that both shapes have the same 'size' is passed
-- to 'boundsCheck' with the message @"shape mismatch"@ (whether it is
-- actually checked depends on 'boundsCheck').
reshape :: HasCallStack => ShapeR sh -> sh -> ShapeR sh' -> Array sh' e -> Array sh e
reshape shR sh shR' (Array sh' adata) =
  boundsCheck "shape mismatch" sameSize (Array sh adata)
  where
    sameSize = size shR sh == size shR' sh'
-- | Multi-dimensional indexing on a (representation, array) pair; see
-- 'indexArray'.
(!) :: (ArrayR (Array sh e), Array sh e) -> sh -> e
(!) ~(arrR, arr) ix = indexArray arrR arr ix
-- | Linear indexing on an (element representation, array) pair; see
-- 'linearIndexArray'.
(!!) :: (TypeR e, Array sh e) -> Int -> e
(!!) ~(eltR, arr) i = linearIndexArray eltR arr i
-- | Read the element at a multi-dimensional index. The index is converted to
-- a linear offset with 'toIndex' and looked up with 'indexArrayData'.
indexArray :: ArrayR (Array sh e) -> Array sh e -> sh -> e
indexArray (ArrayR shapeR eltR) (Array extent payload) ix =
  let offset = toIndex shapeR extent ix
  in  indexArrayData eltR payload offset
-- | Read the element at a linear (row-major flattened) offset, ignoring the
-- array's shape.
linearIndexArray :: TypeR e -> Array sh e -> Int -> e
linearIndexArray eltR (Array _ payload) = \offset -> indexArrayData eltR payload offset
-- | Render an array with a header naming its rank and shape:
--
-- * rank 0: @Scalar Z@ followed by the element list;
-- * rank 1: @Vector (shape)@ followed by the element list;
-- * rank 2: @Matrix (shape)@ followed by the grid from 'showMatrix';
-- * higher ranks: @Array (shape)@ followed by the element list.
--
-- Element lists are produced with 'showListWith' over 'toList'.
showArray :: (e -> ShowS) -> ArrayR (Array sh e) -> Array sh e -> String
showArray f arrR@(ArrayR shR _) arr@(Array sh _) =
  case shR of
    ShapeRz                         -> "Scalar Z " ++ elems
    ShapeRsnoc ShapeRz              -> labelled "Vector" elems
    ShapeRsnoc (ShapeRsnoc ShapeRz) -> labelled "Matrix" (showMatrix f arrR arr)
    _                               -> labelled "Array" elems
  where
    labelled name body = name ++ " (" ++ showShape shR sh ++ ") " ++ body
    elems              = showListWith f (toList arrR arr) ""
-- | Render at most the first @n@ elements of an array (in 'toList' order) as
-- a comma-separated list in square brackets, using @f@ for each element.
-- When elements remain after the first @n@, the list is closed with
-- @" ..]"@ instead of @"]"@. For example, with @n = 2@ the elements 1, 2, 3
-- render as @[1,2 ..]@, and with @n = 5@ they render as @[1,2,3]@.
--
-- Commas are placed only between elements, matching the 'showListWith'
-- format used by 'showArray'.
showArrayShort :: Int -> (e -> ShowS) -> ArrayR (Array sh e) -> Array sh e -> String
showArrayShort n f arrR arr = '[' : go 0 (toList arrR arr)
  where
    go _ [] = "]"
    go i (x:xs)
      | i >= n    = " ..]"
      | otherwise = separator i (f x (go (i + 1) xs))

    -- no comma before the first element, one before every later element
    separator :: Int -> ShowS
    separator 0 = id
    separator _ = (',' :)
-- | Render a two-dimensional array as a grid, one row per line. Each cell is
-- right-aligned to the widest rendering in its column plus one space of
-- padding, cells are comma-separated, and the grid is closed with @]@.
-- An array with no elements renders as @[]@.
showMatrix :: (e -> ShowS) -> ArrayR (Array DIM2 e) -> Array DIM2 e -> String
showMatrix f (ArrayR _ eltR) arr@(Array sh _)
  | rows * cols == 0 = "[]"
  | otherwise        = "\n [" ++ render 0 0
  where
    (((), rows), cols) = sh

    -- rendered text of the element at linear (row-major) index i
    cellText :: Int -> String
    cellText i = f (linearIndexArray eltR arr i) ""

    -- rendered length of every element, row-major
    lengths = U.generate (rows * cols) (length . cellText)

    -- widest rendering within each column
    widths = U.generate cols $ \c ->
      U.maximum (U.generate rows (\r -> lengths U.! (r * cols + c)))

    render :: Int -> Int -> String
    render !r !c
      | c >= cols = render (r + 1) 0
      | otherwise =
          let !i    = r * cols + c
              !l    = lengths U.! i
              !w    = widths U.! c
              lead  = if r > 0 && c == 0 then "\n " else ""
              trail = if r >= rows - 1 && c >= cols - 1
                        then "]"
                        else ',' : render r (c + 1)
          in  lead ++ replicate (w - l + 1) ' ' ++ cellText i ++ trail
-- | Drop the innermost dimension from an array representation, keeping the
-- element representation.
reduceRank :: ArrayR (Array (sh, Int) e) -> ArrayR (Array sh e)
reduceRank (ArrayR shR aeR) = case shR of
  ShapeRsnoc innerR -> ArrayR innerR aeR
-- | Fully evaluate an array: first its shape (via 'rnfShape'), then its
-- payload (via 'rnfArrayData').
rnfArray :: ArrayR a -> a -> ()
rnfArray (ArrayR shapeR eltR) (Array extent payload) =
  case rnfShape shapeR extent of
    () -> rnfArrayData eltR payload
-- | Fully evaluate an array representation: the shape representation first,
-- then every scalar type in the element representation.
rnfArrayR :: ArrayR arr -> ()
rnfArrayR (ArrayR shapeR eltR) =
  case rnfShapeR shapeR of
    () -> rnfTupR rnfScalarType eltR
-- | Fully evaluate a tuple of arrays by walking its representation: unit is
-- forced, single arrays go through 'rnfArray', and pairs are evaluated left
-- component first.
rnfArraysR :: ArraysR arrs -> arrs -> ()
rnfArraysR reprs arrs = case reprs of
  TupRunit             -> case arrs of () -> ()
  TupRsingle arrR      -> rnfArray arrR arrs
  TupRpair leftR rightR ->
    case arrs of
      (l, r) -> rnfArraysR leftR l `seq` rnfArraysR rightR r
-- | Lift an array representation into a typed Template Haskell expression
-- that rebuilds it from the lifted shape and element representations.
liftArrayR :: ArrayR a -> Q (TExp (ArrayR a))
liftArrayR (ArrayR shR tR) =
  let shapeQ = liftShapeR shR
      eltQ   = liftTypeR tR
  in  [|| ArrayR $$shapeQ $$eltQ ||]
-- | Lift a tuple-of-arrays representation into a typed Template Haskell
-- expression, structurally: each leaf is lifted with 'liftArrayR'.
liftArraysR :: ArraysR arrs -> Q (TExp (ArraysR arrs))
liftArraysR reprs = case reprs of
  TupRunit              -> [|| TupRunit ||]
  TupRsingle single     -> [|| TupRsingle $$(liftArrayR single) ||]
  TupRpair leftR rightR -> [|| TupRpair $$(liftArraysR leftR) $$(liftArraysR rightR) ||]
-- | Lift an array value into a typed Template Haskell expression that
-- rebuilds it. The shape is lifted with 'liftElt' at its 'shapeType', and the
-- payload is lifted with 'liftArrayData' using the element count @size shR sh@.
--
-- The resulting expression is given an explicit @Array sh e@ type
-- annotation whose types are built with 'liftTypeQ'. NOTE(review): presumably
-- this pins down the type of the generated code; confirm the motivation.
liftArray :: forall sh e. ArrayR (Array sh e) -> Array sh e -> Q (TExp (Array sh e))
liftArray (ArrayR shR adR) (Array sh adata) =
  [|| Array $$(liftElt (shapeType shR) sh) $$(liftArrayData sz adR adata) ||] `at` [t| Array $(liftTypeQ (shapeType shR)) $(liftTypeQ adR) |]
  where
    -- number of elements in the payload
    sz :: Int
    sz = size shR sh

    -- Attach a type signature to a typed expression by round-tripping through
    -- the untyped representation; the coercion back to 'TExp' is unchecked.
    at :: Q (TExp t) -> Q Type -> Q (TExp t)
    at e t = unsafeTExpCoerce $ sigE (unTypeQ e) t