-- Copyright 2020 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Arrays of dynamic size, but static rank.  The arrays are polymorphic in the underlying
-- linear data structure used to store the actual values.
module Data.Array.Internal.RankedG(
  Array(..), Vector, VecElem,
  size, shapeL, rank,
  toList, fromList, toVector, fromVector,
  normalize,
  scalar, unScalar, constant,
  reshape, stretch, stretchOuter, transpose,
  index, pad,
  mapA, zipWithA, zipWith3A,
  append, concatOuter,
  ravel, unravel,
  window, stride, rotate,
  slice, rerank, rerank2, rev,
  reduce, foldrA, traverseA,
  allSameA,
  sumA, productA, maximumA, minimumA,
  anyA, allA,
  broadcast,
  generate, iterateN, iota,
  ) where
import Control.Monad(replicateM)
import Control.DeepSeq
import Data.Data(Data)
import Data.List(sort)
import GHC.Generics(Generic)
import GHC.Stack
import GHC.TypeLits(Nat, type (+), KnownNat, type (<=))
import Test.QuickCheck hiding (generate)
import Text.PrettyPrint.HughesPJClass hiding ((<>))

import Data.Array.Internal

-- | Arrays stored in a /v/ with values of type /a/.  The rank /n/ is part of
-- the type, while the sizes of the individual dimensions are runtime values.
--
-- The constructor holds, strictly:
--
--   * the shape, i.e. the size of each dimension, outermost first;
--
--   * a strided view ('T') of the element storage: the strides, an 'Int'
--     (0 for freshly built arrays, presumably the starting offset), and the
--     backing vector.
data Array (n :: Nat) v a = A !ShapeL !(T v a)
  deriving (Generic, Data)

-- | Renders an array as @fromList \<shape\> \<elements\>@, the same form that
-- the 'Read' instance accepts.  The output is parenthesized when the
-- surrounding precedence exceeds 10.
instance (Vector v, Show a, VecElem v a) => Show (Array n v a) where
  showsPrec d arr@(A sh _) = showParen (d > 10) rendered
    where
      rendered = showString "fromList "
               . showsPrec 11 sh
               . showChar ' '
               . showsPrec 11 (toList arr)

-- | Parses the @fromList \<shape\> \<elements\>@ form produced by 'Show'.
-- A parse is rejected unless the shape has exactly @n@ dimensions and the
-- product of the shape equals the number of elements.
instance (KnownNat n, Vector v, Read a, VecElem v a) => Read (Array n v a) where
  readsPrec d = readParen (d > 10) parse
    where
      parse input = do
        ("fromList", afterKeyword) <- lex input
        (sh, afterShape) <- readsPrec 11 afterKeyword
        (elems, rest) <- readsPrec 11 afterShape
        if length sh == valueOf @n && product sh == length elems
          then [(fromList sh elems, rest)]
          else []

