module Camfort.Specification.Units.InferenceBackend
( inconsistentConstraints, criticalVariables, inferVariables
, shiftTerms, flattenConstraints, flattenUnits, constraintsToMatrix, constraintsToMatrices
, rref, isInconsistentRREF, genUnitAssignments )
where
import Data.Tuple (swap)
import Data.Maybe (maybeToList)
import Data.List ((\\), findIndex, partition, sortBy, group, intercalate, tails, sort)
import Data.Generics.Uniplate.Operations (rewrite, universeBi)
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.ST
import Control.Arrow (first, second)
import qualified Data.Map.Strict as M
import qualified Data.Array as A
import Camfort.Analysis.Annotations
import Camfort.Specification.Units.Environment
import Numeric.LinearAlgebra (
atIndex, (<>), (><), rank, (?), toLists, toList, fromLists, fromList, rows, cols,
takeRows, takeColumns, dropRows, dropColumns, subMatrix, diag, build, fromBlocks,
ident, flatten, lu, dispf
)
import qualified Numeric.LinearAlgebra as H
import Numeric.LinearAlgebra.Devel (
newMatrix, readMatrix, writeMatrix, runSTMatrix, freezeMatrix, STMatrix
)
import qualified Debug.Trace as D
-- | Return the constraints whose matrix rows were reported as inconsistent
-- by 'constraintsToMatrices', or 'Nothing' when no row was reported (or
-- when there are no constraints at all).
--
-- 'constraintsToMatrices' discards every constraint whose two flattened
-- sides are equal before it builds its matrices, so the row indices it
-- reports refer to the /filtered/ list. The same filter is therefore applied
-- here before pairing constraints with row numbers; pairing against the
-- unfiltered list would report the wrong constraints whenever a trivial
-- constraint precedes an inconsistent one.
inconsistentConstraints :: Constraints -> Maybe Constraints
inconsistentConstraints [] = Nothing
inconsistentConstraints cons
  | null inconsists = Nothing
  | otherwise = Just [ con | (con, i) <- zip rowCons [0 ..], i `elem` inconsists ]
  where
    (_, _, inconsists, _, _) = constraintsToMatrices cons
    -- Constraints that actually own a matrix row, in row order.
    rowCons =
      [ con
      | (con, sides) <- zip cons (flattenConstraints cons)
      , uncurry (/=) sides
      ]
-- | List the column units that are not the leading (first non-zero) entry
-- of any row once the constraint matrix has been reduced by 'rref'.
-- Columns holding a 'UnitName' are excluded from the result.
criticalVariables :: Constraints -> [UnitInfo]
criticalVariables [] = []
criticalVariables cons =
  [ unit | unit <- map (colA A.!) freeCols, not (isNamedUnit unit) ]
  where
    (unsolvedM, _, colA) = constraintsToMatrix cons
    reducedRows = H.toLists (rref unsolvedM)
    -- Column index of the first non-zero entry of each row that has one.
    pivotCols = [ col | Just col <- map (findIndex (/= 0)) reducedRows ]
    freeCols = A.indices colA \\ pivotCols
    isNamedUnit (UnitName _) = True
    isNamedUnit _ = False
-- | Extract variable/unit pairs from 'genUnitAssignments'.
--
-- Only assignments whose left-hand side is a single term are kept: either a
-- 'UnitVar' or a 'UnitParamVarAbs' raised to a power that is approximately
-- 1. All 'UnitVar' results come first, followed by the 'UnitParamVarAbs'
-- results.
inferVariables :: Constraints -> [(VV, UnitInfo)]
inferVariables cons = concatMap plainVar solved ++ concatMap paramVar solved
  where
    solved = genUnitAssignments cons

    plainVar ([UnitPow (UnitVar v) p], units)
      | p `approxEq` 1 = [(v, units)]
    plainVar _ = []

    paramVar ([UnitPow (UnitParamVarAbs (_, v)) p], units)
      | p `approxEq` 1 = [(v, units)]
    paramVar _ = []
