module Camfort.Specification.Units.InferenceBackend
( inconsistentConstraints, criticalVariables, inferVariables
, shiftTerms, flattenConstraints, flattenUnits, constraintsToMatrix, rref, isInconsistentRREF )
import Data.Tuple (swap)
import Data.Maybe (maybeToList)
import Data.List ((\\), findIndex, partition, sortBy, group)
import Data.Generics.Uniplate.Operations (rewrite)
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.ST
import Control.Arrow (first, second)
import qualified Data.Map as M
import qualified Data.Array as A
import Camfort.Analysis.Annotations
import Camfort.Specification.Units.Environment
import Numeric.LinearAlgebra (
atIndex, (<>), (><), rank, (?), toLists, toList, fromLists, fromList, rows, cols,
takeRows, takeColumns, dropRows, dropColumns, subMatrix, diag, build, fromBlocks,
ident, flatten, lu, dispf
import qualified Numeric.LinearAlgebra as H
import Numeric.LinearAlgebra.Devel (
newMatrix, readMatrix, writeMatrix, runSTMatrix
import qualified Debug.Trace as D
-- | Report the constraints that 'findInconsistentRows' (via
-- 'constraintsToMatrix') flags as inconsistent.
--
-- Returns 'Nothing' for an empty constraint list or when no row is flagged;
-- otherwise 'Just' the flagged constraints in their original order.  Row @i@
-- of the constraint matrix corresponds to the @i@-th constraint.
inconsistentConstraints :: Constraints -> Maybe Constraints
inconsistentConstraints [] = Nothing
inconsistentConstraints cons
  | null inconsists = Nothing
  | otherwise       = Just [ con | (con, i) <- zip cons [0 ..], i `elem` inconsists ]
  where
    (_, inconsists, _) = constraintsToMatrix cons
-- | Units of the constraint-matrix columns that receive no pivot (no row of
-- the reduced row echelon form has its leading non-zero entry there),
-- excluding 'UnitName' columns.
criticalVariables :: Constraints -> [UnitInfo]
criticalVariables [] = []
criticalVariables cons = filter (not . isUnitName) $ map (colA A.!) criticalIndices
  where
    (unsolvedM, _, colA) = constraintsToMatrix cons
    solvedM              = rref unsolvedM
    -- column index of the leading non-zero entry of each non-zero row
    uncriticalIndices    = concatMap (maybeToList . findIndex (/= 0)) $ H.toLists solvedM
    -- columns that never hold a pivot
    criticalIndices      = A.indices colA \\ uncriticalIndices
    isUnitName (UnitName _) = True
    isUnitName _            = False
-- | Read variable units off the reduced row echelon form of the constraints.
--
-- Each row of the RREF is turned back into unit factors, using the row
-- entries as exponents of the column units.  A row whose non-'UnitName' part
-- is exactly one @UnitVar v@ with exponent approximately 1 yields @v@,
-- paired with the product of that row's 'UnitName' factors ('UnitlessVar'
-- when there are none).  Returns @[]@ for an empty constraint list, or when
-- any row is reported inconsistent.
inferVariables :: Constraints -> [(String, UnitInfo)]
inferVariables [] = []
inferVariables cons
  | null inconsists = [ (var, unitOf units)
                      | ([UnitPow (UnitVar var) k], units) <- map (partition (not . isUnitName)) unitPows
                      , k `approxEq` 1 ]
  | otherwise       = []
  where
    (unsolvedM, inconsists, colA) = constraintsToMatrix cons
    solvedM  = rref unsolvedM
    -- named colUnits so it does not shadow hmatrix's 'cols'
    colUnits = A.elems colA
    unitPows = map (concatMap flattenUnits . zipWith UnitPow colUnits) (H.toLists solvedM)
    unitOf [] = UnitlessVar
    unitOf us = foldl1 UnitMul us
    isUnitName (UnitPow (UnitName _) _) = True
    isUnitName _                        = False
-- | Flatten both sides of every constraint with 'flattenUnits'.
--
-- This was a verbatim copy of 'flattenConstraints'.  It now delegates to it,
-- so the two cannot drift apart.
-- NOTE(review): not exported and not referenced elsewhere in this module;
-- confirm it is still needed.
simplifyConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
simplifyConstraints = flattenConstraints
-- | Normalise a unit expression by applying the local rules below everywhere
-- until none applies (Uniplate 'rewrite').
simplifyUnits :: UnitInfo -> UnitInfo
simplifyUnits = rewrite rw
  where
    -- re-associate products to the right: (u1 * u2) * u3  ==>  u1 * (u2 * u3)
    rw (UnitMul (UnitMul u1 u2) u3)                         = Just $ UnitMul u1 (UnitMul u2 u3)
    -- u * u  ==>  u^2
    rw (UnitMul u1 u2) | u1 == u2                           = Just $ UnitPow u1 2
    -- (u^p1)^p2  ==>  u^(p1 * p2)
    rw (UnitPow (UnitPow u1 p1) p2)                         = Just $ UnitPow u1 (p1 * p2)
    -- u^p1 * u^p2  ==>  u^(p1 + p2)
    rw (UnitMul (UnitPow u1 p1) (UnitPow u2 p2)) | u1 == u2 = Just $ UnitPow u1 (p1 + p2)
    -- u^p with p approximately 0  ==>  unitless
    rw (UnitPow _ p) | p `approxEq` 0                       = Just UnitlessLit
    -- unitless factors disappear from products
    rw (UnitMul UnitlessLit u)                              = Just u
    rw (UnitMul u UnitlessLit)                              = Just u
    rw _                                                    = Nothing
-- | Flatten a unit expression into @UnitPow u p@ factors, one per distinct
-- base unit @u@.
--
-- Products are split and powers are distributed, and each base unit is run
-- through 'simplifyUnits'.  Exponents of equal base units are summed.
-- Factors whose exponent is approximately zero are dropped, as is
-- 'UnitlessLit'.  The result is in ascending 'Ord' order of the base units
-- (from 'M.toList').
flattenUnits :: UnitInfo -> [UnitInfo]
flattenUnits = map (uncurry UnitPow) . M.toList
             . M.filterWithKey (\ u _ -> u /= UnitlessLit)
             . M.filter (not . approxEq 0)
             . M.fromListWith (+)
             . map (first simplifyUnits)
             . factors
  where
    -- Split into (base unit, exponent) pairs.  It is named 'factors' rather
    -- than 'flatten' so that it does not shadow hmatrix's imported 'flatten'.
    factors (UnitMul u1 u2) = factors u1 ++ factors u2
    factors (UnitPow u p)   = map (second (p *)) $ factors u
    factors u               = [(u, 1)]
-- | @approxEq a b@ holds when @a@ and @b@ differ by less than 'epsilon'.
-- It is used to compare unit exponents, which are 'Double' matrix entries.
approxEq :: Double -> Double -> Bool
approxEq a b = abs (b - a) < epsilon

-- | Absolute tolerance used by 'approxEq'.
epsilon :: Double
epsilon = 0.001
-- | Build the augmented coefficient matrix for a list of constraints.
--
-- Each constraint becomes one row.  After 'shiftTerms', the
-- non-'UnitName' terms form the left-hand block of columns and the
-- 'UnitName' terms form the right-hand block.
--
-- Returns three things:
--
-- * the augmented matrix (just the left block when the right block is empty);
-- * the row indices reported by 'findInconsistentRows';
-- * an array mapping every column index of the augmented matrix to its unit.
constraintsToMatrix :: Constraints -> (H.Matrix Double, [Int], A.Array Int UnitInfo)
constraintsToMatrix cons = (augM, inconsists, A.listArray (0, length colElems - 1) colElems)
  where
    shiftedCons     = map shiftTerms (flattenConstraints cons)
    (lhsM, lhsCols) = flattenedToMatrix (map fst shiftedCons)
    (rhsM, rhsCols) = flattenedToMatrix (map snd shiftedCons)
    colElems        = A.elems lhsCols ++ A.elems rhsCols
    augM
      | rows rhsM == 0 || cols rhsM == 0 = lhsM
      | otherwise                        = fromBlocks [[lhsM, rhsM]]
    inconsists      = findInconsistentRows lhsM augM
-- | Turn rows of @UnitPow u k@ terms into a coefficient matrix with one row
-- per input list and one column per distinct unit @u@.
--
-- Columns are numbered in 'colSort' order.  Repeated units within a row
-- have their coefficients summed.  The returned array maps each column index
-- to its unit.  Terms that are not 'UnitPow' are ignored.  This matches how
-- the set of columns is collected.
flattenedToMatrix :: [[UnitInfo]] -> (H.Matrix Double, A.Array Int UnitInfo)
flattenedToMatrix cons = (m, A.array (0, numCols - 1) (map swap uniqUnits))
  where
    m = runSTMatrix $ do
      newM <- newMatrix 0 numRows numCols
      forM_ (zip cons [0 ..]) $ \ (unitPows, row) ->
        forM_ [ (u, k) | UnitPow u k <- unitPows ] $ \ (u, k) ->
          case M.lookup u colMap of
            Just col -> readMatrix newM row col >>= (writeMatrix newM row col . (+ k))
            Nothing  -> return ()
      return newM
    -- every distinct unit in colSort order, paired with its column index
    -- (the (u:_) pattern keeps one representative of each group without 'head')
    uniqUnits = zip [ u | (u : _) <- group (sortBy colSort [ u | UnitPow u _ <- concat cons ]) ] [0 ..]
    colMap    = M.fromList uniqUnits
    numRows   = length cons
    numCols   = M.size colMap
-- | Negate the exponent of every @UnitPow@ term.  This is used when a term
-- is moved to the other side of an equation.
-- The input is expected to contain only 'UnitPow' terms, as produced by
-- 'flattenUnits'.
negateCons :: [UnitInfo] -> [UnitInfo]
negateCons = map (\ (UnitPow u k) -> UnitPow u (negate k))
-- | Column ordering for the constraint matrix.
--
-- Any 'UnitLiteral' sorts before every other unit, and two literals are
-- compared by their contents.  All other pairs fall back to the 'Ord'
-- instance of 'UnitInfo'.
colSort :: UnitInfo -> UnitInfo -> Ordering
colSort x y = case (x, y) of
  (UnitLiteral i, UnitLiteral j) -> compare i j
  (UnitLiteral _, _)             -> LT
  (_, UnitLiteral _)             -> GT
  _                              -> compare x y
-- | Flatten both sides of every equality constraint into lists of
-- @UnitPow@ factors with 'flattenUnits'.
flattenConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
flattenConstraints = map flattenSides
  where
    flattenSides (ConEq lhs rhs) = (flattenUnits lhs, flattenUnits rhs)
-- | Rearrange one flattened equation @lhs = rhs@ so that every 'UnitName'
-- term ends up on the right and every other term on the left.
--
-- Terms that change sides pass through 'negateCons'.  Terms that stay keep
-- their original relative order and come first on their side.
shiftTerms :: ([UnitInfo], [UnitInfo]) -> ([UnitInfo], [UnitInfo])
shiftTerms (lhs, rhs) = (lhsOk ++ negateCons rhsShift, rhsOk ++ negateCons lhsShift)
  where
    (lhsOk, lhsShift) = partition (not . isUnitName) lhs
    (rhsOk, rhsShift) = partition isUnitName rhs
    isUnitName (UnitPow (UnitName _) _) = True
    isUnitName _                        = False
-- | Check whether the last row of an augmented matrix reads @0 = 1@.
--
-- That is, its final (augmented) entry equals 1 and the rest of the row has
-- rank 0.  A matrix with no rows or no columns is never reported as
-- inconsistent; previously such a matrix indexed out of range.
isInconsistentRREF :: H.Matrix Double -> Bool
isInconsistentRREF a
  | rows a == 0 || cols a == 0 = False
  | otherwise = a @@> (lastRow, lastCol) == 1 && (cols coeffs == 0 || rank coeffs == 0)
  where
    lastRow = rows a - 1
    lastCol = cols a - 1
    -- last row without its augmented column
    coeffs  = takeColumns lastCol (dropRows lastRow a)
-- | Reduced row echelon form of a matrix, computed by 'rrefMatrices''.
rref :: H.Matrix Double -> H.Matrix Double
rref a = reduced
  where
    (_, reduced) = rrefMatrices' a 0 0 []
-- | The elementary matrices applied while reducing a matrix to row echelon
-- form, with the most recently applied one first.
rrefMatrices :: H.Matrix Double -> [H.Matrix Double]
rrefMatrices a = elementaryOps
  where
    (elementaryOps, _) = rrefMatrices' a 0 0 []
-- | Product of all elementary matrices recorded during row reduction, i.e.
-- a single matrix @E@ such that @E <> a@ equals @rref a@ (up to rounding).
-- The starting value of the fold is the identity matrix of @rows a@.
rrefMatrix :: H.Matrix Double -> H.Matrix Double
rrefMatrix a = foldr (<>) identity ops
  where
    identity = ident (rows a)
    (ops, _) = rrefMatrices' a 0 0 []
-- | Worker for 'rref', 'rrefMatrices' and 'rrefMatrix' (Gauss-Jordan
-- elimination expressed as elementary matrix products).
--
-- @rrefMatrices' a j k mats@ processes column @j@.  The argument @k@ counts
-- the earlier columns that were skipped because they held no non-zero entry
-- at or below the current row, so the current pivot row is @j - k@.  Every
-- elementary matrix applied to @a@ is consed onto @mats@.
--
-- Returns the accumulated matrices (most recent first) and the reduced
-- matrix.
rrefMatrices' :: H.Matrix Double -> Int -> Int -> [H.Matrix Double]
              -> ([H.Matrix Double], H.Matrix Double)
rrefMatrices' a j k mats
  -- every row already has a pivot
  | j - k == n          = (mats, a)
  -- every column has been processed
  | j == m              = (mats, a)
  -- zero in the pivot position: swap in a lower row with a non-zero entry in
  -- this column, or skip the column when there is none
  | a @@> (p, j) == 0   = case findIndex (/= 0) below of
      Nothing -> rrefMatrices' a (j + 1) (k + 1) mats
      Just i' -> let swapMat = elemRowSwap n (p + i') p
                 in rrefMatrices' (swapMat <> a) j k (swapMat : mats)
  -- otherwise make the pivot 1 and clear the rest of its column
  | otherwise           = rrefMatrices' a2 (j + 1) k mats2
  where
    n     = rows a
    m     = cols a
    p     = j - k  -- current pivot row
    below = getColumnBelow a (p, j)

    -- scale the pivot row so that the pivot becomes 1 (skipped if it already is)
    erm = elemRowMult n p (recip (a @@> (p, j)))
    (a1, mats1)
      | a @@> (p, j) /= 1 = (erm <> a, erm : mats)
      | otherwise         = (a, mats)

    -- One matrix that subtracts (entry in column j) * (pivot row) from every
    -- other row with a non-zero entry in column j.  Scaling touched only the
    -- pivot row, so the other rows' column-j entries can be read from a.
    elim = runSTMatrix $ do
      e <- newMatrix 0 n n
      forM_ [0 .. n - 1] $ \ d -> writeMatrix e d d 1
      forM_ [0 .. n - 1] $ \ i ->
        let x = a @@> (i, j)
        in when (i /= p && x /= 0) $ writeMatrix e i p (negate x)
      return e
    (a2, mats2) = (elim <> a1, elim : mats1)
