module Camfort.Specification.Units.SolveHMatrix
( rref, rrefMatrices, convertToHMatrix, convertFromHMatrix, isInconsistentRREF
, dispf, Units, lu, rank, takeRows )
where
import Data.Ratio
import Debug.Trace (trace)
import Numeric.LinearAlgebra (
atIndex, (<>), (><), rank, (?), toLists, toList, fromLists, fromList, rows, cols,
Matrix, takeRows, takeColumns, dropRows, dropColumns, subMatrix, diag, build, fromBlocks,
ident, flatten, lu, dispf
)
import Numeric.LinearAlgebra.Devel (
newMatrix, writeMatrix, runSTMatrix
)
import Control.Monad (filterM)
import Control.Monad.ST
import qualified Data.Matrix as Old (nrows, ncols, toList, Matrix, fromList)
import Foreign.Storable (Storable)
import Data.List (findIndex, nub, sort, (\\))
import Data.Maybe (fromMaybe)
import Camfort.Specification.Units.Environment (LinearSystem, UnitConstant(..))
import Language.Fortran (MeasureUnit)
-- | Check the last row of a matrix for an inconsistency: the bottom-right
-- entry equals 1 while the rest of that row (every column except the last)
-- has rank 0.
--
-- NOTE(review): judging by the name, the argument is expected to already be
-- in reduced row echelon form; nothing here enforces that.
isInconsistentRREF a = lastEntry == 1 && rank lastRowCoeffs == 0
  where
    lastRow = rows a - 1
    lastCol = cols a - 1
    lastEntry = a @@> (lastRow, lastCol)
    -- the final row with its last column removed
    lastRowCoeffs = takeColumns lastCol (dropRows lastRow a)
-- | Reduce a matrix and return only the resulting matrix, i.e. the second
-- component of what the worker 'rrefMatrices'' produces when started at
-- column 0 with no skipped columns and no recorded steps.
rref :: Matrix Double -> Matrix Double
rref mat = reduced
  where
    (_, reduced) = rrefMatrices' mat 0 0 []
-- | Return the list of elementary matrices recorded while reducing the
-- argument, i.e. the first component of the worker's result. The worker
-- conses each new matrix onto the front, so the most recent step comes first.
rrefMatrices :: Matrix Double -> [Matrix Double]
rrefMatrices mat = steps
  where
    (steps, _) = rrefMatrices' mat 0 0 []
-- | Multiply all recorded elementary matrices together (right-folded onto an
-- identity matrix of matching size), giving a single transformation matrix.
--
-- NOTE(review): presumably @rrefMatrix a <> a == rref a@ — confirm.
rrefMatrix :: Matrix Double -> Matrix Double
rrefMatrix mat = foldr (<>) (ident (rows mat)) steps
  where
    (steps, _) = rrefMatrices' mat 0 0 []
