{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NoStarIsType #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module LAoP.Matrix.Type
(
Matrix (..),
Countable,
CountableDims,
CountableN,
CountableDimsN,
FLN,
Liftable,
Trivial,
TrivialP,
Zero,
One,
one,
join,
fork,
I.FromNat,
I.Count,
I.Normalize,
I.FromLists,
fromLists,
toLists,
toList,
matrixBuilder',
matrixBuilder,
row,
col,
zeros,
ones,
bang,
point,
constant,
fmapM,
bimapM,
unitM,
multM,
selectM,
returnM,
bindM,
columns,
columns',
rows,
rows',
tr,
(.|),
(./),
cond,
abideJF,
abideFJ,
zipWithM,
(===),
p1,
p2,
(|||),
i1,
i2,
(-|-),
(><),
fstM,
sndM,
kr,
iden,
comp,
fromF',
fromF,
toRel,
pretty,
prettyPrint
)
where
import Data.Void
import Data.Proxy
import Data.Kind
import GHC.TypeLits
import Control.DeepSeq
import LAoP.Utils
import qualified LAoP.Matrix.Internal as I
import Prelude hiding (id, (.))
-- | A matrix with entries of type @e@, whose columns are indexed by the type
-- @cols@ and whose rows are indexed by the type @rows@.
--
-- The user-facing index types are arbitrary 'Type's; the wrapped
-- 'I.Matrix' is indexed by their 'I.Normalize'd forms.  Every listed
-- instance is borrowed unchanged from that underlying 'I.Matrix' through
-- @DerivingVia@.
newtype Matrix e (cols :: Type) (rows :: Type) = M (I.Matrix e (I.Normalize cols) (I.Normalize rows))
  deriving (Show, Num, Eq, Ord, NFData) via (I.Matrix e (I.Normalize cols) (I.Normalize rows))
-- | The number of inhabitants of @a@, as computed by 'I.Count', is a
-- statically known natural.
type Countable a = KnownNat (I.Count a)

-- | 'Countable' for both dimensions of a matrix.
type CountableDims a b = (Countable a, Countable b)

-- | Like 'Countable', but counts the 'I.Normalize'd form of @a@ (the form
-- the wrapped 'I.Matrix' is actually indexed by).
type CountableN a = KnownNat (I.Count (I.Normalize a))

-- | 'CountableN' for both dimensions of a matrix.
type CountableDimsN a b = (CountableN a, CountableN b)

-- | 'I.FromLists' stated on the normalized forms of @a@ and @b@.
type FLN a b = I.FromLists (I.Normalize a) (I.Normalize b)

-- | Constraints required to turn a function @a -> b@ into a matrix with
-- entries in @e@ (see 'fromF', 'fmapM', 'bimapM').
-- NOTE(review): presumably the domain and codomain are enumerated with
-- Bounded/Enum and function results compared with Eq — confirm in I.fromF.
type Liftable e a b = (Bounded a, Bounded b, Enum a, Enum b, Eq b, Num e, Ord e)

-- | Equalities about 'I.Normalize' that some signatures (e.g. 'cond') must
-- state explicitly.  NOTE(review): presumably these always hold
-- (normalization being idempotent), but GHC cannot prove them for an
-- abstract @a@ — confirm against the definition of 'I.Normalize'.
type Trivial a = I.Normalize (I.Normalize a) ~ I.Normalize (I.Normalize (I.Normalize a))
type Trivial2 a = I.Normalize a ~ I.Normalize (I.Normalize a)
type Trivial3 a = I.FromNat (I.Count (I.Normalize (I.Normalize a))) ~ I.Normalize (I.Normalize a)

-- | Normalizing a pair equals pairing the normalized components; required by
-- the pair-indexed operations ('multM', 'kr', 'fstM', 'sndM', '><').
type TrivialP a b = I.Normalize (a, b) ~ I.Normalize (I.Normalize a, I.Normalize b)
-- | Matrices over a numeric entry type form a constrained category: the
-- identity morphism is 'iden' and composition is 'comp'.
instance (Num e) => Category (Matrix e) where
  -- An object must satisfy exactly what 'iden' needs besides @Num e@ (which
  -- the instance head already provides).
  type Object (Matrix e) a = (FLN a a, CountableN a)

  id = iden

  f . g = comp f g
-- | Relabels both dimensions of a matrix through functions: @f@ maps the
-- column type @a@ to @b@ and @g@ maps the row type @c@ to @d@.
--
-- The input is post-composed with the matrix of @g@ and pre-composed with
-- the transpose of the matrix of @f@.
bimapM ::
  ( Liftable e a b,
    Liftable e c d,
    CountableDimsN a c,
    CountableDimsN b d,
    FLN d c,
    FLN b a
  ) => (a -> b) -> (c -> d) -> Matrix e a c -> Matrix e b d
bimapM f g m = onRows . m . onCols
  where
    -- Matrix of @g@: from @c@ to @d@, applied after @m@.
    onRows = fromF g
    -- Transposed matrix of @f@: from the new column type @b@ back to @a@.
    onCols = tr (fromF f)
-- | The empty type, usable as a dimension with no indices.
type Zero = Void

-- | The unit type, usable as a one-index dimension (it is the column type of
-- 'col' vectors and the row type of 'row' vectors).
type One = ()
-- | The 1x1 matrix (indexed by 'One' on both sides) holding the given entry.
one :: e -> Matrix e One One
one x = M (I.One x)
-- | Places two matrices that share a row type side by side; the result's
-- column type is @Either a b@ (delegates to the 'I.Join' constructor).
join ::
  Matrix e a rows ->
  Matrix e b rows ->
  Matrix e (Either a b) rows
join (M left) (M right) = M $ I.Join left right

infixl 3 |||

-- | Infix synonym for 'join'.
(|||) ::
  Matrix e a rows ->
  Matrix e b rows ->
  Matrix e (Either a b) rows
l ||| r = join l r
-- | Stacks two matrices that share a column type; the result's row type is
-- @Either a b@ (delegates to the 'I.Fork' constructor).
fork ::
  Matrix e cols a ->
  Matrix e cols b ->
  Matrix e cols (Either a b)
fork (M upper) (M lower) = M $ I.Fork upper lower

infixl 2 ===

-- | Infix synonym for 'fork'.
(===) ::
  Matrix e cols a ->
  Matrix e cols b ->
  Matrix e cols (Either a b)
x === y = fork x y
-- | Maps a function over the row type of a matrix by post-composing with
-- the matrix of that function ('fromF').
fmapM ::
  ( Liftable e a b,
    CountableDimsN a b,
    FLN b a
  ) =>
  (a -> b) -> Matrix e c a -> Matrix e c b
fmapM f m = comp (fromF f) m
-- | The 1x1 matrix whose single entry is @1@ (the same as @one 1@).
unitM :: (Num e) => Matrix e () ()
unitM = M (I.One 1)
-- | Pairs the row types of two matrices with a common column type; this is
-- exactly 'kr'.
multM ::
  ( CountableDimsN a b,
    CountableN (a, b),
    Num e,
    FLN (a, b) a,
    FLN (a, b) b,
    TrivialP a b
  ) => Matrix e c a -> Matrix e c b -> Matrix e c (a, b)
multM x y = kr x y
-- | Embeds a value as a column vector over @a@. The entry at position
-- @fromEnum a@ is @1@ and every other entry is @0@.
--
-- The vector has @'I.Count' a@ entries.
--
-- NOTE(review): this assumes 'fromEnum' numbers the inhabitants of @a@
-- from @0@ to @Count a - 1@. An out-of-range index produces a list whose
-- length differs from @Count a@, which is then handed to 'col' as-is.
--
-- @Enum e@ is no longer needed by the body, because the padding is built
-- with 'replicate' rather than @[0,0..]@. It is kept so that the exported
-- signature, and its @TypeApplications@ order, stay unchanged.
returnM ::
  forall e a .
  ( Num e,
    Enum e,
    Enum a,
    FLN () a,
    Countable a
  ) => a -> Matrix e One a
returnM a = col (replicate before 0 ++ [1] ++ replicate after 0)
  where
    -- Total number of entries, read from the type-level count of @a@.
    size = fromInteger $ natVal (Proxy :: Proxy (I.Count a))
    -- Zeros preceding the @1@.
    before = fromEnum a
    -- Zeros following the @1@. 'replicate' yields [] for a negative count,
    -- just as the previous @take n [0,0..]@ did.
    after = size - before - 1
-- | Composition in diagram order: first @m@ (from @a@ to @b@), then @k@
-- (from @b@ to @c@).  Equivalent to @flip comp@.
bindM :: (Num e) => Matrix e a b -> Matrix e b c -> Matrix e a c
bindM m k = comp k m
-- | Builds a matrix from nested lists via 'I.fromLists'.
-- NOTE(review): the list orientation and the handling of ill-sized input
-- are decided by 'I.fromLists' — confirm there.
fromLists :: (FLN cols rows) => [[e]] -> Matrix e cols rows
fromLists xss = M (I.fromLists xss)
-- | Builds a matrix by evaluating a function at integer index pairs, via
-- 'I.matrixBuilder''.  NOTE(review): which component is the row and which
-- the column (and the index base) is fixed by 'I.matrixBuilder''.
matrixBuilder' ::
  (FLN cols rows, CountableDimsN cols rows)
  => ((Int, Int) -> e) -> Matrix e cols rows
matrixBuilder' gen = M (I.matrixBuilder' gen)
-- | Builds a matrix by evaluating a function at index pairs @(a, b)@, where
-- @a@ is the column type and @b@ the row type (delegates to
-- 'I.matrixBuilder').
matrixBuilder ::
  ( FLN a b,
    Enum a,
    Enum b,
    Bounded a,
    Bounded b,
    Eq a,
    CountableDimsN a b
  ) => ((a, b) -> e) -> Matrix e a b
