{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StrictData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE NoStarIsType #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
{-# OPTIONS_GHC -fno-warn-incomplete-uni-patterns #-}
module NumHask.Array.Dynamic
(
Array (..),
fromFlatList,
toFlatList,
index,
tabulate,
reshape,
transpose,
diag,
ident,
singleton,
selects,
selectsExcept,
folds,
extracts,
extractsExcept,
joins,
maps,
concatenate,
insert,
append,
reorder,
expand,
apply,
contract,
dot,
mult,
slice,
squeeze,
fromScalar,
toScalar,
col,
row,
mmult,
)
where
import Data.List (intercalate)
import qualified Data.Vector as V
import GHC.Show (Show (..))
import NumHask.Array.Shape
import NumHask.Prelude as P hiding (product)
-- | A multidimensional array whose shape is carried at the value level.
--
-- 'shape' lists the extent of each dimension; 'unArray' stores the elements
-- flat, in the order that 'index' and 'tabulate' map to and from via
-- 'flatten' and 'shapen'.
--
-- The derived 'Eq' and 'Ord' instances compare 'shape' first, then the
-- element vector.
data Array a = Array
  { shape :: [Int],
    unArray :: V.Vector a
  }
  deriving (Eq, Ord, Generic)
-- | Map over every element; the shape is left untouched.
instance Functor Array where
  fmap f (Array sh xs) = Array sh (fmap f xs)
-- | Fold over the elements in their flat storage order.
instance Foldable Array where
  foldr f z = V.foldr f z . unArray
-- | Traverse the elements in flat storage order, then rebuild an array with
-- the original shape via 'fromFlatList'.
instance Traversable Array where
  traverse f (Array sh xs) = fmap (fromFlatList sh) (traverse f (V.toList xs))
-- | Render an array as nested, bracketed lists.
--
-- A rank-0 array shows its single element, a rank-1 array shows a
-- comma-separated list, and higher ranks are split along dimension 0 (with
-- 'extracts') and rendered recursively, one sub-array per line.
instance (Show a) => Show (Array a) where
  show arr = render (rank (shape arr)) arr
    where
      -- depth is the rank of the outermost array; it fixes how far nested
      -- lines are indented so that their opening brackets line up.
      render depth sub = case shape sub of
        [] -> GHC.Show.show (V.head (unArray sub))
        [_] -> "[" ++ intercalate ", " (GHC.Show.show <$> toFlatList sub) ++ "]"
        ds ->
          let pad = replicate (depth - length ds + 1) ' '
              rows = render depth <$> toFlatList (extracts [0] sub)
           in "[" ++ intercalate (",\n" ++ pad) rows ++ "]"
-- | Build an array of shape @ds@ from a flat list of elements, in the same
-- flat order that 'index' and 'tabulate' use.
--
-- Only the first @size ds@ elements of @l@ are consumed, so an infinite list
-- is accepted.
--
-- Throws 'NumHaskException' if @l@ has fewer elements than the shape needs.
-- Such an array would hold fewer elements than its shape claims, and 'index'
-- (which uses 'V.unsafeIndex') would then read past the end of the vector.
fromFlatList :: [Int] -> [a] -> Array a
fromFlatList ds l
  | V.length v == n = Array ds v
  | otherwise =
      throw (NumHaskException "fromFlatList: fewer elements than the shape requires")
  where
    n = size ds
    v = V.fromList (take n l)
-- | All elements as a list, in flat storage order (the shape is discarded).
toFlatList :: Array a -> [a]
toFlatList = V.toList . unArray
-- | The element at multi-dimensional position @idx@.
--
-- The position is converted to a flat offset with 'flatten'. The lookup uses
-- 'V.unsafeIndex', so no bounds check is performed: callers must pass an
-- in-range index.
index :: Array a -> [Int] -> a
index arr idx = unArray arr `V.unsafeIndex` flatten (shape arr) idx
-- | Build an array of shape @ds@ by calling @f@ on every position.
--
-- Each flat offset in @[0, size ds)@ is turned back into a multi-dimensional
-- index with 'shapen' before @f@ is applied.
tabulate :: [Int] -> ([Int] -> a) -> Array a
tabulate ds f = Array ds (V.generate (size ds) (\k -> f (shapen ds k)))
-- | View the elements of @a@ under a new shape @s@.
--
-- Each target position is flattened with respect to @s@, then expanded back
-- into an index of @a@'s own shape. The flat element order is therefore
-- preserved. The intended use is a shape with the same number of elements.
reshape ::
  [Int] ->
  Array a ->
  Array a
reshape s a = tabulate s (\i -> index a (shapen (shape a) (flatten s i)))
-- | Reverse the order of all dimensions.
--
-- The result has shape @reverse (shape a)@, and its element at @i@ is @a@'s
-- element at @reverse i@.
transpose :: Array a -> Array a
transpose a = tabulate (reverse (shape a)) (\i -> index a (reverse i))
-- | Identity-style array of shape @ds@.
--
-- A position holds 'one' when all of its index components are equal (the
-- generalised diagonal), and 'zero' everywhere else.
ident :: (Additive a, Multiplicative a) => [Int] -> Array a
ident ds = tabulate ds (\i -> bool zero one (isDiag i))
  where
    -- True when every pair of adjacent components is equal; lists with fewer
    -- than two components are trivially on the diagonal.
    isDiag (x : rest@(y : _)) = x == y && isDiag rest
    isDiag _ = True
-- | The generalised diagonal of @a@ as a rank-1 array.
--
-- The length is the smallest dimension of @a@, and element @k@ is @a@ at
-- index @[k, k, ..., k]@ (one @k@ per dimension).
diag ::
  Array a ->
  Array a
diag a = tabulate [NumHask.Array.Shape.minimum (shape a)] pick
  where
    r = rank (shape a)
    pick (k : _) = index a (replicate r k)
    pick [] = throw (NumHaskException "Rank Underflow")
-- | An array of shape @ds@ in which every element is @a@.
singleton :: [Int] -> a -> Array a
singleton ds a = tabulate ds (\_ -> a)
-- | Fix the dimensions listed in @ds@ at the positions in @i@.
--
-- The result is the sub-array over the remaining dimensions, with shape
-- @dropIndexes (shape a) ds@. Each result index is re-expanded to a full
-- index of @a@ with 'addIndexes'.
selects ::
  [Int] ->
  [Int] ->
  Array a ->
  Array a
selects ds i a =
  tabulate (dropIndexes (shape a) ds) (\s -> index a (addIndexes s ds i))
-- | Like 'selects', but @ds@ names the dimensions to keep.
--
-- The dimensions to fix are the complement of @ds@, computed with 'exclude'
-- over @a@'s rank.
selectsExcept ::
  [Int] ->
  [Int] ->
  Array a ->
  Array a
selectsExcept ds i a = selects fixed i a
  where
    fixed = exclude (rank (shape a)) ds
-- | Reduce, with @f@, each sub-array obtained by fixing the dimensions @ds@.
--
-- The result is indexed by the @ds@ dimensions, with shape
-- @takeIndexes (shape a) ds@.
folds ::
  (Array a -> b) ->
  [Int] ->
  Array a ->
  Array b
folds f ds a = tabulate (takeIndexes (shape a) ds) (\s -> f (selects ds s a))
-- | Split an array into sub-arrays along the dimensions @ds@.
--
-- The outer array is indexed by the @ds@ dimensions, with shape
-- @takeIndexes (shape a) ds@. Each of its elements is the sub-array obtained
-- by fixing those dimensions with 'selects'.
extracts ::
  [Int] ->
  Array a ->
  Array (Array a)
extracts ds a = tabulate (takeIndexes (shape a) ds) go
  where
    go s = selects ds s a
-- | Like 'extracts', but @ds@ names the dimensions kept inside each sub-array.
--
-- The outer dimensions are the complement of @ds@, computed with 'exclude'
-- over @a@'s rank.
extractsExcept ::
  [Int] ->
  Array a ->
  Array (Array a)
extractsExcept ds a = extracts (exclude (rank (shape a)) ds) a
-- | Reassemble an array of arrays into one array (the inverse of 'extracts').
--
-- The outer dimensions are placed at the positions @ds@ (via 'addIndexes').
-- The inner shape is read from the element at the all-zero outer index, so
-- all inner arrays are assumed to share that shape.
joins ::
  [Int] ->
  Array (Array a) ->
  Array a
joins ds a = tabulate (addIndexes innerShape ds outerShape) element
  where
    outerShape = shape a
    innerShape = shape (index a (replicate (rank outerShape) 0))
    element s = index (index a (takeIndexes s ds)) (dropIndexes s ds)
-- | Apply an array transformation to every sub-array along @ds@.
--
-- This extracts along @ds@, maps @f@ over the pieces, and joins them back.
maps ::
  (Array a -> Array b) ->
  [Int] ->
  Array a ->
  Array b
maps f ds a = joins ds (f <$> extracts ds a)
-- | Concatenate @a0@ and @a1@ along dimension @d@.
--
-- The result shape comes from 'concatenate''. Positions whose @d@-component
-- is below @a0@'s extent along @d@ read from @a0@. The rest read from @a1@,
-- with the @d@-component shifted back by that extent.
concatenate ::
  Int ->
  Array a ->
  Array a ->
  Array a
concatenate d a0 a1 = tabulate (concatenate' d ds0 (shape a1)) go
  where
    ds0 = shape a0
    n0 = ds0 !! d
    go s
      | s !! d >= n0 = index a1 (addIndex (dropIndex s d) d (s !! d - n0))
      | otherwise = index a0 s
-- | Insert the slice @b@ into @a@ at position @i@ along dimension @d@.
--
-- The result has shape @incAt d (shape a)@:
--
-- * positions before @i@ along @d@ read @a@ directly;
-- * the position @i@ reads @b@ (indexed with dimension @d@ dropped);
-- * positions after @i@ read @a@ at the index adjusted by 'decAt'.
insert ::
  Int ->
  Int ->
  Array a ->
  Array a ->
  Array a
insert d i a b = tabulate (incAt d (shape a)) pick
  where
    pick s
      | k < i = index a s
      | k == i = index b (dropIndex s d)
      | otherwise = index a (decAt d s)
      where
        k = s !! d
-- | Append the slice @b@ to the end of @a@ along dimension @d@.
--
-- This is 'insert' at position @dimension (shape a) d@.
append ::
  Int ->
  Array a ->
  Array a ->
  Array a
append d a = insert d (dimension (shape a) d) a
-- | Permute the dimensions of @a@ according to @ds@.
--
-- The result shape is @reorder' (shape a) ds@. A result index @s@ is mapped
-- back to an index of @a@ by placing its components at the positions in
-- @ds@ (via 'addIndexes').
reorder ::
  [Int] ->
  Array a ->
  Array a
reorder ds a = tabulate (reorder' (shape a) ds) (\s -> index a (addIndexes [] ds s))
-- | Outer product of two arrays under the binary function @f@.
--
-- The result has shape @shape a ++ shape b@. At index @i@, the first
-- @rank (shape a)@ components index @a@ and the remaining components
-- index @b@.
expand ::
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array c
expand f a b = tabulate (shape a ++ shape b) go
  where
    r = rank (shape a)
    go i = f (index a (take r i)) (index b (drop r i))
-- | Apply every function in @f@ to every element of @a@ (an outer application).
--
-- The result has shape @shape f ++ shape a@. The leading components select
-- the function; the trailing components select the argument.
apply ::
  Array (a -> b) ->
  Array a ->
  Array b
apply f a = tabulate (shape f ++ shape a) go
  where
    r = rank (shape f)
    go i = index f (take r i) (index a (drop r i))
-- | Contract the dimensions @xs@ of @a@.
--
-- For each sub-array over @xs@ (from 'extractsExcept'), its 'diag' is taken
-- and reduced with @f@.
contract ::
  (Array a -> b) ->
  [Int] ->
  Array a ->
  Array b
contract f xs a = fmap (f . diag) (extractsExcept xs a)
-- | Generalised dot product.
--
-- The arrays are combined element-wise with @g@ via 'expand'. The result is
-- then contracted with @f@ over two adjacent dimensions: the last dimension
-- of @a@ and the first dimension of @b@.
dot ::
  (Array c -> d) ->
  (a -> b -> c) ->
  Array a ->
  Array b ->
  Array d
dot f g a b = contract f [r - 1, r] (expand g a b)
  where
    r = rank (shape a)
-- | Array multiplication: 'dot' with 'sum' as the reduction and '(*)' as
-- the element-wise combination.
mult ::
  ( Additive a,
    Multiplicative a
  ) =>
  Array a ->
  Array a ->
  Array a
mult a b = dot sum (*) a b
-- | Select, for each dimension, the positions listed in the matching entry
-- of @pss@.
--
-- The result shape is @ranks pss@. Result index component @k@ in dimension
-- @n@ picks position @(pss !! n) !! k@ of @a@.
slice ::
  [[Int]] ->
  Array a ->
  Array a
slice pss a = tabulate (ranks pss) (\s -> index a (zipWith (!!) pss s))
-- | Rewrite the shape with 'squeeze'' and leave the element vector untouched.
--
-- NOTE(review): 'squeeze'' presumably removes size-one dimensions; confirm
-- in "NumHask.Array.Shape".
squeeze ::
  Array a ->
  Array a
squeeze a = a {shape = squeeze' (shape a)}
-- | The element of a rank-0 array, i.e. the element at the empty index.
fromScalar :: Array a -> a
fromScalar a = index a []
-- | Wrap a single value as a rank-0 array (shape @[]@).
toScalar :: a -> Array a
toScalar = fromFlatList [] . (: [])
-- | Row @i@ of a matrix, returned as a rank-1 array.
--
-- The shape is read as @m : n : _@ (only the first two dimensions are
-- consulted). The result is the @n@ consecutive elements starting at flat
-- offset @i * n@.
--
-- Throws 'NumHaskException' when the array has fewer than two dimensions, or
-- when @i@ is outside @[0, m)@. Previously these cases hit an
-- irrefutable-pattern failure or an out-of-range 'V.slice'.
row :: Int -> Array a -> Array a
row i (Array s a) = case s of
  (m : n : _)
    | i >= 0 && i < m -> Array [n] (V.slice (i * n) n a)
    | otherwise -> throw (NumHaskException "row: index out of range")
  _ -> throw (NumHaskException "Needs two dimensions")
-- | Column @i@ of a matrix, returned as a rank-1 array of length @m@.
--
-- The shape is read as @m : n : _@ (only the first two dimensions are
-- consulted). Element @x@ of the result is the element at flat offset
-- @i + x * n@.
--
-- Throws 'NumHaskException' when the array has fewer than two dimensions, or
-- when @i@ is outside @[0, n)@. Without this check, an out-of-range @i@ would
-- reach 'V.unsafeIndex'. It would then silently read another column's data
-- or run past the end of the vector.
col :: Int -> Array a -> Array a
col i (Array s a) = case s of
  (m : n : _)
    | i >= 0 && i < n ->
        Array [m] (V.generate m (\x -> V.unsafeIndex a (i + x * n)))
    | otherwise -> throw (NumHaskException "col: index out of range")
  _ -> throw (NumHaskException "Needs two dimensions")
-- | Matrix multiplication of an @m x k@ array by a @k x n@ array.
--
-- The result is @m x n@. Entry @(i, j)@ is the sum of the element-wise
-- products of row @i@ of the first array and column @j@ of the second.
--
-- Both shapes are validated up front. Throws 'NumHaskException' if either
-- argument is not rank 2, or if the inner dimensions differ. Previously a
-- mismatch silently produced a wrong result or failed late inside 'V.!'.
mmult ::
  (Ring a) =>
  Array a ->
  Array a ->
  Array a
mmult (Array sx x) (Array sy y) =
  case (sx, sy) of
    ([m, k], [k', n])
      | k == k' -> tabulate [m, n] (entry k n)
      | otherwise -> throw (NumHaskException "mmult: inner dimensions do not match")
    _ -> throw (NumHaskException "Needs two dimensions")
  where
    -- Row i of x is the contiguous slice starting at i * k; column j of y
    -- is gathered from stride-n offsets j, j + n, j + 2n, ...
    entry k n (i : j : _) =
      sum $
        V.zipWith
          (*)
          (V.slice (i * k) k x)
          (V.generate k (\x' -> y V.! (j + x' * n)))
    entry _ _ _ = throw (NumHaskException "Needs two dimensions")
{-# INLINE mmult #-}