{-
   Copyright 2016, Dominic Orchard, Andrew Rice, Mistral Contrastin, Matthew Danish

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
-}
{-# LANGUAGE ImplicitParams, BangPatterns #-}

module Camfort.Specification.Units.Solve where

import Data.Ratio
import Data.List
import qualified Data.Matrix as DM
import qualified Data.Vector as V
import Control.Exception
import System.IO.Unsafe
import qualified Debug.Trace as D


import Language.Fortran
import Camfort.Specification.Units.Environment
import Camfort.Specification.Units.SolveHMatrix

-- | Top-level entry point: solve a linear system using the back-end chosen
-- by the implicit @?solver@ parameter.
--
-- Only the 'Custom' choice is handled, and it is routed to the
-- HMatrix-based 'solveSystemH'; the LAPACK and list-based ('solveSystemC')
-- routes are currently disabled.  Any other 'Solver' value is a runtime
-- pattern-match failure.
solveSystem :: (?solver :: Solver) => LinearSystem -> Consistency LinearSystem
solveSystem
  | Custom <- ?solver = solveSystemH
--------------------------------------------------
-- CUSTOM SOLVER
--------------------------------------------------

-- | Entry point of the list-based custom solver: run the elimination
-- driver 'solveSystem'' starting from the top-left entry of the system.
solveSystemC :: LinearSystem -> Consistency LinearSystem
solveSystemC system = solveSystem' system startCol startRow
  where
    -- Data.Matrix indices are 1-based, so the first column/row is 1.
    startCol = 1
    startRow = 1

-- | Elimination driver for the list-based custom solver.
--
-- @col@ is the column currently being processed and @row@ is the next row
-- that will receive a pivot.  While columns remain, the first row at or
-- below @row@ with a non-zero entry in @col@ is handed to 'elimRow'.  Once
-- every column has been processed, the rows from @row@ onwards never got a
-- pivot: 'checkSystem' checks their right-hand sides and, on success,
-- 'cutSystem' drops them.
solveSystem' :: LinearSystem -> Col -> Row -> Consistency LinearSystem
solveSystem' system@(matrix, _) col row
  | col <= DM.ncols matrix = elimRow system pivotRow col row
  | otherwise              = efmap (cutSystem row) (checkSystem system row)
  where
    -- Candidate pivot: first row in [row .. nrows] that is non-zero in col.
    pivotRow = find (\r -> matrix DM.! (r, col) /= 0) [row .. DM.nrows matrix]

-- | Keep only the first @k - 1@ rows of the system: both the matrix rows
-- and the matching right-hand-side entries.
--
-- When @k <= 1@ nothing is kept. This happens when no pivot was found at
-- all, e.g. for an all-zero or zero-column matrix. 'DM.submatrix' raises
-- an error for an empty row range (ending row below starting row), so that
-- case builds a zero-row matrix with the same column count directly.
cutSystem :: Int -> LinearSystem -> LinearSystem
cutSystem k (matrix, vector)
  | k <= 1    = (noRows, [])
  | otherwise = ( DM.submatrix 1 (k - 1) 1 (DM.ncols matrix) matrix
                , take (k - 1) vector )
  where
    -- The generator is never invoked because the matrix has zero rows.
    noRows = DM.matrix 0 (DM.ncols matrix)
               (const (error "cutSystem: zero-row matrix has no entries"))

-- | Scan rows @k@ .. @nrows@ and require each right-hand-side entry to be
-- @Unitful []@.
--
-- Returns 'Ok' with the system unchanged when every scanned entry passes.
-- Otherwise it returns 'Bad' with the first offending (1-based) row index,
-- that row's right-hand-side value and the row's coefficients as a list.
checkSystem :: LinearSystem -> Row -> Consistency LinearSystem
checkSystem system@(matrix, vector) start = scan start
  where
    scan row
      | row > DM.nrows matrix = Ok system
      | rhs /= Unitful []     = Bad system row (rhs, V.toList (DM.getRow row matrix))
      | otherwise             = scan (row + 1)
      where
        -- The right-hand-side list is 0-based; matrix rows are 1-based.
        rhs = vector !! (row - 1)

-- | One pivot step of the custom solver at column @col@, target row @row@.
--
-- With no pivot row ('Nothing') the column is skipped and the same target
-- row is reused for the next column. Otherwise the following happens:
--
--   * the pivot row is scaled so its entry in @col@ becomes 1;
--   * it is swapped into position @row@, with the right-hand side scaled
--     and swapped to match;
--   * 'elimRow'' clears column @col@ from every other row;
--   * elimination continues at the next column and the next row.
elimRow :: LinearSystem -> Maybe Row -> Col -> Row -> Consistency LinearSystem
elimRow system Nothing col row = solveSystem' system (col + 1) row
elimRow (matrix, vector) (Just pivotRow) col row =
  solveSystem' (elimRow' (normalised, rhs) row col) (col + 1) (row + 1)
  where
    pivot = matrix DM.! (pivotRow, col)
    -- Make the pivot entry 1 (skip the work when it already is).
    scaled
      | pivot == 1 = matrix
      | otherwise  = DM.scaleRow (recip pivot) pivotRow matrix
    -- Bring the pivot row up to the target position.
    normalised
      | row == pivotRow = scaled
      | otherwise       = DM.switchRows row pivotRow scaled
    -- Same scale-and-swap on the right-hand side.
    rhs = switchScaleElems row pivotRow (fromRational (recip pivot)) vector