-- | Entries of column @j@ of @a@, from row @i@ down to the last row.
getColumnBelow :: H.Matrix Double -> (Int, Int) -> [Double]
getColumnBelow a (i, j) = concat . H.toLists $ subMatrix (i, j) (rows a - i, 1) a
-- | Elementary matrix that scales row @i@ of an @n@-row matrix by @k@: the
-- @n x n@ identity with @k@ in diagonal position @i@.
elemRowMult :: Int -> Int -> Double -> H.Matrix Double
elemRowMult n i k = diag (H.fromList (replicate i 1.0 ++ [k] ++ replicate (n - i - 1) 1.0))
-- | Elementary matrix for "row @i@ += @k@ * row @j@" on an @n@-row matrix:
-- the @n x n@ identity with entry @(i, j)@ set to @k@.
-- NOTE(review): when @i == j@ this overwrites diagonal entry @i@ with @k@,
-- which makes it a row scaling rather than a row addition.
elemRowAdd :: Int -> Int -> Int -> Double -> H.Matrix Double
elemRowAdd n i j k = runSTMatrix $ do
  m <- newMatrix 0 n n
  forM_ [0 .. n - 1] $ \ d -> writeMatrix m d d 1
  writeMatrix m i j k
  return m
-- | Elementary permutation matrix that swaps rows @i@ and @j@ of an
-- @n@-row matrix.  It is built by reordering the rows of the identity.
elemRowSwap :: Int -> Int -> Int -> H.Matrix Double
elemRowSwap n i j
  | i == j    = ident n
  | i > j     = elemRowSwap n j i
  | otherwise = extractRows order $ ident n
  where
    -- rows 0..n-1 with i and j exchanged (here i < j)
    order = [0 .. i - 1] ++ [j] ++ [i + 1 .. j - 1] ++ [i] ++ [j + 1 .. n - 1]
