module Camfort.Specification.Units.InferenceBackend
( inconsistentConstraints, criticalVariables, inferVariables
, shiftTerms, flattenConstraints, flattenUnits, constraintsToMatrix, constraintsToMatrices
, rref, isInconsistentRREF, genUnitAssignments )
where
import Data.Tuple (swap)
import Data.Maybe (maybeToList)
import Data.List ((\\), findIndex, partition, sortBy, group, tails)
import Data.Generics.Uniplate.Operations (rewrite)
import Control.Monad
import Control.Monad.ST
import Control.Arrow (first, second)
import qualified Data.Map.Strict as M
import qualified Data.Array as A
import Camfort.Specification.Units.Environment
import Numeric.LinearAlgebra (
atIndex, (<>), rank, (?), rows, cols,
takeColumns, dropRows, subMatrix, diag, fromBlocks,
ident,
)
import qualified Numeric.LinearAlgebra as H
import Numeric.LinearAlgebra.Devel (
newMatrix, readMatrix, writeMatrix, runSTMatrix, freezeMatrix, STMatrix
)
-- | Identify the constraints that make the system unsatisfiable, or
-- 'Nothing' when 'constraintsToMatrices' reports no inconsistent row.
--
-- 'constraintsToMatrices' discards constraints whose two flattened sides
-- are identical before it builds its matrices, so the row indices it
-- returns refer to that filtered list rather than to @cons@.  The same
-- filter is applied here so each index maps back to the right constraint.
inconsistentConstraints :: Constraints -> Maybe Constraints
inconsistentConstraints [] = Nothing
inconsistentConstraints cons
  | null inconsists = Nothing
  | otherwise       = Just [ con | (con, i) <- zip nonTrivial [0..], i `elem` inconsists ]
  where
    (_, _, inconsists, _, _) = constraintsToMatrices cons
    -- Mirrors `filter (uncurry (/=)) $ flattenConstraints cons` in
    -- constraintsToMatrices: element i here is matrix row i there.
    nonTrivial = [ con | (con, (lhs, rhs)) <- zip cons (flattenConstraints cons), lhs /= rhs ]
-- | Units whose columns hold no leading (pivot) entry in any row of the
-- row-reduced constraint matrix, excluding plain named units ('UnitName').
criticalVariables :: Constraints -> [UnitInfo]
criticalVariables [] = []
criticalVariables cons =
  [ unit
  | col <- A.indices colA \\ pivotCols
  , let unit = colA A.! col
  , not (isNamedUnit unit)
  ]
  where
    (unsolvedM, _, colA) = constraintsToMatrix cons
    -- Column of the first non-zero entry in each reduced row, if any.
    pivotCols = [ c | row <- H.toLists (rref unsolvedM), Just c <- [findIndex (/= 0) row] ]
    isNamedUnit (UnitName _) = True
    isNamedUnit _            = False
-- | Read concrete variable assignments out of 'genUnitAssignments': keep
-- only those whose left-hand side is a single variable ('UnitVar' or
-- 'UnitParamVarAbs') raised to a power approximately equal to 1.  Plain
-- variables are listed before abstract parameter variables.
inferVariables :: Constraints -> [(VV, UnitInfo)]
inferVariables cons = plainVars ++ abstractVars
  where
    assignments  = genUnitAssignments cons
    plainVars    = [ (v, u) | ([UnitPow (UnitVar v) p], u) <- assignments, approxEq p 1 ]
    abstractVars = [ (v, u) | ([UnitPow (UnitParamVarAbs (_, v)) p], u) <- assignments, approxEq p 1 ]