-- | For every row @n /= k@, add @-a(n, m)@ times row @k@ to row @n@.
-- Rows that are already zero in column @m@ are skipped.
--
-- When row @k@ holds 1 in column @m@, as after the pivot normalisation in
-- 'elimRow', this zeroes column @m@ everywhere except row @k@. It is an
-- alternative formulation of the matrix update in 'elimRow''.
msteeper :: (Eq a, Num a) => DM.Matrix a -> Int -> Int -> DM.Matrix a
msteeper matrix k m = foldl' step matrix [1 .. DM.nrows matrix]
  where
    -- Each step reads the coefficient from the matrix as updated so far;
    -- only row n is modified, so earlier steps never affect it.
    step acc n
      | n == k    = acc
      | s == 0    = acc
      | otherwise = DM.combineRows n s k acc
      where
        s = negate (acc DM.! (n, m))

-- | Eliminate column @m@ from every row except the pivot row @k@.
--
-- Row @k@ is expected to already hold 1 in column @m@, as arranged by
-- 'elimRow'. Then subtracting @a(n, m)@ times row @k@ from row @n@ zeroes
-- entry @(n, m)@. The same row operation is applied to the right-hand side
-- vector; the pivot row's own entry is left unchanged.
elimRow' :: LinearSystem -> Row -> Col -> LinearSystem
elimRow' (matrix, vector) k m = (matrix', vector')
  where
    -- Right-hand side of the pivot row. It is computed once instead of
    -- re-indexing the list for every element.
    pivotRhs = vector !! (k - 1)

    -- Row n += (-a(n, m)) * row k; rows already zero in column m are left alone.
    mstep acc n =
      let s = negate (acc DM.! (n, m))
      in if s == 0 then acc else DM.combineRows n s k acc

    -- foldl' keeps the running matrix evaluated rather than building a
    -- chain of nrows nested thunks.
    matrix' = foldl' mstep matrix ([1 .. k - 1] ++ [k + 1 .. DM.nrows matrix])

    -- Coefficients come from the input matrix (before the updates above).
    vector' =
      [ if n == k then x else x - fromRational (matrix DM.! (n, m)) * pivotRhs
      | (n, x) <- zip [1 ..] vector
      ]

-- | Swap the (1-based) positions @i@ and @j@ of a list, scaling the element
-- that moves from @j@ into @i@ by @factor@.
--
-- The result has @factor * old_j@ at position @i@ and @old_i@ at position
-- @j@. When @i == j@ it simply scales that one element.
switchScaleElems :: Num a => Int -> Int -> a -> [a] -> [a]
switchScaleElems i j factor list =
  let -- Cut the list open at position j, remembering the element there.
      (beforeJ, atJ : afterJ) = splitAt (j - 1) list
      -- Put the original i-th element in slot j.
      withOldIAtJ = beforeJ ++ (list !! (i - 1)) : afterJ
      -- Cut again at position i and drop whatever sits there now.
      (beforeI, _ : afterI) = splitAt (i - 1) withOldIAtJ
  in beforeI ++ (factor * atJ) : afterI

--------------------------------------------------
-- | Top-level custom solver based on HMatrix, reporting the outcome as a
-- 'Consistency' value.
--
-- The actual work is done by 'solveSystemH_Either':
--
--   * convert with 'convertToHMatrix';
--   * apply 'rref' and keep the first 'rank' rows;
--   * convert back with 'convertFromHMatrix'.
--
-- Success becomes 'Ok'. On conversion failure, the first reported index
-- @n@ is used to build the 'Bad' report.
--
-- NOTE(review): @n@ is used 0-based for the right-hand side (@v !! n@) but
-- 1-based for the matrix (@DM.getRow n m@; Data.Matrix rows start at 1).
-- The reported row is @DM.nrows m@ rather than @n@. Confirm both against
-- 'convertToHMatrix'. An empty @Left []@ is not handled.
solveSystemH :: LinearSystem -> Consistency LinearSystem
solveSystemH system@(m, v) =
  case solveSystemH_Either system of
    Right reduced -> Ok reduced
    Left (n : _)  -> Bad system (DM.nrows m) (v !! n, V.toList (DM.getRow n m))

--------------------------------------------------
-- | Top-level custom solver based on HMatrix. Failure is reported as
-- 'Left' with the index list from 'convertToHMatrix', rather than as a
-- 'Consistency' value.
--
-- On success:
--
--   * the coefficient matrix goes through 'rref' (presumably reduced row
--     echelon form);
--   * only its first 'rank' rows are kept;
--   * the result is converted back into a 'LinearSystem' together with the
--     unit list.
solveSystemH_Either :: LinearSystem -> Either [Int] LinearSystem
solveSystemH_Either system = fmap reduce (convertToHMatrix system)
  where
    reduce (coeffs, units) =
      let echelon     = rref coeffs
          independent = takeRows (rank echelon) echelon
      in convertFromHMatrix (independent, units)