{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Array
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Array
  where

import Data.Array.Accelerate.Array.Data
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Elt
import Data.Array.Accelerate.Representation.Shape                   hiding ( zip )
import Data.Array.Accelerate.Representation.Type

import Language.Haskell.TH
import Language.Haskell.TH.Syntax
import System.IO.Unsafe
import Text.Show                                                    ( showListWith )
import Prelude                                                      hiding ( (!!) )
import qualified Data.Vector.Unboxed                                as U


-- | An array in representation form: its extent (the shape) together with
-- the element payload. Both type arguments are /representation/ types, of
-- the shape and of the elements respectively.
--
data Array sh e where
  Array :: sh             -- extent of every dimension, i.e. the shape
        -> ArrayData e    -- element payload, addressed by linear index
        -> Array sh e

-- | A singleton array holding exactly one element.
type Scalar = Array DIM0

-- | A one-dimensional array.
type Vector = Array DIM1

-- | A two-dimensional array.
type Matrix = Array DIM2

-- | Segment descriptor (vector of segment lengths).
--
-- To represent nested one-dimensional arrays, we use a flat array of data
-- values in conjunction with a /segment descriptor/, which stores the lengths
-- of the subarrays.
--
type Segments = Vector

-- | Type witness for an array: it records the representation of the shape
-- and of the element type, and fixes the indexed type to @Array sh e@.
--
data ArrayR a where
  ArrayR :: { arrayRshape :: ShapeR sh   -- ^ witness for the shape (and so the rank)
            , arrayRtype  :: TypeR e     -- ^ witness for the element type
            }
         -> ArrayR (Array sh e)

-- | Renders as @Array DIM\<rank\> \<element type\>@.
instance Show (ArrayR a) where
  show (ArrayR shR eR) = concat ["Array DIM", show (rank shR), " ", show eR]

-- | Representation of a (possibly nested) tuple of arrays: a 'TupR' whose
-- leaves are 'ArrayR' witnesses.
--
type ArraysR = TupR ArrayR

-- | The unit renders as @()@, a leaf as its 'ArrayR', and a pair as
-- @(left,right)@.
instance Show (TupR ArrayR e) where
  show repr = case repr of
    TupRunit         -> "()"
    TupRsingle aR    -> show aR
    TupRpair aR1 aR2 -> concat ["(", show aR1, ",", show aR2, ")"]

-- | 'ShowS' form of the 'Show' instance for 'ArraysR' (precedence 0).
showArraysR :: ArraysR a -> ShowS
showArraysR aR = showsPrec 0 aR

-- | Representation of a single array, viewed as a one-leaf array tuple.
arraysRarray :: ShapeR sh -> TypeR e -> ArraysR (Array sh e)
arraysRarray shR = TupRsingle . ArrayR shR

-- | Representation of a pair of arrays in the snoc-list tuple encoding
-- @(((), a), b)@.
arraysRpair :: ArrayR a -> ArrayR b -> ArraysR (((), a), b)
arraysRpair a b = TupRpair (TupRpair TupRunit (TupRsingle a)) (TupRsingle b)

-- | Creates a new, uninitialized Accelerate array with the given extent.
--
-- The payload is allocated with room for exactly @size shR sh@ elements;
-- callers must write every element before reading it.
--
allocateArray :: ArrayR (Array sh e) -> sh -> IO (Array sh e)
allocateArray (ArrayR shR eR) sh =
  newArrayData eR (size shR sh) >>= \adata ->
    return $! Array sh adata

-- | Create an array from its representation function, applied at each
-- index of the array.
--
-- This is the pure counterpart of 'fromFunctionM'. Using 'unsafePerformIO'
-- is acceptable here because the only effects involved are allocating a
-- fresh array and writing the (pure) results of @f@ into it. No reference
-- to that array escapes before it is returned.
--
fromFunction :: ArrayR (Array sh e) -> sh -> (sh -> e) -> Array sh e
fromFunction repr sh f = unsafePerformIO $! fromFunctionM repr sh pureAt
  where
    pureAt ix = return (f ix)

-- | Create an array using a monadic function applied at each index.
--
-- Elements are produced in increasing linear-index order: position @i@ is
-- converted back to a shape index with 'fromIndex', so the effects of @f@
-- run in that order.
--
-- @since 1.2.0.0
--
fromFunctionM :: ArrayR (Array sh e) -> sh -> (sh -> IO e) -> IO (Array sh e)
fromFunctionM (ArrayR shR eR) sh f = do
  let !total = size shR sh
  mad <- newArrayData eR total
  -- Fill slots 0 .. total-1 one at a time.
  let fill :: Int -> IO ()
      fill !ix =
        if ix < total
          then do
            x <- f (fromIndex shR sh ix)
            writeArrayData eR mad ix x
            fill (ix + 1)
          else return ()
  fill 0
  return $! mad `seq` Array sh mad


-- | Convert a list into an Accelerate 'Array' in dense row-major order.
--
-- Exactly @size shR sh@ elements are consumed. Any surplus elements are
-- ignored, and a list that is too short raises an error.
--
fromList :: forall sh e. ArrayR (Array sh e) -> sh -> [e] -> Array sh e
fromList (ArrayR shR eR) sh xs = adata `seq` Array sh adata
  where
    -- Assume the array is in dense row-major order. This is safe because
    -- otherwise backends would not be able to directly memcpy.
    --
    !n = size shR sh

    (adata, _) = runArrayData @e $ do
      marr <- newArrayData eR n
      -- Write list elements into consecutive slots until n slots are filled.
      let fill :: Int -> [e] -> IO ()
          fill !i rest
            | i >= n    = return ()
            | otherwise = case rest of
                (v:vs) -> writeArrayData eR marr i v >> fill (i + 1) vs
                []     -> error "Data.Array.Accelerate.fromList: not enough input data"
      fill 0 xs
      return (marr, undefined)


-- | Convert an accelerated 'Array' to a list in row-major order.
--
-- The list is produced lazily: each element is read from the payload only
-- when it is demanded.
--
toList :: ArrayR (Array sh e) -> Array sh e -> [e]
toList (ArrayR shR eR) (Array sh adata) = map (indexArrayData eR adata) [0 .. n - 1]
  where
    -- Assume the underlying array is in row-major order. This is safe because
    -- otherwise backends would not be able to directly memcpy.
    --
    !n = size shR sh

-- | Concatenate a list of vectors into one vector, preserving the order of
-- both the vectors and their elements.
--
concatVectors :: forall e. TypeR e -> [Vector e] -> Vector e
concatVectors tR vs = adata `seq` Array ((), total) adata
  where
    -- Length of each input, and the output position at which each one starts.
    lens    = map (size dim1 . shape) vs
    offsets = scanl (+) 0 lens
    total   = sum lens

    (adata, _) = runArrayData @e $ do
      dst <- newArrayData tR total
      -- Copy one input vector into the output, starting at slot `base`.
      let copy :: (Vector e, Int) -> IO ()
          copy (Array ((), m) src, base) =
            mapM_ (\j -> writeArrayData tR dst (base + j) (indexArrayData tR src j)) [0 .. m - 1]
      mapM_ copy (zip vs offsets)
      return (dst, undefined)

-- | The extent (shape) of an array.
shape :: Array sh e -> sh
shape arr = case arr of
  Array sh _ -> sh

-- | Reinterpret an array's payload under a new shape without copying.
--
-- The new and old shapes must describe the same number of elements; this is
-- checked with 'boundsCheck' (message @"shape mismatch"@).
--
reshape :: HasCallStack => ShapeR sh -> sh -> ShapeR sh' -> Array sh' e -> Array sh e
reshape shR sh shR' (Array sh' adata) =
  boundsCheck "shape mismatch" sameSize (Array sh adata)
  where
    sameSize = size shR sh == size shR' sh'