-- | Row-reduce the constraint matrix and turn each reduced row into an
-- assignment.  The row's terms are split into those that are not powers of
-- 'UnitName', 'UnitParamEAPAbs' or 'UnitParamPosAbs' (left component) and
-- those that are.  The latter pass through 'negatePosAbs' and are
-- multiplied together with 'UnitMul' (right component, 'UnitlessVar' if
-- there are none).  Returns no assignments at all when
-- 'constraintsToMatrix' reports any inconsistent row.
genUnitAssignments :: [Constraint] -> [([UnitInfo], UnitInfo)]
genUnitAssignments cons
  | not (null badRows) = []
  | otherwise          = map toAssignment (H.toLists (rref unsolvedM))
  where
    (unsolvedM, badRows, colA) = constraintsToMatrix cons
    colUnits = A.elems colA

    -- Pair each coefficient with its column's unit, flatten, and split.
    toAssignment coeffs =
      let (rhsUnits, lhsUnits) = partition isRhsTerm (concatMap flattenUnits (zipWith UnitPow colUnits coeffs))
      in  (lhsUnits, combine (map negatePosAbs rhsUnits))

    isRhsTerm (UnitPow (UnitName _) _)        = True
    isRhsTerm (UnitPow (UnitParamEAPAbs _) _) = True
    isRhsTerm (UnitPow (UnitParamPosAbs _) _) = True
    isRhsTerm _                               = False

    combine [] = UnitlessVar
    combine us = foldl1 UnitMul us
-- | Normalise a unit expression by applying the rewrite rules below with
-- uniplate's 'rewrite' until none applies.  Rules are tried in the order
-- written: re-associate products to the right, merge equal factors and
-- powers of the same base, collapse nested powers, turn (near-)zero powers
-- into 'UnitlessLit', and drop 'UnitlessLit' factors.
simplifyUnits :: UnitInfo -> UnitInfo
simplifyUnits = rewrite step
  where
    step unit = case unit of
      UnitMul (UnitMul x y) z                           -> Just (UnitMul x (UnitMul y z))
      UnitMul x y | x == y                              -> Just (UnitPow x 2)
      UnitPow (UnitPow x p) q                           -> Just (UnitPow x (p * q))
      UnitMul (UnitPow x p) (UnitPow y q) | x == y      -> Just (UnitPow x (p + q))
      UnitPow _ p | p `approxEq` 0                      -> Just UnitlessLit
      UnitMul UnitlessLit y                             -> Just y
      UnitMul x UnitlessLit                             -> Just x
      _                                                 -> Nothing
-- | Flatten a unit expression into a list of @UnitPow base power@ terms.
-- Products are split into factors and nested powers multiplied through.
-- Each base is simplified with 'simplifyUnits' and exponents of equal
-- bases are summed.  Terms whose base is 'UnitlessLit' or whose exponent is
-- approximately zero are dropped.  The result is in ascending base order
-- (the key order of "Data.Map").
flattenUnits :: UnitInfo -> [UnitInfo]
flattenUnits unit = [ UnitPow base power | (base, power) <- M.toList kept ]
  where
    terms  = [ (simplifyUnits base, power) | (base, power) <- expand unit ]
    summed = M.fromListWith (+) terms
    kept   = M.filterWithKey keep summed
    keep base power = base /= UnitlessLit && not (approxEq 0 power)

    -- Split products and push exponents down onto their bases.
    expand (UnitMul u1 u2) = expand u1 ++ expand u2
    expand (UnitPow u p)   = [ (base, p * q) | (base, q) <- expand u ]
    expand u               = [(u, 1)]
-- | Equality up to 'epsilon'.  Used for exponents and matrix coefficients,
-- which come out of floating-point arithmetic where exact '==' is fragile.
approxEq :: Double -> Double -> Bool
approxEq a b = abs (b - a) < epsilon

-- | Absolute tolerance used by 'approxEq'.
epsilon :: Double
epsilon = 0.001
-- | Build the augmented coefficient matrix for a list of constraints.
--
-- Returns:
--
-- * the matrix @[lhs | rhs]@, or just the lhs block when the rhs block has
--   no rows or no columns;
-- * the row indices reported by 'findInconsistentRows';
-- * an array giving the unit of every column, lhs columns first, then rhs.
--
-- Unlike 'constraintsToMatrices', no constraint is dropped, so row @i@
-- corresponds to the @i@-th constraint in @cons@.
constraintsToMatrix :: Constraints -> (H.Matrix Double, [Int], A.Array Int UnitInfo)
constraintsToMatrix cons = (augM, inconsists, A.listArray (0, length colElems - 1) colElems)
  where
    (lhs, rhs)      = unzip (map shiftTerms (flattenConstraints cons))
    (lhsM, lhsCols) = flattenedToMatrix lhs
    (rhsM, rhsCols) = flattenedToMatrix rhs
    colElems        = A.elems lhsCols ++ A.elems rhsCols
    augM
      | rows rhsM == 0 || cols rhsM == 0 = lhsM
      | otherwise                        = fromBlocks [[lhsM, rhsM]]
    inconsists      = findInconsistentRows lhsM augM
