module Camfort.Specification.Units.Solve where
import Data.Ratio
import Data.List
import qualified Data.Matrix as DM
import qualified Data.Vector as V
import Control.Exception
import System.IO.Unsafe
import qualified Debug.Trace as D
import Language.Fortran
import Camfort.Specification.Units.Environment
import Camfort.Specification.Units.SolveHMatrix
-- | Entry point for solving a unit-constraint system: looks at the implicit
-- @?solver@ parameter and runs the corresponding solver on the system.
--
-- The 'Custom' solver is routed to 'solveSystemH'.
--
-- NOTE(review): only 'Custom' is matched here; if 'Solver' has further
-- constructors this case is partial -- confirm against its definition.
solveSystem :: (?solver :: Solver) => LinearSystem -> Consistency LinearSystem
solveSystem system =
  case ?solver of
    Custom -> solveSystemH system
-- | Solve a system with the hand-written elimination (@solveSystem'@),
-- starting at the first column and the first row.
solveSystemC :: LinearSystem -> Consistency LinearSystem
solveSystemC system = solveSystem' system firstCol firstRow
  where
    -- Matrix indices are 1-based.
    firstCol = 1
    firstRow = 1
-- | One step of the column-by-column elimination.
--
-- @m@ is the column being processed and @k@ the row at which the search for a
-- non-zero entry starts.
--
-- * When @m@ is past the last column, the rows from @k@ onwards are checked
--   with 'checkSystem' and @cutSystem k@ is applied to the outcome through
--   'efmap'.
-- * Otherwise the first row in @[k .. nrows]@ with a non-zero entry in column
--   @m@ (or 'Nothing' if there is none) is handed to 'elimRow'.
solveSystem' :: LinearSystem -> Col -> Row -> Consistency LinearSystem
solveSystem' system@(matrix, _) m k =
  if m > DM.ncols matrix
    then efmap (cutSystem k) (checkSystem system k)
    else elimRow system pivotRow m k
  where
    hasEntry r = matrix DM.! (r, m) /= 0
    pivotRow = find hasEntry [k .. DM.nrows matrix]
-- | Keep only the first @k - 1@ rows of a system: the matrix is cut down to
-- rows @1 .. k - 1@ (all columns) and the vector to its first @k - 1@
-- elements.
--
-- NOTE(review): for @k == 1@ this asks 'DM.submatrix' for the row range
-- @1 .. 0@ -- confirm that the matrix library accepts an empty range.
cutSystem :: Int -> LinearSystem -> LinearSystem
cutSystem k (matrix, vector) = (matrix', vector')
  where
    -- Number of rows to keep; row k and everything below it is dropped.
    keep = k - 1
    matrix' = DM.submatrix 1 keep 1 (DM.ncols matrix) matrix
    vector' = take keep vector
-- | Walk the rows from @k@ downwards and check that each one has the empty
-- unit @Unitful []@ on its right-hand side.
--
-- Returns 'Ok' with the unchanged system once @k@ is past the last row.
-- Returns 'Bad' at the first offending row, carrying the system, the row
-- number, and a pair of that row's unit and its matrix coefficients.
checkSystem :: LinearSystem -> Row -> Consistency LinearSystem
checkSystem system@(matrix, vector) k
  | k > DM.nrows matrix = Ok system
  | unit /= Unitful [] = Bad system k (unit, vars)
  | otherwise = checkSystem system (k + 1)
  where
    -- Matrix rows are 1-based while the vector is a 0-based list.
    unit = vector !! (k - 1)
    vars = V.toList (DM.getRow k matrix)
-- | Process column @m@ given the result of the pivot search.
--
-- * 'Nothing': no row has a non-zero entry in this column, so move on to
--   column @m + 1@ with the same row @k@.
-- * @'Just' n@: scale row @n@ so its entry in column @m@ becomes 1, swap it
--   into position @k@, apply the matching change to the vector with
--   'switchScaleElems', run @elimRow'@ on the result, and continue with
--   column @m + 1@ and row @k + 1@.
elimRow :: LinearSystem -> Maybe Row -> Col -> Row -> Consistency LinearSystem
elimRow system Nothing m k = solveSystem' system (m + 1) k
elimRow (matrix, vector) (Just n) m k = solveSystem' reduced (m + 1) (k + 1)
  where
    pivot = matrix DM.! (n, m)
    -- Normalise the pivot row; skipped when the pivot is already 1.
    scaled
      | pivot == 1 = matrix
      | otherwise = DM.scaleRow (recip pivot) n matrix
    -- Bring the pivot row up to position k unless it is already there.
    matrix'
      | k == n = scaled
      | otherwise = DM.switchRows k n scaled
    vector' = switchScaleElems k n (fromRational (recip pivot)) vector
    reduced = elimRow' (matrix', vector') k m
-- | Clear column @m@ in every row except row @k@ by adding a multiple of row
-- @k@ to each of them, visiting the rows from 1 to the last row in order.
--
-- 'DM.combineRows' /adds/ a multiple of one row to another, so the multiplier
-- is the negated entry of the row being cleared. This cancels the entry only
-- when row @k@ has a 1 in column @m@.
msteeper :: (Eq a, Num a) => DM.Matrix a -> Int -> Int -> DM.Matrix a
msteeper matrix k m = go matrix 1
  where
    -- The row count is not changed by row combination, so read it once.
    rows = DM.nrows matrix
    go acc n
      | n > rows = acc
      | n == k = go acc (n + 1)
      | s == 0 = go acc (n + 1)
      | otherwise = go (DM.combineRows n s k acc) (n + 1)
      where
        s = negate (acc DM.! (n, m))
-- | Eliminate column @m@ from every row other than the pivot row @k@.
--
-- For each such row @n@, @entry(n, m)@ times row @k@ is subtracted from row
-- @n@ of the matrix, and @entry(n, m)@ times the pivot row's unit is
-- subtracted from element @n@ of the vector. Row @k@ and its vector element
-- are left unchanged.
elimRow' :: LinearSystem -> Row -> Col -> LinearSystem
elimRow' (matrix, vector) k m = (matrix', vector')
  where
    otherRows = [1 .. k - 1] ++ [k + 1 .. DM.nrows matrix]
    -- 'DM.combineRows' adds a multiple of row k to row n, so the multiplier
    -- must be the negated entry for the subtraction to happen.
    mstep acc n =
      let s = negate (acc DM.! (n, m))
       in if s == 0 then acc else DM.combineRows n s k acc
    matrix' = foldl' mstep matrix otherRows
    -- Matrix rows are 1-based while the vector is a 0-based list.
    pivotUnit = vector !! (k - 1)
    -- The coefficients come from the original matrix: an entry (n, m) is only
    -- touched by the step for row n itself.
    vector' =
      [ if n == k then x else x - fromRational (matrix DM.! (n, m)) * pivotUnit
      | (n, x) <- zip [1 ..] vector
      ]
-- | List counterpart of "scale row @j@, then swap rows @i@ and @j@", using
-- 1-based positions.
--
-- In the result, position @i@ holds @factor@ times the old element at @j@,
-- and position @j@ holds the old element at @i@. When @i == j@ the element is
-- just scaled in place.
--
-- Both positions must lie within the list; otherwise the pattern matches
-- below fail.
switchScaleElems :: Num a => Int -> Int -> a -> [a] -> [a]
switchScaleElems i j factor list = before ++ factor * atJ : after
  where
    -- Take out the element at j and put the element from i in its place.
    (leftOfJ, atJ : rightOfJ) = splitAt (j - 1) list
    moved = leftOfJ ++ list !! (i - 1) : rightOfJ
    -- Then overwrite position i with the scaled element taken from j.
    (before, _ : after) = splitAt (i - 1) moved
-- | Solve a system through the HMatrix-based pipeline and report the outcome
-- as a 'Consistency' value.
--
-- The conversion and reduction are shared with 'solveSystemH_Either'. Its
-- 'Right' result is wrapped in 'Ok'. For a 'Left', the first reported index
-- @n@ is turned into a 'Bad' carrying the unit @v !! n@ and row @n@ of the
-- matrix.
--
-- NOTE(review): @v !! n@ is 0-based while @DM.getRow n@ is 1-based, so one of
-- the two is likely off by one -- confirm what 'convertToHMatrix' reports.
-- NOTE(review): the row passed to 'Bad' is @DM.nrows m@, not @n@ -- confirm
-- this is intended.
-- NOTE(review): a @Left []@ is not matched -- presumably the index list is
-- never empty; verify.
solveSystemH :: LinearSystem -> Consistency LinearSystem
solveSystemH system@(m, v) =
  case solveSystemH_Either system of
    Left (n : _) -> Bad system (DM.nrows m) (v !! n, V.toList (DM.getRow n m))
    Right solved -> Ok solved
-- | Solve a system through the HMatrix-based pipeline, reporting failure as a
-- plain list of indices.
--
-- A 'Left' from 'convertToHMatrix' is passed through unchanged. On 'Right',
-- the converted matrix is reduced with 'rref' (presumably reduced row echelon
-- form), cut down to its first @rank@ rows with 'takeRows', and converted
-- back together with the units by 'convertFromHMatrix'.
solveSystemH_Either :: LinearSystem -> Either [Int] LinearSystem
solveSystemH_Either system = either Left reduce (convertToHMatrix system)
  where
    reduce (hmatrix, units) = Right (convertFromHMatrix (trimmed, units))
      where
        echelon = rref hmatrix
        trimmed = takeRows (rank echelon) echelon