matrixBuilder gen = M $ I.matrixBuilder gen
-- | A column vector (column type 'One') built from a list of entries.
col :: (FLN () rows) => [e] -> Matrix e One rows
col entries = M (I.col entries)

-- | A row vector (row type 'One') built from a list of entries.
row :: (FLN cols ()) => [e] -> Matrix e cols One
row entries = M (I.row entries)
-- | Like 'fromF', but the matrix dimensions @cols@ and @rows@ are chosen
-- independently of the function's domain @a@ and codomain @b@ (delegates
-- to 'I.fromF'').
fromF' ::
  ( Liftable e a b,
    CountableDimsN cols rows,
    FLN rows cols
  ) =>
  (a -> b) -> Matrix e cols rows
fromF' f = M (I.fromF' f)
-- | The matrix of a function @a -> b@: column type @a@, row type @b@
-- (delegates to 'I.fromF').
fromF ::
  ( Liftable e a b,
    CountableDimsN a b,
    FLN b a
  ) =>
  (a -> b) -> Matrix e a b
fromF f = M (I.fromF f)
-- | Turns a binary predicate into a matrix with entries of type
-- @Natural 0 1@ (delegates to 'I.toRel').
toRel ::
  ( Liftable (Natural 0 1) a b,
    CountableDimsN a b,
    FLN b a
  ) => (a -> b -> Bool) -> Matrix (Natural 0 1) a b