-- | Two arrays are equal when their shapes are equal and their element
-- storage compares equal according to 'equalT' for that shape.
instance (Vector v, Eq a, VecElem v a, Eq (v a)) => Eq (Array n v a) where
  (==) (A shL tL) (A shR tR)
    | shL /= shR = False
    | otherwise  = equalT shL tL tR
  {-# INLINE (==) #-}

-- | Arrays are ordered first by shape; arrays of equal shape are then
-- ordered by 'compareT' on their element storage.
instance (Vector v, Ord a, Ord (v a), VecElem v a) => Ord (Array n v a) where
  compare (A shL tL) (A shR tR) =
    case compare shL shR of
      EQ      -> compareT shL tL tR
      decided -> decided
  {-# INLINE compare #-}

-- | Pretty-printing delegates to 'ppT' with the array's shape and storage.
instance (Vector v, Pretty a, VecElem v a) => Pretty (Array n v a) where
  pPrintPrec lvl prec arr = case arr of
    A sh t -> ppT lvl prec sh t

-- | Full evaluation forces the shape first, then the element storage.
instance (NFData (v a)) => NFData (Array n v a) where
  rnf (A sh t) = seq (rnf sh) (rnf t)

-- | The number of elements in the array, i.e. the product of its shape.
-- O(rank) time; the elements themselves are not touched.
{-# INLINE size #-}
size :: Array n v a -> Int
size arr = product (shapeL arr)

-- | The shape of an array: the list of the sizes of its dimensions.
-- The first list element is the outermost dimension, which varies most
-- slowly in the linearization of the array.
-- O(1) time.
{-# INLINE shapeL #-}
shapeL :: Array n v a -> ShapeL
shapeL arr = case arr of
  A sh _ -> sh

-- | The rank of an array, i.e. its number of dimensions.  The value comes
-- from the type-level @n@; the array argument is only forced to WHNF.
-- O(1) time.
{-# INLINE rank #-}
rank :: forall n v a . (KnownNat n) => Array n v a -> Int
rank arr = arr `seq` valueOf @n

-- | Select the @i@-th slice along the outermost dimension, yielding an array
-- of one lower rank.  Fails if the array has rank 0 or if @i@ is outside
-- @[0, outermost size)@.
-- O(1) time.
{-# INLINE index #-}
index :: (Vector v, HasCallStack) => Array (1+n) v a -> Int -> Array n v a
index (A sh t) i = case sh of
  [] -> error "index: scalar"
  outer : inner
    | i >= 0 && i < outer -> A inner (indexT t i)
    | otherwise -> error $ "index: out of bounds " ++ show (i, outer)

-- | List all elements of the array in linearization order (outermost
-- dimension varying slowest).
-- O(n) time.
{-# INLINE toList #-}
toList :: (Vector v, VecElem v a) => Array n v a -> [a]
toList arr = case arr of
  A sh t -> toListT sh t

-- | Produce a vector holding all elements in linearization order.
-- O(1) time when the storage is already in that order, O(n) otherwise.
{-# INLINE toVector #-}
toVector :: (Vector v, VecElem v a) => Array n v a -> v a
toVector arr = case arr of
  A sh t -> toVectorT sh t

-- | Build an array from a list whose elements are given in linearization
-- order (outermost dimension varying slowest).
-- Fails if the product of the shape differs from the length of the list,
-- or if the number of dimensions in the shape differs from the rank @n@.
-- O(n) time.
{-# INLINE fromList #-}
fromList :: forall n v a . (HasCallStack, Vector v, VecElem v a, KnownNat n) =>
            ShapeL -> [a] -> Array n v a
fromList sh xs
  | expected /= actual =
      error $ "fromList: size mismatch " ++ show (expected, actual)
  | dims /= typeRank =
      error $ "fromList: rank mismatch " ++ show (dims, typeRank)
  | otherwise = A sh (T strideList 0 (vFromList xs))
  where
    -- getStridesT yields the total element count followed by the strides.
    expected : strideList = getStridesT sh
    actual = length xs
    dims = length sh
    typeRank = valueOf @n :: Int

-- | Convert from a vector with the elements given in the linearization order.
-- Fails if the product of the shape differs from the length of the vector,
-- or if the number of dimensions in the shape differs from the rank @n@.
-- O(1) time.
{-# INLINE fromVector #-}
fromVector :: forall n v a . (HasCallStack, Vector v, VecElem v a, KnownNat n) =>
              ShapeL -> v a -> Array n v a
fromVector ss v
  | total /= l =
      -- Separator space added so the message matches fromList's format.
      error $ "fromVector: size mismatch " ++ show (total, l)
  | length ss /= valueOf @n =
      error $ "fromVector: rank mismatch " ++ show (length ss, valueOf @n :: Int)
  | otherwise = A ss $ T st 0 v
  where
    -- getStridesT yields the total element count followed by the strides.
    total : st = getStridesT ss
    l = vLength v

-- | Rebuild the array so that its backing vector is in linearization order.
-- Semantically the identity, but it can change the performance of later
-- operations considerably.
-- O(1) time if the vector is already in order, O(n) otherwise.
{-# INLINE normalize #-}
normalize :: (Vector v, VecElem v a, KnownNat n) => Array n v a -> Array n v a
normalize arr = fromVector sh vec
  where
    sh  = shapeL arr
    vec = toVector arr

-- | Change the shape of an array.  Fails if the new shape has a different
-- number of elements than the array, or if the number of dimensions in the
-- new shape differs from the result rank @n'@.
-- O(n) or O(1) time.
{-# INLINE reshape #-}
reshape :: forall n n' v a . (HasCallStack,Vector v, VecElem v a, KnownNat n, KnownNat n') =>
           ShapeL -> Array n v a -> Array n' v a
reshape sh (A sh' t@(T ost oo v))
  | newSize /= oldSize = error $ "reshape: size mismatch " ++ show (sh, sh')
  | length sh /= valueOf @n' =
      -- Report the rank that was actually checked: the result rank n'.
      error $ "reshape: rank mismatch " ++ show (length sh, valueOf @n' :: Int)
  -- Fast special case for a singleton vector: every element of the result
  -- is that single value, so all strides can be 0.
  | vLength v == 1 = A sh $ T (map (const 0) sh) 0 v
  -- When simpleReshape finds strides for the new shape, reuse the existing
  -- vector (and its offset field `oo`) without copying.
  | Just nst <- simpleReshape ost sh' sh = A sh $ T nst oo v
  -- Otherwise copy the elements into a fresh vector in linearization order.
  | otherwise = A sh $ T st 0 $ toVectorT sh' t
  where
    -- getStridesT yields the total element count followed by the strides.
    newSize : st = getStridesT sh
    oldSize = product sh'

-- | Change the size of dimensions that have size 1; such a dimension may be
-- given any size.  Every other dimension must keep its size, and the new
-- shape must have the same number of dimensions as the old one.
-- O(1) time.
{-# INLINE stretch #-}
stretch :: (HasCallStack) => ShapeL -> Array n v a -> Array n v a
stretch newSh (A oldSh t) =
  case flags newSh oldSh of
    Just bs -> A newSh (stretchT bs t)
    Nothing -> error $ "stretch: incompatible " ++ show (newSh, oldSh)
  where
    -- One flag per dimension: True where a size-1 dimension is stretched,
    -- False where the size is unchanged; Nothing if the shapes conflict.
    flags (x:xs) (y:ys)
      | x == y = (False :) <$> flags xs ys
      | y == 1 = (True :) <$> flags xs ys
    flags [] [] = Just []
    flags _ _ = Nothing

-- | Change the size of the outermost dimension by replication.
-- The outermost dimension must have size 1; it is replaced by the given
-- size and all inner dimensions are kept.
-- Fails if the outermost dimension does not have size 1.
{-# INLINE stretchOuter #-}
stretchOuter :: (HasCallStack, 1 <= n) =>
                Int -> Array n v a -> Array n v a
stretchOuter s (A (1:sh) vs) =
  -- Exactly one flag per dimension (stretch the outermost, keep the inner
  -- ones), matching the flag list that 'stretch' hands to 'stretchT'.
  A (s:sh) $ stretchT (True : map (const False) sh) vs
stretchOuter _ _ = error "stretchOuter: needs outermost dimension of size 1"

-- | Wrap a single value as a rank-0 array with the empty shape.
-- O(1) time.
{-# INLINE scalar #-}
scalar :: (Vector v, VecElem v a) => a -> Array 0 v a
scalar x = A [] (scalarT x)

-- | Extract the single value held by a rank-0 array.
-- O(1) time.
{-# INLINE unScalar #-}
unScalar :: (Vector v, VecElem v a) => Array 0 v a -> a
unScalar arr = case arr of
  A _ t -> unScalarT t

-- | Make an array of the given shape in which every element is the same
-- value.  Fails if 'badShape' rejects the shape, or if the number of
-- dimensions differs from the rank @n@.  The checks run as soon as the
-- shape is supplied, before the element value is given.
-- O(1) time.
{-# INLINE constant #-}
constant :: forall n v a . (Vector v, VecElem v a, KnownNat n) =>
            ShapeL -> a -> Array n v a
constant sh
  | badShape sh = error $ "constant: bad shape: " ++ show sh
  | length sh /= valueOf @n = error "constant: rank mismatch"
  | otherwise = \x -> A sh (constantT sh x)

-- | Apply a function to every element; the shape is unchanged.
-- O(n) time.
{-# INLINE mapA #-}
mapA :: (Vector v, VecElem v a, VecElem v b) =>
        (a -> b) -> Array n v a -> Array n v b
mapA f arr = case arr of
  A sh t -> A sh $ mapT sh f t

-- | Combine two arrays elementwise with the given function.  The result has
-- the common shape of the inputs.
-- Fails if the shapes of the two arrays differ.
-- O(n) time.
{-# INLINE zipWithA #-}
zipWithA :: (HasCallStack, Vector v, VecElem v a, VecElem v b, VecElem v c) =>
            (a -> b -> c) -> Array n v a -> Array n v b -> Array n v c
zipWithA f (A s t) (A s' t')
  | s == s'   = A s (zipWithT s f t t')
  | otherwise = error $ "zipWithA: shape mismatch: " ++ show (s, s')

-- | Combine three arrays elementwise with the given function.  The result
-- has the common shape of the inputs.
-- Fails unless all three arrays have the same shape.
-- O(n) time.
{-# INLINE zipWith3A #-}
zipWith3A :: (HasCallStack, Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d) =>
             (a -> b -> c -> d) -> Array n v a -> Array n v b -> Array n v c -> Array n v d
zipWith3A f (A s t) (A s' t') (A s'' t'')
  | s == s' && s == s'' = A s (zipWith3T s f t t' t'')
  | otherwise = error $ "zipWith3A: shape mismatch: " ++ show (s, s', s'')

-- | Pad each dimension on the low and high side with the given value.
-- The list holds (low, high) padding widths; the new shape and storage are
-- both computed by 'padT'.
-- O(n) time.
{-# INLINE pad #-}
pad :: forall n a v . (Vector v, VecElem v a) =>
       [(Int, Int)] -> a -> Array n v a -> Array n v a
pad widths fill (A sh t) =
  case padT fill widths sh t of
    (sh', t') -> A sh' t'

-- | Do an arbitrary array transposition.
-- The argument @is@ permutes the outermost @length is@ dimensions; any
-- remaining inner dimensions keep their positions.  Fails if @is@ is longer
-- than the rank of the array, or if it is not a permutation of
-- @[0 .. length is - 1]@.
-- O(1) time.
{-# INLINE transpose #-}
transpose :: forall n v a . (HasCallStack, KnownNat n) =>
            [Int] -> Array n v a -> Array n v a
transpose is (A sh t)
  | l > r = error "transpose: rank exceeded"
  | sort is /= [0 .. l-1] =
      error $ "transpose: not a permutation: " ++ show is
  | otherwise = A (permute is' sh) (transposeT is' t)
  where
    l = length is
    r = valueOf @n :: Int
    -- Extend the given permutation with the identity on the inner dimensions.
    is' = is ++ [l .. r-1]

-- | Append two arrays along the outermost dimension.  All dimensions except
-- the outermost must be equal; otherwise, or if either array has rank 0,
-- this fails.  The result's outermost size is the sum of the two inputs'.
-- O(n) time.
{-# INLINE append #-}
append :: (Vector v, VecElem v a, KnownNat n) =>
          Array n v a -> Array n v a -> Array n v a
append x y = case (shapeL x, shapeL y) of
  (nx : innerX, ny : innerY)
    | innerX == innerY ->
        fromVector (nx + ny : innerX) (vAppend (toVector x) (toVector y))
  _ -> error "append: bad shape"

-- | Concatenate a number of arrays into a single array along the outermost
-- dimension.  The outer dimensions add up; all other dimensions must agree.
-- It is an error to pass an empty list (the result shape would be unknown),
-- rank-0 arrays (they have no outer dimension), or arrays whose inner
-- dimensions differ.
-- O(n) time.
{-# INLINE concatOuter #-}
concatOuter :: (Vector v, VecElem v a, KnownNat n) => [Array n v a] -> Array n v a
concatOuter [] = error "concatOuter: empty list"
concatOuter as
  -- Rank-0 shapes are rejected explicitly.  Otherwise they would fail
  -- inside the shape arithmetic with an opaque Prelude error.
  | any null shs = error "concatOuter: rank 0 arrays have no outer dimension"
  | not $ allSame inners =
      error $ "concatOuter: non-conforming inner dimensions: " ++ show shs
  | otherwise = fromVector sh' $ vConcat $ map toVector as
  where shs@(sh:_) = map shapeL as
        -- Outermost size of each array.
        outers = [ s | s : _ <- shs ]
        -- Remaining (inner) dimensions of each array.
        inners = [ ss | _ : ss <- shs ]
        sh' = sum outers : drop 1 sh

-- | Turn a rank-1 array of arrays into a single array by making the outer array into the outermost
-- dimension of the result array.  All the arrays must have the same shape.
-- It is an error if the outer array is empty or the inner shapes differ.
-- O(n) time.
{-# INLINE ravel #-}
ravel :: (Vector v, Vector v', VecElem v a, VecElem v' (Array n v a), KnownNat (1+n)) =>
         Array 1 v' (Array n v a) -> Array (1+n) v a
ravel aa =
  case toList aa of
    [] -> error "ravel: empty array"
    subs@(first : _)
      | allSame shapes ->
          -- New outer dimension: one entry per inner array.
          fromVector (length subs : shapeL first) (vConcat (map toVector subs))
      | otherwise ->
          error $ "ravel: non-conforming inner dimensions: " ++ show shapes
      where shapes = map shapeL subs

-- | Turn an array into a nested array: a rank-1 array holding the subarrays
-- found one level below the outermost dimension.  This is the inverse of
-- 'ravel', i.e., @ravel . unravel == id@.
{-# INLINE unravel #-}
unravel :: (Vector v, Vector v', VecElem v a, VecElem v' (Array n v a)) =>
           Array (1+n) v a -> Array 1 v' (Array n v a)
-- Rerank one level down and wrap every subarray as a rank-0 element.
unravel arr = rerank @1 scalar arr

-- | Make a window of the outermost dimensions.
-- The rank increases with the length of the window list.
-- E.g., if the shape of the array is @[10,12,8]@ and
-- the window size is @[3,3]@ then the resulting array will have shape
-- @[8,10,3,3,8]@.
--
-- E.g., @window [2] (fromList [4] [1,2,3,4]) == fromList [3,2] [1,2, 2,3, 3,4]@
-- O(1) time.
--
-- If the window parameter @ws = [w1,...,wk]@ and @wa = window ws a@ then
-- @wa `index` i1 ... `index` ik == slice [(i1,w1),...,(ik,wk)] a@.
--
-- Each window size @w@ must satisfy @0 <= w <= s@, where @s@ is the size of
-- the dimension it applies to; other sizes are an error.
{-# INLINE window #-}
window :: forall n n' v a . (Vector v, KnownNat n, KnownNat n') =>
          [Int] -> Array n v a -> Array n' v a
window aws _ | valueOf @n' /= length aws + valueOf @n =
  error $ "window: rank mismatch: " ++ show (valueOf @n :: Int, length aws, valueOf @n' :: Int)
window aws (A ash (T ss o v)) = A (win aws ash) (T (ss' ++ ss) o v)
  where
    -- Each new in-window dimension steps through memory with the same
    -- stride as the outer dimension it was cut from.
    ss' = zipWith const ss aws
    win (w:ws) (s:sh)
      -- A negative size would yield more than s+1 windows, reaching past
      -- the end of the dimension, so it is rejected too.
      | 0 <= w && w <= s = s - w + 1 : win ws sh
      | otherwise = error $ "window: bad window size : " ++ show (w, s)
    win [] sh = aws ++ sh
    win _ _ = error $ "window: rank mismatch: " ++ show (aws, ash)

-- | Stride the outermost dimensions.
-- E.g., if the array shape is @[10,12,8]@ and the strides are
-- @[2,2]@ then the resulting shape will be @[5,6,8]@.
-- Every stride must be positive, and there may be no more strides than
-- dimensions.
-- O(1) time.
{-# INLINE stride #-}
stride :: (Vector v) => [Int] -> Array n v a -> Array n v a
stride ats (A ash (T ss o v)) = A (str ats ash) (T (zipWith (*) (ats ++ repeat 1) ss) o v)
  where
    -- New size of each strided dimension: the ceiling of s / t.
    str (t:ts) (s:sh)
      -- A zero stride would divide by zero and a negative one would
      -- produce a negative dimension size.
      | t > 0 = (s+t-1) `quot` t : str ts sh
      | otherwise = error $ "stride: bad stride: " ++ show t
    str [] sh = sh
    str _ _ = error $ "stride: rank mismatch: " ++ show (ats, ash)

-- | Rotate the array k times along the d'th dimension.
-- E.g., if the array shape is @[2, 3, 2]@, d is 1, and k is 4,
-- the resulting shape will be @[2, 4, 3, 2]@.
--
-- The result gains a new dimension of size @k@ at position @d@.  It indexes
-- rotated copies of each subarray found below the first @d@ dimensions.
-- NOTE(review): for a non-empty subarray, the element-count arithmetic
-- below appears to work out only when @1 <= k <= h + 1@, where @h@ is the
-- size of dimension @d@; otherwise the final 'reshape' fails.  Confirm
-- whether callers need other values of @k@.
rotate :: forall d p v a.
          (KnownNat p, KnownNat d,
          Vector v, VecElem v a,
          -- Nonsense
          (d + (p + 1)) ~ ((p + d) + 1),
          (d + p) ~ (p + d),
          1 <= p + 1,
          KnownNat ((p + d) + 1),
          KnownNat (p + 1),
          KnownNat (1 + (p + 1))
          ) =>
          Int -> Array (p + d) v a -> Array (p + d + 1) v a
rotate k x = rerank @d @p @(p + 1) rotations x
 where
  -- Rotations of one rank-p subarray, stacked along a new outer dimension.
  -- The final 'rev' reverses the order in which they are stacked.
  rotations :: Array p v a -> Array (p + 1) v a
  rotations sub = rev [0] (reshape (k : outer : inner) picked)
    where
      outer : inner = shapeL sub
      cellSize = product inner      -- elements per index of the outer dim
      subSize = outer * cellSize    -- elements in the whole subarray
      -- k+1 back-to-back copies of the subarray, viewed as one flat vector.
      copies = stretchOuter (k + 1) (reshape @p @(p + 1) (1 : outer : inner) sub)
      flat = reshape @(p + 1) @1 [(k + 1) * subSize] copies
      -- Windows one subarray long, taken every subSize + cellSize elements,
      -- so each window starts one outer index later than the previous one.
      picked = stride [subSize + cellSize] (window @1 @2 [subSize] flat)

-- | Extract a slice of an array.
-- The first argument is a list of @(offset, length)@ pairs, one per outer
-- dimension; dimensions beyond the end of the list are kept whole.
-- The length of the slicing argument must not exceed the rank of the array,
-- and every slice (with non-negative offset and length) must fall within
-- the array dimensions.
-- E.g. @slice [(1,2)] (fromList [4] [1,2,3,4]) == fromList [2] [2,3]@.
-- O(1) time.
{-# INLINE slice #-}
slice :: [(Int, Int)] -> Array n v a -> Array n v a
slice asl (A ash (T ats ao v)) = A rsh (T ats o v)
  where (o, rsh) = slc asl ash ats
        -- Walk the slices, the dimensions and the strides together,
        -- accumulating the new offset and the new shape.
        slc ((k,n):sl) (s:sh) (t:ts)
          -- A negative length would otherwise pass the upper-bound test and
          -- produce a negative dimension size.
          | k < 0 || n < 0 || k + n > s = error "slice: out of bounds"
          | otherwise = (i + k*t, n:ns)
          where (i, ns) = slc sl sh ts
        slc [] sh _ = (ao, sh)
        -- Reached when there are more slices than array dimensions.
        slc _ _ _ = error "slice: slicing argument exceeds array rank"

-- | Apply a function to the subarrays /n/ levels down and make
-- the results into an array with the same /n/ outermost dimensions.
-- The /n/ must not exceed the rank of the array.
-- O(n) time.
{-# INLINE rerank #-}
rerank :: forall n i o v v' a b .
          (Vector v, Vector v', VecElem v a, VecElem v' b
          , KnownNat n, KnownNat o, KnownNat (n+o), KnownNat (1+o)) =>
          (Array i v a -> Array o v' b) -> Array (n+i) v a -> Array (n+o) v' b
rerank f (A sh t) = ravelOuter outerSh results
  where
    (outerSh, innerSh) = splitAt (valueOf @n) sh
    -- f applied to every subarray below the outer n dimensions.
    results = [ f (A innerSh sub) | sub <- subArraysT outerSh t ]

-- | Glue a non-empty list of equally shaped arrays into one array.  Its
-- shape is the given outer shape followed by the shared shape of the parts.
-- Errors on an empty list or on parts with differing shapes.
-- NOTE(review): the number of parts presumably has to equal the product of
-- the outer shape; this relies on 'fromVector' to reject a mismatch.
-- Confirm.
{-# INLINABLE ravelOuter #-}
ravelOuter :: (Vector v, VecElem v a, KnownNat m) => ShapeL -> [Array n v a] -> Array m v a
ravelOuter _ [] = error "ravelOuter: empty list"
ravelOuter outerSh parts@(p0 : _)
  | allSame shapes =
      fromVector (outerSh ++ shapeL p0) (vConcat (map toVector parts))
  | otherwise =
      error $ "ravelOuter: non-conforming inner dimensions: " ++ show shapes
  where shapes = map shapeL parts

-- | Apply a two-argument function to the subarrays /n/ levels down and make
-- the results into an array with the same /n/ outermost dimensions.
-- The /n/ must not exceed the rank of the array, and the /n/ outermost
-- dimensions of the two arguments must agree.
-- O(n) time.
{-# INLINE rerank2 #-}
rerank2 :: forall n i o a b c v .
           (Vector v, VecElem v a, VecElem v b, VecElem v c,
            KnownNat n, KnownNat o, KnownNat (n+o), KnownNat (1+o)) =>
           (Array i v a -> Array i v b -> Array o v c) -> Array (n+i) v a -> Array (n+i) v b -> Array (n+o) v c
rerank2 f (A shA tA) (A shB tB)
  | outerA /= outerB = error "rerank2: shape mismatch"
  | otherwise =
      ravelOuter outerA
        [ f (A innerA subA) (A innerB subB)
        | (subA, subB) <- zip (subArraysT outerA tA) (subArraysT outerA tB) ]
  where
    depth :: Int
    depth = valueOf @n
    (outerA, innerA) = splitAt depth shA
    (outerB, innerB) = splitAt depth shB

-- | Reverse the given dimensions, with the outermost being dimension 0.
-- Every listed dimension must lie in @[0, rank)@; otherwise this is an error.
-- O(1) time.
{-# INLINE rev #-}
rev :: [Int] -> Array n v a -> Array n v a
rev rs (A sh t)
  | any outOfRange rs = error "reverse: bad reverse dimension"
  | otherwise = A sh (reverseT rs sh t)
  where
    outOfRange d = d < 0 || d >= length sh

-- | Reduce all elements of an array, with operator @op@ and start value @z@,
-- into a rank 0 array.
-- To reduce parts use 'rerank' and 'transpose' together with 'reduce'.
-- O(n) time.
{-# INLINE reduce #-}
reduce :: (Vector v, VecElem v a) =>
          (a -> a -> a) -> a -> Array n v a -> Array 0 v a
reduce op z arr = case arr of
  A sh t -> A [] (reduceT sh op z t)

-- | Right fold across all elements of an array, delegating to 'foldrT'.
{-# INLINE foldrA #-}
foldrA :: (Vector v, VecElem v a) => (a -> b -> b) -> b -> Array n v a -> b
foldrA step z arr = case arr of
  A shp vals -> foldrT shp step z vals

-- | Constrained version of 'traverse' for 'Array's.
-- The shape of the array is kept; only the elements are traversed.
{-# INLINE traverseA #-}
traverseA
  :: (Vector v, VecElem v a, VecElem v b, Applicative f)
  => (a -> f b) -> Array n v a -> f (Array n v b)
traverseA g (A shp vals) = fmap (A shp) (traverseT shp g vals)

-- | Check if all elements of the array are equal (delegates to 'allSameT').
allSameA :: (Vector v, VecElem v a, Eq a) => Array r v a -> Bool
allSameA arr = case arr of
  A shp vals -> allSameT shp vals

instance (KnownNat r, Vector v, VecElem v a, Arbitrary a) => Arbitrary (Array r v a) where
  -- Pick a random shape of small positive dimensions, rejecting shapes that
  -- would hold 10000 or more elements (so we don't generate huge arrays),
  -- then fill it with random elements.
  arbitrary = do
    let dimGen :: Gen Int
        dimGen = getSmall . getPositive <$> arbitrary
    shape <- replicateM (valueOf @r) dimGen `suchThat` (\s -> product s < 10000)
    elems <- vector (product shape)
    return (fromList shape elems)

-- | Sum of all elements (delegates to 'sumT').
{-# INLINE sumA #-}
sumA :: (Vector v, VecElem v a, Num a) => Array r v a -> a
sumA arr = case arr of
  A shp vals -> sumT shp vals

-- | Product of all elements (delegates to 'productT').
{-# INLINE productA #-}
productA :: (Vector v, VecElem v a, Num a) => Array r v a -> a
productA arr = case arr of
  A shp vals -> productT shp vals

-- | Maximum of all elements.  It is an error to call this on an array with
-- no elements.
{-# INLINE maximumA #-}
maximumA :: (HasCallStack, Vector v, VecElem v a, Ord a) => Array r v a -> a
maximumA arr@(A sh t)
  | size arr <= 0 = error "maximumA called with empty array"
  | otherwise = maximumT sh t

-- | Minimum of all elements.  It is an error to call this on an array with
-- no elements.
{-# INLINE minimumA #-}
minimumA :: (HasCallStack, Vector v, VecElem v a, Ord a) => Array r v a -> a
minimumA arr@(A sh t)
  | size arr <= 0 = error "minimumA called with empty array"
  | otherwise = minimumT sh t

-- | Test if the predicate holds for any element (delegates to 'anyT').
{-# INLINE anyA #-}
anyA :: (Vector v, VecElem v a) => (a -> Bool) -> Array r v a -> Bool
anyA prd arr = case arr of
  A shp vals -> anyT shp prd vals

-- | Test if the predicate holds for all elements (delegates to 'allT').
{-# INLINE allA #-}
allA :: (Vector v, VecElem v a) => (a -> Bool) -> Array r v a -> Bool
allA prd arr = case arr of
  A shp vals -> allT shp prd vals

-- | Put the dimensions of the argument into the specified dimensions,
-- and just replicate the data along all other dimensions.
-- @broadcast ds sh a@ produces an array of shape @sh@.  The list @ds@ holds
-- the dimension indices of the result that carry the dimensions of @a@:
--
-- * it must have exactly one entry per dimension of @a@;
-- * it must be strictly ascending and lie within the rank of the result;
-- * @sh@ must agree with the shape of @a@ at those positions.
--
-- Any violation is an error.
broadcast :: forall r' r v a .
             (HasCallStack, Vector v, VecElem v a, KnownNat r, KnownNat r') =>
             [Int] -> ShapeL -> Array r v a -> Array r' v a
broadcast ds sh a | length ds /= valueOf @r = error "broadcast: wrong number of broadcasts"
                  | any (\ d -> d < 0 || d >= r) ds = error "broadcast: bad dimension"
                  | not ascending = error "broadcast: unordered dimensions"
                  | length sh /= r = error "broadcast: wrong rank"
                  -- Without this check, a shape with the same element count
                  -- but different dimensions would be silently reinterpreted
                  -- by 'reshape'.
                  | kept /= shapeL a =
                      error $ "broadcast: shape mismatch: " ++ show (kept, shapeL a)
                  | otherwise = stretch sh $ reshape rsh a
  where r = valueOf @r'
        -- Result shape with every replicated dimension collapsed to 1.
        rsh = [ if i `elem` ds then s else 1 | (i, s) <- zip [0..] sh ]
        -- The result dimensions that carry the argument's data.
        kept = [ s | (i, s) <- zip [0..] sh, i `elem` ds ]
        ascending = and (zipWith (<) ds (drop 1 ds))

-- | Generate an array with a function that computes the value for each index.
-- The length of the shape must equal the rank @n@; otherwise this is an error.
{-# INLINE generate #-}
generate :: forall n v a .
            (KnownNat n, Vector v, VecElem v a) =>
            ShapeL -> ([Int] -> a) -> Array n v a
generate sh
  | given == expected = \f -> A sh (generateT sh f)
  | otherwise = error $ "generate: rank mismatch " ++ show (given, expected)
  where
    given = length sh
    expected = valueOf @n :: Int

-- | Iterate a function n times, starting from the given value, producing a
-- rank-1 array of shape @[n]@ (elements computed by 'iterateNT').
{-# INLINE iterateN #-}
iterateN :: forall v a .
            (Vector v, VecElem v a) =>
            Int -> (a -> a) -> a -> Array 1 v a
iterateN count step start = A [count] (iterateNT count step start)

-- | Generate a vector from 0 to n-1, as a rank-1 array of shape @[n]@.
{-# INLINE iota #-}
iota :: forall v a .
        (Vector v, VecElem v a, Enum a, Num a) =>
        Int -> Array 1 v a
iota len = A [len] (iotaT len)