-- | Convert a 'Rational' to a 'Double' via 'fromRational'.
toDouble :: Rational -> Double
toDouble r = fromRational r
-- | Convert a 'Double' to a 'Rational' via 'toRational'.
fromDouble :: Double -> Rational
fromDouble d = toRational d
-- | Indices of the rows left out of the first row subset found to be
-- consistent.
--
-- A subset is consistent when its rows of the coefficient matrix @coA@ and
-- of the augmented matrix @augA@ have equal rank.  Subsets are tried in the
-- order produced by @filterM (const [True, False])@, which starts with the
-- full set.  If no subset passes, every row is reported.
--
-- NOTE(review): the search is exponential in the number of rows when the
-- full system is inconsistent.  The order also favours keeping lower-indexed
-- rows rather than finding the largest consistent subset; confirm this is
-- intended.
findInconsistentRows :: H.Matrix Double -> H.Matrix Double -> [Int]
findInconsistentRows coA augA = allRows \\ consistent
  where
    allRows = [0 .. rows augA - 1]
    consistent = case filter sameRank (pset allRows) of
      (ns : _) -> ns
      []       -> []
    -- Rouche-Capelli: a linear system is consistent iff its coefficient matrix
    -- and its augmented matrix have the same rank
    sameRank ns = rank (extractRows ns coA) == rank (extractRows ns augA)
    -- every subset of the list, full set first
    pset = filterM (const [True, False])
-- | Select the listed rows of a matrix, in the order given.  This is hmatrix's
-- '?' with the arguments flipped.  The explicit signature avoids relying on
-- the monomorphism restriction to fix the element type.
extractRows :: [Int] -> H.Matrix Double -> H.Matrix Double
extractRows = flip (?)
-- | Element of a matrix at a (row, column) index (hmatrix 'atIndex').
(@@>) :: H.Matrix Double -> (Int, Int) -> Double
m @@> ix = m `atIndex` ix