toRel rel = M (I.toRel rel)
-- | The entries as nested lists, as produced by 'I.toLists'.
toLists :: Matrix e cols rows -> [[e]]
toLists m = case m of
  M inner -> I.toLists inner

-- | The entries as a flat list, as produced by 'I.toList'.
toList :: Matrix e cols rows -> [e]
toList m = case m of
  M inner -> I.toList inner
-- | Typed wrapper around 'I.zeros' (by name, the all-zero matrix).
zeros ::
  (Num e, FLN cols rows, CountableDimsN cols rows)
  => Matrix e cols rows
zeros = M I.zeros

-- | Typed wrapper around 'I.ones' (by name, the all-one matrix).
ones ::
  (Num e, FLN cols rows, CountableDimsN cols rows)
  => Matrix e cols rows
ones = M I.ones

-- | Typed wrapper around 'I.constant' (by name, every entry set to the
-- given value).
constant ::
  (Num e, FLN cols rows, CountableDimsN cols rows)
  => e -> Matrix e cols rows
constant x = M (I.constant x)
-- | A matrix from @cols@ to 'One' (a single row), wrapping 'I.bang'.
-- NOTE(review): presumably the matrix of the unique function into @()@,
-- i.e. all ones — confirm in LAoP.Matrix.Internal.
bang ::
  forall e cols.
  (Num e, Enum e, FLN cols (), CountableN cols) =>
  Matrix e cols One
bang = M I.bang
-- | The column vector of a value: the matrix of the constant function from
-- 'One' that returns @x@.
point ::
  ( Bounded a,
    Enum a,
    Eq a,
    Num e,
    Ord e,
    CountableN a,
    FLN a One
  ) => a -> Matrix e One a
point x = fromF (const x)
-- | The identity matrix on @a@ (wraps 'I.iden').
--
-- NOINLINE so that 'iden' survives long enough for the @comp/iden1@ and
-- @comp/iden2@ rewrite rules attached to 'comp' to match it.
iden ::
  (Num e, FLN a a, CountableN a) =>
  Matrix e a a
