-- Copyright 2020 Google LLC
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--      http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UndecidableSuperClasses #-}
module Data.Array.Internal(module Data.Array.Internal) where
import Control.DeepSeq
import Data.Data(Data)
import qualified Data.DList as DL
import Data.Kind (Type)
import Data.List(foldl', zipWith4, zipWith5, sortBy, sortOn, foldl1')
import Data.Proxy
import GHC.Exts(Constraint, build)
import GHC.Generics(Generic)
import GHC.TypeLits(KnownNat, natVal)
import Text.PrettyPrint
import Text.PrettyPrint.HughesPJClass

{- HLINT ignore "Reduce duplication" -}

-- | The 'Vector' class is the interface to the underlying storage for the
-- arrays.  Each method is one basic operation on that storage; the
-- @Vector []@ instance in this module gives a list-based reference meaning
-- for every one of them.
--
-- Some storage types, like unboxed vectors, require an extra constraint on
-- their elements, which the associated constraint 'VecElem' lets you
-- express.  Vector types that don't need the constraint can set it to a
-- dummy class that every type satisfies, such as 'None'.
class Vector v where
  -- | The constraint an element type must satisfy to be stored in @v@.
  type VecElem v :: Type -> Constraint
  -- | The element at the given index.
  vIndex    :: (VecElem v a) => v a -> Int -> a
  -- | The number of elements.
  vLength   :: (VecElem v a) => v a -> Int
  -- | All elements, in order.
  vToList   :: (VecElem v a) => v a -> [a]
  -- | Build a vector from a list, keeping the element order.
  vFromList :: (VecElem v a) => [a] -> v a
  -- | A one-element vector.
  vSingleton:: (VecElem v a) => a -> v a
  -- | @vReplicate n x@ is a vector of @n@ copies of @x@.
  vReplicate:: (VecElem v a) => Int -> a -> v a
  -- | Apply a function to every element.
  vMap      :: (VecElem v a, VecElem v b) => (a -> b) -> v a -> v b
  -- | Combine two vectors element-wise.
  vZipWith  :: (VecElem v a, VecElem v b, VecElem v c) => (a -> b -> c) -> v a -> v b -> v c
  -- | Combine three vectors element-wise.
  vZipWith3 :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d) => (a -> b -> c -> d) -> v a -> v b -> v c -> v d
  -- | Combine four vectors element-wise.
  vZipWith4 :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e) => (a -> b -> c -> d -> e) -> v a -> v b -> v c -> v d -> v e
  -- | Combine five vectors element-wise.
  vZipWith5 :: (VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e, VecElem v f) => (a -> b -> c -> d -> e -> f) -> v a -> v b -> v c -> v d -> v e -> v f
  -- | Concatenate two vectors.
  vAppend   :: (VecElem v a) => v a -> v a -> v a
  -- | Concatenate a list of vectors.
  vConcat   :: (VecElem v a) => [v a] -> v a
  -- | Fold with an operator and a starting value (the list instance uses a
  -- strict left fold, 'foldl'').
  vFold     :: (VecElem v a) => (a -> a -> a) -> a -> v a -> a
  -- | @vSlice o n v@ is the @n@ elements of @v@ starting at index @o@.
  vSlice    :: (VecElem v a) => Int -> Int -> v a -> v a
  -- | Sum of the elements.
  vSum      :: (VecElem v a, Num a) => v a -> a
  -- | Product of the elements.
  vProduct  :: (VecElem v a, Num a) => v a -> a
  -- | Largest element.
  vMaximum  :: (VecElem v a, Ord a) => v a -> a
  -- | Smallest element.
  vMinimum  :: (VecElem v a, Ord a) => v a -> a
  -- | @vUpdate v us@ is @v@ with, for each @(i, x)@ in @us@, the element at
  -- index @i@ replaced by @x@.
  vUpdate   :: (VecElem v a) => v a -> [(Int, a)] -> v a
  -- | @vGenerate n f@ is the vector @[f 0, ..., f (n-1)]@.
  vGenerate :: (VecElem v a) => Int -> (Int -> a) -> v a
  -- | Whether every element satisfies the predicate.
  vAll      :: (VecElem v a) => (a -> Bool) -> v a -> Bool
  -- | Whether some element satisfies the predicate.
  vAny      :: (VecElem v a) => (a -> Bool) -> v a -> Bool

-- | A constraint that every type satisfies.  It serves as the 'VecElem' of
-- storage types that place no requirement on their elements.
class None a
instance None a

-- | List-backed instance.  It is not used anywhere; it serves as the
-- reference semantics for the 'Vector' operations.
instance Vector [] where
  type VecElem [] = None
  vIndex xs i       = xs !! i
  vLength           = length
  vToList xs        = xs
  vFromList xs      = xs
  vSingleton x      = [x]
  vReplicate        = replicate
  vMap              = map
  vZipWith          = zipWith
  vZipWith3         = zipWith3
  vZipWith4         = zipWith4
  vZipWith5         = zipWith5
  vAppend           = (++)
  vConcat           = concat
  vFold             = foldl'
  vSlice off len xs = take len (drop off xs)
  vSum              = sum
  vProduct          = product
  vMaximum          = maximum
  vMinimum          = minimum
  -- Walk the list and the index-sorted updates together.  An index below
  -- the current position (negative or repeated) is a "bad index"; updates
  -- left over after the list ends are "out of bounds".
  vUpdate xs us = walk xs (sortOn fst us) 0
    where
      walk [] [] _ = []
      walk [] (_ : _) _ = error "vUpdate: out of bounds"
      walk ys [] _ = ys
      walk (y : ys) pending@((i, new) : later) pos =
        case compare i pos of
          LT -> error "vUpdate: bad index"
          EQ -> new : walk ys later (pos + 1)
          GT -> y : walk ys pending (pos + 1)
  vGenerate n f     = [ f i | i <- [0 .. n - 1] ]
  vAll              = all
  vAny              = any

-- | Render a value to a 'String' with 'pPrintPrec', using the given
-- 'PrettyLevel' and the outermost precedence (0).
prettyShowL :: (Pretty a) => PrettyLevel -> a -> String
prettyShowL lvl x = render (pPrintPrec lvl 0 x)

-- | The type /T/ is the internal type of arrays.  In general,
-- operations on /T/ do no sanity checking as that should be done
-- at the point of call.
--
-- To avoid manipulating the data, indexing into the vector that holds it
-- goes through strides and an offset.  Item /i/ of the outermost dimension
-- starts at vector index @offset + i*strides[0]@; item /i,j/ of the two
-- outermost dimensions is at vector index
-- @offset + i*strides[0] + j*strides[1]@, and so on.
data T v a = T
    { strides :: ![Int]   -- ^ One stride per dimension; the length is the rank.
    , offset  :: !Int     -- ^ Vector index of the array's first element.
    , values  :: !(v a)   -- ^ The underlying storage.
    }
    deriving (Show, Generic, Data)

-- | Full evaluation of an array, including its underlying vector (hence the
-- @NFData (v a)@ requirement).  The body is empty, so 'rnf' comes from
-- deepseq's default method, which works through the 'Generic' instance
-- derived for 'T'.
instance NFData (v a) => NFData (T v a)

-- | The shape of an array is a list of its dimensions, outermost first.
-- A shape containing a negative dimension is flagged by 'badShape'.
type ShapeL = [Int]

-- | True when some dimension of the shape is negative.
badShape :: ShapeL -> Bool
badShape = not . all (>= 0)

-- | Equality of two arrays that have the same shape.
--
-- If both arrays have exactly the same layout (strides, offset and
-- underlying vector), they are equal without flattening anything.
-- Otherwise both are flattened with 'toVectorT' and the flattened vectors
-- are compared.
--
-- Note this assumes the shape is the same for both arrays.
-- TODO(augustss): if the array is a small fraction of the vector this can be inefficient.
{-# INLINABLE equalT #-}
equalT :: (Vector v, VecElem v a, Eq a, Eq (v a))
                  => ShapeL -> T v a -> T v a -> Bool
equalT s x y = sameLayout || toVectorT s x == toVectorT s y
  where
    sameLayout = strides x == strides y
              && offset x == offset y
              && values x == values y

-- | Order two arrays by flattening both with 'toVectorT' and comparing the
-- flattened vectors with the vector type's 'Ord' instance.
-- Note this assumes the shape is the same for both arrays.
{-# INLINABLE compareT #-}
compareT :: (Vector v, VecElem v a, Ord a, Ord (v a))
            => ShapeL -> T v a -> T v a -> Ordering
compareT s x y = flatX `compare` flatY
  where
    flatX = toVectorT s x
    flatY = toVectorT s y

-- | Given the dimensions, return the normal stride in the underlying vector
-- for each dimension, preceded by the total number of elements.  The
-- result therefore has one more element than the shape; for example,
-- shape @[2,3,4]@ gives @[24,12,4,1]@.
{-# INLINE getStridesT #-}
getStridesT :: ShapeL -> [Int]
getStridesT dims = scanr (\d acc -> d * acc) 1 dims

-- | Convert an array with shape @sh@ to the list of its elements in
-- row-major order.
--
-- A canonically laid out array is read straight from its vector.
-- Otherwise the list is produced with 'build' by walking every index of
-- every dimension, outermost first.
-- XXX Copy special cases from Tensor.
{-# INLINE toListT #-}
toListT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [a]
toListT sh arr@(T strides0 off0 vec)
  | isCanonicalT (getStridesT sh) arr = vToList vec
  | otherwise = build (\cons nil ->
      -- TODO: because unScalarT uses vIndex, this has unnecessary bounds
      -- checks.  We should expose an unchecked indexing function in the Vector
      -- class, add top-level bounds checks to cover the full range we'll
      -- access, and then do all accesses with the unchecked version.
      let -- Emit the elements of the sub-array with the remaining dimensions,
          -- strides and offset in front of the given tail.
          walk [] ss off rest = cons (unScalarT (T ss off vec)) rest
          walk (d : dims) ss off rest = foldr step rest [0 .. d - 1]
            where
              step i k = case indexT (T ss off vec) i of
                T ss' off' _ -> walk dims ss' off' k
      in walk sh strides0 off0 nil)

-- | Check if the strides are canonical, i.e., if the vector has the natural
-- layout.  The first argument is the total element count followed by the
-- normal strides, as returned by 'getStridesT' for the array's shape.  The
-- array is canonical when its offset is 0, its strides are the normal
-- ones, and its vector holds exactly the total element count.  An empty
-- first argument is an internal error.
-- XXX Copy special cases from Tensor.
{-# INLINE isCanonicalT #-}
isCanonicalT :: (Vector v, VecElem v a) => [Int] -> T v a -> Bool
isCanonicalT (total : normalStrides) (T ss off vec) =
  and [ off == 0                -- vector offset is 0
      , ss == normalStrides     -- all strides are normal
      , vLength vec == total    -- the vector is the right size
      ]
isCanonicalT _ _ = error "impossible"

-- | Wrap a value as a rank-0 (scalar) array backed by a one-element vector.
{-# INLINE scalarT #-}
scalarT :: (Vector v, VecElem v a) => a -> T v a
scalarT x = T [] 0 (vSingleton x)

-- | Extract the value of a scalar array: the vector element at the array's
-- offset.
{-# INLINE unScalarT #-}
unScalarT :: (Vector v, VecElem v a) => T v a -> a
unScalarT t = vIndex (values t) (offset t)

-- | Make an array of the given shape whose elements are all @x@.  Every
-- stride is 0, so every index reads the single element of a one-element
-- vector.
{-# INLINE constantT #-}
constantT :: (Vector v, VecElem v a) => ShapeL -> a -> T v a
constantT sh x = T (0 <$ sh) 0 (vSingleton x)

-- TODO: change to return a list of vectors.
-- | Convert an array with shape @sh@ to a vector of its elements in the
-- natural (row-major) order.
--
-- Strategies, cheapest first:
--
-- * every stride is normal and the underlying vector holds exactly the
--   array's elements: return the underlying vector itself;
-- * the innermost stride is normal (or the array is a scalar): concatenate
--   contiguous slices of the underlying vector, each as long as possible;
-- * otherwise every slice would have length 1, so collect the elements
--   through 'toListT' instead.
{-# INLINE toVectorT #-}
toVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
toVectorT sh a@(T ats ao v) =
  let l : ts' = getStridesT sh
      -- Are strides ok from this point?  Element k is True when dimension k
      -- and every dimension inside it have normal strides.  'scanr' appends
      -- its seed, so the list has rank + 1 elements and the last is always
      -- True.
      oks = scanr (&&) True (zipWith (==) ats ts')
      -- Does the innermost dimension have its normal stride?  That is
      -- element (rank - 1) of oks, not element rank, which is the
      -- always-True seed.  A scalar has no dimensions and uses the slicing
      -- path, which handles it directly.
      innerOk = null sh || oks !! (length sh - 1)
      loop _ [] _ o =
        -- No dimensions left: a single element.
        DL.singleton (vSlice o 1 v)
      loop (b:bs) (s:ss) (t:ts) o =
        if b then
          -- All strides normal from this point,
          -- so just take a slice of the underlying vector.
          DL.singleton (vSlice o (s*t) v)
        else
          -- Strides are not normal, collect slices.
          DL.concat [ loop bs ss ts (i*t + o) | i <- [0 .. s-1] ]
      loop _ _ _ _ = error "impossible"
  in  if head oks && vLength v == l then
        -- All strides are normal, return entire vector
        v
      else if innerOk then
        -- Innermost dimension is normal, so slices are non-trivial.
        vConcat $ DL.toList $ loop oks sh ats ao
      else
        -- All slices would have length 1, going via a list is faster.
        vFromList $ toListT sh a

-- | Convert an array to a vector containing its elements, but not
-- necessarily in row-major order.  This is used for reductions with
-- commutative and associative operations, where element order does not
-- matter.
--
-- If some transposition of the array has normal strides, its elements
-- occupy one contiguous run of the underlying vector, and that run is
-- sliced out directly.  To check this, the (stride, size) pairs are sorted
-- by descending stride.  The sorted strides are then compared with the
-- normal strides of the correspondingly rearranged shape.  If they differ,
-- this falls back to 'toVectorT'.
{-# INLINE toUnorderedVectorT #-}
toUnorderedVectorT :: (Vector v, VecElem v a) => ShapeL -> T v a -> v a
toUnorderedVectorT sh a@(T ats ao v)
  | permutedStrides == normalStrides = vSlice ao total v
  | otherwise = toVectorT sh a
  where
    -- Dimensions reordered by decreasing stride (ties broken by size).
    (permutedStrides, permutedShape) =
      unzip (sortBy (\p q -> compare q p) (zip ats sh))
    total : normalStrides = getStridesT permutedShape

-- | View a vector whose elements are in the natural order as an array of
-- the given shape: normal strides, offset 0.  The vector's length is not
-- checked against the shape.
{-# INLINE fromVectorT #-}
fromVectorT :: ShapeL -> v a -> T v a
fromVectorT sh vec = T (drop 1 (getStridesT sh)) 0 vec

-- | Build an array of the given shape from a list of elements in row-major
-- order.
{-# INLINE fromListT #-}
fromListT :: (Vector v, VecElem v a) => [Int] -> [a] -> T v a
fromListT sh xs = fromVectorT sh (vFromList xs)

-- | Index into the outermost dimension of an array.  The result drops the
-- first stride and advances the offset by @i@ times that stride.  The index
-- is not bounds-checked, and indexing a rank-0 array is an internal error.
{-# INLINE indexT #-}
indexT :: T v a -> Int -> T v a
indexT (T ss off vec) i =
  case ss of
    s : inner -> T inner (off + i * s) vec
    []        -> error "impossible"

-- | Stretch the flagged dimensions so they can have arbitrary size.
-- The stretched dimensions must have size 1; stretching sets their stride
-- to 0, so every index along them reads the same element.
{-# INLINE stretchT #-}
stretchT :: [Bool] -> T v a -> T v a
stretchT bs (T ss o v) = T [ if b then 0 else s | (b, s) <- zip bs ss ] o v

-- | Map a function over the elements of an array with shape @sh@.
--
-- When the underlying vector holds no more elements than the array, mapping
-- the whole vector is no more work than mapping the elements, and the
-- strides and offset are kept as they are.  Otherwise only the array's
-- elements are mapped, after gathering them with 'toVectorT', and the
-- result has the normal layout.
{-# INLINE mapT #-}
mapT :: (Vector v, VecElem v a, VecElem v b) => ShapeL -> (a -> b) -> T v a -> T v b
mapT sh f t@(T ss o v)
  | vLength v <= product sh = T ss o (vMap f v)
  | otherwise = fromVectorT sh (vMap f (toVectorT sh t))

-- | Zip two arrays of shape @sh@ with a function.
--
-- An array whose underlying vector has a single element is treated as
-- constant, so zipping with it becomes a 'mapT' over the other array.  If
-- both arrays are constant, @f@ is applied once and the strides of the
-- first array are kept.  In the general case both arrays are gathered with
-- 'toVectorT' and zipped element-wise.
{-# INLINE zipWithT #-}
zipWithT :: (Vector v, VecElem v a, VecElem v b, VecElem v c) =>
            ShapeL -> (a -> b -> c) -> T v a -> T v b -> T v c
zipWithT sh f t@(T ss _ v) t'@(T _ _ v')
  | lenL == 1 && lenR == 1 = T ss 0 (vSingleton (f firstL firstR))
  | lenL == 1              = mapT sh (f firstL) t'
  | lenR == 1              = mapT sh (`f` firstR) t
  | otherwise              =
      fromVectorT sh (vZipWith f (toVectorT sh t) (toVectorT sh t'))
  where
    lenL   = vLength v
    lenR   = vLength v'
    firstL = vIndex v 0
    firstR = vIndex v' 0

-- Zip three arrays with a function.
-- If all three backing vectors hold a single element, the result is a
-- one-element array that reuses the first operand's strides.  Otherwise each
-- operand is flattened to sh with toVectorT and the three are combined with
-- vZipWith3.
{-# INLINE zipWith3T #-}
zipWith3T :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d) =>
             ShapeL -> (a -> b -> c -> d) -> T v a -> T v b -> T v c -> T v d
zipWith3T sh f t@(T ss _ xs) t'@(T _ _ ys) t''@(T _ _ zs)
  | all (== 1) [vLength xs, vLength ys, vLength zs] =
      T ss 0 (vSingleton (f (vIndex xs 0) (vIndex ys 0) (vIndex zs 0)))
  | otherwise =
      fromVectorT sh
        (vZipWith3 f (toVectorT sh t) (toVectorT sh t') (toVectorT sh t''))

-- Zip four arrays with a function.
-- Every operand is flattened to sh with toVectorT, and the flat vectors are
-- combined element-wise with vZipWith4.
{-# INLINE zipWith4T #-}
zipWith4T :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e) => ShapeL -> (a -> b -> c -> d -> e) -> T v a -> T v b -> T v c -> T v d -> T v e
zipWith4T sh f t t' t'' t''' =
  fromVectorT sh $
    vZipWith4 f (toVectorT sh t)
                (toVectorT sh t')
                (toVectorT sh t'')
                (toVectorT sh t''')

-- Zip five arrays with a function.
-- Every operand is flattened to sh with toVectorT, and the flat vectors are
-- combined element-wise with vZipWith5.
{-# INLINE zipWith5T #-}
zipWith5T :: (Vector v, VecElem v a, VecElem v b, VecElem v c, VecElem v d, VecElem v e, VecElem v f) => ShapeL -> (a -> b -> c -> d -> e -> f) -> T v a -> T v b -> T v c -> T v d -> T v e -> T v f
zipWith5T sh f t t' t'' t''' t'''' =
  fromVectorT sh $
    vZipWith5 f (toVectorT sh t)
                (toVectorT sh t')
                (toVectorT sh t'')
                (toVectorT sh t''')
                (toVectorT sh t'''')

-- Do an arbitrary transposition.  The first argument should be
-- a permutation of the dimensions, i.e., the numbers [0..r-1] in some order
-- (where r is the rank of the array).
-- Only the strides are permuted; the offset and backing vector are reused.
{-# INLINE transposeT #-}
transposeT :: [Int] -> T v a -> T v a
transposeT perm (T strides off vec) = T (permute perm strides) off vec

-- Return all subarrays n dimensions down.
-- The shape argument should be a prefix of the array shape.
-- Subarrays come out in index order, outermost dimension slowest.  The list
-- is built front to back by threading the tail through an accumulator.
{-# INLINE subArraysT #-}
subArraysT :: ShapeL -> T v a -> [T v a]
subArraysT sh ten = go sh ten []
  where
    go [] t rest = t : rest
    go (n : ns) t rest =
      foldr (\ i acc -> go ns (indexT t i) acc) rest [0 .. n - 1]

-- Reverse the given dimensions.
-- rs lists the (0-based) dimensions to reverse.  Only the view changes:
-- - each reversed dimension of size m with stride st has its stride negated;
-- - the offset advances by (m - 1) * st;
-- - the backing vector is reused.
{-# INLINE reverseT #-}
reverseT :: [Int] -> ShapeL -> T v a -> T v a
reverseT rs sh (T ats ao v) = T newStrides newOffset v
  where
    (newOffset, newStrides) = walk 0 sh ats
    -- walk d shape strides: d is the index of the current dimension.
    walk :: Int -> [Int] -> [Int] -> (Int, [Int])
    walk !_ [] [] = (ao, [])
    walk d (m:ms) (st:sts) =
      let (off, rest) = walk (d + 1) ms sts
      in  if d `elem` rs
            then (off + (m - 1) * st, negate st : rest)
            else (off,                st        : rest)
    walk _ _ _ = error "reverseT: impossible"

-- Reduction of all array elements.
-- Flattens the array to sh, folds it with vFold f z and wraps the result
-- with scalarT.
{-# INLINE reduceT #-}
reduceT :: (Vector v, VecElem v a) =>
           ShapeL -> (a -> a -> a) -> a -> T v a -> T v a
reduceT sh f z = \ t -> scalarT (vFold f z (toVectorT sh t))

-- Right fold via toListT.
-- The elements are listed in the order produced by toListT sh.
{-# INLINE foldrT #-}
foldrT
  :: (Vector v, VecElem v a) => ShapeL -> (a -> b -> b) -> b -> T v a -> b
foldrT sh f z arr = foldr f z elements
  where elements = toListT sh arr

-- Traversal via toListT/fromListT.
-- Effects run in toListT order; the results are rebuilt into an array of
-- the same shape sh.
{-# INLINE traverseT #-}
traverseT
  :: (Vector v, VecElem v a, VecElem v b, Applicative f)
  => ShapeL -> (a -> f b) -> T v a -> f (T v b)
traverseT sh f a = fromListT sh <$> traverse f (toListT sh a)

-- Fast check if all elements are equal.
-- A backing vector with at most one element is trivially uniform.
-- Otherwise the view is flattened to sh and every element is compared with
-- the first.
-- A flattened view can be empty even when the backing vector is larger
-- (presumably a view whose shape contains a 0, e.g. an empty slice).  Such a
-- view is uniform too, and must not be indexed at position 0.
allSameT :: (Vector v, VecElem v a, Eq a) => ShapeL -> T v a -> Bool
allSameT sh t@(T _ _ v)
  | vLength v <= 1 = True
  | otherwise =
    let !v' = toVectorT sh t
    in  if vLength v' == 0
          then True
          else let !x = vIndex v' 0
               in  vAll (x ==) v'

-- | A rectangle of text, stored as its list of lines (top line first).
newtype Rect = Rect { unRect :: [String] }

-- | Split a string into lines to form a (possibly ragged) Rect.
toRect :: String -> Rect
toRect s = Rect (lines s)

-- | Join the lines of a Rect back into a string, each line newline-terminated.
fromRect :: Rect -> String
fromRect = unlines . unRect

-- Make each Rect be of size h * w
-- Lines are right-aligned by prepending spaces up to width w.  Lines already
-- w or wider are left unchanged.  Missing lines are filled with w spaces up
-- to h lines.
rectPad :: Int -> Int -> Rect -> Rect
rectPad h w (Rect rows) = Rect (map rightAlign rows ++ filler)
  where
    rightAlign s = replicate (w - length s) ' ' ++ s
    filler = replicate (h - length rows) (replicate w ' ')

-- Horizontal catenation.  Assumes the input rectangles are padded.
-- Adds one space between Rects.  The result has as many lines as the
-- shorter input.
hcatRect :: Rect -> Rect -> Rect
hcatRect (Rect xs) (Rect ys) = Rect [ x ++ ' ' : y | (x, y) <- zip xs ys ]

-- Vertical catenation.  Assumes the input rectangles are padded.
-- Adds no space between Rects: the lines of the second follow the first.
vcatRect :: Rect -> Rect -> Rect
vcatRect (Rect top) (Rect bottom) = Rect (top ++ bottom)

-- | Number of lines in a Rect.
rectHeight :: Rect -> Int
rectHeight (Rect rows) = length rows

-- Widest line (0 for a Rect with no lines).
rectWidth :: Rect -> Int
rectWidth (Rect rows) = foldr (max . length) 0 rows

-- | Pretty-print an array as a Doc.
-- Elements are rendered with prettyShowL at level l and laid out by ppT_.
-- At prettyNormal or above, the layout is boxed with bars, Unicode
-- box-drawing and a header; below that, no decoration is added.
-- The result is parenthesised when the precedence p exceeds 10.
ppT
  :: (Vector v, VecElem v a, Pretty a)
  => PrettyLevel -> Rational -> ShapeL -> T v a -> Doc
ppT l p sh =
  maybeParens (p > 10) . stackLines . map text . unRect . box mode . ppT_ (prettyShowL l) sh
  where
    fancy = l >= prettyNormal
    mode = BoxMode fancy fancy fancy
    stackLines :: [Doc] -> Doc
    stackLines = foldl' ($+$) empty

-- Lay out an array as a Rect of text:
-- - render every element with show_;
-- - pad all cells to a common height and width (right-aligned by rectPad);
-- - arrange the padded cells according to sh with showsT.
ppT_
  :: (Vector v, VecElem v a)
  => (a -> String) -> ShapeL -> T v a -> Rect
ppT_ show_ sh t = showsT sh grid
  where
    cells = [ toRect (show_ x) | x <- toListT sh t ]
    cellH = maximum [ rectHeight c | c <- cells ]
    cellW = maximum [ rectWidth c | c <- cells ]
    grid :: T [] Rect
    grid = T (tail (getStridesT sh)) 0 (map (rectPad cellH cellW) cells)

-- Lay out an array of (already padded) cell Rects as a single Rect.
-- - A scalar is its single cell.
-- - A rank-1 array places its cells side by side; hcatRect puts one space
--   between them.
-- - A higher-rank array stacks the layouts of its subarrays vertically,
--   separated by (rank - 2) blank lines.
-- A dimension of size 0 yields an empty Rect instead of failing in foldl1'
-- on an empty list.
showsT :: [Int] -> T [] Rect -> Rect
showsT []     t = unScalarT t
showsT (0:_)  _ = Rect []
showsT s@[_]  t = foldl1' hcatRect $ toListT s t
showsT (n:ns) t = foldl1' vcat' rs
  where vcat' x y = vcatRect x (vcatRect spc y)
        -- (length ns - 1) blank lines as wide as the first subarray layout.
        spc = Rect $ replicate (length ns - 1) (replicate (rectWidth (head rs)) ' ')
        rs = [ showsT ns (indexT t i) | i <- [0..n-1] ]

-- | Decoration options for box.
data BoxMode = BoxMode
  { _bmBars    :: Bool  -- ^ wrap each non-empty line in vertical bars
  , _bmUnicode :: Bool  -- ^ use Unicode box-drawing characters instead of ASCII
  , _bmHeader  :: Bool  -- ^ add a top and a bottom border line
  }

-- | Box mode with every decoration turned off.
prettyBoxMode :: BoxMode
prettyBoxMode = BoxMode { _bmBars = False, _bmUnicode = False, _bmHeader = False }

-- Possibly draw a box around a (padded) rectangle.
-- - _bmBars wraps every non-empty line in vertical bars.
-- - _bmHeader adds a top and a bottom border.
-- - _bmUnicode selects box-drawing characters instead of ASCII.
-- The border width is taken from the first line, since the input is assumed
-- padded.  An empty rectangle gets a zero-width border instead of crashing on
-- 'head'.
box :: BoxMode -> Rect -> Rect
box BoxMode{..} (Rect ls) =
  let bar | _bmUnicode = '\x2502'
          | otherwise  = '|'
      dash | _bmUnicode = '\x2500'
           | otherwise  = '-'
      ls' | _bmBars   = map (\ l -> if null l then l else [bar] ++ l ++ [bar]) ls
          | otherwise = ls
      width = case ls of
                l : _ -> length l
                []    -> 0
      h = replicate width dash
      t | _bmUnicode = "\x250c" ++ h ++ "\x2510"
        | otherwise  = "+" ++ h ++ "+"
      b | _bmUnicode = "\x2514" ++ h ++ "\x2518"
        | otherwise  = t
      ls'' | _bmHeader = [t] ++ ls' ++ [b]
           | otherwise = ls'
  in  Rect ls''

-- | Like zipWith, but keeps the unmatched tail of the second list.
-- Extra elements of the first list are dropped.
zipWithLong2 :: (a -> b -> b) -> [a] -> [b] -> [b]
zipWithLong2 f = go
  where
    go (a:as) (b:bs) = f a b : go as bs
    go _      bs     = bs

-- | Pad an array with a fill value.
-- Each (lo, hi) pair in aps adds lo elements before and hi elements after
-- the corresponding leading dimension.  Dimensions without a pair are
-- unchanged.  Returns the new shape together with the padded array.
padT :: forall v a . (Vector v, VecElem v a) => a -> [(Int, Int)] -> ShapeL -> T v a -> ([Int], T v a)
padT fill aps ash at =
  (newShape, fromVectorT newShape (vConcat (go aps ash innerStrides at)))
  where
    newShape :: [Int]
    newShape = zipWithLong2 (\ (lo, hi) s -> lo + s + hi) aps ash
    _ : innerStrides = getStridesT newShape
    -- Emit the padded data dimension by dimension: the leading fill, each
    -- (recursively padded) subarray, then the trailing fill.  n is the number
    -- of elements one index of this dimension spans in the padded result.
    go :: [(Int, Int)] -> ShapeL -> [Int] -> T v a -> [v a]
    go [] sh _ t = [toVectorT sh t]
    go ((lo, hi) : ps) (s : sh) (n : ns) t =
      vReplicate (n * lo) fill
        : concatMap (go ps sh ns . indexT t) [0 .. s - 1]
        ++ [vReplicate (n * hi) fill]
    go _ _ _ _ = error $ "pad: rank mismatch: " ++ show (length aps, length ash)

-- Check if a reshape is just adding/removing some dimensions of
-- size 1, in which case it can be done by just manipulating
-- the strides.  Given the old strides, the old shape, and the
-- new shape it will return the possible new strides.
-- Size-1 dimensions of the new shape get stride 0.  Every other new
-- dimension reuses, in order, the stride of an old non-1 dimension.
simpleReshape :: [Int] -> ShapeL -> ShapeL -> Maybe [Int]
simpleReshape osts os ns
  | nonUnit os /= nonUnit ns = Nothing
  | otherwise                = Just (assign ns keptStrides)
  where
    nonUnit :: [Int] -> [Int]
    nonUnit = filter (/= 1)
    -- Strides of the old dimensions that are not of size 1.
    keptStrides = [ st | (st, s) <- zip osts os, s /= 1 ]
    assign :: [Int] -> [Int] -> [Int]
    assign [] [] = []
    assign (1 : ds) sts        = 0  : assign ds sts
    assign (_ : ds) (st : sts) = st : assign ds sts
    assign _ _ = error $ "simpleReshape: shouldn't happen: " ++ show (osts, os, ns)

-- Note: assumes + is commutative&associative, because the elements are
-- visited in whatever order toUnorderedVectorT provides.
{-# INLINE sumT #-}
sumT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
sumT sh = \ t -> vSum (toUnorderedVectorT sh t)

-- Note: assumes * is commutative&associative, because the elements are
-- visited in whatever order toUnorderedVectorT provides.
{-# INLINE productT #-}
productT :: (Vector v, VecElem v a, Num a) => ShapeL -> T v a -> a
productT sh = \ t -> vProduct (toUnorderedVectorT sh t)

-- Note: assumes max is commutative&associative, because the elements are
-- visited in whatever order toUnorderedVectorT provides.
{-# INLINE maximumT #-}
maximumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
maximumT sh = \ t -> vMaximum (toUnorderedVectorT sh t)

-- Note: assumes min is commutative&associative, because the elements are
-- visited in whatever order toUnorderedVectorT provides.
{-# INLINE minimumT #-}
minimumT :: (Vector v, VecElem v a, Ord a) => ShapeL -> T v a -> a
minimumT sh = \ t -> vMinimum (toUnorderedVectorT sh t)

-- True if some element satisfies p (elements are checked in
-- toUnorderedVectorT order).
{-# INLINE anyT #-}
anyT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
anyT sh p = \ t -> vAny p (toUnorderedVectorT sh t)

-- True if every element satisfies p (elements are checked in
-- toUnorderedVectorT order).
{-# INLINE allT #-}
allT :: (Vector v, VecElem v a) => ShapeL -> (a -> Bool) -> T v a -> Bool
allT sh p = \ t -> vAll p (toUnorderedVectorT sh t)

-- Return a fresh array in which every (multi-index, value) pair in us
-- overwrites the element at that index.  Multi-indices are turned into
-- linear positions in the flattened vector using the strides from
-- getStridesT sh.
{-# INLINE updateT #-}
updateT :: (Vector v, VecElem v a) => ShapeL -> T v a -> [([Int], a)] -> T v a
updateT sh t us =
  T strides 0 $ vUpdate (toVectorT sh t) [ (linear is, x) | (is, x) <- us ]
  where
    _ : strides = getStridesT sh
    linear is = sum (zipWith (*) is strides)

-- Build an array of shape sh whose element at multi-index is is f is.
-- The flat position is turned back into a multi-index by repeated quotRem
-- with the strides from getStridesT sh.
{-# INLINE generateT #-}
generateT :: (Vector v, VecElem v a) => ShapeL -> ([Int] -> a) -> T v a
generateT sh f = T strides 0 (vGenerate total (f . unflatten strides))
  where
    total : strides = getStridesT sh
    unflatten [] _ = []
    unflatten (st : sts) i =
      let (q, r) = i `quotRem` st
      in  q : unflatten sts r

-- Rank-1 array of length n holding x, f x, f (f x), ...
{-# INLINE iterateNT #-}
iterateNT :: (Vector v, VecElem v a) => Int -> (a -> a) -> a -> T v a
iterateNT n f x = fromListT [n] values
  where values = take n (iterate f x)

-- Rank-1 array of length n holding 0, 1, ..., n-1.
-- The elements are taken from [0 ..] instead of enumerating up to
-- fromIntegral n - 1.  That bound underflows for n == 0 on unsigned or
-- natural element types: Word8 would enumerate [0 .. 255] for a length-0
-- array, and Natural would throw.
{-# INLINE iotaT #-}
iotaT :: (Vector v, VecElem v a, Enum a, Num a) => Int -> T v a
iotaT n = fromListT [n] (take n [0 ..])    -- TODO: should use V.enumFromTo instead

-------

-- | Permute the elements of a list, the first argument is indices into the original list.
-- The result has one element per index; indices out of range fail in '!!'.
permute :: [Int] -> [a] -> [a]
permute is xs = [ xs !! i | i <- is ]

-- | Like 'dropWhile' but at the end of the list: drops the longest suffix
-- whose elements all satisfy the predicate.
revDropWhile :: (a -> Bool) -> [a] -> [a]
revDropWhile p xs = reverse (dropWhile p (reverse xs))

-- | True when every element equals the first one (vacuously True for []).
allSame :: (Eq a) => [a] -> Bool
allSame xs = case xs of
  []         -> True
  (x : rest) -> all (x ==) rest

-- | Get the value of a type level Nat, converted to any 'Num' type.
-- Use with explicit type application, i.e., @valueOf \@42@
{-# INLINE valueOf #-}
valueOf :: forall n i . (KnownNat n, Num i) => i
valueOf = fromInteger (natVal proxy)
  where proxy = Proxy :: Proxy n