-- | Worker for the row reduction.
--
-- * @a@    – the matrix in its current, partially reduced state
-- * @j@    – the column currently being processed
-- * @k@    – how many columns so far were skipped for lack of a pivot, so the
--            row being processed is @j - k@
-- * @mats@ – elementary matrices applied so far, most recent first
--
-- Returns the accumulated elementary matrices together with the final matrix.
rrefMatrices' :: Matrix Double -> Int -> Int -> [Matrix Double] -> ([Matrix Double], Matrix Double)
rrefMatrices' a j k mats
  -- ran out of rows
  | pivotRow == n = (mats, a)
  -- ran out of columns
  | j == m = (mats, a)
  -- no pivot in place yet: look further down the column for a non-zero entry
  | pivot == 0 = case findIndex (/= 0) below of
      -- the column is zero from this row down; move on to the next column
      Nothing -> rrefMatrices' a (j + 1) (k + 1) mats
      -- swap the row holding the non-zero entry into the pivot position and
      -- retry the same cell
      Just i' ->
        let swapMat = elemRowSwap n (pivotRow + i') pivotRow
         in rrefMatrices' (swapMat <> a) j k (swapMat : mats)
  | otherwise = rrefMatrices' a2 (j + 1) k mats2
  where
    n = rows a
    m = cols a
    pivotRow = j - k
    pivot = a @@> (pivotRow, j)
    below = getColumnBelow a (pivotRow, j)

    -- scale the pivot row so the pivot becomes 1 (skipped if it already is)
    scale = elemRowMult n pivotRow (recip pivot)
    (a1, mats1)
      | pivot /= 1 = (scale <> a, scale : mats)
      | otherwise = (a, mats)

    -- A single matrix that cancels every other non-zero entry of column j:
    -- the identity, plus the negated entry written at (i, pivotRow) for each
    -- such row i. Rows other than the pivot row are untouched by the scaling
    -- step, so reading the entries from @a@ is equivalent to reading @a1@.
    cancel = runSTMatrix $ do
      new <- newMatrix 0 n n
      mapM_ (\d -> writeMatrix new d d 1) [0 .. n - 1]
      mapM_
        (\i -> writeMatrix new i pivotRow (negate (a @@> (i, j))))
        [i | i <- [0 .. n - 1], i /= pivotRow, a @@> (i, j) /= 0]
      return new

    a2 = cancel <> a1
    mats2 = cancel : mats1
-- | The entries of column @j@ from row @i@ (inclusive) down to the last row,
-- as a list.
getColumnBelow :: Matrix Double -> (Int, Int) -> [Double]
getColumnBelow a (i, j) = concat . toLists $ subMatrix (i, j) (n - i, 1) a
  where
    n = rows a
-- | The @n@-by-@n@ diagonal matrix whose diagonal is all ones except for the
-- value @k@ at position @i@; left-multiplying by it scales row @i@ by @k@.
elemRowMult :: Int -> Int -> Double -> Matrix Double
elemRowMult n i k = diag (fromList (replicate i 1.0 ++ [k] ++ replicate (n - i - 1) 1.0))
-- | The @n@-by-@n@ identity matrix with the entry at @(i, j)@ overwritten by
-- @k@. When @i == j@ the diagonal entry itself is replaced by @k@.
elemRowAdd :: Int -> Int -> Int -> Double -> Matrix Double
elemRowAdd n i j k = runSTMatrix $ do
  m <- newMatrix 0 n n
  -- fill in the diagonal first, then the single extra entry
  mapM_ (\d -> writeMatrix m d d 1) [0 .. n - 1]
  writeMatrix m i j k
  return m
-- | Declarative counterpart of 'elemRowAdd': ones on the diagonal, @k@ at
-- @(i, j)@ and zeros everywhere else. Out-of-range indices raise an error
-- with a descriptive message.
--
-- NOTE(review): the @build n n f@ call is kept exactly as found; confirm it
-- matches the hmatrix API in use.
elemRowAdd_spec n i j k
  | i < 0 || i >= n = error "elemRowAdd_spec: row index i out of range"
  | j < 0 || j >= n = error "elemRowAdd_spec: column index j out of range"
  | otherwise = build n n f
  where
    f (i', j')
      | i == i' && j == j' = k
      | i' == j' = 1
      | otherwise = 0
-- | The @n@-by-@n@ identity matrix with rows @i@ and @j@ exchanged. The
-- argument order of @i@ and @j@ does not matter; equal indices give the
-- plain identity.
elemRowSwap :: Int -> Int -> Int -> Matrix Double
elemRowSwap n i j
  | i == j = ident n
  | i > j = elemRowSwap n j i
  | otherwise = extractRows rowOrder (ident n)
  where
    -- every row index in order, with positions i and j trading places
    rowOrder = [0 .. i - 1] ++ [j] ++ [i + 1 .. j - 1] ++ [i] ++ [j + 1 .. n - 1]
-- | The measure units that label, in order, the unit columns appended to the
-- coefficient matrix by 'convertToHMatrix'.
type Units = [MeasureUnit]
-- | Convert a linear system into an augmented hmatrix matrix: the
-- coefficient matrix with one extra column per distinct unit placed to its
-- right.
--
-- Returns @Left@ with the row indices reported by 'findInconsistentRows'
-- when there are any, otherwise @Right@ with the augmented matrix and the
-- units labelling its extra columns.
convertToHMatrix :: LinearSystem -> Either [Int] (Matrix Double, Units)
convertToHMatrix (a, ucs) = case findInconsistentRows a' augA of
  [] -> Right (augA, units)
  ns -> Left ns
  where
    a' = convertMatrixToHMatrix a
    units = ucsToUnits ucs
    unitA = unitsToUnitA ucs units
    augA = fromBlocks [[a', unitA]]
-- | Split an augmented matrix back into a linear system. The last
-- @length units@ columns are read as unit powers; the remaining columns are
-- the coefficient matrix.
convertFromHMatrix :: (Matrix Double, [MeasureUnit]) -> LinearSystem
convertFromHMatrix (a, units) = (a', ucs')
  where
    -- number of leading columns that belong to the coefficient matrix
    coeffCols = cols a - length units
    a' = convertHMatrixToMatrix (takeColumns coeffCols a)
    unitA = dropColumns coeffCols a
    ucs = unitAToUcs unitA units
    -- fall back to one empty 'Unitful' per row when no constants came out
    -- (presumably the case when there are no unit columns — confirm)
    ucs' = if null ucs then replicate (rows a) (Unitful []) else ucs
-- | Rebuild a 'Old.Matrix' of rationals as an hmatrix 'Matrix' of doubles
-- with the same dimensions, converting each entry with 'toDouble'.
convertMatrixToHMatrix :: Old.Matrix Rational -> Matrix Double
convertMatrixToHMatrix a = (nRows >< nCols) entries
  where
    nRows = Old.nrows a
    nCols = Old.ncols a
    entries = [toDouble x | x <- Old.toList a]
-- | Rebuild an hmatrix 'Matrix' of doubles as a 'Old.Matrix' of rationals
-- with the same dimensions, converting each entry with 'fromDouble'.
convertHMatrixToMatrix :: Matrix Double -> Old.Matrix Rational
convertHMatrixToMatrix a = Old.fromList (rows a) (cols a) rationals
  where
    rationals = map fromDouble (toList (flatten a))
-- | Convert a 'Rational' to a 'Double' via 'fromRational'.
toDouble :: Rational -> Double
toDouble r = fromRational r
-- | Convert a 'Double' to a 'Rational' via 'toRational'.
fromDouble :: Double -> Rational
fromDouble d = toRational d
-- | Build the unit matrix: one row per unit constant and one column per
-- entry of @units@. A 'Unitful' constant contributes the value it associates
-- with each unit (0 where the unit is absent); any other constant
-- contributes an all-zero row.
unitsToUnitA :: [UnitConstant] -> Units -> Matrix Double
unitsToUnitA ucs units = fromLists (map rowFor ucs)
  where
    rowFor (Unitful us) = [toDouble (fromMaybe 0 (lookup u us)) | u <- units]
    rowFor _ = [0 | _ <- units]
-- | Collect every unit mentioned by the 'Unitful' constants, with
-- duplicates removed and the result sorted. Other constants contribute
-- nothing.
ucsToUnits :: [UnitConstant] -> Units
ucsToUnits ucs = sort (nub (concatMap mentioned ucs))
  where
    mentioned (Unitful us) = map fst us
    mentioned _ = []
-- | Turn each row of the unit matrix into a 'Unitful' constant, pairing
-- every unit with the row's value in the matching column and dropping pairs
-- whose value is zero.
unitAToUcs :: Matrix Double -> Units -> [UnitConstant]
unitAToUcs unitA units = [Unitful (nonZero row) | row <- toLists unitA]
  where
    nonZero row = [(u, p) | (u, p) <- zip units (map fromDouble row), p /= 0]
-- | Row indices that are left out of the first subset of rows (in the order
-- 'pset' enumerates subsets, starting with the full set) for which the
-- coefficient matrix and the augmented matrix have equal rank. An empty
-- result means the full set of rows already passes that check.
--
-- Every subset of rows may be tried, so the cost grows exponentially with
-- the number of rows.
findInconsistentRows :: Matrix Double -> Matrix Double -> [Int]
findInconsistentRows coA augA = allRows \\ consistent
  where
    allRows = [0 .. rows augA - 1]
    -- first subset passing the rank check; no rows if none does
    consistent = case filter (tryRows coA augA) (pset allRows) of
      (ns : _) -> ns
      [] -> []
-- | True when, restricted to the rows listed in @ns@, the coefficient matrix
-- and the augmented matrix have the same rank.
tryRows coA augA ns = rank (pick coA) == rank (pick augA)
  where
    pick = extractRows ns
-- | All sub-lists of a list (its power set), starting with the full list and
-- ending with the empty list.
pset = filterM (\_ -> [True, False])
-- | Select the rows with the given indices, in the given order; this is
-- hmatrix's '?' with its arguments flipped.
extractRows rowIdxs mat = mat ? rowIdxs
-- | Infix shorthand for 'atIndex': the element of a container at an index.
container @@> idx = atIndex container idx