module Camfort.Specification.Units.InferenceBackend
( inconsistentConstraints, criticalVariables, inferVariables
, shiftTerms, flattenConstraints, flattenUnits, constraintsToMatrix, rref, isInconsistentRREF )
where
import Data.Tuple (swap)
import Data.Maybe (maybeToList)
import Data.List ((\\), findIndex, partition, sortBy, group)
import Data.Generics.Uniplate.Operations (rewrite)
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.ST
import Control.Arrow (first, second)
import qualified Data.Map.Strict as M
import qualified Data.Array as A
import Camfort.Analysis.Annotations
import Camfort.Specification.Units.Environment
import Numeric.LinearAlgebra (
atIndex, (<>), (><), rank, (?), toLists, toList, fromLists, fromList, rows, cols,
takeRows, takeColumns, dropRows, dropColumns, subMatrix, diag, build, fromBlocks,
ident, flatten, lu, dispf
)
import qualified Numeric.LinearAlgebra as H
import Numeric.LinearAlgebra.Devel (
newMatrix, readMatrix, writeMatrix, runSTMatrix
)
import qualified Debug.Trace as D
-- | Pick out the constraints whose position in the input list is among
-- the row indices that 'constraintsToMatrix' reports as inconsistent.
-- Yields 'Nothing' for an empty constraint list or when no row is
-- reported.
inconsistentConstraints :: Constraints -> Maybe Constraints
inconsistentConstraints [] = Nothing
inconsistentConstraints cons = case badRows of
  [] -> Nothing
  _  -> Just . map fst . filter ((`elem` badRows) . snd) $ zip cons [0 ..]
  where
    -- Only the list of inconsistent row indices is needed here.
    (_, badRows, _) = constraintsToMatrix cons
-- | Units attached to the matrix columns that are not the first
-- non-zero column of any row of the reduced matrix, with plain
-- 'UnitName' entries removed.
criticalVariables :: Constraints -> [UnitInfo]
criticalVariables [] = []
criticalVariables cons =
  [ unit
  | idx <- freeColumns
  , let unit = columnUnits A.! idx
  , not (isNamed unit)
  ]
  where
    (coeffs, _, columnUnits) = constraintsToMatrix cons
    reduced = rref coeffs
    -- Column of the first non-zero entry of each row (rows of zeros
    -- contribute nothing).
    pivotColumns = [ c | row <- H.toLists reduced, Just c <- [findIndex (/= 0) row] ]
    freeColumns = A.indices columnUnits \\ pivotColumns
    isNamed (UnitName _) = True
    isNamed _ = False
-- | For every row of the reduced matrix that contains exactly one
-- non-'UnitName' term, where that term is a 'UnitVar' raised to
-- (approximately) the power 1, pair the variable with the product of
-- the remaining 'UnitName' powers in that row ('UnitlessVar' when there
-- are none). Produces nothing when any row is reported inconsistent.
inferVariables :: Constraints -> [(VV, UnitInfo)]
inferVariables [] = []
inferVariables cons
  | not (null inconsists) = []
  | otherwise =
      [ (var, combine named)
      | row <- solvedRows
      , ([UnitPow (UnitVar var) k], named) <- [partition (not . isNamedPow) row]
      , k `approxEq` 1
      ]
  where
    (coeffs, inconsists, columnUnits) = constraintsToMatrix cons
    headers = A.elems columnUnits
    -- Each reduced row re-expressed as a flattened list of unit powers.
    solvedRows =
      [ concatMap flattenUnits (zipWith UnitPow headers coeffRow)
      | coeffRow <- H.toLists (rref coeffs)
      ]
    combine [] = UnitlessVar
    combine us = foldl1 UnitMul us
    isNamedPow (UnitPow (UnitName _) _) = True
    isNamedPow _ = False
-- | Flatten both sides of every equality constraint with 'flattenUnits'.
simplifyConstraints = map bothSides
  where
    bothSides (ConEq u1 u2) = (flattenUnits u1, flattenUnits u2)
