{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
module Numeric.Matrix.QR
  ( QR (..)
  , LQ (..)
  , MatrixQR (..)
  , detViaQR
  , inverseViaQR
  , qrSolveR
  , qrSolveL
  ) where

import Control.Monad
import Control.Monad.ST
import Data.Kind
import Numeric.Basics
import Numeric.DataFrame.ST
import Numeric.DataFrame.SubSpace
import Numeric.DataFrame.Type
import Numeric.Dimensions
import Numeric.Matrix.Internal
import Numeric.Scalar.Internal
import Numeric.Subroutine.Householder
import Numeric.Subroutine.SolveTriangular
import Unsafe.Coerce                     (unsafeCoerce)

-- | Outcome of a QR factorization, i.e. the factors in
--   \( A = QR \) for an \( n \times m \) matrix \( A \).
data QR (t :: Type) (n :: Nat) (m :: Nat) = QR
  { -- | Orthogonal factor \( Q \), an \( n \times n \) matrix.
    qrQ    :: Matrix t n n
    -- | Determinant of the orthogonal factor, \( |Q| = \pm 1 \);
    --   stored so that determinants can be obtained cheaply.
  , qrQDet :: Scalar t
    -- | Upper-triangular factor \( R \), an \( n \times m \) matrix.
  , qrR    :: Matrix t n m
  }

-- | Result of LQ factorization
--   \( A = LQ \).
data LQ (t :: Type) (n :: Nat) (m :: Nat)
  = LQ
  { lqL    :: Matrix t n m
    -- ^ Lower-triangular matrix \( L \)
  , lqQ    :: Matrix t m m
    -- ^ Orthogonal matrix \( Q \)
  , lqQDet :: Scalar t
    -- ^ A shortcut for evaluating a determinant of \( |Q| = \pm 1 \)
  }

deriving instance ( Eq t, PrimBytes t
                  , KnownDim n, KnownDim m) => Eq (QR t n m)
deriving instance ( Show t, PrimBytes t
                  , KnownDim n, KnownDim m) => Show (QR t n m)
deriving instance ( Eq t, PrimBytes t
                  , KnownDim n, KnownDim m) => Eq (LQ t n m)
deriving instance ( Show t, PrimBytes t
                  , KnownDim n, KnownDim m) => Show (LQ t n m)

-- | QR and LQ factorizations of an \( n \times m \) matrix.
class (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
      => MatrixQR t (n :: Nat) (m :: Nat) where
    -- | Compute QR factorization
    qr :: Matrix t n m -> QR t n m
    -- | Compute LQ factorization
    lq :: Matrix t n m -> LQ t n m

-- | Householder-reflection based implementation.
--
--   For an empty matrix (\( n = 0 \) or \( m = 0 \)) no reflections are
--   needed, and the trivial factorization with an identity \( Q \)
--   (determinant 1) is returned.
instance (PrimBytes t, Ord t, Epsilon t, KnownDim n, KnownDim m)
         => MatrixQR t (n :: Nat) (m :: Nat) where
    qr a
        -- Nothing to reflect: A = I * A holds trivially.
        -- (Fields are set by name; the fields are qrQ, qrQDet, qrR.)
      | lim == 0 = QR { qrQ = eye, qrQDet = 1, qrR = a }
      | otherwise = runST $ do
          uPtr <- newDataFrame
          -- Take a private copy of the identity: pPtr is mutated in place
          -- below, and 'eye' is an immutable value that may be shared.
          pPtr <- thawDataFrame eye
          rPtr <- thawDataFrame a
          -- Each reflection reports whether it flips the sign of |Q|;
          -- the flags are combined with XOR via (/=).
          detNegative <-
            let f x i = (x /=) <$> householderReflectionInplaceL
                          uPtr pPtr rPtr (Idx i :* Idx i :* U)
            in  foldM f False [0 .. lim - 1]
          qrR <- unsafeFreezeDataFrame rPtr
          qrQ <- unsafeFreezeDataFrame pPtr
          let qrQDet = if detNegative then -1 else 1
          return QR {..}
      where
        n = dimVal' @n
        m = dimVal' @m
        -- lim - 1 is only evaluated when lim > 0 (guarded above).
        lim = min n m

    lq a
        -- Nothing to reflect: A = A * I holds trivially.
      | lim == 0 = LQ { lqL = a, lqQ = eye, lqQDet = 1 }
      | otherwise = runST $ do
          uPtr <- newDataFrame
          -- Private copy of the identity; see the note in 'qr'.
          pPtr <- thawDataFrame eye
          lPtr <- thawDataFrame a
          detNegative <-
            let f x i = (x /=) <$> householderReflectionInplaceR
                          uPtr pPtr lPtr (Idx i :* Idx i :* U)
            in  foldM f False [0 .. lim - 1]
          lqL <- unsafeFreezeDataFrame lPtr
          -- WARNING! todo: get rid of transpose.
          lqQ <- transpose <$> unsafeFreezeDataFrame pPtr
          let lqQDet = if detNegative then -1 else 1
          return LQ {..}
      where
        n = dimVal' @n
        m = dimVal' @m
        lim = min n m

-- | Calculate determinant of a matrix via QR decomposition:
--   \( |A| = |Q| \prod_i R_{ii} \).
--   The determinant of an empty (\( 0 \times 0 \)) matrix is 1.
detViaQR :: forall t n . MatrixQR t n n => Matrix t n n -> Scalar t
detViaQR m = go qrQDet diagOffsets
  where
    n = fromIntegral (dimVal' @n) :: Int
    QR {..} = qr m
    -- Flat offsets of the diagonal elements of the n x n matrix R.
    -- The upper bound is n*n - 1 (the last diagonal offset), so that the
    -- list is empty for n == 0 instead of containing a bogus offset 0.
    diagOffsets = [0, n + 1 .. n * n - 1]
    -- Strict left fold; multiplies in the order R_00, R_11, ...
    go acc [] = acc
    go acc (off : offs) =
      let acc' = scalar (ixOff off qrR) * acc
      in  acc' `seq` go acc' offs

-- | Calculate inverse of a matrix via QR decomposition,
--   i.e. solve \( AX = I \) using 'qrSolveR'.
inverseViaQR :: forall t n . MatrixQR t n n => Matrix t n n -> Matrix t n n
inverseViaQR = (`qrSolveR` eye)

{- | Compute a QR or LQ decomposition of matrix \( A : n \times m \),
     and solve a system of linear equations \( Ax = b \).

     If \( n \geq m \) QR decomposition is used;
     if \( n > m \) this function solves a linear least squares problem.
     If \( n < m \) (underdetermined system) LQ decomposition is used
       to yield a minimum norm solution.
 -}
qrSolveR :: forall t (n :: Nat) (m :: Nat) (ds :: [Nat])
          . (MatrixQR t n m, Dimensions ds)
         => Matrix t n m -> DataFrame t (n :+ ds) -> DataFrame t (m :+ ds)
qrSolveR a b = case compareDim dn dm of
    SEQ
      | Dict <- (unsafeCoerce (Dict @(m ~ m)) :: Dict (m ~ n))
        -> runST $ do
          let QR {..} = qr a
          -- NB: make a stateful product for transposed mat
          xPtr <- thawDataFrame (transpose qrQ %* b)
          solveUpperTriangularR qrR xPtr
          unsafeFreezeDataFrame xPtr
    SGT
      | Dict <- unsafeCoerce (Dict @(m <= m)) :: Dict (m <= n)
      , Dict <- unsafeCoerce (Dict @(n ~ n))
                  :: Dict ((((n - m) + 1) + m) ~ (n + 1) )
      , D <- minusDim dn dm `plusDim` D1
        -> runST $ do
          let QR {..} = qr a
              i0 :: Idxs '[n - m + 1]
              i0 = Idx 0 :* U
          -- Only the first m rows of Q^T b are used by R's square part.
          -- NB: make a stateful product for transposed mat
          xPtr <- thawDataFrame ( slice i0 $ transpose qrQ %* b )
          solveUpperTriangularR qrR xPtr
          unsafeFreezeDataFrame xPtr
    SLT
      | Dict <- unsafeCoerce (Dict @(m <= m)) :: Dict (n <= m)
      , Dict <- unsafeCoerce (Dict @(n ~ n))
                  :: Dict ((((m - n) + 1) + n) ~ (m + 1))
      , Dict <- unsafeCoerce (Dict @(n ~ n))
                  :: Dict (((n + 1) + (m - n)) ~ (m + 1))
      , dd@D <- minusDim dm dn
      , D <- dd `plusDim` D1
      , D <- dn `plusDim` D1
        -> runST $ do
          let LQ {..} = lq a
              i0 :: Idxs '[m - n + 1]
              i0 = Idx 0 :* U
              iz :: Idxs '[n + 1]
              iz = maxBound :* U
          -- y = [b; 0]: the first n rows get b, the remaining m - n rows
          -- are zeroed (minimum norm choice).
          xPtr <- newDataFrame
          copyDataFrame i0 b xPtr
          copyDataFrame iz (0 :: DataFrame t ((m - n) ': ds)) xPtr
          solveLowerTriangularR lqL (subDataFrameView i0 xPtr)
          -- NB: make a stateful product for transposed mat
          (transpose lqQ %*) <$> unsafeFreezeDataFrame xPtr
#if !MIN_VERSION_GLASGOW_HASKELL(8,10,0,0)
    _ -> error "Numeric.Matrix.QR.qrSolveR: impossible pattern"
#endif
  where
    dn = dim @n
    dm = dim @m

{- | Compute a QR or LQ decomposition of matrix \( A : n \times m \),
     and solve a system of linear equations \( xA = b \).

     If \( n \leq m \) LQ decomposition is used;
     if \( n < m \) this function solves a linear least squares problem.
     If \( n > m \) (underdetermined system) QR decomposition is used
       to yield a minimum norm solution.
 -}
qrSolveL :: forall t (n :: Nat) (m :: Nat) (ds :: [Nat])
          . (MatrixQR t n m, Dimensions ds)
         => Matrix t n m -> DataFrame t (ds +: m) -> DataFrame t (ds +: n)
qrSolveL a b
  | Dims <- Snoc (dims @ds) dn
  , Dims <- Snoc (dims @ds) dm
  , Dict <- Dict @(SnocList ds n _) -- remind GHC that we have (ds +: n)
  , Dict <- Dict @(SnocList ds m _) -- remind GHC that we have (ds +: m)
  = case compareDim dn dm of
      SEQ
        | Dict <- (unsafeCoerce (Dict @(m ~ m)) :: Dict (m ~ n))
          -> runST $ do
            let LQ {..} = lq a
            -- NB: make a stateful product for transposed mat
            xPtr <- thawDataFrame (b %* transpose lqQ)
            solveLowerTriangularL xPtr lqL
            unsafeFreezeDataFrame xPtr
      SGT
        | Dict <- unsafeCoerce (Dict @(m <= m)) :: Dict (m <= n)
        , Dict <- unsafeCoerce (Dict @(n ~ n)) :: Dict ((m + (n - m)) ~ n)
        , D <- minusDim dn dm
          -> runST $ do
            let QR {..} = qr a
            xPtr <- thawDataFrame b
            solveUpperTriangularL xPtr qrR
            -- Pad each solution with n - m zeros (minimum norm choice)
            -- and rotate back with Q.
            ewmap @t @ds @'[n]
                (\x -> qrQ %* appendDF x (0 :: DataFrame t '[n - m]))
              <$> unsafeFreezeDataFrame xPtr
      SLT
        | Dict <- unsafeCoerce (Dict @(m <= m)) :: Dict (n <= m)
        , Dict <- unsafeCoerce (Dict @(n ~ n))
                    :: Dict ((((m - n) + 1) + n) ~ (m + 1))
        , D <- minusDim dm dn `plusDim` D1
          -> runST $ do
            let LQ {..} = lq a
                i0 :: Idx (m - n + 1)
                i0 = Idx 0
            -- NB: make a stateful product for transposed mat
            xPtr <- thawDataFrame (b %* transpose lqQ)
            solveLowerTriangularL xPtr lqL
            -- Keep only the first n components of each solution.
            ewmap @t @ds @'[n] (sslice i0) <$> unsafeFreezeDataFrame xPtr
#if !MIN_VERSION_GLASGOW_HASKELL(8,10,0,0)
      _ -> error "Numeric.Matrix.QR.qrSolveL/compareDim: impossible pattern"
#endif
  | otherwise = error "Numeric.Matrix.QR.qrSolveL: impossible pattern"
  where
    dn = dim @n
    dm = dim @m