{- |
Module     : Persistence.Matrix
Copyright  : (c) Eben Cowley, 2018
License    : BSD 3 Clause
Maintainer : eben.cowley42@gmail.com
Stability  : experimental

This module contains a variety of matrix utility functions, used in the computation of Betti numbers and simplicial homology groups.

Most importantly, it includes functions for computing the rank, normal form, and kernel of matrices. For the computation of homology groups and Betti numbers, one must perform column operations on one matrix to get it into column echelon form and find its kernel while also performing the inverse row operations on the next matrix to be operated on.

Bool is an instance of Num here (instance given in Util) so that functions can be somewhat generalized to act on both integers and integers modulo 2.

-}

module Matrix
  ( BMatrix
  , getDiagonal
  , getUnsignedDiagonal
  , transposeMat
  , transposePar
  , multiply
  , multiplyPar
  , rankBool
  , kernelBool
  , imgInKerBool
  ) where

{--FOR DEVS---------------------------------------------------------------

Matrices are transformed by iterating through each row and selecting a pivot. When computing column echelon form, zero rows are simply skipped; when computing Smith normal form, a row operation is performed (if possible) to deal with a zero row.

To get the Smith normal form, the entire pivot row and column are eliminated before continuing. Also, the pivot is always a diagonal element.

To get column echelon form, every element in the pivot row after the pivot is eliminated. To get the kernel, all column operations used to bring the matrix into this form are also performed on the identity matrix. To get the image of one matrix inside the kernel of the one being put into column echelon form, perform the inverse row operations on the matrix whose image is needed. See the Stanford paper or the blog post on simplicial homology.

To get the rank of a matrix, count the non-zero columns in the column echelon form. To get the kernel, take the columns of the identity (after all of the same column operations have been performed on it) which correspond to zero columns of the column echelon form.

Eliminating elements is a slightly more complicated process since only integer operations are allowed. First, every element that must be eliminated is made divisible by the pivot by using the Bezout coefficients from the extended Euclidean algorithm. Once this is done, integer division and subtraction can be used to eliminate the elements.

Boolean matrices are much easier to work with: they are ordinary matrices with elements modulo 2. Bool is an instance of Num here, and the instance is given in Util.

--}

import Util

import Data.List as L
import Data.Vector as V
import Control.Parallel.Strategies

--BASIC STUFF-------------------------------------------------------------

-- | Matrix of integers, stored as a vector of rows.
type IMatrix = Vector (Vector Int)

-- | Matrix of integers modulo 2, stored as a vector of rows. Alternatively,
-- a matrix over the field with 2 elements.
type BMatrix = Vector (Vector Bool)

-- | Take the transpose of a matrix (no fancy optimizations, yet).
--
-- The number of columns is read from the first row, so all rows are assumed
-- to have the same length. A matrix with no rows transposes to an empty
-- matrix instead of failing on 'V.head'.
transposeMat :: Vector (Vector a) -> Vector (Vector a)
transposeMat mat
  | V.null mat = V.empty
  | otherwise  = V.generate (V.length $ V.head mat) (\i -> V.map (! i) mat)

-- | Take the transpose of a matrix using parallel evaluation of rows
-- (via 'parMapVec').
--
-- The number of columns is read from the first row, so all rows are assumed
-- to have the same length. A matrix with no rows transposes to an empty
-- matrix instead of failing on 'V.head'.
transposePar :: Vector (Vector a) -> Vector (Vector a)
transposePar mat
  | V.null mat = V.empty
  | otherwise  =
      parMapVec (\i -> V.map (! i) mat) $ V.enumFromN 0 (V.length $ V.head mat)

-- | Multiply two matrices: each row of the first matrix is combined with
-- every column of the second (obtained by transposing it) using 'dotProduct'.
multiply :: Num a => Vector (Vector a) -> Vector (Vector a) -> Vector (Vector a)
multiply mat1 mat2 = V.map timesColumns mat1
  where
    columns2         = transposeMat mat2
    timesColumns row = V.map (dotProduct row) columns2

-- | Multiply matrices, evaluating the rows of the product with 'parMapVec'
-- (in parallel if processors are available). The transposed second matrix
-- is forced to weak head normal form with 'rseq' before the rows are
-- computed.
multiplyPar :: Num a => Vector (Vector a) -> Vector (Vector a) -> Vector (Vector a)
multiplyPar mat1 mat2 =
  runEval $ rseq columns2 >> return (parMapVec (\row -> V.map (dotProduct row) columns2) mat1)
  where
    columns2 = transposeMat mat2