-- | Simplify a unit expression by handing the local rules below to
-- uniplate's 'rewrite'. The rules are tried in the order listed; a rule
-- whose guard fails falls through to the next one.
simplifyUnits :: UnitInfo -> UnitInfo
simplifyUnits = rewrite step
  where
    step unit = case unit of
      -- Re-associate products to the right.
      UnitMul (UnitMul u1 u2) u3 -> Just (UnitMul u1 (UnitMul u2 u3))
      -- u * u  ==>  u ** 2
      UnitMul u1 u2
        | u1 == u2 -> Just (UnitPow u1 2)
      -- (u ** p1) ** p2  ==>  u ** (p1 * p2)
      UnitPow (UnitPow u1 p1) p2 -> Just (UnitPow u1 (p1 * p2))
      -- (u ** p1) * (u ** p2)  ==>  u ** (p1 + p2)
      UnitMul (UnitPow u1 p1) (UnitPow u2 p2)
        | u1 == u2 -> Just (UnitPow u1 (p1 + p2))
      -- u ** 0  ==>  unitless
      UnitPow _ p
        | p `approxEq` 0 -> Just UnitlessLit
      -- Unitless factors are dropped from products.
      UnitMul UnitlessLit u -> Just u
      UnitMul u UnitlessLit -> Just u
      _ -> Nothing
-- | Break a unit expression into a list of @UnitPow base power@ terms:
-- products are split, nested powers are multiplied through, each base
-- is passed through 'simplifyUnits', powers of equal bases are summed,
-- and terms with an (approximately) zero power or a 'UnitlessLit' base
-- are discarded. The result follows the key order of the intermediate
-- map.
flattenUnits :: UnitInfo -> [UnitInfo]
flattenUnits unit = [ UnitPow base power | (base, power) <- M.toList kept ]
  where
    kept = M.filterWithKey keep summed
    summed =
      M.fromListWith (+) [ (simplifyUnits base, power) | (base, power) <- factors unit ]
    keep base power = base /= UnitlessLit && not (approxEq 0 power)
    -- (base, exponent) pairs of the multiplicative factors.
    factors (UnitMul u1 u2) = factors u1 ++ factors u2
    factors (UnitPow u p) = [ (base, p * q) | (base, q) <- factors u ]
    factors u = [(u, 1)]
-- | Approximate equality: the two values differ by less than 'epsilon'.
approxEq :: Double -> Double -> Bool
approxEq a b = abs (b - a) < epsilon

-- | Tolerance used by 'approxEq'.
epsilon :: Double
epsilon = 0.001
-- | Convert a set of constraints into a matrix of coefficients.
--
-- Returns, in order:
--
-- * the coefficient matrix: the left-hand-side block, with the
--   right-hand-side block appended to its right unless that block has
--   no rows or no columns;
-- * the row indices that 'findInconsistentRows' reports for it;
-- * an array mapping column numbers back to the unit in that column
--   (left-hand-side columns first, then right-hand-side columns).
constraintsToMatrix :: Constraints -> (H.Matrix Double, [Int], A.Array Int UnitInfo)
constraintsToMatrix cons = (augM, inconsists, A.listArray (0, length colElems - 1) colElems)
  where
    -- Each constraint as a pair of flattened sides.
    consPairs = flattenConstraints cons
    -- Move terms to the side of the equation 'shiftTerms' puts them on.
    shiftedCons = map shiftTerms consPairs
    lhs = map fst shiftedCons
    rhs = map snd shiftedCons
    (lhsM, lhsCols) = flattenedToMatrix lhs
    (rhsM, rhsCols) = flattenedToMatrix rhs
    colElems = A.elems lhsCols ++ A.elems rhsCols
    augM
      | rows rhsM == 0 || cols rhsM == 0 = lhsM
      | otherwise = fromBlocks [[lhsM, rhsM]]
    inconsists = findInconsistentRows lhsM augM
-- | Build a coefficient matrix from a list of flattened constraint
-- sides (one inner list of @UnitPow unit k@ terms per row).
--
-- Every distinct unit gets its own column; the second component maps
-- each column number back to its unit. The exponents of a unit that
-- occurs several times in one row are added together.
flattenedToMatrix :: [[UnitInfo]] -> (H.Matrix Double, A.Array Int UnitInfo)
flattenedToMatrix cons = (m, A.array (0, numCols - 1) (map swap uniqUnits))
  where
    m = runSTMatrix $ do
      mat <- newMatrix 0 numRows numCols
      -- One matrix row per constraint side.
      forM_ (zip cons [0 ..]) $ \(unitPows, row) ->
        -- Accumulate each term's exponent into its unit's column.
        forM_ unitPows $ \(UnitPow u k) ->
          case M.lookup u colMap of
            Just col -> readMatrix mat row col >>= (writeMatrix mat row col . (+ k))
            _ -> return ()
      return mat
    -- Distinct units, ordered by 'colSort' and numbered from 0.
    uniqUnits =
      flip zip [0 ..] . map head . group . sortBy colSort $
        [ u | UnitPow u _ <- concat cons ]
    -- Unit -> column number.
    colMap = M.fromList uniqUnits
    numRows = length cons
    numCols = M.size colMap