-- | Index an array, supplied together with its representation, at a
-- multi-dimensional index.
(!) :: (ArrayR (Array sh e), Array sh e) -> sh -> e
(!) p ix = indexArray (fst p) (snd p) ix

-- | Index an array, supplied together with its element representation, at a
-- linear (flat) index.
(!!) :: (TypeR e, Array sh e) -> Int -> e
(!!) p i = linearIndexArray (fst p) (snd p) i

-- | Read the element at a multi-dimensional index. The index is converted to
-- a linear offset with 'toIndex' relative to the array's own extent.
indexArray :: ArrayR (Array sh e) -> Array sh e -> sh -> e
indexArray (ArrayR shR adR) (Array sh adata) ix = indexArrayData adR adata offset
  where
    offset = toIndex shR sh ix

-- | Read the element at a linear (flat) index, ignoring the array's shape.
linearIndexArray :: TypeR e -> Array sh e -> Int -> e
linearIndexArray adR arr = case arr of
  Array _ adata -> indexArrayData adR adata

-- | Render an array with a label chosen by its rank.
--
-- Rank 0 renders as @Scalar Z@; ranks 1 and 2 render as @Vector@ and
-- @Matrix@; all other ranks render as @Array@. The shape follows the label
-- in parentheses, and then come the elements. Matrices use 'showMatrix';
-- every other rank prints the elements as a flat list formatted with @f@.
--
showArray :: (e -> ShowS) -> ArrayR (Array sh e) -> Array sh e -> String
showArray f arrR@(ArrayR shR _) arr@(Array sh _) =
  case shR of
    ShapeRz                         -> "Scalar Z " ++ elems
    ShapeRsnoc ShapeRz              -> labelled "Vector (" elems
    ShapeRsnoc (ShapeRsnoc ShapeRz) -> labelled "Matrix (" (showMatrix f arrR arr)
    _                               -> labelled "Array (" elems
  where
    labelled open body = open ++ showShape shR sh ++ ") " ++ body
    elems              = showListWith f (toList arrR arr) ""