-- | Like 'constraintsToMatrix' but keeps the lhs and rhs blocks separate.
-- Returns @(lhs matrix, rhs matrix, inconsistent row indices, lhs column
-- units, rhs column units)@.  Constraints whose flattened sides are equal
-- are discarded first, so row indices refer to the remaining constraints.
constraintsToMatrices :: Constraints -> (H.Matrix Double, H.Matrix Double, [Int], A.Array Int UnitInfo, A.Array Int UnitInfo)
constraintsToMatrices cons = (lhsM, rhsM, inconsists, lhsCols, rhsCols)
  where
    nonTrivial           = [ pair | pair@(l, r) <- flattenConstraints cons, l /= r ]
    (lhsSides, rhsSides) = unzip (map shiftTerms nonTrivial)
    (lhsM, lhsCols)      = flattenedToMatrix lhsSides
    (rhsM, rhsCols)      = flattenedToMatrix rhsSides
    hasRhs               = rows rhsM /= 0 && cols rhsM /= 0
    augM                 = if hasRhs then fromBlocks [[lhsM, rhsM]] else lhsM
    inconsists           = findInconsistentRows lhsM augM
-- | Turn a list of flattened unit rows into a coefficient matrix.
--
-- Each distinct base unit occurring in a @UnitPow base k@ term gets one
-- column, ordered by 'colSort'.  Entry @(row, col)@ is the sum of the
-- exponents of that base in that row.  Also returns the array mapping each
-- column index back to its unit.  Terms that are not 'UnitPow' are ignored.
flattenedToMatrix :: [[UnitInfo]] -> (H.Matrix Double, A.Array Int UnitInfo)
flattenedToMatrix cons = (mat, A.array (0, numCols - 1) (map swap uniqUnits))
  where
    mat = runSTMatrix $ do
      stm <- newMatrix 0 numRows numCols
      forM_ (zip cons [0 ..]) $ \ (unitPows, row) ->
        forM_ [ (u, k) | UnitPow u k <- unitPows ] $ \ (u, k) ->
          case M.lookup u colMap of
            Just col -> readMatrix stm row col >>= (writeMatrix stm row col . (+ k))
            Nothing  -> return ()
      return stm
    -- One (unit, column index) pair per distinct base, in colSort order.
    uniqUnits = zip [ base | (base : _) <- group (sortBy colSort [ u | UnitPow u _ <- concat cons ]) ] [0 ..]
    colMap    = M.fromList uniqUnits
    numRows   = length cons
    numCols   = M.size colMap
-- | Negate the exponent of every term in a flattened unit list, as needed
-- when 'shiftTerms' moves terms to the other side of an equation.  A term
-- that is not already a 'UnitPow' is treated as having exponent 1.
negateCons :: [UnitInfo] -> [UnitInfo]
negateCons = map negateTerm
  where
    negateTerm (UnitPow u k) = UnitPow u (negate k)
    negateTerm u             = UnitPow u (-1)
