{-# LANGUAGE DeriveFunctor #-}

module Jikka.Common.Matrix
  ( Matrix,
    unMatrix,
    makeMatrix,
    makeMatrix',
    matsize,
    matsize',
    matcheck,
    matzero,
    matone,
    matadd,
    matmul,
    matap,
    matscalar,
    matpow,
  )
where

import Control.Monad
import Control.Monad.ST
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

-- | `Matrix` is a dense matrix stored as a vector of rows.
-- The wrapped vectors are guaranteed not to be jagged: every row has the same length.
-- The data constructor is not exported, so values are built through this module
-- (e.g. `makeMatrix`, `makeMatrix'`, `matzero`, `matone`).
newtype Matrix a = Matrix (V.Vector (V.Vector a))
  deriving (Functor, Show, Eq, Ord)

-- | `unMatrix` exposes the underlying vector of rows of a matrix.
unMatrix :: Matrix a -> V.Vector (V.Vector a)
unMatrix m = case m of
  Matrix rows -> rows

-- | `matsize` computes the size of a matrix as @(height, width)@,
-- i.e. the number of rows and the length of a row.
-- An empty matrix has size @(0, 0)@.
matsize :: Matrix a -> (Int, Int)
matsize = matsize' . unMatrix

-- | `matsize'` computes the size @(height, width)@ of a vector of rows.
-- The width is read from the first row only, so this assumes the input is a
-- matrix (see `matcheck`). An empty input has size @(0, 0)@.
matsize' :: V.Vector (V.Vector a) -> (Int, Int)
matsize' a = case a V.!? 0 of
  Nothing -> (0, 0)
  Just firstRow -> (V.length a, V.length firstRow)

-- | `matcheck` checks that a given vector of vectors is a matrix, i.e. that
-- every row has the same length as the first one.
-- It returns `False` for jagged arrays and `True` for the empty vector.
matcheck :: V.Vector (V.Vector a) -> Bool
matcheck a = V.all hasWidth a
  where
    (_, width) = matsize' a
    hasWidth row = V.length row == width

-- | `makeMatrix` wraps a vector of rows as a `Matrix`.
-- It returns `Nothing` when the rows are jagged (see `matcheck`).
makeMatrix :: V.Vector (V.Vector a) -> Maybe (Matrix a)
makeMatrix a
  | matcheck a = Just (Matrix a)
  | otherwise = Nothing

-- | `makeMatrix'` is the partial variant of `makeMatrix`.
-- It calls `error` instead of returning `Nothing` when the rows are jagged.
makeMatrix' :: V.Vector (V.Vector a) -> Matrix a
makeMatrix' a = case makeMatrix a of
  Just m -> m
  Nothing -> error "Jikka.Common.Matrix.makeMatrix': the input is not a matrix"

-- | `matzero` builds the \(n \times n\) zero matrix.
matzero :: Num a => Int -> Matrix a
matzero n = Matrix (V.replicate n zeroRow)
  where
    -- a single row value shared by every row of the result
    zeroRow = V.replicate n 0

-- | `matone` builds the \(n \times n\) identity matrix.
matone :: Num a => Int -> Matrix a
matone n = Matrix (V.generate n unitRow)
  where
    unitRow i = V.generate n (entry i)
    entry i j
      | i == j = 1
      | otherwise = 0

-- | `matadd` calculates the addition \(A + B\) of two matrices \(A, B\).
-- This assumes sizes of inputs match; the shape of the result follows \(A\).
matadd :: Num a => Matrix a -> Matrix a -> Matrix a
matadd (Matrix a) (Matrix b) = Matrix (V.imap addRow a)
  where
    -- row @y@ of the result: each entry of A's row plus the matching entry of B
    addRow y row = V.imap (\x v -> v + (b V.! y V.! x)) row

-- | `matmul` calculates the multiplication \(A B\) of two matrices \(A, B\).
-- This assumes sizes of inputs match, i.e. the width of \(A\) equals the height of \(B\).
-- The result has the height of \(A\) and the width of \(B\).
--
-- Each output cell is accumulated with a strict write (`$!`). A lazy
-- `MV.modify` on a boxed mutable vector would instead leave a chain of
-- unevaluated additions in every cell (and keep both inputs alive).
matmul :: Num a => Matrix a -> Matrix a -> Matrix a
matmul (Matrix a) (Matrix b) = runST $ do
  let (h, n) = matsize' a
  let (_, w) = matsize' b
  c <- MV.replicateM h (MV.replicate w 0)
  forM_ [0 .. h - 1] $ \y -> do
    let ay = a V.! y
    -- the output row for @y@ does not depend on @z@ or @x@, so read it once
    row <- MV.read c y
    forM_ [0 .. n - 1] $ \z -> do
      let ayz = ay V.! z
          bz = b V.! z
      forM_ [0 .. w - 1] $ \x -> do
        acc <- MV.read row x
        MV.write row x $! acc + ayz * (bz V.! x)
  -- freeze the outer vector of mutable rows, then freeze each row
  Matrix <$> (V.freeze c >>= V.mapM V.freeze)

-- | `matap` calculates the multiplication \(A x\) of a matrix \(A\) and a vector \(x\).
-- This assumes sizes of inputs match, i.e. the length of \(x\) is at least the width of \(A\).
--
-- Entry @y@ of the result is a strict left fold over row @y@ of \(A\), adding
-- @row[x] * b[x]@ in index order starting from @0@. This gives the same
-- summation order as a per-entry accumulator, but no chain of unevaluated
-- additions is built up.
matap :: Num a => Matrix a -> V.Vector a -> V.Vector a
matap (Matrix a) b = V.map dotRow a
  where
    dotRow = V.ifoldl' (\acc x axy -> acc + axy * (b V.! x)) 0

-- | `matscalar` computes the scalar multiple \(c A\) by multiplying every
-- entry of the matrix by @c@ on the left.
matscalar :: Num a => a -> Matrix a -> Matrix a
matscalar c = fmap (c *)

-- | `matpow` calculates the power \(A^k\) of a matrix \(A\) and a natural number \(k\)
-- by binary exponentiation, using \(O(\log k)\) calls to `matmul`.
-- This assumes inputs are square matrices; \(A^0\) is the identity matrix whose
-- size is the height of \(A\).
-- This fails (calls `error`) for \(k \lt 0\).
matpow :: Num a => Matrix a -> Integer -> Matrix a
matpow _ k | k < 0 = error "cannot calculate a negative power for a monoid"
matpow a k = go (matone h) a k
  where
    (h, _) = matsize a
    -- @go acc base e@ computes \(acc \cdot base^e\) by repeated squaring.
    go acc _ 0 = acc
    go acc base e =
      let acc' = if odd e then matmul acc base else acc
       in go acc' (matmul base base) (e `div` 2)