-- | Render at most @n@ elements of an array as a bracketed, comma-separated
-- list. Elements appear in the order produced by 'toList' and are formatted
-- with @f@. When the array holds more than @n@ elements, the list is closed
-- with @" ..]"@ to mark the truncation.
--
-- For example, with @n = 3@ an array holding @1,2,3,4,5@ renders as
-- @[1,2,3 ..]@; with @n = 10@ it renders as @[1,2,3,4,5]@.
--
showArrayShort :: Int -> (e -> ShowS) -> ArrayR (Array sh e) -> Array sh e -> String
showArrayShort n f arrR arr = '[' : go 0 (toList arrR arr)
  where
    -- `go i xs` renders the remaining elements `xs`; `i` is the position of
    -- the head of `xs` within the whole array.
    go _ []       = "]"
    go i (x:xs)
      | i >= n    = " ..]"
      | otherwise = sep i (f x (go (i+1) xs))

    -- Commas go *between* elements, so the element at position 0 gets none.
    -- (Prefixing every element with ',' produced e.g. "[,1,2,3 ..]".)
    sep :: Int -> ShowS
    sep 0 = id
    sep _ = (',' :)

-- TODO: Make special formatting optional? It is more difficult to
-- copy/paste the result, for example. Also it does not look good if the
-- matrix row does not fit on a single line.
--
-- | Render a two-dimensional array as an aligned grid.
--
-- An empty matrix renders as @[]@. Otherwise the output starts on a new line
-- with @"  ["@, and each subsequent row starts on its own line. Every cell is
-- right-aligned to the width of the widest cell in its column, plus one
-- leading space. Cells are separated by commas and the grid ends with @]@.
--
showMatrix :: (e -> ShowS) -> ArrayR (Array DIM2 e) -> Array DIM2 e -> String
showMatrix f (ArrayR _ eR) arr@(Array sh _)
  | rows * cols == 0 = "[]"
  | otherwise        = "\n  [" ++ render 0 0
  where
    (((), rows), cols) = sh

    -- Rendered text of the element at row-major linear index i.
    cellText :: Int -> String
    cellText i = f (linearIndexArray eR arr i) ""

    -- Printed length of every cell, in row-major order.
    cellLengths :: U.Vector Int
    cellLengths = U.generate (rows * cols) (length . cellText)

    -- Width of the widest cell in each column.
    colWidths :: U.Vector Int
    colWidths = U.generate cols $ \c ->
      U.maximum (U.generate rows (\r -> cellLengths U.! (r * cols + c)))

    -- Render from cell (r, c) to the end of the matrix. Running past the last
    -- column moves on to the start of the next row; the final cell closes
    -- the grid instead of recursing.
    render :: Int -> Int -> String
    render !r !c
      | c >= cols = render (r + 1) 0
      | otherwise =
          let !i      = r * cols + c
              !width  = colWidths U.! c
              !len    = cellLengths U.! i
              padding = replicate (width - len + 1) ' '
              prefix
                | r > 0 && c == 0 = "\n   "
                | otherwise       = ""
              suffix
                | r >= rows - 1 && c >= cols - 1 = "]"
                | otherwise                      = ',' : render r (c + 1)
          in  prefix ++ padding ++ cellText i ++ suffix

