```{-
Copyright 2016, Dominic Orchard, Andrew Rice, Mistral Contrastin, Matthew Danish

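Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0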

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-}
{-# LANGUAGE ImplicitParams, BangPatterns #-}

module Camfort.Specification.Units.Solve where

import Data.Ratio
import Data.List
import qualified Data.Matrix as DM
import qualified Data.Vector as V
import Control.Exception
import System.IO.Unsafe
import qualified Debug.Trace as D

import Language.Fortran
import Camfort.Specification.Units.Environment
import Camfort.Specification.Units.SolveHMatrix

-- Top-level, select the solver
-- | Solve a linear system with the solver chosen by the implicit @?solver@
-- parameter. Only the 'Custom' selection is matched here, and it is routed
-- to the HMatrix-based 'solveSystemH'.
--
-- NOTE(review): the commented-out alternatives suggest 'Solver' once had a
-- 'LAPACK' constructor; if it still exists, this case is incomplete and will
-- fail at runtime for that value -- confirm against the 'Solver' definition.
solveSystem :: (?solver :: Solver) => LinearSystem -> Consistency LinearSystem
solveSystem system =
  case ?solver of
    Custom -> solveSystemH system
    -- LAPACK -> solveSystemL system
    -- Custom -> solveSystemC system
--------------------------------------------------
-- CUSTOM SOLVER
--------------------------------------------------

-- Top-level custom solver
-- | Entry point of the hand-written row-reduction solver: start processing at
-- the first column, placing the first pivot in the first row.
solveSystemC :: LinearSystem -> Consistency LinearSystem
solveSystemC system = solveSystem' system firstColumn firstRow
  where
    firstColumn = 1
    firstRow    = 1

-- | Row-reduce the system one column at a time.
--
-- @col@ is the column being processed and @row@ is the row that will receive
-- the next pivot. While columns remain, the first row at or below @row@ with a
-- non-zero entry in @col@ (if any) is handed to 'elimRow'. Once every column
-- has been visited, the remaining rows are checked with 'checkSystem' and
-- 'cutSystem' @row@ is applied to that result through 'efmap', keeping rows
-- @1 .. row - 1@.
solveSystem' :: LinearSystem -> Col -> Row -> Consistency LinearSystem
solveSystem' system@(matrix, _) col row
  | col <= DM.ncols matrix = elimRow system pivot col row
  | otherwise              = efmap (cutSystem row) (checkSystem system row)
  where
    -- First candidate pivot row for this column, searching from 'row' down.
    pivot = find (\r -> matrix DM.! (r, col) /= 0) [row .. DM.nrows matrix]

-- | Keep only rows @1 .. k - 1@ of the coefficient matrix (all columns) and
-- the first @k - 1@ entries of the right-hand-side vector. As called from
-- 'solveSystem'', these are the rows that received a pivot.
--
-- When @k <= 1@ no row is kept. 'DM.submatrix' rejects an ending row smaller
-- than the starting row (and a starting column beyond a zero-column matrix),
-- so that case builds an empty system with the same number of columns
-- instead of calling it.
cutSystem :: Int -> LinearSystem -> LinearSystem
cutSystem k (matrix, vector)
  | k <= 1    = (DM.fromList 0 cols [], [])
  | otherwise = ( DM.submatrix 1 (k - 1) 1 cols matrix
                , take (k - 1) vector )
  where
    cols = DM.ncols matrix

-- | Scan rows @k .. nrows@ of a row-reduced system for an inconsistency.
--
-- A row whose coefficients are all zero but whose right-hand side differs
-- from @Unitful []@ (presumably the unitless constant) states @0 = u@ for a
-- non-trivial @u@, which cannot hold. The first such row is reported as
-- 'Bad' together with its unit and its coefficient list. If no such row
-- exists, the system is returned unchanged in 'Ok'.
--
-- The previous text of the second guard opened a @let@ with no @in@ body
-- (a parse error); this completes it. In the 'solveSystem'' call path every
-- row from @k@ on is already all-zero, so the coefficient test only matters
-- for other callers.
checkSystem :: LinearSystem -> Row -> Consistency LinearSystem
checkSystem (matrix, vector) k
  | k > DM.nrows matrix                   = Ok (matrix, vector)
  | unit /= Unitful [] && all (== 0) vars = Bad (matrix, vector) k (unit, vars)
  | otherwise                             = checkSystem (matrix, vector) (k + 1)
  where
    -- Only forced once the first guard has confirmed that row k exists.
    unit = vector !! (k - 1)
    vars = V.toList (DM.getRow k matrix)

-- | Perform one elimination step for column @col@, given the pivot row found
-- for it (if any).
--
-- * 'Nothing': the column has no non-zero entry at or below @row@; move on to
--   the next column, still waiting to fill @row@.
-- * 'Just' @pivotRow@: scale @pivotRow@ so its entry in @col@ becomes 1, swap
--   it into position @row@ (mirroring the scale and swap on the
--   right-hand-side vector via 'switchScaleElems'), clear @col@ from every
--   other row with 'elimRow'', then continue with the next column and row.
elimRow :: LinearSystem -> Maybe Row -> Col -> Row -> Consistency LinearSystem
elimRow system Nothing col row = solveSystem' system (col + 1) row
elimRow (matrix, vector) (Just pivotRow) col row =
  solveSystem' (elimRow' (normMatrix, normVector) row col) (col + 1) (row + 1)
  where
    pivot = matrix DM.! (pivotRow, col)
    scaled
      | pivot == 1 = matrix
      | otherwise  = DM.scaleRow (recip pivot) pivotRow matrix
    normMatrix
      | row == pivotRow = scaled
      | otherwise       = DM.switchRows row pivotRow scaled
    normVector = switchScaleElems row pivotRow (fromRational (recip pivot)) vector

-- | Clear column @m@ in every row except the pivot row @k@ by adding
-- @-(entry (n, m))@ times row @k@ to each other row @n@. Rows whose entry in
-- column @m@ is already zero are left untouched.
--
-- The column is only fully cleared when the pivot entry @(k, m)@ is 1, which
-- is what 'elimRow' arranges before eliminating. Row @k@ is never modified,
-- so the rows can be processed in any order; a strict left fold avoids
-- building a chain of matrix thunks. 'elimRow'' performs the same
-- elimination inline.
msteeper :: (Eq a, Num a) => DM.Matrix a -> Int -> Int -> DM.Matrix a
msteeper matrix k m = foldl' clearRow matrix otherRows
  where
    otherRows = filter (/= k) [1 .. DM.nrows matrix]
    clearRow acc n =
      case negate (acc DM.! (n, m)) of
        0 -> acc
        s -> DM.combineRows n s k acc

-- | Eliminate column @m@ from every row except the pivot row @k@, assuming the
-- pivot entry @(k, m)@ is already 1.
--
-- Each other row @n@ has @entry (n, m)@ times row @k@ subtracted from it, and
-- the same combination is applied to the right-hand-side vector. Entry @k@ of
-- the vector is left unchanged. The multipliers are read from the matrix as
-- it was before elimination.
elimRow' :: LinearSystem -> Row -> Col -> LinearSystem
elimRow' (matrix, vector) k m = (matrix', vector')
  where
    -- Loop-invariant: look the pivot unit up once rather than once per row.
    pivotUnit = vector !! (k - 1)

    otherRows = [1 .. k - 1] ++ [k + 1 .. DM.nrows matrix]

    clearRow acc n =
      let s = negate (acc DM.! (n, m))
      in if s == 0 then acc else DM.combineRows n s k acc

    -- Strict fold: avoid accumulating a chain of unevaluated matrix updates.
    matrix' = foldl' clearRow matrix otherRows

    vector' =
      [ if n == k then x else x - fromRational (matrix DM.! (n, m)) * pivotUnit
      | (n, x) <- zip [1 ..] vector ]

-- | Swap the elements at 1-based positions @i@ and @j@, multiplying the one
-- that moves into position @i@ (originally at @j@) by @factor@. When @i == j@
-- the element at that position is simply scaled. Out-of-range positions
-- fail with a pattern-match error once the result is forced.
switchScaleElems :: Num a => Int -> Int -> a -> [a] -> [a]
switchScaleElems i j factor list =
  case splitAt (j - 1) list of
    (beforeJ, fromJ : afterJ) ->
      -- First put the original i-th element into slot j ...
      case splitAt (i - 1) (beforeJ ++ (list !! (i - 1)) : afterJ) of
        -- ... then overwrite slot i with the scaled original j-th element.
        (beforeI, _ : afterI) -> beforeI ++ (factor * fromJ) : afterI

--------------------------------------------------
-- Top-level custom solver based on HMatrix
-- | Top-level solver that delegates to the HMatrix-based backend.
--
-- The system is converted with 'convertToHMatrix'. On success it is reduced
-- with 'rref', truncated to its first @rank@ rows with 'takeRows', converted
-- back with 'convertFromHMatrix', and returned in 'Ok'. If conversion reports
-- inconsistent rows, the first one is reported through 'Bad' with its
-- right-hand-side unit and its coefficient row.
solveSystemH :: LinearSystem -> Consistency LinearSystem
solveSystemH system@(m, v) =
  case convertToHMatrix system of
    Right (m', units) ->
      let m2 = rref m'
      in Ok (convertFromHMatrix (takeRows (rank m2) m2, units))
    Left (n : _) ->
      -- NOTE(review): 'n' is treated as a 0-based row index (as @v !! n@
      -- already did, and as hmatrix indexes rows). 'DM.getRow' is 1-based, so
      -- the matching coefficient row is @n + 1@; previously the unit and the
      -- coefficients came from different rows, and n = 0 crashed.
      -- The reported row stays @DM.nrows m@ (unlike 'checkSystem', which
      -- reports the offending row) -- confirm what callers expect.
      Bad system (DM.nrows m) (v !! n, V.toList (DM.getRow (n + 1) m))
    Left [] ->
      error "solveSystemH: convertToHMatrix failed without reporting an inconsistent row"

--------------------------------------------------
-- Top-level custom solver based on HMatrix
-- This version uses "Either" result instead of "Consistency".
-- | HMatrix-based solver returning 'Either' instead of 'Consistency'.
--
-- Row indices reported by 'convertToHMatrix' are passed through unchanged as
-- 'Left'. Otherwise the converted matrix is reduced with 'rref', truncated to
-- its first @rank@ rows and converted back with 'convertFromHMatrix'.
solveSystemH_Either :: LinearSystem -> Either [Int] LinearSystem
solveSystemH_Either system@(_, _) = either Left reduce (convertToHMatrix system)
  where
    reduce (hm, units) =
      let echelon = rref hm
      in Right (convertFromHMatrix (takeRows (rank echelon) echelon, units))
```