-- | Solve the constraint system and read one assignment off each row of the
-- reduced matrix.
--
-- Returns the empty list when 'constraintsToMatrix' reports any inconsistent
-- row. Otherwise each row is turned into a pair: the terms that are not
-- "right-hand side" units, and the remaining terms (after 'negatePosAbs')
-- multiplied together with 'UnitMul', or 'UnitlessVar' if there are none.
genUnitAssignments :: [Constraint] -> [([UnitInfo], UnitInfo)]
genUnitAssignments cons
  | not (null inconsists) = []
  | otherwise = map assignment (H.toLists (rref unsolvedM))
  where
    (unsolvedM, inconsists, colA) = constraintsToMatrix cons
    colUnits = A.elems colA

    -- Pair every coefficient of the row with its column unit, flatten, and
    -- split the terms into the two sides of the assignment.
    assignment coeffs = (lhsTerms, combine (map negatePosAbs rhsTerms))
      where
        terms = concatMap flattenUnits (zipWith UnitPow colUnits coeffs)
        (lhsTerms, rhsTerms) = partition (not . onRHS) terms

    onRHS (UnitPow (UnitName _) _) = True
    onRHS (UnitPow (UnitParamEAPAbs _) _) = True
    onRHS (UnitPow (UnitParamPosAbs _) _) = True
    onRHS _ = False

    combine [] = UnitlessVar
    combine us = foldl1 UnitMul us
-- | Flatten both sides of every 'ConEq' constraint with 'flattenUnits'.
-- Any other constraint constructor is a pattern-match failure.
simplifyConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
simplifyConstraints = map bothSides
  where
    bothSides (ConEq lhs rhs) = (flattenUnits lhs, flattenUnits rhs)