-- | Negate the exponent of every @UnitPow@ term in the list.
negateCons :: [UnitInfo] -> [UnitInfo]
negateCons = map negatePow
  where
    negatePow (UnitPow u k) = UnitPow u (-k)
-- | Column ordering: 'UnitLiteral' values come before everything else
-- and are ordered among themselves by their payload; all other pairs
-- use the ordinary 'compare'.
colSort x y = case (x, y) of
  (UnitLiteral i, UnitLiteral j) -> compare i j
  (UnitLiteral _, _) -> LT
  (_, UnitLiteral _) -> GT
  _ -> compare x y
-- | Turn every equality constraint into a pair of flattened sides
-- (left, right), each produced by 'flattenUnits'.
flattenConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
flattenConstraints cons = map sides cons
  where
    sides (ConEq u1 u2) = (flattenUnits u1, flattenUnits u2)
-- | Rearrange one flattened constraint so that powers of 'UnitName' end
-- up on the right-hand side and every other term on the left-hand side.
-- Terms that change sides are passed through 'negateCons'.
shiftTerms :: ([UnitInfo], [UnitInfo]) -> ([UnitInfo], [UnitInfo])
shiftTerms (lhs, rhs) =
  ( keepLeft ++ negateCons moveLeft
  , keepRight ++ negateCons moveRight
  )
  where
    -- Named powers found on the left have to move right ...
    (moveRight, keepLeft) = partition isNamedPow lhs
    -- ... and everything else found on the right has to move left.
    (keepRight, moveLeft) = partition isNamedPow rhs
    isNamedPow (UnitPow (UnitName _) _) = True
    isNamedPow _ = False
-- | Test applied to a matrix in reduced row echelon form: 'True' when
-- the bottom-right entry is 1 and the rest of the last row (every
-- column but the last) has rank 0.
isInconsistentRREF a =
  a @@> (lastRow, lastCol) == 1
    && rank (takeColumns lastCol (dropRows lastRow a)) == 0
  where
    lastRow = rows a - 1
    lastCol = cols a - 1
-- | Reduced row echelon form: the matrix component of the result of
-- 'rrefMatrices''.
rref :: H.Matrix Double -> H.Matrix Double
rref a = reduced
  where
    (_, reduced) = rrefMatrices' a 0 0 []
-- | The list of transformation matrices accumulated by 'rrefMatrices''
-- while reducing the input.
rrefMatrices :: H.Matrix Double -> [H.Matrix Double]
rrefMatrices a = steps
  where
    (steps, _) = rrefMatrices' a 0 0 []
-- | The transformation matrices accumulated by 'rrefMatrices''
-- multiplied together (right fold, starting from the identity of size
-- @rows a@).
rrefMatrix :: H.Matrix Double -> H.Matrix Double
rrefMatrix a = foldr (<>) (ident (rows a)) steps
  where
    (steps, _) = rrefMatrices' a 0 0 []
