{-# OPTIONS -fno-warn-incomplete-patterns #-}
{-# LANGUAGE PackageImports #-}
module Data.Array.Repa.Algorithms.Matrix
(
row
, col
, mmultP, mmultS
, transpose2P, transpose2S
, trace2P, trace2S)
where
import Data.Array.Repa as R
import Data.Array.Repa.Eval as R
import Data.Array.Repa.Unsafe as R
import Control.Monad.ST.Strict
-- | Take the row number of a rank-2 index.
--
--   For an index @Z :. r :. c@ this returns @r@.
row :: DIM2 -> Int
row (Z :. r :. _) = r
{-# INLINE row #-}
-- | Take the column number of a rank-2 index.
--
--   For an index @Z :. r :. c@ this returns @c@.
col :: DIM2 -> Int
col (Z :. _ :. c) = c
{-# INLINE col #-}
-- | Matrix-matrix multiply, in parallel.
--
--   Multiplies an @h1 x w1@ matrix by an @h2 x w2@ matrix, giving an
--   @h1 x w2@ result. The second operand is transposed first, so each output
--   element is the dot product of a row of @arr@ with a row of the
--   transposed @brr@. Both inputs are forced with 'deepSeqArrays' before any
--   work is scheduled.
--
--   Calls 'error' if the inner dimensions disagree (@w1 /= h2@). Without this
--   check, 'R.zipWith' would intersect the two row extents and silently
--   return a wrong product.
mmultP  :: Monad m
        => Array U DIM2 Double
        -> Array U DIM2 Double
        -> m (Array U DIM2 Double)
mmultP arr brr
 | w1 /= h2
 = error $ "Data.Array.Repa.Algorithms.Matrix.mmultP: inner dimensions do not match ("
         ++ show w1 ++ " /= " ++ show h2 ++ ")"
 | otherwise
 = deepSeqArrays [arr, brr] $
   do   trr <- transpose2P brr
        computeP
         $ fromFunction (Z :. h1 :. w2)
         $ \ix -> R.sumAllS
                $ R.zipWith (*)
                        (unsafeSlice arr (Any :. row ix :. All))
                        (unsafeSlice trr (Any :. col ix :. All))
 where  (Z :. h1 :. w1) = extent arr
        (Z :. h2 :. w2) = extent brr
{-# NOINLINE mmultP #-}
-- | Matrix-matrix multiply, sequentially.
--
--   Multiplies an @h1 x w1@ matrix by an @h2 x w2@ matrix, giving an
--   @h1 x w2@ result. Inside 'runST', the second operand is transposed and
--   sequenced with 'R.now' before it is sliced. Each output element is then
--   the dot product of a row of @arr@ with a row of the transposed @brr@.
--   Both inputs are forced with 'deepSeqArrays' first.
--
--   Calls 'error' if the inner dimensions disagree (@w1 /= h2@). Without this
--   check, 'R.zipWith' would intersect the two row extents and silently
--   return a wrong product.
mmultS  :: Array U DIM2 Double
        -> Array U DIM2 Double
        -> Array U DIM2 Double
mmultS arr brr
 | w1 /= h2
 = error $ "Data.Array.Repa.Algorithms.Matrix.mmultS: inner dimensions do not match ("
         ++ show w1 ++ " /= " ++ show h2 ++ ")"
 | otherwise
 = deepSeqArrays [arr, brr] $ runST $
   do   trr <- R.now $ transpose2S brr
        return
         $ computeS
         $ fromFunction (Z :. h1 :. w2)
         $ \ix -> R.sumAllS
                $ R.zipWith (*)
                        (unsafeSlice arr (Any :. row ix :. All))
                        (unsafeSlice trr (Any :. col ix :. All))
 where  (Z :. h1 :. w1) = extent arr
        (Z :. h2 :. w2) = extent brr
{-# NOINLINE mmultS #-}
-- | Transpose a 2D matrix, in parallel.
--
--   The input is forced with 'deepSeqArray'. Each result element at
--   @Z :. i :. j@ is read from @Z :. j :. i@ of the input via
--   'unsafeBackpermute', and the whole result is computed with
--   'computeUnboxedP'. The result extent is the input extent with its two
--   components swapped, so every permuted index stays in bounds.
transpose2P
        :: Monad m
        => Array U DIM2 Double
        -> m (Array U DIM2 Double)
transpose2P arr
 = deepSeqArray arr
 $ computeUnboxedP
 $ unsafeBackpermute newExtent swap arr
 where  swap :: DIM2 -> DIM2
        swap (Z :. i :. j) = Z :. j :. i

        newExtent :: DIM2
        newExtent = swap (extent arr)
{-# NOINLINE transpose2P #-}
-- | Transpose a 2D matrix, sequentially.
--
--   The input is forced with 'deepSeqArray'. Each result element at
--   @Z :. i :. j@ is read from @Z :. j :. i@ of the input via
--   'unsafeBackpermute', and the whole result is computed with
--   'computeUnboxedS'. The result extent is the input extent with its two
--   components swapped, so every permuted index stays in bounds.
transpose2S
        :: Array U DIM2 Double
        -> Array U DIM2 Double
transpose2S arr
 = deepSeqArray arr
 $ computeUnboxedS
 $ unsafeBackpermute newExtent swap arr
 where  swap :: DIM2 -> DIM2
        swap (Z :. i :. j) = Z :. j :. i

        newExtent :: DIM2
        newExtent = swap (extent arr)
{-# NOINLINE transpose2S #-}
-- | Sum the main diagonal of a 2D matrix, in parallel.
--
--   The diagonal has @min rows cols@ elements, so a non-square matrix is
--   accepted and the trace is taken over its leading square part. Diagonal
--   elements are gathered with 'unsafeBackpermute' and summed with
--   'sumAllP'. Every index @Z :. i :. i@ is in bounds because
--   @i < min rows cols@.
trace2P :: Monad m => Array U DIM2 Double -> m Double
trace2P x
 = sumAllP
 $ unsafeBackpermute (Z :. diagLen) (\(Z :. i) -> Z :. i :. i) x
 where  (Z :. nRows :. nColumns) = extent x
        diagLen                  = min nRows nColumns
-- | Sum the main diagonal of a 2D matrix, sequentially.
--
--   The diagonal has @min rows cols@ elements, so a non-square matrix is
--   accepted and the trace is taken over its leading square part. Diagonal
--   elements are gathered with 'unsafeBackpermute' and summed with
--   'sumAllS'. Every index @Z :. i :. i@ is in bounds because
--   @i < min rows cols@.
trace2S :: Array U DIM2 Double -> Double
trace2S x
 = sumAllS
 $ unsafeBackpermute (Z :. diagLen) (\(Z :. i) -> Z :. i :. i) x
 where  (Z :. nRows :. nColumns) = extent x
        diagLen                  = min nRows nColumns