module Data.Array.DynamicS.MatMul(matMul) where
import Prelude hiding ((<>))
import GHC.Stack
import Data.Array.DynamicS
import Numeric.LinearAlgebra as N

-- | Matrix product of two rank-2 arrays.
--
-- An array of shape @[m, n]@ times an array of shape @[n, o]@ yields an
-- array of shape @[m, o]@. Non-degenerate products are computed by
-- converting both operands to hmatrix matrices and using 'N.<>'.
--
-- Calls 'error' (with the two offending shapes in the message) unless both
-- arguments have rank 2 and the inner dimensions agree.
matMul :: (HasCallStack, N.Numeric a) => Array a -> Array a -> Array a
matMul x y =
  case (shapeL x, shapeL y) of
    ([m, n], [n', o]) | n == n' ->
      if m == 0 || n == 0 || o == 0
        -- hmatrix's 'N.reshape' yields a 0x0 matrix when asked for 0
        -- columns, which would drop the row count of an m x 0 operand and
        -- give a result of the wrong shape. Any product with a zero
        -- dimension is an all-zero m x o array (empty when m or o is 0),
        -- so build it directly.
        then fromVector [m, o] $ N.konst 0 (m * o)
        else
          let -- Row-major views of the operands, n and o columns respectively.
              lhs = N.reshape n $ toVector x
              rhs = N.reshape o $ toVector y
          in  fromVector [m, o] $ N.flatten (lhs N.<> rhs)
    sz -> error $ "matMul: expected two conforming two-dimensional arrays, got shape " ++ show sz