{-# LANGUAGE DataKinds #-}
module Data.Array.RankedS.MatMul(matMul) where
import Prelude hiding ((<>))
import GHC.Stack
import Data.Array.RankedS
import Numeric.LinearAlgebra as N

-- | Multiply two rank-2 arrays as matrices, delegating the product to
-- hmatrix.
--
-- Given @x@ of shape @[m, n]@ and @y@ of shape @[n, o]@, the result has
-- shape @[m, o]@. Each argument's element vector ('toVector') is turned
-- into an hmatrix 'Matrix' with 'N.reshape', which groups the elements
-- into rows of the given column count. The hmatrix product ('N.<>') is
-- then flattened back into an 'Array' with 'fromVector'.
--
-- Calls 'error' (reporting the call site via 'HasCallStack') when either
-- shape is not two-dimensional or when the inner dimensions differ.
-- The error message includes both offending shapes.
--
-- NOTE(review): this assumes 'toVector' yields elements in row-major order,
-- matching 'N.reshape''s row grouping. Confirm this against the orthotope
-- docs.
--
-- NOTE(review): 'N.reshape' derives the row count from the vector length
-- and the column count. When @n == 0@ or @o == 0@ it is therefore given a
-- column count of 0. Confirm hmatrix still yields an @[m, o]@ result for
-- these degenerate shapes.
matMul :: (HasCallStack, N.Numeric a) => Array 2 a -> Array 2 a -> Array 2 a
matMul x y =
  case (shapeL x, shapeL y) of
    ([m, n], [n', o]) | n == n' ->
      let -- Reinterpret the flat element vectors as n-column and
          -- o-column matrices respectively.
          x' = N.reshape n $ toVector x
          y' = N.reshape o $ toVector y
          xy' = x' N.<> y'
      in  fromVector [m, o] $ N.flatten xy'
    sz ->
      error $
        "matMul: expected two conforming two-dimensional arrays, got shape "
          ++ show sz