-- | Repeatedly apply a fixed set of algebraic rewrite rules to a unit
-- expression (via Uniplate's 'rewrite') until none of them applies.
--
-- The rules are tried in the order written: re-associate nested products to
-- the right, turn @u * u@ into @u ^ 2@, collapse nested powers, merge powers
-- of the same base, replace a power with an (approximately) zero exponent by
-- 'UnitlessLit', and drop 'UnitlessLit' factors.
simplifyUnits :: UnitInfo -> UnitInfo
simplifyUnits = rewrite step
  where
    step unit = case unit of
      UnitMul (UnitMul a b) c -> Just (UnitMul a (UnitMul b c))
      UnitMul a b
        | a == b -> Just (UnitPow a 2)
      UnitPow (UnitPow base inner) outer -> Just (UnitPow base (inner * outer))
      UnitMul (UnitPow a p) (UnitPow b q)
        | a == b -> Just (UnitPow a (p + q))
      UnitPow _ p
        | p `approxEq` 0 -> Just UnitlessLit
      UnitMul UnitlessLit u -> Just u
      UnitMul u UnitlessLit -> Just u
      _ -> Nothing
-- | Flatten a unit expression into a list of 'UnitPow' terms.
--
-- Products are split into their factors and powers are distributed over
-- them. Each base is passed through 'simplifyUnits', and the exponents of
-- equal bases are added together. Terms whose exponent is approximately
-- zero, and the 'UnitlessLit' base, are dropped. The result is in the key
-- order of the intermediate map.
flattenUnits :: UnitInfo -> [UnitInfo]
flattenUnits unit =
  [ UnitPow base power
  | (base, power) <- M.toList combined
  , not (approxEq 0 power)
  , base /= UnitlessLit
  ]
  where
    combined =
      M.fromListWith (+) [ (simplifyUnits base, power) | (base, power) <- factors unit ]

    factors (UnitMul l r) = factors l ++ factors r
    factors (UnitPow u p) = [ (base, p * q) | (base, q) <- factors u ]
    factors u = [(u, 1)]
-- | Tolerant equality: 'True' when the two values differ by less than
-- 'epsilon' in absolute terms.
approxEq :: Double -> Double -> Bool
approxEq a b = abs (b - a) < epsilon

-- | Absolute tolerance used by 'approxEq'.
epsilon :: Double
epsilon = 0.001
-- | Build the augmented matrix for a set of constraints.
--
-- Every constraint is flattened by 'flattenConstraints', its terms are
-- redistributed by 'shiftTerms', and the two sides are converted to matrices
-- by 'flattenedToMatrix'. The result is:
--
-- * the augmented matrix: the left-hand matrix followed by the right-hand
--   matrix, or just the left-hand matrix when the right-hand one has no rows
--   or no columns;
-- * the row indices reported by 'findInconsistentRows';
-- * an array mapping each column index of the augmented matrix to the unit
--   that column stands for (left-hand columns first).
constraintsToMatrix :: Constraints -> (H.Matrix Double, [Int], A.Array Int UnitInfo)
constraintsToMatrix cons = (augM, inconsists, colA)
  where
    (lhs, rhs) = unzip (map shiftTerms (flattenConstraints cons))
    (lhsM, lhsCols) = flattenedToMatrix lhs
    (rhsM, rhsCols) = flattenedToMatrix rhs
    colElems = A.elems lhsCols ++ A.elems rhsCols
    -- Array bounds are inclusive, so the last index is one below the count.
    colA = A.listArray (0, length colElems - 1) colElems
    augM
      | rows rhsM == 0 || cols rhsM == 0 = lhsM
      | otherwise = fromBlocks [[lhsM, rhsM]]
    inconsists = findInconsistentRows lhsM augM
-- | Like 'constraintsToMatrix', but returns the left-hand and right-hand
-- matrices (and their column arrays) separately instead of one augmented
-- matrix.
--
-- Constraints whose two flattened sides are equal are discarded before the
-- matrices are built, so row indices refer to the remaining constraints.
-- The inconsistent-row list is still computed against the augmented matrix.
constraintsToMatrices :: Constraints -> (H.Matrix Double, H.Matrix Double, [Int], A.Array Int UnitInfo, A.Array Int UnitInfo)
constraintsToMatrices cons = (lhsM, rhsM, inconsists, lhsCols, rhsCols)
  where
    nonTrivial = [ sides | sides@(l, r) <- flattenConstraints cons, l /= r ]
    (lhsRows, rhsRows) = unzip (map shiftTerms nonTrivial)
    (lhsM, lhsCols) = flattenedToMatrix lhsRows
    (rhsM, rhsCols) = flattenedToMatrix rhsRows
    augmented
      | rows rhsM == 0 || cols rhsM == 0 = lhsM
      | otherwise = fromBlocks [[lhsM, rhsM]]
    inconsists = findInconsistentRows lhsM augmented
-- | Turn a list of flattened unit products (one list of 'UnitPow' terms per
-- row) into a matrix of exponents plus an array naming each column.
--
-- The distinct bases of all terms, ordered by 'colSort', become the columns.
-- Entry @(row, col)@ is the sum of the exponents that the row's terms carry
-- for that column's base; all other entries are 0. A term that is not a
-- 'UnitPow' is a pattern-match failure.
flattenedToMatrix :: [[UnitInfo]] -> (H.Matrix Double, A.Array Int UnitInfo)
flattenedToMatrix cons = (matrix, A.array (0, numCols - 1) (map swap uniqUnits))
  where
    matrix = runSTMatrix $ do
      m <- newMatrix 0 numRows numCols
      forM_ (zip cons [0 ..]) $ \ (unitPows, row) ->
        forM_ unitPows $ \ (UnitPow u k) ->
          case M.lookup u colMap of
            -- Accumulate, so repeated bases within one row add up.
            Just col -> readMatrix m row col >>= (writeMatrix m row col . (+ k))
            Nothing -> return ()
      return m
    -- Distinct bases paired with their column index.
    uniqUnits = flip zip [0 ..] . map head . group . sortBy colSort $ [ u | UnitPow u _ <- concat cons ]
    colMap = M.fromList uniqUnits
    numRows = length cons
    numCols = M.size colMap
-- | Negate the exponent of every 'UnitPow' term in the list. This is what
-- moving a factor to the other side of an equation requires. An element
-- that is not a 'UnitPow' is a pattern-match failure.
negateCons :: [UnitInfo] -> [UnitInfo]
negateCons = map negatePow
  where
    negatePow (UnitPow u k) = UnitPow u (negate k)
-- | Negate the exponent of a 'UnitPow' whose base is a 'UnitParamPosAbs';
-- every other unit is returned unchanged.
negatePosAbs :: UnitInfo -> UnitInfo
negatePosAbs (UnitPow (UnitParamPosAbs x) k) = UnitPow (UnitParamPosAbs x) (negate k)
negatePosAbs u = u
-- | Ordering used to lay out matrix columns: 'UnitLiteral's sort before
-- everything else, 'UnitParamPosAbs' units sort after everything else, and
-- all remaining cases fall back to the 'Ord' instance of 'UnitInfo'.
colSort :: UnitInfo -> UnitInfo -> Ordering
colSort a b = case (a, b) of
  (UnitLiteral i, UnitLiteral j) -> compare i j
  (UnitLiteral _, _) -> LT
  (_, UnitLiteral _) -> GT
  (UnitParamPosAbs x, UnitParamPosAbs y) -> compare x y
  (UnitParamPosAbs _, _) -> GT
  (_, UnitParamPosAbs _) -> LT
  _ -> compare a b
-- | 'True' for a 'UnitPow' whose base is a 'UnitName' or a
-- 'UnitParamEAPAbs'; 'False' for anything else.
isUnitRHS :: UnitInfo -> Bool
isUnitRHS unit = case unit of
  UnitPow (UnitName _) _ -> True
  UnitPow (UnitParamEAPAbs _) _ -> True
  _ -> False
-- | Redistribute the terms of a flattened equation so that every term
-- satisfying 'isUnitRHS' ends up on the right and every other term ends up
-- on the left. Terms that change sides are passed through 'negateCons' and
-- appended after the terms that stay.
shiftTerms :: ([UnitInfo], [UnitInfo]) -> ([UnitInfo], [UnitInfo])
shiftTerms (lhs, rhs) =
  ( stayLeft ++ negateCons goLeft
  , stayRight ++ negateCons goRight
  )
  where
    (goRight, stayLeft) = partition isUnitRHS lhs
    (stayRight, goLeft) = partition isUnitRHS rhs
-- | Flatten both sides of every 'ConEq' constraint with 'flattenUnits',
-- yielding one (left, right) pair per constraint. Any other constraint
-- constructor is a pattern-match failure.
flattenConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
flattenConstraints = map bothSides
  where
    bothSides (ConEq lhs rhs) = (flattenUnits lhs, flattenUnits rhs)
-- | Test whether a matrix (expected to be in reduced row echelon form) has
-- a last row of the shape @0 ... 0 1@: the bottom-right entry is 1 and the
-- rest of the last row has rank 0.
isInconsistentRREF a =
  a @@> (lastRow, lastCol) == 1
    && rank (takeColumns lastCol (dropRows lastRow a)) == 0
  where
    lastRow = rows a - 1
    lastCol = cols a - 1
-- | The matrix produced by running 'rrefMatrices'' from the top-left corner
-- with no accumulated elementary matrices.
rref :: H.Matrix Double -> H.Matrix Double
rref a = reduced
  where
    (_, reduced) = rrefMatrices' a 0 0 []
-- | The list of elementary matrices accumulated by 'rrefMatrices'' while
-- reducing the given matrix (most recently applied first).
rrefMatrices :: H.Matrix Double -> [H.Matrix Double]
rrefMatrices a = steps
  where
    (steps, _) = rrefMatrices' a 0 0 []
-- | The product of all elementary matrices accumulated by 'rrefMatrices'',
-- multiplied in list order and ending with an identity matrix whose size is
-- the row count of the input.
rrefMatrix :: H.Matrix Double -> H.Matrix Double
rrefMatrix a = foldr (<>) (ident (rows a)) steps
  where
    (steps, _) = rrefMatrices' a 0 0 []
-- | Worker behind 'rref', 'rrefMatrices' and 'rrefMatrix'.
--
-- Arguments: the matrix reduced so far, the current column @j@, the number
-- of columns @k@ that were skipped because they held no usable pivot (so the
-- current row is @j - k@), and the elementary matrices applied so far (most
-- recent first). Returns the final list of elementary matrices together with
-- the reduced matrix.
rrefMatrices' :: H.Matrix Double -> Int -> Int -> [H.Matrix Double]
              -> ([H.Matrix Double], H.Matrix Double)
rrefMatrices' a j k mats
  -- Ran out of rows or columns: finished.
  | row == n = (mats, a)
  | j == m = (mats, a)
  -- No pivot at the current position: look further down the column.
  | pivot == 0 = case findIndex (/= 0) below of
      -- The column is zero from here down; move on to the next column
      -- while staying on the same row.
      Nothing -> rrefMatrices' a (j + 1) (k + 1) mats
      -- Swap in a row that does have a non-zero entry, then retry.
      Just i' ->
        let swapMat = elemRowSwap n (row + i') row
        in rrefMatrices' (swapMat <> a) j k (swapMat : mats)
  -- Pivot found: scale it to 1, clear the rest of the column, advance.
  | otherwise = rrefMatrices' a2 (j + 1) k mats2
  where
    n = rows a
    m = cols a
    row = j - k
    pivot = a @@> (row, j)
    below = getColumnBelow a (row, j)

    -- Scale the pivot row, unless the pivot is already exactly 1.
    scaleMat = elemRowMult n row (recip pivot)
    (a1, mats1)
      | pivot /= 1 = (scaleMat <> a, scaleMat : mats)
      | otherwise = (a, mats)

    -- Apply the cancellation matrix only if it differs from the identity.
    (a2, mats2)
      | anyCancelled = (cancelMat <> a1, cancelMat : mats1)
      | otherwise = (a1, mats1)

    -- A single matrix that cancels every other non-zero entry of column j
    -- in one multiplication: the identity, with (i, row) set to minus the
    -- entry at (i, j) for each such row i. Entries are read from the
    -- unscaled matrix, which agrees with the scaled one outside the pivot
    -- row.
    (anyCancelled, cancelMat) = runST $ do
      new <- newMatrix 0 n n :: ST s (STMatrix s Double)
      forM_ [0 .. n - 1] $ \ d -> writeMatrix new d d 1
      let go written i
            | i >= n = return written
            | i == row = go written (i + 1)
            | a @@> (i, j) == 0 = go written (i + 1)
            | otherwise =
                writeMatrix new i row (negate (a @@> (i, j))) >> go True (i + 1)
      written <- go False 0
      fmap ((,) written) (freezeMatrix new)
-- | The entries of column @j@ from row @i@ (inclusive) down to the last
-- row, as a list.
getColumnBelow :: H.Matrix Double -> (Int, Int) -> [Double]
getColumnBelow a (i, j) = concat . H.toLists $ subMatrix (i, j) (n - i, 1) a
  where
    n = rows a
-- | The @n@-by-@n@ diagonal matrix that is the identity except for entry
-- @(i, i)@, which is @k@. Multiplying by it on the left scales row @i@ by
-- @k@.
elemRowMult :: Int -> Int -> Double -> H.Matrix Double
elemRowMult n i k = diag (H.fromList (replicate i 1.0 ++ [k] ++ replicate (n - i - 1) 1.0))
-- | The @n@-by-@n@ identity matrix with entry @(i, j)@ overwritten by @k@.
-- The diagonal is written first, so for @i == j@ the entry ends up as @k@.
elemRowAdd :: Int -> Int -> Int -> Double -> H.Matrix Double
elemRowAdd n i j k = runSTMatrix $ do
  m <- newMatrix 0 n n
  forM_ [0 .. n - 1] $ \ d -> writeMatrix m d d 1
  writeMatrix m i j k
  return m
-- | The @n@-by-@n@ identity matrix with rows @i@ and @j@ exchanged (the
-- plain identity when @i == j@). The argument order of @i@ and @j@ does not
-- matter.
elemRowSwap :: Int -> Int -> Int -> H.Matrix Double
elemRowSwap n i j
  | i == j = ident n
  | i > j = elemRowSwap n j i
  | otherwise = extractRows rowOrder (ident n)
  where
    -- Every row index in order, with i and j trading places (here i < j).
    rowOrder = [0 .. i - 1] ++ [j] ++ [i + 1 .. j - 1] ++ [i] ++ [j + 1 .. n - 1]
-- | Convert a 'Rational' to a 'Double' using 'fromRational'.
toDouble :: Rational -> Double
toDouble ratio = fromRational ratio
-- | Convert a 'Double' to the 'Rational' it represents, using 'toRational'.
fromDouble :: Double -> Rational
fromDouble value = toRational value
-- | Given a coefficient matrix and the corresponding augmented matrix,
-- return the indices of the rows that are not part of the longest trailing
-- run of rows forming a consistent system.
--
-- A set of rows counts as consistent when the selected rows of the
-- coefficient matrix and of the augmented matrix have the same rank (the
-- rank criterion for solvability of a linear system). Suffixes of the row
-- list are tried from longest to shortest; everything outside the first
-- consistent suffix is reported. If no non-empty suffix is consistent, every
-- row is reported.
findInconsistentRows :: H.Matrix Double -> H.Matrix Double -> [Int]
findInconsistentRows coA augA = allRows \\ consistent
  where
    allRows = [0 .. rows augA - 1]
    -- Only non-empty suffixes are tested, so 'rank' is never asked about a
    -- matrix with zero rows.
    candidates = [ ns | ns@(_ : _) <- tails allRows ]
    consistent = case filter sameRank candidates of
      (ns : _) -> ns
      [] -> []
    sameRank ns = rank (extractRows ns coA) == rank (extractRows ns augA)
-- | All sub-lists of a list (its power set), keeping element order. Subsets
-- that include the head come before those that do not, so the full list is
-- first and the empty list is last.
pset :: [a] -> [[a]]
pset [] = [[]]
pset (x : xs) = map (x :) rest ++ rest
  where
    rest = pset xs
-- | Select the given rows, in the given order, from a matrix. This is the
-- hmatrix '?' operator with its arguments flipped.
extractRows :: [Int] -> H.Matrix Double -> H.Matrix Double
extractRows rowIxs mat = mat ? rowIxs
-- | Infix shorthand for 'atIndex': read the element at the given index.
(@@>) container ix = atIndex container ix
-- | Render a list of constraints for debugging: a line of 50 dashes, the
-- given title followed by a colon, one line per constraint, and a closing
-- line of 50 carets. A 'ConEq' is shown as its two flattened sides joined by
-- @" === "@; a 'ConConj' as its members joined by @" && "@.
showCons str cons = unlines (header ++ map render cons ++ footer)
  where
    header = [replicate 50 '-', str ++ ":"]
    footer = [replicate 50 '^']
    render (ConEq lhs rhs) = show (flattenUnits lhs) ++ " === " ++ show (flattenUnits rhs)
    render (ConConj members) = intercalate " && " (map render members)