-- | Worker for 'rref', 'rrefMatrices' and 'rrefMatrix'.
--
-- @rrefMatrices' a j k mats@ processes column @j@ with the pivot in row
-- @j - k@ (@k@ is incremented each time a column without a usable pivot
-- is skipped). @mats@ collects the elementary matrices applied so far,
-- most recent first. The result is that list together with the reduced
-- matrix.
rrefMatrices' :: H.Matrix Double -> Int -> Int -> [H.Matrix Double] -> ([H.Matrix Double], H.Matrix Double)
rrefMatrices' a j k mats
  -- Ran out of rows or columns: done.
  | pivotRow == n = (mats, a)
  | j == m = (mats, a)
  -- Pivot cell is zero: look further down the column for a row to swap in.
  | pivot == 0 = case findIndex (/= 0) below of
      -- Nothing but zeros from here down: skip this column.
      Nothing -> rrefMatrices' a (j + 1) (k + 1) mats
      -- Swap the first row with a non-zero entry into the pivot row and
      -- retry the same cell.
      Just i' ->
        let swapMat = elemRowSwap n (pivotRow + i') pivotRow
         in rrefMatrices' (swapMat <> a) j k (swapMat : mats)
  -- Non-zero pivot: scale it to 1, clear the rest of the column, move on.
  | otherwise = rrefMatrices' a2 (j + 1) k mats2
  where
    n = rows a
    m = cols a
    pivotRow = j - k
    pivot = a @@> (pivotRow, j)
    below = getColumnBelow a (pivotRow, j)
    erm = elemRowMult n pivotRow (recip pivot)
    -- Scale the pivot row only when the pivot is not already 1.
    (a1, mats1)
      | pivot /= 1 = (erm <> a, erm : mats)
      | otherwise = (a, mats)
    -- A single matrix that performs every row addition needed to cancel
    -- the other non-zero entries of column j at once: the identity plus,
    -- for each such row i, the negated entry placed at (i, pivotRow).
    cancel = runSTMatrix $ do
      new <- newMatrix 0 n n
      forM_ [0 .. n - 1] $ \d -> writeMatrix new d d 1
      forM_ [0 .. n - 1] $ \i ->
        when (i /= pivotRow && a @@> (i, j) /= 0) $
          writeMatrix new i pivotRow (negate (a @@> (i, j)))
      return new
    (a2, mats2) = (cancel <> a1, cancel : mats1)
-- | The entries of column @j@ from row @i@ down to the last row, as a
-- list (the entry at @(i, j)@ itself comes first).
getColumnBelow a (i, j) = concat . H.toLists $ subMatrix (i, j) (n - i, 1) a
  where
    n = rows a
-- | Diagonal @n@-by-@n@ matrix whose diagonal is all ones except for
-- the value @k@ at position @i@.
elemRowMult :: Int -> Int -> Double -> H.Matrix Double
elemRowMult n i k = diag (H.fromList (replicate i 1.0 ++ [k] ++ replicate (n - i - 1) 1.0))
-- | @n@-by-@n@ matrix with ones on the diagonal and the value @k@
-- written at @(i, j)@ afterwards (so for @i == j@ the diagonal entry is
-- replaced by @k@).
elemRowAdd :: Int -> Int -> Int -> Double -> H.Matrix Double
elemRowAdd n i j k = runSTMatrix $ do
  m <- newMatrix 0 n n
  forM_ [0 .. n - 1] $ \d -> writeMatrix m d d 1
  writeMatrix m i j k
  return m
-- | The @n@-by-@n@ identity matrix with rows @i@ and @j@ exchanged.
-- The argument order of @i@ and @j@ does not matter, and @i == j@
-- gives the identity itself.
elemRowSwap :: Int -> Int -> Int -> H.Matrix Double
elemRowSwap n i j
  | i == j = ident n
  | i > j = elemRowSwap n j i
  | otherwise = extractRows swapped (ident n)
  where
    -- Row order 0 .. n-1 with positions i and j exchanged (i < j here).
    swapped = [0 .. i - 1] ++ [j] ++ [i + 1 .. j - 1] ++ [i] ++ [j + 1 .. n - 1]
-- | Convert a 'Rational' to a 'Double' via 'fromRational'.
toDouble :: Rational -> Double
toDouble r = fromRational r
-- | Convert a 'Double' to a 'Rational' via 'toRational'.
fromDouble :: Double -> Rational
fromDouble d = toRational d
-- | Row indices that have to be left out to make the system consistent.
--
-- Subsets of the row indices of @augA@ are tried in the order 'pset'
-- produces them. The first subset for which the selected rows of the
-- coefficient matrix @coA@ and of the augmented matrix @augA@ have the
-- same rank is taken as the consistent set, and every row outside it
-- is returned. If no subset passes, all rows are returned.
--
-- Note: 'pset' enumerates every subset, so the number of candidates
-- grows as 2^(number of rows).
findInconsistentRows :: H.Matrix Double -> H.Matrix Double -> [Int]
findInconsistentRows coA augA = allRows \\ consistent
  where
    allRows = [0 .. rows augA - 1]
    consistent = case filter sameRank (pset allRows) of
      (ns : _) -> ns
      [] -> []
    -- Equal ranks on the selected rows.
    sameRank ns = rank (extractRows ns coA) == rank (extractRows ns augA)
-- | All sub-lists of a list (its power set), keeping element order.
-- Sub-lists containing the head come before those without it, so the
-- full list is produced first and the empty list last.
pset [] = [[]]
pset (x : xs) = map (x :) rest ++ rest
  where
    rest = pset xs
-- | Select the rows with the given indices from a matrix; this is the
-- hmatrix '?' operator with its arguments flipped.
extractRows rowIdxs mat = mat ? rowIdxs
-- | Infix alias for 'atIndex'.
(@@>) m i = atIndex m i