-- | Drop the innermost dimension from an array representation, keeping the
-- element type.
reduceRank :: ArrayR (Array (sh, Int) e) -> ArrayR (Array sh e)
reduceRank (ArrayR shR eR) = case shR of
  ShapeRsnoc inner -> ArrayR inner eR

-- | Fully evaluate an array: its shape first, then its payload.
rnfArray :: ArrayR a -> a -> ()
rnfArray (ArrayR shR adR) arr = case arr of
  Array sh ad -> rnfShape shR sh `seq` rnfArrayData adR ad

-- | Fully evaluate an array representation: the shape witness, then every
-- scalar type in the element-type tuple.
rnfArrayR :: ArrayR arr -> ()
rnfArrayR (ArrayR shR tR) =
  seq (rnfShapeR shR) (rnfTupR rnfScalarType tR)

-- | Fully evaluate a tuple of arrays, guided by its representation. The
-- left component of a pair is evaluated before the right.
rnfArraysR :: ArraysR arrs -> arrs -> ()
rnfArraysR repr arrs = case repr of
  TupRunit         -> case arrs of
                        () -> ()
  TupRsingle arrR  -> rnfArray arrR arrs
  TupRpair aR1 aR2 -> case arrs of
                        (a1, a2) -> rnfArraysR aR1 a1 `seq` rnfArraysR aR2 a2

-- | Lift an array representation into a typed Template Haskell expression
-- that rebuilds it.
liftArrayR :: ArrayR a -> Q (TExp (ArrayR a))
liftArrayR repr = case repr of
  ArrayR shR tR -> [|| ArrayR $$(liftShapeR shR) $$(liftTypeR tR) ||]

-- | Lift a tuple-of-arrays representation into a typed Template Haskell
-- expression, rebuilding the tuple structure node by node.
liftArraysR :: ArraysR arrs -> Q (TExp (ArraysR arrs))
liftArraysR repr = case repr of
  TupRunit       -> [|| TupRunit ||]
  TupRsingle aR  -> [|| TupRsingle $$(liftArrayR aR) ||]
  TupRpair r1 r2 -> [|| TupRpair $$(liftArraysR r1) $$(liftArraysR r2) ||]

-- | Lift a concrete array into a typed Template Haskell expression that
-- rebuilds it. The generated expression carries an explicit
-- @Array \<shape type\> \<element type\>@ signature.
--
liftArray :: forall sh e. ArrayR (Array sh e) -> Array sh e -> Q (TExp (Array sh e))
liftArray (ArrayR shR adR) (Array sh adata) = annotate expr ty
  where
    -- Number of elements in the payload; liftArrayData needs it.
    n :: Int
    n = size shR sh

    expr = [|| Array $$(liftElt (shapeType shR) sh) $$(liftArrayData n adR adata) ||]
    ty   = [t| Array $(liftTypeQ (shapeType shR)) $(liftTypeQ adR) |]

    -- Wrap the generated expression in a type signature. NOTE(review):
    -- presumably needed because the payload type alone does not determine
    -- the element type at the splice site; confirm.
    annotate :: Q (TExp t) -> Q Type -> Q (TExp t)
    annotate e t = unsafeTExpCoerce (sigE (unTypeQ e) t)