-- | Negate the exponent of a 'UnitParamPosAbs' power; every other unit is
-- returned unchanged.  'genUnitAssignments' applies this to the terms it
-- places on the right-hand side of an assignment.
negatePosAbs :: UnitInfo -> UnitInfo
negatePosAbs (UnitPow (UnitParamPosAbs x) k) = UnitPow (UnitParamPosAbs x) (negate k)
negatePosAbs u                               = u
-- | Column ordering for the constraint matrix.  'UnitLiteral' units come
-- first (by literal id) and 'UnitParamPosAbs' units come last (by payload).
-- All other units sit in between, ordered by their 'Ord' instance.
colSort :: UnitInfo -> UnitInfo -> Ordering
colSort x y = case compare (colGroup x) (colGroup y) of
    EQ    -> within x y
    other -> other
  where
    colGroup :: UnitInfo -> Int
    colGroup (UnitLiteral _)     = 0
    colGroup (UnitParamPosAbs _) = 2
    colGroup _                   = 1

    within (UnitLiteral i) (UnitLiteral j)         = compare i j
    within (UnitParamPosAbs a) (UnitParamPosAbs b) = compare a b
    within a b                                     = compare a b
-- | Whether a flattened term belongs on the right-hand side of a shifted
-- constraint: powers of 'UnitName' or 'UnitParamEAPAbs'.
isUnitRHS :: UnitInfo -> Bool
isUnitRHS term = case term of
  UnitPow (UnitName _) _        -> True
  UnitPow (UnitParamEAPAbs _) _ -> True
  _                             -> False
-- | Rearrange one flattened equation so that terms accepted by 'isUnitRHS'
-- sit on the right and all remaining terms on the left.  Terms that cross
-- sides are passed through 'negateCons'.  Terms already on the correct
-- side keep their original order and come first.
shiftTerms :: ([UnitInfo], [UnitInfo]) -> ([UnitInfo], [UnitInfo])
shiftTerms (lhs, rhs) =
  let (stayLeft, goRight) = partition (not . isUnitRHS) lhs
      (stayRight, goLeft) = partition isUnitRHS rhs
  in  (stayLeft ++ negateCons goLeft, stayRight ++ negateCons goRight)
-- | Flatten both sides of every 'ConEq' constraint with 'flattenUnits',
-- preserving the order of the constraints.
flattenConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
flattenConstraints cons = map flattenSides cons
  where
    flattenSides (ConEq lhs rhs) = (flattenUnits lhs, flattenUnits rhs)
-- | For a matrix in reduced row echelon form whose last column is the
-- augmented right-hand side (assumed, not checked), test whether the last
-- row reads @0 = 1@.  That is: the bottom-right entry is 1 and every other
-- entry of that row is zero.  An empty matrix is never inconsistent.
isInconsistentRREF :: H.Matrix Double -> Bool
isInconsistentRREF a
  | rows a == 0 || cols a == 0 = False
  | otherwise =
      a @@> (lastRow, lastCol) == 1
        && rank (takeColumns lastCol (dropRows lastRow a)) == 0
  where
    lastRow = rows a - 1
    lastCol = cols a - 1
-- | Reduced row echelon form of a matrix.  Delegates to 'rrefMatrices''
-- and discards the list of elementary matrices it also records.
rref :: H.Matrix Double -> H.Matrix Double
rref a = reduced
  where
    (_, reduced) = rrefMatrices' a 0 0 []
-- | Worker for 'rref': Gauss-Jordan elimination, one column per step.
--
-- @j@ is the column being processed.  @k@ counts the columns skipped so far
-- because they had no non-zero entry at or below the current row, so @j - k@
-- is the row that should receive the next pivot.  Every elementary matrix
-- left-multiplied onto the matrix is pushed onto @mats@ (most recent first).
-- The final list is returned together with the reduced matrix.
rrefMatrices' :: H.Matrix Double -> Int -> Int -> [H.Matrix Double]
              -> ([H.Matrix Double], H.Matrix Double)