-- | Get the diagonal elements, from the top-left corner down to whichever
-- runs out first: the rows or the columns of the first row. An empty matrix
-- gives @[]@.
getDiagonal :: Vector (Vector a) -> [a]
getDiagonal matrix
  | V.null matrix = []
  | otherwise     = [matrix ! k ! k | k <- [0 .. diagLen - 1]]
  where
    diagLen = min (V.length matrix) (V.length (V.head matrix))

-- | Get the absolute value of each of the diagonal elements in a list.
-- The diagonal is taken the same way as in 'getDiagonal'; an empty matrix
-- gives @[]@.
getUnsignedDiagonal :: Num a => Vector (Vector a) -> [a]
getUnsignedDiagonal matrix
  | V.null matrix = []
  | otherwise     = L.zipWith (\k row -> abs (row ! k)) [0 .. diagLen - 1] (V.toList matrix)
  where
    diagLen = min (V.length matrix) (V.length $ V.head matrix)

-- Applies a two-column integer operation to every row of the matrix. With
-- e1 and e2 the entries of a row at index1 and index2, the row becomes
--
--   e1' = c11*e1 + c12*e2
--   e2' = c21*e1 + c22*e2
--
-- and all other entries are unchanged. Both indices must be in range and
-- distinct (if they are equal, the update for index2 wins). They may be
-- given in either order; the old slice-based version silently corrupted
-- rows unless index1 < index2.
colOperation :: Int -> Int -> (Int, Int, Int, Int) -> IMatrix -> IMatrix
colOperation index1 index2 (c11, c12, c21, c22) matrix = V.map calc matrix
  where
    calc row =
      let elem1 = row ! index1
          elem2 = row ! index2
      in row V.// [(index1, c11*elem1 + c12*elem2), (index2, c22*elem2 + c21*elem1)]

-- Same two-column operation as colOperation, but rows are processed with
-- parMapVec. With e1 and e2 the entries of a row at index1 and index2, the
-- row becomes
--
--   e1' = c11*e1 + c12*e2
--   e2' = c21*e1 + c22*e2
--
-- Both indices must be in range and distinct (if they are equal, the update
-- for index2 wins). They may be given in either order.
colOperationPar :: Int -> Int -> (Int, Int, Int, Int) -> IMatrix -> IMatrix
colOperationPar index1 index2 (c11, c12, c21, c22) matrix = parMapVec calc matrix
  where
    calc row =
      let elem1 = row ! index1
          elem2 = row ! index2
      in row V.// [(index1, c11*elem1 + c12*elem2), (index2, c22*elem2 + c21*elem1)]

-- Applies a two-row integer operation. With r1 and r2 the rows at index1
-- and index2, they are replaced by
--
--   r1' = c11*r1 + c12*r2
--   r2' = c21*r1 + c22*r2
--
-- and all other rows are unchanged. `mul` and `add` come from Util; they are
-- presumably scalar-times-row and row-plus-row. Both indices must be in range
-- and distinct (if they are equal, the update for index2 wins). They may be
-- given in either order; the old slice-based version silently corrupted the
-- matrix unless index1 < index2.
rowOperation :: Int -> Int -> (Int, Int, Int, Int) -> IMatrix -> IMatrix
rowOperation index1 index2 (c11, c12, c21, c22) matrix =
  matrix V.// [ (index1, (c11 `mul` row1) `add` (c12 `mul` row2))
              , (index2, (c22 `mul` row2) `add` (c21 `mul` row1))
              ]
  where
    row1 = matrix ! index1
    row2 = matrix ! index2

-- Same two-row operation as rowOperation, but the two new rows are computed
-- in parallel. With r1 and r2 the rows at index1 and index2:
--
--   r1' = c11*r1 + c12*r2
--   r2' = c21*r1 + c22*r2
--
-- Both new rows are sparked with rpar and then waited on individually with
-- rseq. The previous `rseq (a,b)` only forced the tuple constructor, so it
-- never waited for either spark. Both indices must be in range and distinct
-- (if they are equal, the update for index2 wins).
rowOperationPar :: Int -> Int -> (Int, Int, Int, Int) -> IMatrix -> IMatrix
rowOperationPar index1 index2 (c11, c12, c21, c22) matrix = runEval $ do
  new1 <- rpar $ (c11 `mul` row1) `add` (c12 `mul` row2)
  new2 <- rpar $ (c21 `mul` row1) `add` (c22 `mul` row2)
  _ <- rseq new1
  _ <- rseq new2
  return $ matrix V.// [(index1, new1), (index2, new2)]
  where
    row1 = matrix ! index1
    row2 = matrix ! index2

--BOOLEAN MATRICES--------------------------------------------------------

--RANK--------------------------------------------------------------------

