{-# LANGUAGE DataKinds #-}

-- | Matrix multiplication for rank-2 arrays, with the actual product
-- delegated to "Numeric.LinearAlgebra".
module Data.Array.RankedS.MatMul (matMul) where

-- Prelude's '<>' is hidden so it does not clash with the operator of the
-- same name brought in by the "Numeric.LinearAlgebra" import below.
import Prelude hiding ((<>))

import GHC.Stack

import Data.Array.RankedS
import Numeric.LinearAlgebra as N

-- | Multiply two rank-2 arrays as matrices.
--
-- The shapes of both arguments are inspected at run time.  For an
-- argument pair of shape @[m, n]@ and @[n', o]@ with @n == n'@ the result
-- has shape @[m, o]@.
--
-- Any other combination of shapes (including a mismatch of the inner
-- dimensions) calls 'error' with a message that contains both shapes;
-- the 'HasCallStack' constraint lets that error report the call site.
--
-- NOTE(review): the behaviour for zero-sized dimensions (@n == 0@ or
-- @o == 0@) depends on how 'N.reshape' treats a size argument of 0 --
-- TODO confirm before relying on it for empty matrices.
matMul :: (HasCallStack, N.Numeric a) => Array 2 a -> Array 2 a -> Array 2 a
matMul x y =
  case (shapeL x, shapeL y) of
    ([m, n], [n', o])
      | n == n' ->
          let -- Rebuild each operand as a matrix from its flat element
              -- vector, using the operand's second dimension as the
              -- size argument to 'N.reshape'.
              x' = N.reshape n $ toVector x
              y' = N.reshape o $ toVector y
              -- The matrix product itself.
              xy' = x' N.<> y'
              -- Flatten the product and wrap it as an array of shape [m, o].
              xy = fromVector [m, o] $ N.flatten xy'
           in xy
    -- Reached when the inner dimensions disagree (guard above fails) or
    -- when either shape list does not have exactly two entries.
    sz ->
      error $
        "matMul: expected two conforming two-dimensional arrays, got shape "
          ++ show sz