rrefMatrices' a j k mats
  -- Finished once every row has a pivot or every column has been visited.
  | pivotRow == n = (mats, a)
  | j == m        = (mats, a)
  -- Zero at the pivot position: swap up a lower row that is non-zero in
  -- this column, or skip the column when there is none.
  | pivot == 0 = case findIndex (/= 0) below of
      Nothing -> rrefMatrices' a (j + 1) (k + 1) mats
      Just i' ->
        let swapMat = elemRowSwap n (pivotRow + i') pivotRow
        in  rrefMatrices' (swapMat <> a) j k (swapMat : mats)
  -- Non-zero pivot: scale its row to 1, then clear the rest of column j.
  | otherwise = rrefMatrices' a2 (j + 1) k mats2
  where
    n        = rows a
    m        = cols a
    pivotRow = j - k
    -- Only forced after the two bounds checks above, so always in range.
    pivot    = a @@> (pivotRow, j)
    below    = getColumnBelow a (pivotRow, j)

    -- Scale the pivot row only when the pivot is not already 1.
    erm = elemRowMult n pivotRow (recip pivot)
    (a1, mats1)
      | pivot /= 1 = (erm <> a, erm : mats)
      | otherwise  = (a, mats)

    -- Build a single matrix that cancels every other non-zero entry of
    -- column j at once: the identity plus -a[i][j] at (i, pivotRow).  This
    -- replaces one row-add matrix per entry.  Scaling only changed
    -- pivotRow, so reading the other rows from a equals reading them from a1.
    (isWritten, elimMat) = runST $ do
      st <- newMatrix 0 n n :: ST s (STMatrix s Double)
      forM_ [0 .. n - 1] $ \ d -> writeMatrix st d d 1
      let go written i
            | i >= n            = return written
            | i == pivotRow     = go written (i + 1)
            | a @@> (i, j) == 0 = go written (i + 1)
            | otherwise         = writeMatrix st i pivotRow (negate (a @@> (i, j)))
                                    >> go True (i + 1)
      written <- go False 0
      (,) written <$> freezeMatrix st

    -- Apply the elimination matrix only when it differs from the identity.
    (a2, mats2)
      | isWritten = (elimMat <> a1, elimMat : mats1)
      | otherwise = (a1, mats1)
-- | Entries of column @j@ from row @i@ down to the last row, top to bottom.
getColumnBelow :: H.Matrix Double -> (Int, Int) -> [Double]
getColumnBelow a (i, j) = concat . H.toLists $ subMatrix (i, j) (rows a - i, 1) a
-- | The @n x n@ elementary matrix that, multiplied on the left, scales row
-- @i@ by @k@: the identity matrix with @k@ at diagonal position @i@.
elemRowMult :: Int -> Int -> Double -> H.Matrix Double
elemRowMult n i k = diag (H.fromList (replicate i 1.0 ++ [k] ++ replicate (n - i - 1) 1.0))
-- | The @n x n@ permutation matrix that, multiplied on the left, swaps rows
-- @i@ and @j@: the identity matrix with those two rows exchanged.
elemRowSwap :: Int -> Int -> Int -> H.Matrix Double
elemRowSwap n i j
  | i == j    = ident n
  | i > j     = elemRowSwap n j i
  | otherwise = extractRows (before ++ [j] ++ between ++ [i] ++ after) (ident n)
  where
    before  = [0 .. i - 1]
    between = [i + 1 .. j - 1]
    after   = [j + 1 .. n - 1]
-- | Indices of the rows to discard so that the system becomes consistent.
--
-- Tries suffixes of the rows, longest first, and keeps the first one whose
-- coefficient matrix and augmented matrix have equal rank (the
-- Rouche-Capelli condition).  The rows before that suffix are reported.
-- Only suffixes are considered, so the result is always a prefix of row
-- indices.  The empty suffix is trivially consistent and is used as the
-- fallback without computing any rank.
findInconsistentRows :: H.Matrix Double -> H.Matrix Double -> [Int]
findInconsistentRows coA augA = allRows \\ consistent
  where
    allRows = [0 .. rows augA - 1]
    consistent =
      case filter (tryRows coA augA) [ suffix | suffix@(_ : _) <- tails allRows ] of
        (firstOk : _) -> firstOk
        []            -> []
    tryRows co aug ns = rank (extractRows ns co) == rank (extractRows ns aug)
-- | Select the listed rows of a matrix, in the listed order.  This is
-- hmatrix's '?' with its arguments swapped.
extractRows idxs mat = mat ? idxs
-- | Infix shorthand for 'atIndex', e.g. element lookup at a (row, column) pair.
(@@>) mat ix = atIndex mat ix