-- Tries to place a pivot at position (rowIndex, colIndex) of a mod 2 matrix.
--
--   * If that entry is already True, returns the matrix unchanged with no
--     swap. The flag says whether the pivot row has any True entries to the
--     right of colIndex, i.e. whether there is anything left to eliminate.
--   * Otherwise, finds the smallest j > colIndex with a True entry in the
--     pivot row. It then applies `switchElems colIndex j` to every row
--     (presumably swapping those two columns) and reports the swap
--     (colIndex, j). The flag is always True in this case, even if nothing
--     is left to eliminate after the swap. Callers may rely on a swap never
--     being paired with False.
--   * If the pivot row has no True entry at or after colIndex, returns
--     Nothing.
chooseGaussPivotBool :: (Int, Int) -> BMatrix -> Maybe (Bool, BMatrix, Maybe (Int, Int))
chooseGaussPivotBool (rowIndex, colIndex) mat
  | pivotRow ! colIndex = Just (not (V.null later), mat, Nothing)
  | V.null later        = Nothing
  | otherwise           =
      let j = V.head later
      in Just (True, V.map (switchElems colIndex j) mat, Just (colIndex, j))
  where
    pivotRow = mat ! rowIndex
    -- column indices strictly right of the pivot holding a True entry
    later    = V.filter (> colIndex) (V.findIndices id pivotRow)

