{-# LANGUAGE CPP #-}
{-# LANGUAGE Safe #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RebindableSyntax #-}
{-# LANGUAGE RankNTypes #-}
-- 'inv3' annotates a local expression with the type variable @a@ bound by its
-- signature, which requires lexically scoped type variables (harmless if the
-- build already enables it through default-extensions).
{-# LANGUAGE ScopedTypeVariables #-}

module Data.Semimodule.Matrix (
    type M22
  , type M23
  , type M24
  , type M32
  , type M33
  , type M34
  , type M42
  , type M43
  , type M44
  , lensRep
  , grateRep
  , tran
  , row
  , rows
  , col
  , cols
  , (.#)
  , (.*)
  , (#.)
  , (*.)
  , (.#.)
  , (.*.)
  , outer
  , scale
  , dirac
  , identity
  , transpose
  , trace
  , diagonal
  , bdet2
  , det2
  , inv2
  , bdet3
  , det3
  , inv3
  , bdet4
  , det4
  , inv4
  , m22
  , m23
  , m24
  , m32
  , m33
  , m34
  , m42
  , m43
  , m44
) where

import safe Data.Bool
import safe Data.Distributive
import safe Data.Functor.Compose
import safe Data.Functor.Rep
import safe Data.Semifield
import safe Data.Semigroup.Additive
import safe Data.Semigroup.Multiplicative
import safe Data.Semimodule
import safe Data.Semimodule.Transform
import safe Data.Semimodule.Vector
import safe Data.Semiring
import safe Data.Tuple
import safe Prelude hiding (Num(..), Fractional(..), sum, negate)

-- All matrices use row-major representation.

-- | A 2x2 matrix.
type M22 a = V2 (V2 a)

-- | A 2x3 matrix.
type M23 a = V2 (V3 a)

-- | A 2x4 matrix.
type M24 a = V2 (V4 a)

-- | A 3x2 matrix.
type M32 a = V3 (V2 a)

-- | A 3x3 matrix.
type M33 a = V3 (V3 a)

-- | A 3x4 matrix.
type M34 a = V3 (V4 a)

-- | A 4x2 matrix.
type M42 a = V4 (V2 a)

-- | A 4x3 matrix.
type M43 a = V4 (V3 a)

-- | A 4x4 matrix.
type M44 a = V4 (V4 a)

-- | A van Laarhoven lens on the element of a representable functor stored at
-- a given index.
--
-- The getter reads the element at index @i@; the setter rebuilds the
-- container, replacing only the element whose index equals @i@.
--
lensRep :: Eq (Rep f) => Representable f => Rep f -> forall g. Functor g => (a -> g a) -> f a -> g (f a)
lensRep i f s = setter s <$> f (getter s)
  where
    getter = flip index i
    setter s' b = tabulate $ \j -> bool (index s' j) b (i == j)
{-# INLINE lensRep #-}

-- | Build a representable container index by index from a functor of
-- containers.
--
-- The element at index @i@ of the result is obtained by applying the supplied
-- function to @i@ and to the functor of elements found at index @i@.
--
grateRep :: Representable f => forall g. Functor g => (Rep f -> g a -> b) -> g (f a) -> f b
grateRep iab s = tabulate $ \i -> iab i (fmap (`index` i) s)
{-# INLINE grateRep #-}

-- | Lift a matrix into a 'Tran'.
--
-- The resulting transformation tabulates its argument into a column vector,
-- multiplies it on the left by the matrix and indexes into the product.
--
-- @ ('.#') = 'app' . 'tran' @
--
tran :: Semiring a => Basis b f => Basis c g => Foldable g => f (g a) -> Tran a b c
tran m = Tran $ \f -> index $ m .# (tabulate f)

-- | Retrieve a row of a row-major matrix or element of a row vector.
--
-- >>> row I21 (V2 1 2)
-- 1
--
row :: Representable f => Rep f -> f a -> a
row = flip index
{-# INLINE row #-}

-- | Retrieve a column of a row-major matrix.
--
-- >>> row I22 . col I31 $ V2 (V3 1 2 3) (V3 4 5 6)
-- 4
--
col :: Functor f => Representable g => Rep g -> f (g a) -> f a
col j = flip index j . distribute
{-# INLINE col #-}

-- | Outer product of two vectors.
--
-- >>> V2 1 1 `outer` V2 1 1
-- V2 (V2 1 1) (V2 1 1)
--
-- NOTE(review): each entry is computed with the element of the second vector
-- as the left operand of '*' and the element of the first vector as the right
-- operand. This only matters for a non-commutative multiplication; confirm
-- that the operand order is intended.
--
outer :: Semiring a => Functor f => Functor g => f a -> g a -> f (g a)
outer x y = fmap (\z-> fmap (*z) y) x

infixl 7 #.

-- | Multiply a matrix on the left by a row vector.
--
-- >>> V2 1 2 #. m23 3 4 5 6 7 8
-- V3 15 18 21
--
-- >>> V2 1 2 #. m23 3 4 5 6 7 8 #. m32 1 0 0 0 0 0
-- V2 15 0
--
(#.) :: (Semiring a, Free f, Foldable f, Free g) => f a -> f (g a) -> g a
x #. y = tabulate (\j -> x .*. col j y)
{-# INLINE (#.) #-}

infixr 7 .#, .#.

-- | Multiply a matrix on the right by a column vector.
--
-- @ ('.#') = 'app' . 'tran' @
--
-- >>> m23 1 2 3 4 5 6 .# V3 7 8 9
-- V2 50 122
--
-- >>> m22 1 0 0 0 .# m23 1 2 3 4 5 6 .# V3 7 8 9
-- V2 50 0
--
(.#) :: (Semiring a, Free f, Free g, Foldable g) => f (g a) -> g a -> f a
x .# y = tabulate (\i -> row i x .*. y)
{-# INLINE (.#) #-}

-- | Multiply two matrices.
--
-- >>> m22 1 2 3 4 .#. m22 1 2 3 4 :: M22 Int
-- V2 (V2 7 10) (V2 15 22)
--
-- >>> m23 1 2 3 4 5 6 .#. m32 1 2 3 4 4 5 :: M22 Int
-- V2 (V2 19 25) (V2 43 58)
--
(.#.) :: (Semiring a, Free f, Free g, Free h, Foldable g) => f (g a) -> g (h a) -> f (h a)
(.#.) x y = getCompose $ tabulate (\(i,j) -> row i x .*. col j y)
{-# INLINE (.#.) #-}

-- | Obtain a diagonal matrix from a vector.
--
-- >>> scale (V2 2 3)
-- V2 (V2 2 0) (V2 0 3)
--
scale :: (Additive-Monoid) a => Free f => f a -> f (f a)
scale f = flip imapRep f $ \i x -> flip imapRep f (\j _ -> bool zero x $ i == j)
{-# INLINE scale #-}

-- | Identity matrix.
--
-- >>> identity :: M44 Int
-- V4 (V4 1 0 0 0) (V4 0 1 0 0) (V4 0 0 1 0) (V4 0 0 0 1)
--
-- >>> identity :: V3 (V3 Int)
-- V3 (V3 1 0 0) (V3 0 1 0) (V3 0 0 1)
--
identity :: Semiring a => Free f => f (f a)
identity = scale $ pureRep one
{-# INLINE identity #-}

-- | Compute the trace of a matrix, i.e. the sum of its diagonal.
--
-- >>> trace (V2 (V2 a b) (V2 c d))
-- a + d
--
trace :: Semiring a => Free f => Foldable f => f (f a) -> a
trace = sum . diagonal
{-# INLINE trace #-}

-- | Obtain the diagonal of a matrix as a vector.
--
-- >>> diagonal (V2 (V2 a b) (V2 c d))
-- V2 a d
--
diagonal :: Representable f => f (f a) -> f a
diagonal = flip bindRep id
{-# INLINE diagonal #-}

-- Retrieve the element at row @i@ and column @j@ of a row-major matrix:
-- column @j@ is extracted first, then indexed at @i@.
ij :: Representable f => Representable g => Rep f -> Rep g -> f (g a) -> a
ij i j = row i . col j

-- | 2x2 matrix bideterminant over a commutative semiring.
--
-- The first component is the product along the main diagonal, the second
-- the product along the anti-diagonal.
--
-- >>> bdet2 $ m22 1 2 3 4
-- (4,6)
--
bdet2 :: Semiring a => Basis I2 f => Basis I2 g => f (g a) -> (a, a)
bdet2 m = (ij I21 I21 m * ij I22 I22 m, ij I21 I22 m * ij I22 I21 m)
{-# INLINE bdet2 #-}

-- | 2x2 matrix determinant over a commutative ring.
--
-- @
-- 'det2' '==' 'uncurry' ('-') . 'bdet2'
-- @
--
-- >>> det2 $ m22 1 2 3 4 :: Double
-- -2.0
--
det2 :: Ring a => Basis I2 f => Basis I2 g => f (g a) -> a
det2 = uncurry (-) . bdet2
{-# INLINE det2 #-}

-- | 2x2 matrix inverse over a field.
--
-- Computed as the adjugate scaled by the reciprocal of the determinant.
--
-- >>> inv2 $ m22 1 2 3 4 :: M22 Double
-- V2 (V2 (-2.0) 1.0) (V2 1.5 (-0.5))
--
inv2 :: Field a => Basis I2 f => Basis I2 g => f (g a) -> g (f a)
inv2 m = multl (recip $ det2 m) <$> m22 d (-b) (-c) a
  where
    a = ij I21 I21 m
    b = ij I21 I22 m
    c = ij I22 I21 m
    d = ij I22 I22 m
{-# INLINE inv2 #-}

-- | 3x3 matrix bideterminant over a commutative semiring.
--
-- The first component sums the products over the even permutations, the
-- second over the odd permutations.
--
-- >>> bdet3 (V3 (V3 1 2 3) (V3 4 5 6) (V3 7 8 9))
-- (225,225)
--
bdet3 :: Semiring a => Basis I3 f => Basis I3 g => f (g a) -> (a, a)
bdet3 m = (evens, odds)
  where
    evens = a*e*i + g*b*f + d*h*c
    odds  = a*h*f + d*b*i + g*e*c
    a = ij I31 I31 m
    b = ij I31 I32 m
    c = ij I31 I33 m
    d = ij I32 I31 m
    e = ij I32 I32 m
    f = ij I32 I33 m
    g = ij I33 I31 m
    h = ij I33 I32 m
    i = ij I33 I33 m
{-# INLINE bdet3 #-}

-- | 3x3 matrix determinant over a commutative ring.
--
-- @
-- 'det3' '==' 'uncurry' ('-') . 'bdet3'
-- @
--
-- Implementation uses a cofactor expansion (along the first column) to avoid
-- loss of precision.
--
-- >>> det3 (V3 (V3 1 2 3) (V3 4 5 6) (V3 7 8 9))
-- 0
--
det3 :: Ring a => Basis I3 f => Basis I3 g => f (g a) -> a
det3 m = a * (e*i-f*h) - d * (b*i-c*h) + g * (b*f-c*e)
  where
    a = ij I31 I31 m
    b = ij I31 I32 m
    c = ij I31 I33 m
    d = ij I32 I31 m
    e = ij I32 I32 m
    f = ij I32 I33 m
    g = ij I33 I31 m
    h = ij I33 I32 m
    i = ij I33 I33 m
{-# INLINE det3 #-}

-- | 3x3 matrix inverse over a field.
--
-- Computed as the adjugate (built from 2x2 determinants) scaled by the
-- reciprocal of the determinant.
--
-- >>> inv3 $ m33 1 2 4 4 2 2 1 1 1 :: M33 Double
-- V3 (V3 0.0 0.5 (-1.0)) (V3 (-0.5) (-0.75) 3.5) (V3 0.5 0.25 (-1.5))
--
inv3 :: forall a f g. Field a => Basis I3 f => Basis I3 g => f (g a) -> g (f a)
inv3 m = multl (recip $ det3 m) <$> m33 a' b' c' d' e' f' g' h' i'
  where
    a = ij I31 I31 m
    b = ij I31 I32 m
    c = ij I31 I33 m
    d = ij I32 I31 m
    e = ij I32 I32 m
    f = ij I32 I33 m
    g = ij I33 I31 m
    h = ij I33 I32 m
    i = ij I33 I33 m
    -- Entries of the adjugate; the argument order of each call encodes the
    -- cofactor's sign.
    a' = cofactor (e,f,h,i)
    b' = cofactor (c,b,i,h)
    c' = cofactor (b,c,e,f)
    d' = cofactor (f,d,i,g)
    e' = cofactor (a,c,g,i)
    f' = cofactor (c,a,f,d)
    g' = cofactor (d,e,g,h)
    h' = cofactor (b,a,h,g)
    i' = cofactor (a,b,d,e)
    -- The annotation refers to the @a@ bound by the signature's @forall@.
    cofactor (q,r,s,t) = det2 (m22 q r s t :: M22 a)
{-# INLINE inv3 #-}

-- | 4x4 matrix bideterminant over a commutative semiring.
--
-- The first component sums the products over the even permutations, the
-- second over the odd permutations.
--
-- >>> bdet4 (V4 (V4 1 2 3 4) (V4 5 6 7 8) (V4 9 10 11 12) (V4 13 14 15 16))
-- (27728,27728)
--
bdet4 :: Semiring a => Basis I4 f => Basis I4 g => f (g a) -> (a, a)
bdet4 x = (evens, odds)
  where
    evens = a * (f*k*p + g*l*n + h*j*o)
          + b * (g*i*p + e*l*o + h*k*m)
          + c * (e*j*p + f*l*m + h*i*n)
          + d * (f*i*o + e*k*n + g*j*m)
    odds  = a * (g*j*p + f*l*o + h*k*n)
          + b * (e*k*p + g*l*m + h*i*o)
          + c * (f*i*p + e*l*n + h*j*m)
          + d * (e*j*o + f*k*m + g*i*n)
    a = ij I41 I41 x
    b = ij I41 I42 x
    c = ij I41 I43 x
    d = ij I41 I44 x
    e = ij I42 I41 x
    f = ij I42 I42 x
    g = ij I42 I43 x
    h = ij I42 I44 x
    i = ij I43 I41 x
    j = ij I43 I42 x
    k = ij I43 I43 x
    l = ij I43 I44 x
    m = ij I44 I41 x
    n = ij I44 I42 x
    o = ij I44 I43 x
    p = ij I44 I44 x
{-# INLINE bdet4 #-}

-- | 4x4 matrix determinant over a commutative ring.
--
-- @
-- 'det4' '==' 'uncurry' ('-') . 'bdet4'
-- @
--
-- This implementation uses a cofactor expansion to avoid loss of precision:
-- the determinant is assembled from the 2x2 minors of the top two rows
-- (@s0@..@s5@) and of the bottom two rows (@c0@..@c5@).
--
-- >>> det4 (m44 1 0 3 2 2 0 2 1 0 0 0 1 0 3 4 0 :: M44 Rational)
-- (-12) % 1
--
det4 :: Ring a => Basis I4 f => Basis I4 g => f (g a) -> a
det4 x = s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0
  where
    s0 = i00 * i11 - i10 * i01
    s1 = i00 * i12 - i10 * i02
    s2 = i00 * i13 - i10 * i03
    s3 = i01 * i12 - i11 * i02
    s4 = i01 * i13 - i11 * i03
    s5 = i02 * i13 - i12 * i03

    c5 = i22 * i33 - i32 * i23
    c4 = i21 * i33 - i31 * i23
    c3 = i21 * i32 - i31 * i22
    c2 = i20 * i33 - i30 * i23
    c1 = i20 * i32 - i30 * i22
    c0 = i20 * i31 - i30 * i21

    i00 = ij I41 I41 x
    i01 = ij I41 I42 x
    i02 = ij I41 I43 x
    i03 = ij I41 I44 x
    i10 = ij I42 I41 x
    i11 = ij I42 I42 x
    i12 = ij I42 I43 x
    i13 = ij I42 I44 x
    i20 = ij I43 I41 x
    i21 = ij I43 I42 x
    i22 = ij I43 I43 x
    i23 = ij I43 I44 x
    i30 = ij I44 I41 x
    i31 = ij I44 I42 x
    i32 = ij I44 I43 x
    i33 = ij I44 I44 x
{-# INLINE det4 #-}

-- | 4x4 matrix inverse over a field.
--
-- Computed as the adjugate scaled by the reciprocal of the determinant; both
-- are assembled from the same 2x2 minors used by 'det4'.
--
-- >>> row I41 $ inv4 (m44 1 0 3 2 2 0 2 1 0 0 0 1 0 3 4 0 :: M44 Rational)
-- V4 (6 % (-12)) ((-9) % (-12)) ((-3) % (-12)) (0 % (-12))
--
inv4 :: forall a f g. Field a => Basis I4 f => Basis I4 g => f (g a) -> g (f a)
inv4 x = multl (recip det) <$> x'
  where
    i00 = ij I41 I41 x
    i01 = ij I41 I42 x
    i02 = ij I41 I43 x
    i03 = ij I41 I44 x
    i10 = ij I42 I41 x
    i11 = ij I42 I42 x
    i12 = ij I42 I43 x
    i13 = ij I42 I44 x
    i20 = ij I43 I41 x
    i21 = ij I43 I42 x
    i22 = ij I43 I43 x
    i23 = ij I43 I44 x
    i30 = ij I44 I41 x
    i31 = ij I44 I42 x
    i32 = ij I44 I43 x
    i33 = ij I44 I44 x

    -- 2x2 minors of the top two rows.
    s0 = i00 * i11 - i10 * i01
    s1 = i00 * i12 - i10 * i02
    s2 = i00 * i13 - i10 * i03
    s3 = i01 * i12 - i11 * i02
    s4 = i01 * i13 - i11 * i03
    s5 = i02 * i13 - i12 * i03

    -- 2x2 minors of the bottom two rows.
    c5 = i22 * i33 - i32 * i23
    c4 = i21 * i33 - i31 * i23
    c3 = i21 * i32 - i31 * i22
    c2 = i20 * i33 - i30 * i23
    c1 = i20 * i32 - i30 * i22
    c0 = i20 * i31 - i30 * i21

    det = s0 * c5 - s1 * c4 + s2 * c3 + s3 * c2 - s4 * c1 + s5 * c0

    -- Adjugate, in row-major order.
    x' = m44
      (i11 * c5 - i12 * c4 + i13 * c3)
      (-i01 * c5 + i02 * c4 - i03 * c3)
      (i31 * s5 - i32 * s4 + i33 * s3)
      (-i21 * s5 + i22 * s4 - i23 * s3)
      (-i10 * c5 + i12 * c2 - i13 * c1)
      (i00 * c5 - i02 * c2 + i03 * c1)
      (-i30 * s5 + i32 * s2 - i33 * s1)
      (i20 * s5 - i22 * s2 + i23 * s1)
      (i10 * c4 - i11 * c2 + i13 * c0)
      (-i00 * c4 + i01 * c2 - i03 * c0)
      (i30 * s4 - i31 * s2 + i33 * s0)
      (-i20 * s4 + i21 * s2 - i23 * s0)
      (-i10 * c3 + i11 * c1 - i12 * c0)
      (i00 * c3 - i01 * c1 + i02 * c0)
      (-i30 * s3 + i31 * s1 - i32 * s0)
      (i20 * s3 - i21 * s1 + i22 * s0)
{-# INLINE inv4 #-}

-- | Construct a 2x2 matrix.
--
-- Arguments are in row-major order.
--
-- >>> m22 1 2 3 4 :: M22 Int
-- V2 (V2 1 2) (V2 3 4)
--
-- @ 'm22' :: a -> a -> a -> a -> 'M22' a @
--
m22 :: Basis I2 f => Basis I2 g => a -> a -> a -> a -> f (g a)
m22 a b c d = fillI2 (fillI2 a b) (fillI2 c d)
{-# INLINE m22 #-}

-- | Construct a 2x3 matrix.
--
-- Arguments are in row-major order.
--
-- @ 'm23' :: a -> a -> a -> a -> a -> a -> 'M23' a @
--
m23 :: Basis I2 f => Basis I3 g => a -> a -> a -> a -> a -> a -> f (g a)
m23 a b c d e f = fillI2 (fillI3 a b c) (fillI3 d e f)
{-# INLINE m23 #-}

-- | Construct a 2x4 matrix.
--
-- Arguments are in row-major order.
--
-- @ 'm24' :: a -> a -> a -> a -> a -> a -> a -> a -> 'M24' a @
--
m24 :: Basis I2 f => Basis I4 g => a -> a -> a -> a -> a -> a -> a -> a -> f (g a)
m24 a b c d e f g h = fillI2 (fillI4 a b c d) (fillI4 e f g h)
{-# INLINE m24 #-}

-- | Construct a 3x2 matrix.
--
-- Arguments are in row-major order.
--
-- @ 'm32' :: a -> a -> a -> a -> a -> a -> 'M32' a @
--
m32 :: Basis I3 f => Basis I2 g => a -> a -> a -> a -> a -> a -> f (g a)
m32 a b c d e f = fillI3 (fillI2 a b) (fillI2 c d) (fillI2 e f)
{-# INLINE m32 #-}

-- | Construct a 3x3 matrix.
--
-- Arguments are in row-major order.
--
-- @ 'm33' :: a -> a -> a -> a -> a -> a -> a -> a -> a -> 'M33' a @
--
m33 :: Basis I3 f => Basis I3 g => a -> a -> a -> a -> a -> a -> a -> a -> a -> f (g a)
m33 a b c d e f g h i = fillI3 (fillI3 a b c) (fillI3 d e f) (fillI3 g h i)
{-# INLINE m33 #-}

-- | Construct a 3x4 matrix.
--
-- Arguments are in row-major order.
--
m34 :: Basis I3 f => Basis I4 g => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> f (g a)
m34 a b c d e f g h i j k l = fillI3 (fillI4 a b c d) (fillI4 e f g h) (fillI4 i j k l)
{-# INLINE m34 #-}

-- | Construct a 4x2 matrix.
--
-- Arguments are in row-major order.
--
-- @ 'm42' :: a -> a -> a -> a -> a -> a -> a -> a -> 'M42' a @
--
m42 :: Basis I4 f => Basis I2 g => a -> a -> a -> a -> a -> a -> a -> a -> f (g a)
m42 a b c d e f g h = fillI4 (fillI2 a b) (fillI2 c d) (fillI2 e f) (fillI2 g h)
{-# INLINE m42 #-}

-- | Construct a 4x3 matrix.
--
-- Arguments are in row-major order.
--
m43 :: Basis I4 f => Basis I3 g => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> f (g a)
m43 a b c d e f g h i j k l = fillI4 (fillI3 a b c) (fillI3 d e f) (fillI3 g h i) (fillI3 j k l)
{-# INLINE m43 #-}

-- | Construct a 4x4 matrix.
--
-- Arguments are in row-major order.
--
m44 :: Basis I4 f => Basis I4 g => a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> a -> f (g a)
m44 a b c d e f g h i j k l m n o p = fillI4 (fillI4 a b c d) (fillI4 e f g h) (fillI4 i j k l) (fillI4 m n o p)
{-# INLINE m44 #-}