{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE UndecidableInstances #-}
module Numeric.Matrix.Internal
( MatrixTranspose (..)
, SquareMatrix (..)
, MatrixDeterminant (..)
, MatrixInverse (..)
, MatrixLU (..), LUFact (..)
, Matrix
, HomTransform4 (..)
, Mat22f, Mat23f, Mat24f
, Mat32f, Mat33f, Mat34f
, Mat42f, Mat43f, Mat44f
, Mat22d, Mat23d, Mat24d
, Mat32d, Mat33d, Mat34d
, Mat42d, Mat43d, Mat44d
, mat22, mat33, mat44
, (%*)
, pivotMat, luSolve
) where
import Control.Monad (foldM)
import Control.Monad.ST
import Data.Foldable (foldl', forM_)
import Data.List (delete)
import GHC.Base
import Numeric.DataFrame.Contraction ((%*))
import Numeric.DataFrame.Internal.PrimArray
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.PrimBytes
import Numeric.Scalar.Internal
import Numeric.Vector.Internal
-- | Alias for a two-dimensional 'DataFrame': element type @t@ and the
--   dimension list @'[n,m]@.
type Matrix (t :: l) (n :: k) (m :: k) = DataFrame t '[n,m]
-- | Matrices whose two dimensions can be swapped.
class MatrixTranspose t (n :: k) (m :: k) where
    -- | Turn an @n@-by-@m@ matrix into an @m@-by-@n@ matrix.
    transpose :: Matrix t n m -> Matrix t m n
-- | Basic constructors and reductions for @n@-by-@n@ matrices.
class SquareMatrix t (n :: Nat) where
    -- | Matrix with ones on the main diagonal and zeros elsewhere.
    eye :: Matrix t n n
    -- | Matrix with the given scalar on the main diagonal and zeros elsewhere.
    diag :: Scalar t -> Matrix t n n
    -- | Sum of the main-diagonal elements.
    trace :: Matrix t n n -> Scalar t
-- | Square matrices for which a determinant can be computed.
class MatrixDeterminant t (n :: Nat) where
    -- | Determinant of the matrix, returned as a scalar.
    det :: Matrix t n n -> Scalar t
-- | Square matrices for which an inverse can be computed.
class MatrixInverse t (n :: Nat) where
    -- | Inverse of the matrix.
    --   NOTE(review): the class does not say what happens for a singular
    --   input; see the instance for details.
    inverse :: Matrix t n n -> Matrix t n n
-- | Result of an LU factorization with row pivoting, as produced by 'lu'.
data LUFact t n
  = LUFact
  { luLower :: Matrix t n n
    -- ^ Lower factor; 'lu' writes ones on its diagonal.
  , luUpper :: Matrix t n n
    -- ^ Upper factor.
  , luPerm :: Matrix t n n
    -- ^ Row permutation matrix computed by 'pivotMat'.
  , luPermSign :: Scalar t
    -- ^ Sign of the permutation: @-1@ for an odd number of inversions in the
    --   row order, @1@ otherwise.
  }
-- Instances are derived standalone because their contexts mention the
-- element type and the matrix type rather than just the type parameters.
deriving instance (Show t, PrimBytes t, KnownDim n) => Show (LUFact t n)
deriving instance (Eq (Matrix t n n), Eq t) => Eq (LUFact t n)
-- | Square matrices that admit an LU factorization.
class MatrixLU t (n :: Nat) where
    -- | Factorize a matrix into the components of 'LUFact'.
    lu :: Matrix t n n -> LUFact t n
-- | Operations producing 4x4 matrices and converting between 3D and 4D
--   vectors.
--
--   NOTE(review): only the signatures are visible here. The per-method notes
--   below are inferred from names and types and must be confirmed against the
--   instances (in particular argument order, angle units and handedness).
class HomTransform4 t where
    -- | Presumably a translation matrix built from a 4D vector — TODO confirm.
    translate4 :: Vector t 4 -> Matrix t 4 4
    -- | Presumably a translation matrix built from a 3D vector — TODO confirm.
    translate3 :: Vector t 3 -> Matrix t 4 4
    -- | Presumably a rotation about the X axis by the given angle — TODO confirm.
    rotateX :: t -> Matrix t 4 4
    -- | Presumably a rotation about the Y axis by the given angle — TODO confirm.
    rotateY :: t -> Matrix t 4 4
    -- | Presumably a rotation about the Z axis by the given angle — TODO confirm.
    rotateZ :: t -> Matrix t 4 4
    -- | Presumably a rotation about the given 3D axis by the given angle
    --   — TODO confirm (e.g. whether the axis must be normalized).
    rotate :: Vector t 3 -> t -> Matrix t 4 4
    -- | Presumably a rotation built from three Euler angles; which argument
    --   is which angle is not visible here — TODO confirm.
    rotateEuler :: t
                -> t
                -> t
                -> Matrix t 4 4
    -- | Presumably a view matrix built from three 3D vectors (likely some
    --   ordering of up direction, eye position and target) — TODO confirm.
    lookAt :: Vector t 3
           -> Vector t 3
           -> Vector t 3
           -> Matrix t 4 4
    -- | Presumably a perspective projection matrix from four scalar
    --   parameters; their meaning and order are not visible here — TODO confirm.
    perspective :: t
                -> t
                -> t
                -> t
                -> Matrix t 4 4
    -- | Presumably an orthogonal projection matrix from four scalar
    --   parameters; their meaning and order are not visible here — TODO confirm.
    orthogonal :: t
               -> t
               -> t
               -> t
               -> Matrix t 4 4
    -- | Extend a 3D vector to 4D; presumably as a point — TODO confirm the
    --   value of the added component.
    toHomPoint :: Vector t 3 -> Vector t 4
    -- | Extend a 3D vector to 4D; presumably as a direction — TODO confirm the
    --   value of the added component.
    toHomVector :: Vector t 3 -> Vector t 4
    -- | Reduce a 4D vector to 3D — TODO confirm whether it divides by the
    --   fourth component.
    fromHom :: Vector t 4 -> Vector t 3
-- Shorthand aliases. @MatNMf@ is @Matrix Float N M@ and @MatNMd@ is
-- @Matrix Double N M@, for N and M in 2..4.

-- Single-precision matrices.
type Mat22f = Matrix Float 2 2
type Mat32f = Matrix Float 3 2
type Mat42f = Matrix Float 4 2
type Mat23f = Matrix Float 2 3
type Mat33f = Matrix Float 3 3
type Mat43f = Matrix Float 4 3
type Mat24f = Matrix Float 2 4
type Mat34f = Matrix Float 3 4
type Mat44f = Matrix Float 4 4
-- Double-precision matrices.
type Mat22d = Matrix Double 2 2
type Mat32d = Matrix Double 3 2
type Mat42d = Matrix Double 4 2
type Mat23d = Matrix Double 2 3
type Mat33d = Matrix Double 3 3
type Mat43d = Matrix Double 4 3
type Mat24d = Matrix Double 2 4
type Mat34d = Matrix Double 3 4
type Mat44d = Matrix Double 4 4
-- | Build a 2x2 matrix from two 2D vectors by delegating to 'packDF'.
--   NOTE(review): whether the vectors become rows or columns is decided by
--   'packDF' and is not visible here — confirm.
mat22 :: PrimBytes (t :: Type)
      => Vector t 2 -> Vector t 2 -> Matrix t 2 2
mat22 = packDF @_ @2 @'[2]
{-# INLINE mat22 #-}
-- | Build a 3x3 matrix from three 3D vectors by delegating to 'packDF'.
--   NOTE(review): whether the vectors become rows or columns is decided by
--   'packDF' and is not visible here — confirm.
mat33 :: PrimBytes (t :: Type)
      => Vector t 3 -> Vector t 3 -> Vector t 3 -> Matrix t 3 3
mat33 = packDF @_ @3 @'[3]
{-# INLINE mat33 #-}
-- | Build a 4x4 matrix from four 4D vectors by delegating to 'packDF'.
--   NOTE(review): whether the vectors become rows or columns is decided by
--   'packDF' and is not visible here — confirm.
mat44 :: PrimBytes (t :: Type)
      => Vector t 4
      -> Vector t 4
      -> Vector t 4
      -> Vector t 4
      -> Matrix t 4 4
mat44 = packDF @_ @4 @'[4]
{-# INLINE mat44 #-}
-- | Transpose for matrices with statically known dimensions.
--
--   When 'uniqueOrCumulDims' returns a 'Left' value, that value is simply
--   re-broadcast. Otherwise the result is generated element by element:
--   result offset @c * n + r@ receives the source element at offset
--   @r * m + c@.
instance ( KnownDim n, KnownDim m
         , PrimArray t (Matrix t n m)
         , PrimArray t (Matrix t m n)
         ) => MatrixTranspose t (n :: Nat) (m :: Nat) where
    transpose df = case uniqueOrCumulDims df of
        Left e  -> broadcast e
        Right _ -> case gen# outSteps next (0, 0) of
          (# _, out #) -> out
      where
        srcRows = dimVal' @n
        srcCols = dimVal' @m
        -- Steps of the result frame, whose dimensions are [m, n].
        outSteps = CumulDims [srcCols * srcRows, srcRows, 1]
        rows# = case srcRows of W# w -> word2Int# w
        cols# = case srcCols of W# w -> word2Int# w
        -- Generator state is (r, c): r runs over 0..n-1 fastest, and c is
        -- bumped every time r wraps around.
        next (I# r, I# c)
          | isTrue# (r ==# rows#) = next (0, I# (c +# 1#))
          | otherwise =
              (# (I# (r +# 1#), I# c), ix# (r *# cols# +# c) df #)
-- | Transpose for matrices whose dimensions are only known at runtime.
--   Unwraps the 'XFrame', recovers the two concrete dimensions and the
--   element evidence, and delegates to the statically sized 'transpose'.
instance MatrixTranspose (t :: Type) (xn :: XNat) (xm :: XNat) where
    transpose xdf = case xdf of
      XFrame (df :: DataFrame t ns)
        | ((D :: Dim n) :* (D :: Dim m) :* U) <- dims @ns
        , Dict <- inferPrimElem df
          -> XFrame (transpose df :: Matrix t m n)
      -- Reached only if the wrapped frame does not have exactly two dimensions.
      _ -> error "MatrixTranspose/transpose: impossible argument"
-- | Square-matrix primitives for statically sized @n@-by-@n@ frames.
--   Elements are addressed by flat offset; the main diagonal sits at the
--   offsets that are multiples of @n + 1@.
instance (KnownDim n, PrimArray t (Matrix t n n), Num t)
      => SquareMatrix t n where
    eye
      | n <- dimVal' @n
        -- The generator state counts the zeros still to emit before the
        -- next diagonal element: emit 1, then n zeros, then 1 again, ...
      = let f 0 = (# n    , 1 #)
            f k = (# k - 1, 0 #)
        in case gen# (CumulDims [n*n, n, 1]) f 0 of
            (# _, r #) -> r
    diag se
      | n <- dimVal' @n
      , e <- unScalar se
        -- Same generator as 'eye', with @e@ in place of 1 on the diagonal.
      = let f 0 = (# n    , e #)
            f k = (# k - 1, 0 #)
        in case gen# (CumulDims [n*n, n, 1]) f 0 of
            (# _, r #) -> r
    trace df
      | I# n <- fromIntegral $ dimVal' @n
      , n1 <- n +# 1#
        -- Walk the diagonal from the last offset (n*n - 1) down to 0 in
        -- steps of (n + 1). For n == 0 the start offset is -1; the first
        -- guard then yields the empty sum instead of recursing forever on
        -- negative offsets.
      = let f k
              | isTrue# (k <# 0#)  = 0
              | isTrue# (k ==# 0#) = ix# 0# df
              | otherwise          = ix# k df + f (k -# n1)
        in scalar $ f (n *# n -# 1#)
-- | Inverse via LU factorization: the matrix is factorized once, then
--   'luSolve' is applied to each @n@-element slice of 'eye' and the solution
--   is written into a fresh mutable frame.
instance ( KnownDim n, Ord t, Fractional t
         , PrimBytes t
         , PrimArray t (Matrix t n n)
         , PrimArray t (Vector t n)
         , PrimBytes (Vector t n)
         , PrimBytes (Matrix t n n)
         )
         => MatrixInverse t n where
  inverse m = runST $ do
      out <- newDataFrame
      -- Factorize once; every right-hand side reuses the same factors.
      let factors = lu m
      indexWise_ @t @'[n] @'[n] @'[n, n] @_ @()
        (\outer rhs ->
            -- Element k of the solution for slice 'outer' is stored at
            -- index (k, outer) of the result.
            indexWise_ @t @'[n] @'[] @'[n] @_ @()
              (\(k :* U) -> writeDataFrame out (k :* outer))
              (luSolve factors rhs)
        ) eye
      unsafeFreezeDataFrame out
-- | Determinant via LU factorization: the product of the diagonals of both
--   factors, multiplied by the sign of the row permutation.
instance ( KnownDim n, Ord t, Fractional t
         , PrimBytes t, PrimArray t (Matrix t n n))
      => MatrixDeterminant t n where
  det m = prodF (luUpper f) * prodF (luLower f) * luPermSign f
    where
      f = lu m
      n = case dimVal' @n of W# w -> word2Int# w
      -- Distance between consecutive diagonal elements in flat offsets.
      n1 = n +# 1#
      -- Flat offset of the last diagonal element (-1 when n == 0).
      nn1 = n *# n -# 1#
      -- Product of the main diagonal of a matrix.
      prodF a = scalar $ prodF' nn1 a
      -- Walks the diagonal from the last offset down to 0. A negative
      -- offset only occurs for n == 0; it yields the empty product instead
      -- of recursing forever on out-of-range offsets.
      prodF' k a
        | isTrue# (k <# 0#)  = 1
        | isTrue# (k ==# 0#) = ix# 0# a
        | otherwise          = ix# k a * prodF' (k -# n1) a
-- | LU factorization with row pivoting.
--
--   The input is first reordered by 'pivotMat'; the reordered matrix is then
--   factorized into a lower factor with ones on the diagonal and an upper
--   factor. Both factors are built in mutable byte arrays with manually
--   threaded state tokens and frozen at the end.
instance ( KnownDim n, Ord t, Fractional t
         , PrimBytes t, PrimArray t (Matrix t n n))
      => MatrixLU t n where
    lu m' = case runRW# go of
        (# _, (# bu, bl #) #) -> LUFact
            { luLower = fromElems steps 0# bl
            , luUpper = fromElems steps 0# bu
            , luPerm = p
            , luPermSign = si
            }
      where
        dn = dimVal' @n
        dnn = dn * dn
        steps = CumulDims [dnn, dn, 1]
        -- m: row-permuted input, p: permutation matrix, si: permutation sign.
        (m, p, si) = pivotMat m'
        n = case dn of W# w -> word2Int# w
        nn = case dnn of W# w -> word2Int# w
        -- Presumably the size of one element in bytes; the argument looks
        -- like a type proxy that is never evaluated — TODO confirm.
        tbs = byteSize @t undefined
        -- Size of one n*n buffer in bytes.
        bsize = nn *# tbs
        -- Element (i, j) of the permuted input, by flat offset i*n + j.
        ixm i j = ix# (i *# n +# j) m
        -- Counted loop: threads an accumulator and the state token through
        -- f for every index from i (inclusive) up to k (exclusive).
        loop :: (Int# -> a -> State# s -> (# State# s, a #))
             -> Int# -> Int# -> a -> State# s -> (# State# s, a #)
        loop f i k x s
          | isTrue# (i ==# k) = (# s, x #)
          | otherwise = case f i x s of
              (# s', y #) -> loop f ( i +# 1# ) k y s'
        go s0
            -- Allocate both factors and fill them with zero bytes.
          | (# s1, mbl #) <- newByteArray# bsize s0
          , (# s2, mbu #) <- newByteArray# bsize s1
          , s3 <- setByteArray# mbl 0# bsize 0# s2
          , s4 <- setByteArray# mbu 0# bsize 0# s3
            -- Element accessors for L and U, addressed by flat offset i*n + j.
          , readL <- \i j -> readArray @t mbl (i *# n +# j)
          , readU <- \i j -> readArray @t mbu (i *# n +# j)
          , writeL <- \i j -> writeArray @t mbl (i *# n +# j)
          , writeU <- \i j -> writeArray @t mbu (i *# n +# j)
            -- m(i,j) minus the sum over k < i of U(k,j) * L(i,k).
          , computeU <- \i j ->
              let f k x s
                    | (# s' , ukj #) <- readU k j s
                    , (# s'', lik #) <- readL i k s'
                    = (# s'', x - ukj * lik #)
              in loop f 0# i (ixm i j)
            -- Same accumulation, but summing over k < j; the caller divides
            -- the result by U(j,j).
          , computeL' <- \i j ->
              let f k x s
                    | (# s' , ukj #) <- readU k j s
                    , (# s'', lik #) <- readL i k s'
                    = (# s'', x - ukj * lik #)
              in loop f 0# j (ixm i j)
            -- Main loop over columns j = 0 .. n-1.
          , (# sr, () #) <-
              loop
                ( \j _ sj -> case sj of
                    sj0
                        -- Diagonal of L is one.
                      | sj1 <- writeL j j 1 sj0
                        -- Entries of U above the diagonal in column j.
                      , (# sj2, () #) <- loop
                          ( \i _ sij0 -> case computeU i j sij0 of
                              (# sij1, uij #) -> (# writeU i j uij sij1, () #)
                          ) 0# j () sj1
                        -- Diagonal entry of U.
                      , (# sj3, ujj #) <- computeU j j sj2
                      , sj4 <- writeU j j ujj sj3
                      -> case ujj of
                          -- Zero pivot: store zeros below the diagonal of L
                          -- rather than dividing by zero.
                          0 -> loop
                            ( \i _ sij -> (# writeL i j 0 sij, () #)
                            ) (j +# 1#) n () sj4
                          -- Nonzero pivot: entries of L below the diagonal.
                          x -> loop
                            ( \i _ sij0 -> case computeL' i j sij0 of
                                (# sij1, lij #) -> (# writeL i j (lij / x) sij1, () #)
                            ) (j +# 1#) n () sj4
                ) 0# n () s4
            -- Freeze both buffers; they are not written to after this point.
          , (# sf0, bl #) <- unsafeFreezeByteArray# mbl sr
          , (# sf1, bu #) <- unsafeFreezeByteArray# mbu sf0
          = (# sf1, (# bu, bl #) #)
-- | Solve a linear system given the LU factorization of its matrix.
--
--   The right-hand side is first multiplied by the permutation matrix, then
--   a forward pass uses the lower factor (whose diagonal is not divided by)
--   and a backward pass uses the upper factor, dividing by its diagonal.
luSolve :: forall (t :: Type) (n :: Nat)
         . ( KnownDim n, Fractional t
           , PrimArray t (Matrix t n n), PrimArray t (Vector t n))
        => LUFact t n -> Vector t n -> Vector t n
luSolve fact rhs = backward
  where
    lower = luLower fact
    upper = luUpper fact
    -- Right-hand side with the row permutation applied.
    permuted = luPerm fact %* rhs
    dim# = case dimVal' @n of W# w -> word2Int# w
    dim = I# dim#
    -- Element (i, j) of a square matrix, by flat offset i*n + j.
    matAt :: Matrix t n n -> Int -> Int -> Scalar t
    matAt a (I# i) (I# j) = scalar $ ix# (i *# dim# +# j) a
    -- Element i of a vector.
    vecAt :: Vector t n -> Int -> Scalar t
    vecAt v (I# i) = scalar $ ix# i v
    -- Forward pass: rows in increasing order, using already-written entries.
    forward :: Vector t n
    forward = runST $ do
      acc <- newDataFrame
      forM_ [0 .. dim - 1] $ \i -> do
        yi <- foldM
                (\partial j -> do
                    yj <- readDataFrameOff acc j
                    pure (partial - yj * matAt lower i j))
                (vecAt permuted i)
                [0 .. i - 1]
        writeDataFrameOff acc i yi
      unsafeFreezeDataFrame acc
    -- Backward pass: rows in decreasing order, dividing by the diagonal.
    backward :: Vector t n
    backward = runST $ do
      acc <- newDataFrame
      forM_ [dim - 1, dim - 2 .. 0] $ \i -> do
        xi <- foldM
                (\partial j -> do
                    xj <- readDataFrameOff acc j
                    pure (partial - xj * matAt upper i j))
                (vecAt forward i)
                [i + 1 .. dim - 1]
        writeDataFrameOff acc i (xi / matAt upper i i)
      unsafeFreezeDataFrame acc
-- | Choose a row order for pivoting and return
--
--   1. the input matrix with its rows rearranged in that order,
--   2. the permutation matrix that performs the rearrangement, and
--   3. the sign of the permutation (@-1@ for an odd number of inversions in
--      the row order, @1@ otherwise).
--
--   For each column @j@ in turn, the not-yet-used row with the largest
--   absolute value in that column becomes row @j@. Columns in which all
--   remaining rows are zero are filled afterwards with the leftover rows.
pivotMat :: forall (t :: Type) (n :: Nat)
          . (KnownDim n, PrimArray t (Matrix t n n), Ord t, Num t)
         => Matrix t n n -> ( Matrix t n n
                            , Matrix t n n
                            , Scalar t
                            )
pivotMat mat
    = ( fromRows ix
      , fromRows (\i j -> if j == i then 1 else 0)
      , if countMisordered rowOrder `rem` 2 == 1
        then -1 else 1
      )
  where
    dn = dimVal' @n
    steps = CumulDims [dn * dn, dn, 1]
    rowOrder = uncurry fillPass $ searchPass 0 [0..n-1]
    n = fromIntegral dn
    n# = case n of I# x -> x
    -- Build an n-by-n matrix whose k-th row is described by the k-th entry
    -- of rowOrder: element j of that row is @el i j@, where i is the entry.
    fromRows :: (Int -> Int -> t) -> Matrix t n n
    fromRows el = case gen# steps next (0, rowOrder) of
        (# _, r #) -> r
      where
        -- rowOrder has n entries and each entry yields n elements, so the
        -- generator is not expected to run out of rows while elements are
        -- still being requested.
        next ( _, [] ) =
          error "pivotMat: row order exhausted before the matrix was filled"
        next ( j, i:is )
          | j == n    = next (0, is)
          | otherwise = (# (j + 1, i:is), el i j #)
    -- Number of inversions (pairs that are out of order) in a list.
    countMisordered :: [Int] -> Int
    countMisordered [] = 0
    countMisordered (i:is) = foldl' (\c j -> if i > j then succ c else c) 0 is
                           + countMisordered is
    -- Element (i, j) of the input, by flat offset i*n + j.
    ix (I# i) (I# j) = ix# (i *# n# +# j) mat
    -- Largest absolute value in column j among the given rows, paired with
    -- its row index. The result is (0, 0) when no element exceeds zero.
    findMax :: Int -> [Int] -> (t, Int)
    findMax j = foldl' (\(ox, oi) i -> let x = abs (ix i j)
                                       in if x > ox then (x, i)
                                                    else (ox, oi)
                       ) (0, 0)
    -- For every column pick a pivot row ('Just') or leave a gap ('Nothing');
    -- also returns the rows that were never picked.
    searchPass :: Int -> [Int] -> ([Int], [Maybe Int])
    searchPass j is
      | j == n = (is, [])
      | otherwise = case findMax j is of
          (0, _) -> (Nothing:) <$> searchPass (j+1) is
          (_, i) -> (Just i:) <$> searchPass (j+1) (delete i is)
    -- Fill the gaps with the leftover rows, in order.
    fillPass :: [Int] -> [Maybe Int] -> [Int]
    fillPass _ [] = []
    fillPass js (Just i : is) = i : fillPass js is
    fillPass (j:js) (Nothing : is) = j : fillPass js is
    -- Fallback for running out of leftover rows; searchPass leaves exactly
    -- one leftover row per gap, so this is not expected to be reached.
    fillPass [] (Nothing : is) = 0 : fillPass [] is