-- Eliminates the entries to the right of the pivot in row rowIndex of a
-- mod 2 matrix. For every column i in (colIndex, numCols) whose entry in the
-- pivot row is True, column colIndex is added (with Bool's `+`) to column i
-- in every row. The columns to clear are chosen from the pivot row as passed
-- in, which is safe because each step only changes column i.
elimRowBool :: (Int, Int) -> Int -> BMatrix -> BMatrix
elimRowBool (rowIndex, colIndex) numCols elems =
  L.foldl' addPivotCol elems targets
  where
    pivotRow = elems ! rowIndex
    targets  = L.filter (pivotRow !) [colIndex + 1 .. numCols - 1]
    addPivotCol mat i = V.map (\r -> replaceElem i ((r ! i) + (r ! colIndex)) r) mat

-- | Find the rank of a mod 2 matrix (number of linearly independent columns).
--
-- The matrix is reduced to column echelon form with column operations. The
-- rank is the number of columns of the result that contain a non-zero entry.
-- A matrix with no rows has rank 0 (previously this failed on 'V.head').
rankBool :: BMatrix -> Int
rankBool matrix
  | V.null matrix = 0
  | otherwise     = countNonZeroCols $ doColOps (0, 0) matrix
  where
    rows = V.length matrix
    cols = V.length $ V.head matrix

    -- Walk down the rows. The pivot column only advances when the row
    -- yields a pivot; a row with no usable pivot is skipped.
    doColOps (rowIndex, colIndex) mat
      | rowIndex == rows || colIndex == cols = mat
      | otherwise =
          case chooseGaussPivotBool (rowIndex, colIndex) mat of
            Just (True, mx, _)  ->
              doColOps (rowIndex + 1, colIndex + 1) $ elimRowBool (rowIndex, colIndex) cols mx
            -- continue with mx (not mat) so any column swap is never lost
            Just (False, mx, _) -> doColOps (rowIndex + 1, colIndex + 1) mx
            Nothing             -> doColOps (rowIndex + 1, colIndex) mat

    countNonZeroCols mat =
      V.length $ V.filter (\i -> V.any (\row -> row ! i /= 0) mat) $ V.enumFromN 0 cols

--KERNEL------------------------------------------------------------------

-- Eliminates the entries to the right of the pivot in row rowIndex, like
-- elimRowBool. Every column operation (add column colIndex to column i) is
-- applied both to the matrix being reduced and to a second matrix, normally
-- the identity, which records the operations. Returns (reduced matrix,
-- transformed second matrix). The columns to clear are chosen from the
-- pivot row as passed in.
elimRowBoolWithId :: (Int, Int) -> Int -> BMatrix -> BMatrix -> (BMatrix, BMatrix)
elimRowBoolWithId (rowIndex, colIndex) numCols elems identity =
  L.foldl' step (elems, identity) targets
  where
    pivotRow = elems ! rowIndex
    targets  = L.filter (pivotRow !) [colIndex + 1 .. numCols - 1]
    step (mat, ide) i =
      let addPivotCol = V.map (\r -> replaceElem i ((r ! i) + (r ! colIndex)) r)
      in (addPivotCol mat, addPivotCol ide)

-- | Finds a basis of the kernel of a mod 2 matrix, arranging the basis
-- vectors into the rows of the result.
--
-- The matrix A is reduced to column echelon form A*E by column operations,
-- and the same operations are applied to an identity matrix, which therefore
-- ends up as E. Whenever column i of A*E is zero, column i of E is in the
-- kernel. The result's rows are those columns of E. Previously the rows of E
-- were returned, which are not kernel vectors in general; for @[[1,1]]@ the
-- old code gave (0,1) instead of (1,1).
--
-- The matrix must have at least one row, since the column count is read
-- from the first row.
kernelBool :: BMatrix -> BMatrix
kernelBool matrix =
  let rows     = V.length matrix
      cols     = V.length $ V.head matrix
      identity = V.generate cols (\i -> V.generate cols (== i))

      doColOps (rowIndex, colIndex) (ker, ide) =
        if rowIndex == rows || colIndex == cols then (ker, ide)
        else
          case chooseGaussPivotBool (rowIndex, colIndex) ker of
            Just (True, _, Nothing)       ->
              doColOps (rowIndex + 1, colIndex + 1) $
                elimRowBoolWithId (rowIndex, colIndex) cols ker ide
            Just (True, mx, Just (i, j))  ->
              doColOps (rowIndex + 1, colIndex + 1) $
                elimRowBoolWithId (rowIndex, colIndex) cols mx $ V.map (switchElems i j) ide
            -- a column swap must hit both matrices to keep them in sync
            Just (False, mx, Just (i, j)) ->
              doColOps (rowIndex + 1, colIndex + 1) (mx, V.map (switchElems i j) ide)
            Just (False, _, Nothing)      -> doColOps (rowIndex + 1, colIndex + 1) (ker, ide)
            Nothing                       -> doColOps (rowIndex + 1, colIndex) (ker, ide)

      (reduced, ops) = doColOps (0, 0) (matrix, identity)
      isZeroCol i    = V.all (\row -> not $ row ! i) reduced
      column i       = V.map (! i) ops
  in V.map column $ V.filter isZeroCol $ V.enumFromN 0 cols

--IMAGE IN BASIS OF KERNEL------------------------------------------------

-- Eliminates the entries to the right of the pivot in row rowIndex of
-- toColEch. Each column operation is mirrored by the inverse row operation
-- on toImage:
--
--   * toColEch: column i += column colIndex
--   * toImage:  row colIndex := row i `add` row colIndex
--
-- This is done for every i in (colIndex, numCols) whose pivot-row entry is
-- True. Returns (reduced toColEch, transformed toImage).
elimRowBoolWithInv :: (Int, Int) -> Int -> BMatrix -> BMatrix -> (BMatrix, BMatrix)
elimRowBoolWithInv (rowIndex, colIndex) numCols toColEch toImage =
  L.foldl' step (toColEch, toImage) targets
  where
    pivotRow = toColEch ! rowIndex
    targets  = L.filter (pivotRow !) [colIndex + 1 .. numCols - 1]
    step (ech, img) i =
      ( V.map (\r -> replaceElem i ((r ! i) + (r ! colIndex)) r) ech
      , replaceElem colIndex ((img ! i) `add` (img ! colIndex)) img
      )

-- | Calculates the image of the second matrix represented in the basis of the
-- kernel of the first matrix.
--
-- The first matrix is reduced to column echelon form. Every column operation
-- applied to it is mirrored by the inverse row operation on the second
-- matrix. The result consists of the rows of the transformed second matrix
-- whose indices are the zero columns of the reduced first matrix. The first
-- matrix must have at least one row, since its column count is read from
-- the first row.
imgInKerBool :: BMatrix -> BMatrix -> BMatrix
imgInKerBool toColEch toImage = V.map (img !) zeroCols
  where
    rows = V.length toColEch
    cols = V.length $ V.head toColEch

    (ech, img) = reduce (0, 0) (toColEch, toImage)

    zeroCols = V.filter (\c -> forallVec (\row -> not $ row ! c) ech) $ 0 `range` (cols - 1)

    reduce (r, c) pair@(e, m)
      | r == rows || c == cols = pair
      | otherwise =
          case chooseGaussPivotBool (r, c) e of
            Just (True, _, Nothing)       -> reduce (r + 1, c + 1) $ elimRowBoolWithInv (r, c) cols e m
            Just (True, mx, Just (i, j))  -> reduce (r + 1, c + 1) $ elimRowBoolWithInv (r, c) cols mx $ switchElems i j m
            Just (False, mx, Just (i, j)) -> reduce (r + 1, c + 1) (mx, switchElems i j m)
            Just (False, _, _)            -> reduce (r + 1, c + 1) pair
            Nothing                       -> reduce (r + 1, c) pair