iden = M I.iden
{-# NOINLINE iden #-}
-- | Matrix composition: @comp a b@ first goes through @b@ (from @cols@ to
-- @cr@) and then through @a@ (from @cr@ to @rows@), delegating to
-- 'I.comp'.
--
-- Kept NOINLINE so the rules below, which drop a composition with 'iden',
-- get a chance to fire.
comp :: (Num e) => Matrix e cr rows -> Matrix e cols cr -> Matrix e cols rows
comp (M after) (M before) = M $ I.comp after before
{-# NOINLINE comp #-}
{-# RULES
   "comp/iden1" forall m. comp m iden = m ;
   "comp/iden2" forall m. comp iden m = m
#-}
infixl 7 .|

-- | Scalar multiplication, delegating to 'I..|' (by name, every entry is
-- scaled by the scalar).
(.|) :: Num e => e -> Matrix e cols rows -> Matrix e cols rows
scalar .| M m = M ((I..|) scalar m)

infixl 7 ./

-- | Scalar division, delegating to 'I../'.
(./) :: Fractional e => Matrix e cols rows -> e -> Matrix e cols rows
M m ./ divisor = M ((I../) m divisor)
-- | Projection matrix from the coproduct @Either m n@ onto @m@ (wraps
-- 'I.p1').
p1 ::
  ( Num e,
    CountableDimsN n m,
    FLN n m,
    FLN m m
  ) =>
  Matrix e (Either m n) m
p1 = M I.p1

-- | Projection matrix from the coproduct @Either m n@ onto @n@ (wraps
-- 'I.p2').
p2 ::
  ( Num e,
    CountableDimsN n m,
    FLN m n,
    FLN n n
  ) =>
  Matrix e (Either m n) n
p2 = M I.p2
-- | Injection of @m@ into @Either m n@: the transpose of 'p1'.
i1 ::
  ( Num e,
    CountableDimsN n m,
    FLN n m,
    FLN m m
  ) =>
  Matrix e m (Either m n)
i1 = tr p1

-- | Injection of @n@ into @Either m n@: the transpose of 'p2'.
i2 ::
  ( Num e,
    CountableDimsN n m,
    FLN m n,
    FLN n n
  ) =>
  Matrix e n (Either m n)
i2 = tr p2
-- | Number of rows, via 'I.rows' (needs the row count at the type level).
rows :: (CountableN rows) => Matrix e cols rows -> Int
rows m = case m of
  M inner -> I.rows inner

-- | Number of rows, via 'I.rows'' (no type-level count required).
rows' :: Matrix e cols rows -> Int
rows' m = case m of
  M inner -> I.rows' inner

-- | Number of columns, via 'I.columns' (needs the column count at the type
-- level).
columns :: (CountableN cols) => Matrix e cols rows -> Int
columns m = case m of
  M inner -> I.columns inner

-- | Number of columns, via 'I.columns'' (no type-level count required).
columns' :: Matrix e cols rows -> Int
columns' m = case m of
  M inner -> I.columns' inner
infixl 5 -|-

-- | Direct sum of two matrices: from @Either n m@ to @Either k j@
-- (delegates to 'I.-|-').  NOTE(review): by the types, presumably the first
-- matrix acts on the 'Left' part and the second on the 'Right' part.
(-|-) ::
  ( Num e,
    CountableDimsN j k,
    FLN k k,
    FLN j k,
    FLN k j,
    FLN j j
  ) =>
  Matrix e n k ->
  Matrix e m j ->
  Matrix e (Either n m) (Either k j)
M a -|- M b = M (a I.-|- b)
-- | Projection matrix from the pair type @(m, k)@ onto @m@ (wraps
-- 'I.fstM').
--
-- The explicit type applications pin 'I.fstM' to the normalized component
-- types; presumably GHC cannot infer them because 'I.Normalize' is a
-- non-injective type family.
fstM ::
  forall e m k .
  ( Num e,
    CountableDimsN m k,
    CountableN (m, k),
    FLN (m, k) m,
    TrivialP m k
  ) => Matrix e (m, k) m
fstM = M (I.fstM @e @(I.Normalize m) @(I.Normalize k))

-- | Projection matrix from the pair type @(m, k)@ onto @k@ (wraps
-- 'I.sndM'); type applications as in 'fstM'.
sndM ::
  forall e m k.
  ( Num e,
    CountableDimsN k m,
    CountableN (m, k),
    FLN (m, k) k,
    TrivialP m k
  ) => Matrix e (m, k) k
sndM = M (I.sndM @e @(I.Normalize m) @(I.Normalize k))
-- | Khatri–Rao product of two matrices that share a column type.
--
-- Each operand is lifted to the row type @(a, b)@ by composing it with the
-- transpose of 'fstM' or 'sndM'. The two lifted matrices are then combined
-- with '*' from the derived 'Num' instance.
-- NOTE(review): '*' is assumed to be element-wise; confirm in
-- LAoP.Matrix.Internal.
kr ::
  forall e cols a b.
  ( Num e,
    CountableDimsN a b,
    CountableN (a, b),
    FLN (a, b) a,
    FLN (a, b) b,
    TrivialP a b
  ) => Matrix e cols a -> Matrix e cols b -> Matrix e cols (a, b)
kr x y = comp (tr (fstM @e @a @b)) x * comp (tr (sndM @e @a @b)) y
infixl 4 ><

-- | Kronecker-style product, from @(m, n)@ to @(p, q)@.
--
-- Each factor is first pre-composed with the matching projection from the
-- input pair (@x@ with 'fstM', @y@ with 'sndM'). The results are then
-- paired with 'kr'.
(><) ::
  forall e m p n q.
  ( Num e,
    CountableDimsN m n,
    CountableDimsN p q,
    CountableDimsN (m, n) (p, q),
    FLN (m, n) m,
    FLN (m, n) n,
    FLN (p, q) p,
    FLN (p, q) q,
    TrivialP m n,
    TrivialP p q
  ) => Matrix e m p -> Matrix e n q -> Matrix e (m, n) (p, q)
x >< y = kr (comp x (fstM @e @m @n)) (comp y (sndM @e @m @n))
-- | Rearranges the internal block structure via 'I.abideJF'; the matrix
-- type is unchanged.  NOTE(review): by name, presumably applies the
-- join/fork abide law — confirm in LAoP.Matrix.Internal.
abideJF :: Matrix e cols rows -> Matrix e cols rows
abideJF m = case m of
  M inner -> M (I.abideJF inner)

-- | Rearranges the internal block structure via 'I.abideFJ'; the matrix
-- type is unchanged.  NOTE(review): by name, the fork/join counterpart of
-- 'abideJF' — confirm in LAoP.Matrix.Internal.
abideFJ :: Matrix e cols rows -> Matrix e cols rows
abideFJ m = case m of
  M inner -> M (I.abideFJ inner)

-- | Transposition: swaps the column and row index types (delegates to
-- 'I.tr').
tr :: Matrix e cols rows -> Matrix e rows cols
tr m = case m of
  M inner -> M (I.tr inner)
-- | Selective-style combinator that delegates to 'I.select'. It takes a
-- matrix into @Either a b@ and a matrix from @a@ to @b@, and yields a
-- matrix into @b@.
-- NOTE(review): presumably the second matrix is applied to the 'Left'
-- part, while the 'Right' part passes through.
selectM ::
  ( Num e,
    FLN b b,
    CountableN b
  ) => Matrix e cols (Either a b) -> Matrix e a b -> Matrix e cols b
selectM (M choice) (M handler) = M $ I.select choice handler
-- | Conditional combination of two matrices of the same type, driven by
-- a predicate on the column type. It delegates to 'I.cond'.
-- NOTE(review): presumably the first matrix is used where the predicate
-- holds and the second elsewhere — confirm in LAoP.Matrix.Internal.
cond ::
  ( Trivial a,
    Trivial2 a,
    Trivial3 a,
    CountableN a,
    FLN () a,
    FLN a (),
    FLN a a,
    Liftable e a Bool
  )
  =>
  (a -> Bool) -> Matrix e a b -> Matrix e a b -> Matrix e a b
cond predicate (M l) (M r) = M $ I.cond predicate l r
-- | Renders the matrix as a 'String' (delegates to 'I.pretty').
pretty :: (CountableDimsN cols rows, Show e) => Matrix e cols rows -> String
pretty m = case m of
  M inner -> I.pretty inner

-- | Prints the matrix (delegates to 'I.prettyPrint').
prettyPrint :: (CountableDimsN cols rows, Show e) => Matrix e cols rows -> IO ()
prettyPrint m = case m of
  M inner -> I.prettyPrint inner
-- | Combines two matrices of the same shape with a binary function
-- (delegates to 'I.zipWithM'; by name, presumably entry by entry).
zipWithM :: (e -> f -> g) -> Matrix e a b -> Matrix f a b -> Matrix g a b
zipWithM combine (M xs) (M ys) = M $ I.zipWithM combine xs ys