{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
{-# OPTIONS_GHC -fno-solve-constant-dicts #-}
module Data.Matrix.Static (
Matrix
, MatrixConstructor
, IsMatrix
, matrix
, identity
, Identity
, row
, Row
, getRowElems
, GetRowElems
, setRowElems
, SetRowElems
, mapRowElems
, MapRowElems
, col
, Col
, getColElems
, GetColElems
, setColElems
, SetColElems
, mapColElems
, MapColElems
, MatrixMultDims
, MatrixMult(..)
, transpose
, Transpose
, minorMatrix
, MinorMatrix
, Determinant(..)
, minor
, Minor
, cofactor
, Cofactor
, cofactorMatrix
, CofactorMatrix
, adjugateMatrix
, AdjugateMatrix
, inverse
, Inverse
, genMatrixInstance
) where
import Control.Lens (Lens', (^.))
import Data.Kind (Constraint)
import Data.Proxy (Proxy(..))
import Data.Singletons (type (~>))
import Data.Singletons.TH (genDefunSymbols)
import Data.Tensor.Static ( IsTensor(..), Tensor, TensorConstructor, NormalizeDims
, generate, Generate
, subtensor, SubtensorCtx, getSubtensorElems, GetSubtensorElems, setSubtensorElems, SetSubtensorElems
, mapSubtensorElems, MapSubtensorElems
, slice, Slice, getSliceElems, GetSliceElems, setSliceElems, SetSliceElems
, mapSliceElems, MapSliceElems
, tensorElem, TensorElem
, withTensor
, NatsFromTo
, scale, Scale)
import Data.Tensor.Static.TH (genTensorInstance)
import Data.Vector.Static (Vector)
import GHC.TypeLits (Nat, type (<=), type (<=?), type (-), type (+), TypeError, ErrorMessage(..))
import Language.Haskell.TH (Q, Name, Dec)
import Type.List (DemoteWith(..))
import qualified Data.List.NonEmpty as N
import qualified Data.List.Unrolled as U
-- | A rank-2 tensor with dimensions @[m, n]@ and element type @e@.
type Matrix m n e = Tensor '[m, n] e
-- | The 'TensorConstructor' type specialised to the dimensions @[m, n]@.
type MatrixConstructor m n e = TensorConstructor '[m, n] e
-- | The 'IsTensor' constraint specialised to the dimensions @[m, n]@.
type IsMatrix m n e = IsTensor '[m, n] e
-- | Matrix constructor: 'tensor' with its dimensions fixed to @[m, n]@ and
-- its element type fixed to @e@.
matrix ::
  forall m n e.
  IsMatrix m n e =>
  MatrixConstructor m n e
matrix = tensor @'[m, n] @e
{-# INLINE matrix #-}
-- | Constraints required by 'identity' for an @m@-by-@m@ matrix of @e@.
type Identity m e =
  ( IsMatrix m m e
  , Generate '[m, m] e ([Nat] -> Constraint) (IdentityWrk e)
  , Num e
  )

-- | Square matrix whose element at every index is supplied by
-- 'identityWrk': @1@ where both index components coincide, @0@ elsewhere.
identity ::
  forall m e.
  Identity m e =>
  Matrix m m e
identity =
  generate @'[m, m] @e @([Nat] -> Constraint) @(IdentityWrk e) elemAt
  where
    -- Element for one statically known index; which value is produced is
    -- decided by instance resolution on the index.
    elemAt ::
      forall (index :: [Nat]).
      IdentityWrk e index =>
      Proxy index -> e
    elemAt _ = identityWrk @e @index
{-# INLINE identity #-}
-- | Chooses the value of an identity-matrix element from its type-level
-- index.
class IdentityWrk e (index :: [Nat]) where
  -- | The element stored at @index@.
  identityWrk :: e

-- | General two-component index: the element is @0@.
instance {-# OVERLAPPABLE #-} Num e => IdentityWrk e '[r, c] where
  identityWrk = 0
  {-# INLINE identityWrk #-}

-- | Both index components are the same: the element is @1@.
instance {-# OVERLAPPING #-} Num e => IdentityWrk e '[d, d] where
  identityWrk = 1
  {-# INLINE identityWrk #-}
-- | Constraints required by 'row'. The row index is bounded by
-- @r <= m - 1@.
type Row (r :: Nat) (m :: Nat) (n :: Nat) e =
  ( SubtensorCtx '[r] '[m, n] e
  , r <= m - 1
  , NormalizeDims '[n] ~ '[n]
  )

-- | Lens onto row @r@ of a matrix: the subtensor at index @[r]@, viewed as
-- a 'Vector' of @n@ elements.
row ::
  forall (r :: Nat) (m :: Nat) (n :: Nat) e.
  Row r m n e =>
  Lens' (Matrix m n e) (Vector n e)
row = subtensor @'[r] @'[m, n] @e
{-# INLINE row #-}
-- | Constraints required by 'getRowElems'.
type GetRowElems (r :: Nat) (m :: Nat) (n :: Nat) e =
  GetSubtensorElems '[r] '[m, n] e

-- | List the elements of row @r@, i.e. of the subtensor at index @[r]@.
getRowElems ::
  forall (r :: Nat) (m :: Nat) (n :: Nat) e.
  GetRowElems r m n e =>
  Matrix m n e ->
  [e]
getRowElems = getSubtensorElems @'[r] @'[m, n] @e
{-# INLINE getRowElems #-}
-- | Constraints required by 'setRowElems'.
type SetRowElems (r :: Nat) (m :: Nat) (n :: Nat) e =
  SetSubtensorElems '[r] '[m, n] e

-- | Replace the elements of row @r@ with the elements of the given list.
--
-- The 'Maybe' result is whatever 'setSubtensorElems' returns; presumably
-- 'Nothing' signals a list of unsuitable length (TODO confirm against
-- 'setSubtensorElems').
setRowElems ::
  forall (r :: Nat) (m :: Nat) (n :: Nat) e.
  SetRowElems r m n e =>
  Matrix m n e ->
  [e] ->
  Maybe (Matrix m n e)
setRowElems = setSubtensorElems @'[r] @'[m, n] @e
{-# INLINE setRowElems #-}
-- | Constraints required by 'mapRowElems'.
type MapRowElems (r :: Nat) (m :: Nat) (n :: Nat) e =
  MapSubtensorElems '[r] '[m, n] e

-- | Apply a function to the elements of row @r@; delegates to
-- 'mapSubtensorElems' for the subtensor at index @[r]@.
mapRowElems ::
  forall (r :: Nat) (m :: Nat) (n :: Nat) e.
  MapRowElems r m n e =>
  Matrix m n e ->
  (e -> e) ->
  Matrix m n e
mapRowElems = mapSubtensorElems @'[r] @'[m, n] @e
{-# INLINE mapRowElems #-}
-- | Constraints required by 'col'.
type Col (c :: Nat) (m :: Nat) (n :: Nat) e =
  ( Slice '[0, c] '[m, 1] '[m, n] e
  , NormalizeDims '[m, 1] ~ '[m]
  )

-- | Lens onto column @c@ of a matrix, viewed as a 'Vector' of @m@
-- elements. The column is the slice described by @[0, c]@ and @[m, 1]@
-- (presumably start index and extent; confirm against 'slice').
col ::
  forall (c :: Nat) (m :: Nat) (n :: Nat) e.
  Col c m n e =>
  Lens' (Matrix m n e) (Vector m e)
col = slice @'[0, c] @'[m, 1] @'[m, n] @e
{-# INLINE col #-}
-- | Constraints required by 'getColElems'.
type GetColElems (c :: Nat) (m :: Nat) (n :: Nat) e =
  GetSliceElems '[0, c] '[m, 1] '[m, n] e

-- | List the elements of column @c@, i.e. of the slice described by
-- @[0, c]@ and @[m, 1]@.
getColElems ::
  forall (c :: Nat) (m :: Nat) (n :: Nat) e.
  GetColElems c m n e =>
  Matrix m n e ->
  [e]
getColElems = getSliceElems @'[0, c] @'[m, 1] @'[m, n] @e
{-# INLINE getColElems #-}
-- | Constraints required by 'setColElems'.
type SetColElems (c :: Nat) (m :: Nat) (n :: Nat) e =
  SetSliceElems '[0, c] '[m, 1] '[m, n] e

-- | Replace the elements of column @c@ with the elements of the given
-- list.
--
-- The 'Maybe' result is whatever 'setSliceElems' returns; presumably
-- 'Nothing' signals a list of unsuitable length (TODO confirm against
-- 'setSliceElems').
setColElems ::
  forall (c :: Nat) (m :: Nat) (n :: Nat) e.
  SetColElems c m n e =>
  Matrix m n e ->
  [e] ->
  Maybe (Matrix m n e)
setColElems = setSliceElems @'[0, c] @'[m, 1] @'[m, n] @e
{-# INLINE setColElems #-}
-- | Constraints required by 'mapColElems'.
type MapColElems (c :: Nat) (m :: Nat) (n :: Nat) e =
  MapSliceElems '[0, c] '[m, 1] '[m, n] e

-- | Apply a function to the elements of column @c@; delegates to
-- 'mapSliceElems' for the slice described by @[0, c]@ and @[m, 1]@.
mapColElems ::
  forall (c :: Nat) (m :: Nat) (n :: Nat) e.
  MapColElems c m n e =>
  Matrix m n e ->
  (e -> e) ->
  Matrix m n e
mapColElems = mapSliceElems @'[0, c] @'[m, 1] @'[m, n] @e
{-# INLINE mapColElems #-}
-- | Swap the two components of a matrix index. Only two-component indices
-- are covered by an equation.
type family ReverseIndex (index :: [Nat]) :: [Nat] where
  ReverseIndex '[a, b] = '[b, a]
-- | Constraint needed to read, as a 1-by-1 slice, the element of an
-- @[m, n]@ tensor located at the reversed @index@.
type TransposeGo m n e index = GetSliceElems (ReverseIndex index) [1, 1] [m, n] e
-- Generate defunctionalization symbols for the synonym, so that it can be
-- passed partially applied (it is used as @TransposeGoSym3 m n e@).
$(genDefunSymbols [''TransposeGo])
-- | Constraints required by 'transpose'.
type Transpose m n e =
  ( IsMatrix m n e
  , IsMatrix n m e
  , Generate '[n, m] e ([Nat] ~> Constraint) (TransposeGoSym3 m n e)
  )

-- | Transpose an @m@-by-@n@ matrix into an @n@-by-@m@ matrix.
--
-- The element of the result at @index@ is read from the argument as the
-- 1-by-1 slice located at @ReverseIndex index@.
transpose ::
  forall m n e.
  Transpose m n e =>
  Matrix m n e ->
  Matrix n m e
transpose mat =
  generate @'[n, m] @e @([Nat] ~> Constraint) @(TransposeGoSym3 m n e) go
  where
    -- Element of the transposed matrix at one statically known index.
    go ::
      forall (index :: [Nat]).
      TransposeGo m n e index =>
      Proxy index -> e
    go _ =
      -- Explicit match instead of the partial 'head': same result for a
      -- non-empty list, and a failure that names its origin otherwise.
      case getSliceElems @(ReverseIndex index) @[1, 1] mat of
        x : _ -> x
        [] -> error "Data.Matrix.Static.transpose: 1x1 slice produced no elements"
{-# INLINE transpose #-}
-- | Shape of the product of two tensors of shapes @dims0@ and @dims1@.
--
-- Covers matrix-by-matrix, vector-by-matrix and matrix-by-vector; any
-- other pair of shapes is reported as a custom type error.
type family MatrixMultDims (dims0 :: [Nat]) (dims1 :: [Nat]) :: [Nat] where
  MatrixMultDims '[rows, inner] '[inner, cols] = '[rows, cols]
  MatrixMultDims '[inner] '[inner, cols] = '[cols]
  MatrixMultDims '[rows, inner] '[inner] = '[rows]
  MatrixMultDims lhs rhs =
    TypeError
      ( 'Text "Matrices of shapes "
          ':<>: 'ShowType lhs
          ':<>: 'Text " and "
          ':<>: 'ShowType rhs
          ':<>: 'Text " are incompatible for multiplication."
      )
-- | Multiplication of a tensor of shape @dims0@ by a tensor of shape
-- @dims1@, both with element type @e@.
class MatrixMult (dims0 :: [Nat]) (dims1 :: [Nat]) e where
  -- | Multiply the two arguments. The shape of the result is
  -- @MatrixMultDims dims0 dims1@.
  mult ::
    ( IsTensor dims0 e
    , IsTensor dims1 e
    , IsTensor (MatrixMultDims dims0 dims1) e
    ) =>
    Tensor dims0 e ->
    Tensor dims1 e ->
    Tensor (MatrixMultDims dims0 dims1) e
-- | First component of a type-level index.
type family Index0 (index :: [Nat]) :: Nat where
  Index0 (x ': _) = x

-- | Second component of a type-level index.
type family Index1 (index :: [Nat]) :: Nat where
  Index1 (_ ': y ': _) = y
-- | Constraints needed to compute the element at @index@ of the product of
-- an @[m, n]@ matrix and an @[n, o]@ matrix: access to the row and column
-- selected by the index, plus the unrolled sum and zip of length @n@.
type MultMatMatGo (m :: Nat) (n :: Nat) (o :: Nat) e (index :: [Nat]) =
  ( GetRowElems (Index0 index) m n e
  , GetColElems (Index1 index) n o e
  , U.Sum n e
  , U.ZipWith n
  )

-- Generate defunctionalization symbols for the synonym, so that it can be
-- passed partially applied (it is used as @MultMatMatGoSym4 m n o e@).
$(genDefunSymbols [''MultMatMatGo])
-- | Matrix-by-matrix product. The element of the result at index @[i, j]@
-- is the sum of the pairwise products of row @i@ of the first argument
-- and column @j@ of the second.
instance
  ( Num e
  , Generate (MatrixMultDims '[m, n] '[n, o]) e ([Nat] ~> Constraint) (MultMatMatGoSym4 m n o e)
  ) =>
  MatrixMult '[m, n] '[n, o] e where
  mult lhs rhs =
    generate
      @(MatrixMultDims '[m, n] '[n, o])
      @e
      @([Nat] ~> Constraint)
      @(MultMatMatGoSym4 m n o e)
      elemAt
    where
      -- Element of the product at one statically known index.
      elemAt ::
        forall (index :: [Nat]).
        MultMatMatGo m n o e index =>
        Proxy index -> e
      elemAt _ = dot @(Index0 index) @(Index1 index)

      -- Dot product of row i of the left operand with column j of the
      -- right operand.
      dot ::
        forall (i :: Nat) (j :: Nat).
        ( GetRowElems i m n e
        , GetColElems j n o e
        , U.Sum n e
        , U.ZipWith n
        ) =>
        e
      dot = U.sum @n (U.zipWith @n (*) rowElems colElems)
        where
          rowElems = getRowElems @i lhs
          colElems = getColElems @j rhs
  {-# INLINE mult #-}
-- | Constraints needed to compute the element at @index@ of the product of
-- a vector of length @n@ and an @[n, o]@ matrix: access to the column
-- selected by the index, plus the unrolled sum and zip of length @n@.
--
-- The parameter @m@ does not occur on the right-hand side.
type MultVecMatGo (m :: Nat) (n :: Nat) (o :: Nat) e (index :: [Nat]) =
  ( GetColElems (Index0 index) n o e
  , U.Sum n e
  , U.ZipWith n
  )

-- Generate defunctionalization symbols for the synonym, so that it can be
-- passed partially applied (it is used as @MultVecMatGoSym4 m n o e@).
$(genDefunSymbols [''MultVecMatGo])
-- | Vector-by-matrix product. The element of the result at index @[c]@ is
-- the sum of the pairwise products of the vector's elements and column
-- @c@ of the matrix.
--
-- NOTE(review): the type variable @m@ occurs only in the instance
-- context, not in the instance head, so nothing at a use site determines
-- it. Confirm this is intended.
instance
  ( Num e
  , Generate (MatrixMultDims '[n] '[n, o]) e ([Nat] ~> Constraint) (MultVecMatGoSym4 m n o e)
  ) =>
  MatrixMult '[n] '[n, o] e where
  mult vec mat =
    generate
      @(MatrixMultDims '[n] '[n, o])
      @e
      @([Nat] ~> Constraint)
      @(MultVecMatGoSym4 m n o e)
      elemAt
    where
      -- Element of the product at one statically known index.
      elemAt ::
        forall (index :: [Nat]).
        MultVecMatGo m n o e index =>
        Proxy index -> e
      elemAt _ = dot @(Index0 index)

      -- Dot product of the vector with column c of the matrix.
      dot ::
        forall (c :: Nat).
        ( GetColElems c n o e
        , U.Sum n e
        , U.ZipWith n
        ) =>
        e
      dot = U.sum @n (U.zipWith @n (*) vecElems colElems)
        where
          vecElems = toList vec
          colElems = getColElems @c mat
  {-# INLINE mult #-}
-- | Constraints needed to compute the element at @index@ of the product of
-- an @[m, n]@ matrix and a vector of length @n@: access to the row
-- selected by the index, plus the unrolled sum and zip of length @n@.
--
-- The parameter @o@ does not occur on the right-hand side.
type MultMatVecGo (m :: Nat) (n :: Nat) (o :: Nat) e (index :: [Nat]) =
  ( GetRowElems (Index0 index) m n e
  , U.Sum n e
  , U.ZipWith n
  )

-- Generate defunctionalization symbols for the synonym, so that it can be
-- passed partially applied (it is used as @MultMatVecGoSym4 m n o e@).
$(genDefunSymbols [''MultMatVecGo])
-- | Matrix-by-vector product. The element of the result at index @[r]@ is
-- the sum of the pairwise products of row @r@ of the matrix and the
-- vector's elements.
--
-- NOTE(review): the type variable @o@ occurs only in the instance
-- context, not in the instance head, so nothing at a use site determines
-- it. Confirm this is intended.
instance
  ( Num e
  , Generate (MatrixMultDims '[m, n] '[n]) e ([Nat] ~> Constraint) (MultMatVecGoSym4 m n o e)
  ) =>
  MatrixMult '[m, n] '[n] e where
  mult mat vec =
    generate
      @(MatrixMultDims '[m, n] '[n])
      @e
      @([Nat] ~> Constraint)
      @(MultMatVecGoSym4 m n o e)
      elemAt
    where
      -- Element of the product at one statically known index.
      elemAt ::
        forall (index :: [Nat]).
        MultMatVecGo m n o e index =>
        Proxy index -> e
      elemAt _ = dot @(Index0 index)

      -- Dot product of row r of the matrix with the vector.
      dot ::
        forall (r :: Nat).
        ( GetRowElems r m n e
        , U.Sum n e
        , U.ZipWith n
        ) =>
        e
      dot = U.sum @n (U.zipWith @n (*) rowElems vecElems)
        where
          rowElems = getRowElems @r mat
          vecElems = toList vec
  {-# INLINE mult #-}
-- | Translate an index of a minor matrix into the corresponding index of
-- the source matrix, given the index @cutIndex@ that was removed: indices
-- not greater than @cutIndex - 1@ are kept, all others are shifted up by
-- one.
type family MinorMatrixNewIndex (cutIndex :: Nat) (index :: Nat) :: Nat where
  -- A cut at 0 shifts every index; handling it first presumably avoids
  -- evaluating cut - 1 for cut = 0 in the equation below.
  MinorMatrixNewIndex 0 ix = ix + 1
  MinorMatrixNewIndex cut ix = MinorMatrixNewIndex' cut ix (ix <=? cut - 1)

-- | Worker for 'MinorMatrixNewIndex', dispatching on the already computed
-- comparison of the index with the cut position.
type family MinorMatrixNewIndex' (cutIndex :: Nat) (index :: Nat) (indexLTcutIndex :: Bool) :: Nat where
  MinorMatrixNewIndex' cut ix 'True = ix
  MinorMatrixNewIndex' cut ix 'False = ix + 1
-- | Constraint needed to read, as a 1-by-1 slice of an @[n, n]@ matrix,
-- the source element behind the minor-matrix element at @index@, where
-- row @i@ and column @j@ are the ones removed.
type MinorMatrixGo (i :: Nat) (j :: Nat) (n :: Nat) e (index :: [Nat]) =
  GetSliceElems
    [ MinorMatrixNewIndex i (Index0 index)
    , MinorMatrixNewIndex j (Index1 index)
    ]
    [1, 1]
    [n, n]
    e

-- Generate defunctionalization symbols for the synonym, so that it can be
-- passed partially applied (it is used as @MinorMatrixGoSym4 i j n e@).
$(genDefunSymbols [''MinorMatrixGo])
-- | Constraints required by 'minorMatrix'.
type MinorMatrix (i :: Nat) (j :: Nat) (n :: Nat) e =
  Generate [n - 1, n - 1] e ([Nat] ~> Constraint) (MinorMatrixGoSym4 i j n e)

-- | Build the @(n - 1)@-by-@(n - 1)@ matrix obtained from an @n@-by-@n@
-- matrix by leaving out row @i@ and column @j@.
--
-- Each element of the result is read from the argument as a 1-by-1 slice
-- whose position is computed with 'MinorMatrixNewIndex'.
minorMatrix ::
  forall (i :: Nat) (j :: Nat) (n :: Nat) e.
  MinorMatrix i j n e =>
  Matrix n n e ->
  Matrix (n - 1) (n - 1) e
minorMatrix mat =
  generate @[n - 1, n - 1] @e @([Nat] ~> Constraint) @(MinorMatrixGoSym4 i j n e) go
  where
    -- Element of the minor matrix at one statically known index.
    go ::
      forall (index :: [Nat]).
      MinorMatrixGo i j n e index =>
      Proxy index -> e
    go _ = go' @(MinorMatrixNewIndex i (Index0 index)) @(MinorMatrixNewIndex j (Index1 index))

    -- Element of the source matrix at position [r, c].
    go' ::
      forall (r :: Nat) (c :: Nat).
      GetSliceElems [r, c] [1, 1] [n, n] e =>
      e
    go' =
      -- Explicit match instead of the partial 'head': same result for a
      -- non-empty list, and a failure that names its origin otherwise.
      case getSliceElems @[r, c] @[1, 1] @[n, n] @e mat of
        x : _ -> x
        [] -> error "Data.Matrix.Static.minorMatrix: 1x1 slice produced no elements"
{-# INLINE minorMatrix #-}
-- | Determinant of an @n@-by-@n@ matrix with elements of type @e@.
class Determinant (n :: Nat) e where
  -- | Compute the determinant of the given square matrix.
  determinant ::
    Num e =>
    Matrix n n e ->
    e
-- | Closed-form determinant of a 2-by-2 matrix: with the four elements
-- handed over by 'withTensor' named in order, it is
-- @m00 * m11 - m01 * m10@.
instance {-# OVERLAPPING #-}
  ( Num e
  , IsMatrix 2 2 e
  ) =>
  Determinant 2 e where
  determinant mat =
    withTensor mat $ \m00 m01 m10 m11 ->
      m00 * m11 - m01 * m10
  {-# INLINE determinant #-}
-- | Closed-form determinant of a 3-by-3 matrix, expanded along the first
-- three of the nine elements handed over by 'withTensor'.
instance {-# OVERLAPPING #-}
  ( Num e
  , IsMatrix 3 3 e
  ) =>
  Determinant 3 e where
  determinant mat =
    withTensor mat $ \m00 m01 m02 m10 m11 m12 m20 m21 m22 ->
      m00 * (m11 * m22 - m12 * m21)
        - m01 * (m10 * m22 - m12 * m20)
        + m02 * (m10 * m21 - m11 * m20)
  {-# INLINE determinant #-}
-- | Alternating sign indexed by a type-level natural: @1@ for @0@, and
-- @(-1)@ times the sign of the predecessor for every other number.
class Sign (n :: Nat) where
  -- | The sign belonging to @n@, in any 'Num' type.
  sign :: Num a => a

-- | Base case: the sign of @0@ is @1@.
instance {-# OVERLAPPING #-} Sign 0 where
  sign = 1
  {-# INLINE sign #-}

-- | Recursive case: flip the sign of @n - 1@.
instance {-# OVERLAPPABLE #-} Sign (n - 1) => Sign n where
  sign = (-1) * sign @(n - 1)
  {-# INLINE sign #-}
-- | Constraints needed for the term belonging to column @j@ in the
-- determinant of an @[n, n]@ matrix: the determinant of size @n - 1@,
-- access to the element at @[0, j]@, the minor matrix for row 0 and
-- column @j@, and the sign of @j@.
type DeterminantGo (n :: Nat) e (j :: Nat) =
  ( Determinant (n - 1) e
  , TensorElem [0, j] [n, n] e
  , MinorMatrix 0 j n e
  , Sign j
  )

-- Generate defunctionalization symbols for the synonym, so that it can be
-- passed partially applied (it is used as @DeterminantGoSym2 n e@).
$(genDefunSymbols [''DeterminantGo])
-- | General determinant by expansion along row 0: the sum, over the column
-- indices @0@ to @n - 1@, of the sign of @j@ times the element at
-- @[0, j]@ times the determinant of the minor matrix for row 0 and column
-- @j@.
instance {-# OVERLAPPABLE #-}
  ( Num e
  , IsMatrix n n e
  , DemoteWith Nat (Nat ~> Constraint) (DeterminantGoSym2 n e) (NatsFromTo 0 (n - 1))
  , U.Sum n e
  ) =>
  Determinant n e where
  determinant mat =
    U.sum @n $
      demoteWith
        @Nat
        @(Nat ~> Constraint)
        @(DeterminantGoSym2 n e)
        @(NatsFromTo 0 (n - 1))
        term
    where
      -- Contribution of column j to the expansion.
      term ::
        forall (j :: Nat).
        DeterminantGo n e j =>
        Proxy j -> e
      term _ = sign @j * entry * determinant (minorMatrix @0 @j @n @e mat)
        where
          entry = mat ^. tensorElem @[0, j]
  {-# INLINE determinant #-}
-- | Constraints required by 'minor'.
type Minor (i :: Nat) (j :: Nat) (n :: Nat) e =
  ( MinorMatrix i j n e
  , Determinant (n - 1) e
  , Num e
  )

-- | Determinant of the minor matrix obtained by leaving out row @i@ and
-- column @j@ of an @n@-by-@n@ matrix.
minor ::
  forall (i :: Nat) (j :: Nat) (n :: Nat) e.
  Minor i j n e =>
  Matrix n n e ->
  e
minor =
  determinant @(n - 1) @e
    . minorMatrix @i @j @n @e
{-# INLINE minor #-}
-- | Constraints required by 'cofactor'.
type Cofactor (i :: Nat) (j :: Nat) (n :: Nat) e =
  ( Minor i j n e
  , Sign (i + j)
  )

-- | The minor for row @i@ and column @j@, multiplied by the sign of
-- @i + j@.
cofactor ::
  forall (i :: Nat) (j :: Nat) (n :: Nat) e.
  Cofactor i j n e =>
  Matrix n n e ->
  e
cofactor mat = sign @(i + j) * minor @i @j @n @e mat
{-# INLINE cofactor #-}
-- | Constraints needed to compute the cofactor whose row and column are
-- the two components of @index@.
type CofactorMatrixGo (n :: Nat) e (index :: [Nat]) =
  Cofactor (Index0 index) (Index1 index) n e

-- Generate defunctionalization symbols for the synonym, so that it can be
-- passed partially applied (it is used as @CofactorMatrixGoSym2 n e@).
$(genDefunSymbols [''CofactorMatrixGo])
-- | Constraints required by 'cofactorMatrix'.
type CofactorMatrix (n :: Nat) e =
  Generate [n, n] e ([Nat] ~> Constraint) (CofactorMatrixGoSym2 n e)

-- | Matrix whose element at index @[i, j]@ is the cofactor of the argument
-- for row @i@ and column @j@.
cofactorMatrix ::
  forall (n :: Nat) e.
  CofactorMatrix n e =>
  Matrix n n e ->
  Matrix n n e
cofactorMatrix mat =
  generate @[n, n] @e @([Nat] ~> Constraint) @(CofactorMatrixGoSym2 n e) elemAt
  where
    -- Element of the cofactor matrix at one statically known index.
    elemAt ::
      forall (index :: [Nat]).
      CofactorMatrixGo n e index =>
      Proxy index -> e
    elemAt _ = cofactorAt @(Index0 index) @(Index1 index)

    -- Cofactor of the argument for row i and column j.
    cofactorAt ::
      forall (i :: Nat) (j :: Nat).
      Cofactor i j n e =>
      e
    cofactorAt = cofactor @i @j @n @e mat
{-# INLINE cofactorMatrix #-}
-- | Constraints required by 'adjugateMatrix'.
type AdjugateMatrix (n :: Nat) e =
  ( CofactorMatrix n e
  , Transpose n n e
  )

-- | Adjugate of a square matrix: the transpose of its cofactor matrix.
adjugateMatrix ::
  forall (n :: Nat) e.
  AdjugateMatrix n e =>
  Matrix n n e ->
  Matrix n n e
adjugateMatrix =
  transpose
    . cofactorMatrix
{-# INLINE adjugateMatrix #-}
-- | Constraints required by 'inverse'.
type Inverse (n :: Nat) e =
  ( AdjugateMatrix n e
  , Determinant n e
  , Fractional e
  , Scale '[n, n] e
  )

-- | Inverse of a square matrix, computed as the adjugate matrix scaled by
-- the reciprocal of the determinant.
--
-- The determinant is not checked: for a matrix whose determinant is zero
-- the scaling factor is @1 / 0@ in the element type.
inverse ::
  forall (n :: Nat) e.
  Inverse n e =>
  Matrix n n e ->
  Matrix n n e
inverse mat = scale (adjugateMatrix mat) invDet
  where
    invDet = 1 / determinant mat
{-# INLINE inverse #-}
-- | Template Haskell generator for a matrix instance: calls
-- 'genTensorInstance' with the two-element dimension list @[m, n]@ and the
-- given element type name.
genMatrixInstance ::
  -- | First dimension, @m@.
  Int ->
  -- | Second dimension, @n@.
  Int ->
  -- | Name of the element type.
  Name ->
  Q [Dec]
genMatrixInstance m n elemTypeName = genTensorInstance dims elemTypeName
  where
    -- Non-empty dimension list, built directly with the constructor.
    dims = m N.:| [n]