{-
   Copyright 2016, Dominic Orchard, Andrew Rice, Mistral Contrastin, Matthew Danish

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
-}

{-
  Units of measure extension to Fortran: backend
-}

{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Camfort.Specification.Units.InferenceBackend
  ( chooseImplicitNames
  , criticalVariables
  , inconsistentConstraints
  , inferVariables
  -- mainly for debugging and testing:
  , shiftTerms
  , flattenConstraints
  , flattenUnits
  , constraintsToMatrix
  , constraintsToMatrices
  , rref
  , genUnitAssignments
  , genUnitAssignments'
  , provenance
  , splitNormHNF
  ) where

import           Camfort.Specification.Units.Environment
import qualified Camfort.Specification.Units.InferenceBackendFlint as Flint
import           Control.Arrow (first, second, (***))
import           Control.Monad
import           Control.Monad.ST
import           Control.Parallel.Strategies
import qualified Data.Array as A
import           Data.Generics.Uniplate.Operations
  (transformBi, universeBi)
import           Data.Graph.Inductive hiding ((><))
import qualified Data.IntMap as IM
import qualified Data.IntSet as IS
import           Data.List
  ((\\), findIndex, inits, nub, partition, sort, sortBy, group, tails, foldl')
import qualified Data.Map.Strict as M
import           Data.Maybe (fromMaybe, mapMaybe)
import           Data.Ord
import           Data.Tuple (swap)
import           Numeric.LinearAlgebra
  ( atIndex, (<>)
  , rank, (?), (¿)
  , rows, cols
  , subMatrix, diag
  , fromBlocks, ident
  )
import qualified Numeric.LinearAlgebra as H
import           Numeric.LinearAlgebra.Devel
  ( newMatrix, readMatrix
  , writeMatrix, runSTMatrix
  , freezeMatrix, STMatrix
  )
import           Prelude hiding ((<>))


-- | Returns list of formerly-undetermined variables and their units.
--
-- Solves the constraints with 'genUnitAssignments' and keeps only the
-- rows whose left-hand side is a single variable raised to (approximately)
-- the power 1: first the ordinary 'UnitVar' rows, then the
-- 'UnitParamVarAbs' rows (for which the second component of the pair is
-- reported as the variable).
inferVariables :: Constraints -> [(VV, UnitInfo)]
inferVariables cons =
  mapMaybe ordinaryVar assignments ++ mapMaybe paramVar assignments
  where
    assignments = genUnitAssignments cons

    -- "var ** 1 = units" for an ordinary variable.
    ordinaryVar ([UnitPow (UnitVar v) k], units)
      | k `approxEq` 1 = Just (v, units)
    ordinaryVar _      = Nothing

    -- "var ** 1 = units" for a variable tagged as UnitParamVarAbs.
    paramVar ([UnitPow (UnitParamVarAbs (_, v)) k], units)
      | k `approxEq` 1 = Just (v, units)
    paramVar _         = Nothing

-- Detect inconsistency if concrete units are assigned an implicit
-- abstract unit variable with coefficients not equal, or there are
-- monomorphic literals being given parametric polymorphic units.
--
-- Returns the offending assignments re-expressed as constraints; an empty
-- result means no inconsistency was found.
detectInconsistency :: [([UnitInfo], UnitInfo)] -> Constraints
detectInconsistency unitAssignments = inconsist
  where
    -- Flatten each right-hand side and normalise with 'shiftTerms'.
    ua' = map (shiftTerms . fmap flattenUnits) unitAssignments

    -- A lone implicit abstract unit on the left whose power differs from
    -- the power of some term on the right.  The powers are Doubles produced
    -- by matrix arithmetic, so compare them with the same tolerance the rest
    -- of this module uses ('approxEq') rather than exact (/=).
    badImplicits =
      [ fmap foldUnits a
      | a@([UnitPow (UnitParamImpAbs _) k1], rhs) <- ua'
      , UnitPow _ k2 <- rhs
      , k1 `notApproxEq` k2
      ]

    inconsist =
      unitAssignmentsToConstraints badImplicits ++ mustBeUnitless unitAssignments

-- Must be unitless: any assignments of parametric abstract units to
-- monomorphic literals.
--
-- Each assignment's right-hand side is flattened and re-arranged with
-- @shiftTermsBy isLiteral@.  Whenever the first left-hand term is a
-- 'UnitLiteral' raised to some power and the right-hand side mentions any
-- parametric unit, a constraint forcing that literal to be 'UnitlessLit'
-- is emitted.
mustBeUnitless :: [([UnitInfo], UnitInfo)] -> Constraints
mustBeUnitless unitAssignments =
  [ ConEq UnitlessLit (UnitPow (UnitLiteral lit) pow)
  | (UnitPow (UnitLiteral lit) pow : _, rhs) <- literalsShifted
  , any isParametric (universeBi rhs :: [UnitInfo])
  ]
  where
    literalsShifted =
      map (shiftTermsBy isLiteral . fmap flattenUnits) unitAssignments

    isLiteral u = case u of
      UnitLiteral{}           -> True
      UnitPow UnitLiteral{} _ -> True
      _                       -> False

    isParametric u = case u of
      UnitParamVarAbs{} -> True
      UnitParamPosAbs{} -> True
      UnitParamEAPAbs{} -> True
      UnitParamLitAbs{} -> True
      UnitParamImpAbs{} -> True
      UnitPow inner _   -> isParametric inner
      _                 -> False


-- | Convert the assignment format back into constraints: each
-- @(lhsTerms, rhs)@ becomes @ConEq (foldUnits lhsTerms) rhs@.
unitAssignmentsToConstraints :: [([UnitInfo], UnitInfo)] -> Constraints
unitAssignmentsToConstraints assignments =
  -- Lazy pattern: the pair is only taken apart when a field is demanded.
  [ ConEq (foldUnits lhsTerms) rhs | ~(lhsTerms, rhs) <- assignments ]

-- | Raw units-assignment pairs.
--
-- Solves @cons@ with 'genUnitAssignments''.  If 'mustBeUnitless' reports
-- literals that have to be forced unitless, those constraints are added to
-- the solved assignments and the problem is solved again.  Otherwise the
-- assignments are returned when 'detectInconsistency' finds nothing, and
-- @[]@ is returned when it does.
genUnitAssignments :: Constraints -> [([UnitInfo], UnitInfo)]
genUnitAssignments cons =
  case mustBeUnitless assignments of
    [] | null (detectInconsistency assignments) -> assignments
       | otherwise                              -> []
    forced ->
      genUnitAssignments (forced ++ unitAssignmentsToConstraints assignments)
  where
    assignments = genUnitAssignments' colSort cons

-- | Break up the problem of solving normHNF on each group of related
-- columns, then bring it all back together.
--
-- Returns the recombined matrix and, for every column that
-- 'Flint.normHNF' appended to a sub-matrix, the index (in @unsolvedM@) of
-- the column it was generated from.  Generated columns are numbered
-- consecutively from @cols unsolvedM@, and the i-th element of the index
-- list must describe generated column @cols unsolvedM + i@, because
-- 'genUnitAssignments'' appends one entry per index after the original
-- columns and zips that list against the rows of the matrix.
splitNormHNF :: H.Matrix Double -> (H.Matrix Double, [Int])
splitNormHNF unsolvedM = (combinedMat, allNewColIndices)
  where
    combinedMat = joinMat (map (first fst) solvedMs)

    -- 'eachResult' conses every result onto the front of the accumulator,
    -- so 'solvedMs' lists the sub-problems in reverse.  The generated
    -- columns, however, were numbered in forward order.  Reverse before
    -- concatenating so that the source indices line up with the generated
    -- column numbers when more than one sub-problem generates columns.
    -- NOTE(review): assumes 'joinMat' places columns using the index lists
    -- it is given (so its own input order does not matter) — confirm.
    allNewColIndices = concatMap (snd . fst) (reverse solvedMs)

    inParallel = (`using` parTuple2 (parList rseq) rseq)

    (solvedMs, _) =
      inParallel . foldl' eachResult ([], cols unsolvedM)
        $ map (first Flint.normHNF) (splitMat unsolvedM)

    -- For each result, re-number the generated columns and add mappings
    -- for each.
    eachResult (ms, startI) ((m, newColIndices), origCols) =
      (((m, newColIndices'), origCols') : ms, endI)
      where
        -- produce (length newColIndices) number of mappings
        endI           = startI + length newColIndices
        -- re-number the newColIndices according to the lookup list
        newColIndices' = map (origCols !!) newColIndices
        -- add columns in the (combined) matrix for the newly
        -- generated columns from running normHNF on m.
        origCols'      = origCols ++ [startI .. endI - 1]

-- | Solve one constraint system with the given column ordering.
--
-- The constraints are turned into the matrices of @A X = B@ by
-- 'constraintsToMatrices''.  The augmented matrix is solved with
-- 'splitNormHNF', and each resulting row is turned back into a
-- @(lhs terms, rhs unit)@ pair.  Returns @[]@ when there are no columns at
-- all, or when 'constraintsToMatrices'' reported any inconsistent rows.
genUnitAssignments' :: SortFn -> Constraints -> [([UnitInfo], UnitInfo)]
genUnitAssignments' _ [] = []
genUnitAssignments' sortfn cons
  | null colList || not (null inconsists) = []
  | otherwise                             = unitAssignments
  where
    (lhsM, rhsM, inconsists, lhsColA, rhsColA) = constraintsToMatrices' sortfn cons

    isEmptyMat m = rows m == 0 || cols m == 0

    -- The augmented matrix [lhsM | rhsM], dropping whichever half is empty.
    unsolvedM
      | isEmptyMat rhsM = lhsM
      | isEmptyMat lhsM = rhsM
      | otherwise       = fromBlocks [[lhsM, rhsM]]

    -- solvedM may have extra columns (and rows) generated by normHNF;
    -- newColIndices says which original column each extra one came from.
    (solvedM, newColIndices) = splitNormHNF unsolvedM

    numLhsCols = snd (A.bounds lhsColA) + 1

    -- One (coefficient, unit) entry per column of solvedM: the original
    -- columns (coefficient 1), followed by one entry per generated column.
    colList =
      [ (1, u) | u <- A.elems lhsColA ++ A.elems rhsColA ]
        ++ map genC newColIndices

    -- A generated column becomes an implicit abstract unit named after its
    -- source column; if it was derived from the right-hand side, negate it.
    genC n = (coeff, UnitParamImpAbs (show u))
      where
        (k, u) = colList !! n
        coeff  = if n >= numLhsCols then negate k else k

    -- Convert the rows of the solved matrix into flattened unit
    -- expressions in the form of "unit ** k".
    unitPow (k, u) x = UnitPow u (k * x)
    unitPows =
      [ concatMap flattenUnits (zipWith unitPow colList row)
      | row <- H.toLists solvedM ]

    -- Variables to the left, unit names to the right side of the equation.
    unitAssignments = map toAssignment unitPows
    toAssignment =
      fmap (foldUnits . map negatePosAbs)
        . checkSanity
        . partition (not . belongsOnRHS)

    belongsOnRHS (UnitPow (UnitName _) _)               = True
    belongsOnRHS (UnitPow (UnitParamEAPAbs _) _)        = True
    -- Because this version of isUnitRHS differs from the
    -- constraintsToMatrix interpretation, we need to ensure that any
    -- moved ParamPosAbs units are negated, because they are
    -- effectively being shifted across the equal-sign:
    belongsOnRHS (UnitPow (UnitParamImpAbs _) _)        = True
    belongsOnRHS (UnitPow (UnitParamPosAbs (_, 0)) _)   = False
    belongsOnRHS (UnitPow (UnitParamPosAbs _) _)        = True
    belongsOnRHS _                                      = False

-- | Post-process a partitioned solution row @(lhs, rhs)@.
--
-- The whole row is moved to the left-hand side, leaving the right empty,
-- in two cases:
--
-- * the left side is a single ordinary 'UnitVar' and the right side
--   mentions any 'UnitParamPosAbs' or 'UnitParamImpAbs';
-- * the left side is a single 'UnitParamVarAbs' tagged @f@ and the right
--   side mentions a 'UnitParamPosAbs' with a different tag (presumably a
--   different function — confirm).
--
-- Every other row is returned unchanged.
checkSanity :: ([UnitInfo], [UnitInfo]) -> ([UnitInfo], [UnitInfo])
checkSanity (lhs@[UnitPow (UnitVar _) _], rhs)
  | any isAbstractParam (universeBi rhs :: [UnitInfo]) = (lhs ++ rhs, [])
  where
    isAbstractParam (UnitParamPosAbs (_, _)) = True
    isAbstractParam (UnitParamImpAbs _)      = True
    isAbstractParam _                        = False
checkSanity (lhs@[UnitPow (UnitParamVarAbs (f, _)) _], rhs)
  | any fromOtherTag (universeBi rhs :: [UnitInfo]) = (lhs ++ rhs, [])
  where
    fromOtherTag (UnitParamPosAbs (f', _)) = f' /= f
    fromOtherTag _                         = False
checkSanity row = row

--------------------------------------------------

-- | Equality of two 'Double's up to the absolute tolerance 'epsilon'.
--
-- FIXME: the tolerance is absolute and arbitrary; a relative tolerance may
-- be more robust for large coefficients.
approxEq :: Double -> Double -> Bool
approxEq a b = epsilon > abs (b - a)

-- | Negation of 'approxEq'.
notApproxEq :: Double -> Double -> Bool
notApproxEq a = not . approxEq a

-- | Tolerance used by 'approxEq'.
epsilon :: Double
epsilon = 0.001 -- arbitrary

--------------------------------------------------

-- | A row number of a matrix.
type RowNum = Int

-- | A column number of a matrix.
type ColNum = Int

-- | Represents a subproblem of AX=B where the row numbers and column
-- numbers help you re-map back to the original problem.
type Subproblem = ([RowNum], (H.Matrix Double, H.Matrix Double), [ColNum])

-- | Divide up the AX=B problem into smaller problems based on the
-- 'related columns' and their corresponding rows in the
-- right-hand-side of the equation, where lhsM = A and rhsM = B. The
-- resulting list of subproblems contains the new, smaller As and Bs
-- as well as a list of original row numbers and column numbers to
-- aid re-mapping back to the original lhsM and rhsM.
--
-- Rows of lhsM that are (approximately) all zero are offered to every
-- subproblem, padded with zero co-efficients; such a row is kept only
-- when its RHS row is not all zero. Returns @[]@ when lhsM has no
-- columns.
splitMatWithRHS :: H.Matrix Double -> H.Matrix Double -> [Subproblem]
splitMatWithRHS lhsM rhsM
  | cols lhsM > 0 = map (eachComponent . sort) $ scc (relatedColumnsGraph lhsM)
  | otherwise     = []
  where
    -- Row numbers of the 'all zero' rows in lhsM. This depends only on
    -- lhsM, not on the component, so it is computed once here instead
    -- of being rebuilt (with a full 'H.toLists') for every component.
    lhsAllZeroRows :: [RowNum]
    lhsAllZeroRows = map fst . filter (all (approxEq 0) . snd) . zip [0 ..] $ H.toLists lhsM

    -- Gets called on every strongly-connected component / related set of columns.
    eachComponent :: [ColNum] -> Subproblem
    eachComponent cs = (rs, mats, cs)
      where
        -- Selected columns
        lhsSelCols :: H.Matrix Double
        lhsSelCols = lhsM ¿ cs

        csLen :: Int
        csLen = cols lhsSelCols

        -- Rows (numbered, with their selected-column values) that have a
        -- non-zero co-efficient in at least one of the selected columns.
        lhsNonZeroColRows :: [(RowNum, [Double])]
        lhsNonZeroColRows = filter (any (notApproxEq 0) . snd) . zip [0 ..] . H.toLists $ lhsSelCols

        -- The rows above plus the all-zero rows (padded with csLen
        -- zeroes), in ascending row-number order.
        lhsNumberedRows :: [(RowNum, [Double])]
        lhsNumberedRows = sortBy (comparing fst) $
          lhsNonZeroColRows ++ zip lhsAllZeroRows (repeat (replicate csLen 0))

        -- For each of the above LHS rows find the corresponding RHS row.
        rhsSelRows :: [[Double]]
        rhsSelRows
          | rows rhsM > 0 = H.toLists (rhsM ? map fst lhsNumberedRows)
          | otherwise     = []

        reassoc :: (a, b) -> c -> (a, (b, c))
        reassoc (a, b) c = (a, (b, c))

        notAllZero :: (a, ([Double], [Double])) -> Bool
        notAllZero (_, (lhs, rhs)) = any (notApproxEq 0) (lhs ++ rhs)

        -- Zip the selected LHS, RHS rows together, filter out any that are all zeroes.
        numberedRows :: ([RowNum], [([Double], [Double])])
        numberedRows = unzip . filter notAllZero $ zipWith reassoc lhsNumberedRows rhsSelRows

        rs :: [RowNum]                              -- row numbers in the subproblem
        mats :: (H.Matrix Double, H.Matrix Double)  -- LHS/RHS subproblem matrices
        (rs, mats) = second ((H.fromLists *** H.fromLists) . unzip) numberedRows

-- | Split the lhsM/rhsM problem into subproblems and look for
-- inconsistent rows in each one. The result concatenates every
-- inconsistent row number found, translated back into row numbers of
-- the original lhsM.
splitFindInconsistentRows :: H.Matrix Double -> H.Matrix Double -> [RowNum]
splitFindInconsistentRows lhsMat rhsMat =
  [ rs !! i
  | (rs, (lhsM, rhsM), _) <- splitMatWithRHS lhsMat rhsMat
  , i <- findInconsistentRows lhsM (augment lhsM rhsM)
  ]
  where
    -- The augmented matrix [lhs | rhs]. If the RHS has no rows or no
    -- columns it is just the LHS; otherwise, if the LHS has none, it is
    -- just the RHS.
    augment :: H.Matrix Double -> H.Matrix Double -> H.Matrix Double
    augment l r
      | rows r == 0 || cols r == 0 = l
      | rows l == 0 || cols l == 0 = r
      | otherwise                  = fromBlocks [[l, r]]

-- | Break out the 'unrelated' columns of a single matrix into separate
-- matrices, each paired with the (ascending) original positions of its
-- columns. Rows that are exactly zero within a group of columns are
-- dropped from that group's matrix.
splitMat :: H.Matrix Double -> [(H.Matrix Double, [ColNum])]
splitMat m = [ (dropZeroRows (m ¿ cs), cs) | cs <- map sort (scc (relatedColumnsGraph m)) ]
  where
    -- Keep only rows containing at least one co-efficient that is not exactly 0.
    dropZeroRows :: H.Matrix Double -> H.Matrix Double
    dropZeroRows = H.fromLists . filter (any (/= 0)) . H.toLists

-- | Bring together the split matrices and put the columns back in
-- their original order. Rows may not be in the same order as the
-- original, but the constraints should be equivalent.
joinMat :: [(H.Matrix Double, [Int])] -> H.Matrix Double
joinMat ms = H.fromColumns [ col | (_, col) <- sortBy (comparing fst) indexedCols ]
  where
    -- All sub-matrices laid out along the diagonal of one block matrix.
    blockM :: H.Matrix Double
    blockM = H.diagBlock [ sub | (sub, _) <- ms ]

    -- Each column of blockM tagged with its original column number.
    indexedCols = zip (concat [ idxs | (_, idxs) <- ms ]) (H.toColumns blockM)

-- | Turn a matrix into a graph where each node represents a column
-- and each edge represents two columns that have non-zero
-- co-efficients in some row. Basically, 'related columns'. Edges are
-- added in both directions and every node also has a self-edge.
-- Columns that are (approximately) zero in every row do not appear as
-- nodes at all.
relatedColumnsGraph :: H.Matrix Double -> Gr () ()
relatedColumnsGraph m = mkGraph (map (,()) ns) (map (\ (a, b) -> (a, b, ())) es)
  where
    -- For each row, the columns whose co-efficient is not approximately zero.
    nonZeroCols :: [[Int]]
    nonZeroCols = [ [ j | j <- [0 .. cols m - 1], not (m `atIndex` (i, j) `approxEq` 0) ]
                  | i <- [0 .. rows m - 1] ]

    -- Distinct column numbers. An IntSet replaces the quadratic 'nub';
    -- node order does not matter to 'mkGraph'.
    ns :: [Int]
    ns = IS.toList . IS.fromList $ concat nonZeroCols

    -- Every ordered pair (including self-pairs) of columns sharing a row.
    es :: [(Int, Int)]
    es = [ (i, j) | cs <- nonZeroCols, i <- cs, j <- cs ]

-- | Convert a set of constraints into a single augmented matrix of
-- co-efficients @[lhs | rhs]@, the list of inconsistent row numbers,
-- and a reverse mapping of column numbers to units (LHS columns first,
-- then RHS columns).
--
-- This reuses @constraintsToMatrices' colSort@ rather than repeating
-- its flatten/shift/matrix pipeline, so the two entry points cannot
-- drift apart. When there are no LHS terms the result is an empty
-- matrix, no inconsistencies and an empty column array.
constraintsToMatrix :: Constraints -> (H.Matrix Double, [Int], A.Array Int UnitInfo)
constraintsToMatrix cons = (augM, inconsists, A.listArray (0, length colElems - 1) colElems)
  where
    (lhsM, rhsM, inconsists, lhsCols, rhsCols) = constraintsToMatrices' colSort cons

    -- The unit for every column of augM, in column order.
    colElems :: [UnitInfo]
    colElems = A.elems lhsCols ++ A.elems rhsCols

    -- lhsM and rhsM side by side. If rhsM has no rows or no columns the
    -- result is just lhsM; otherwise, if lhsM has none, just rhsM.
    augM :: H.Matrix Double
    augM
      | rows rhsM == 0 || cols rhsM == 0 = lhsM
      | rows lhsM == 0 || cols lhsM == 0 = rhsM
      | otherwise                        = fromBlocks [[lhsM, rhsM]]

-- | 'constraintsToMatrices'' specialised to the default column ordering 'colSort'.
constraintsToMatrices :: Constraints -> (H.Matrix Double, H.Matrix Double, [Int], A.Array Int UnitInfo, A.Array Int UnitInfo)
constraintsToMatrices = constraintsToMatrices' colSort

-- | Convert constraints into separate LHS and RHS co-efficient
-- matrices (columns ordered by @sortfn@), the inconsistent row numbers,
-- and the column-to-unit arrays for each side. When no constraint has
-- any LHS term, empty matrices and arrays are returned.
constraintsToMatrices' :: SortFn -> Constraints -> (H.Matrix Double, H.Matrix Double, [Int], A.Array Int UnitInfo, A.Array Int UnitInfo)
constraintsToMatrices' sortfn cons
  | all null lhs = (H.ident 0, H.ident 0, [], A.listArray (0, -1) [], A.listArray (0, -1) [])
  | otherwise    = (lhsM, rhsM, inconsists, lhsCols, rhsCols)
  where
    -- Flatten each constraint into (lhs, rhs), drop those whose two sides
    -- are identical, and move every term to its proper side.
    (lhs, rhs) = unzip [ shiftTerms c | c@(l, r) <- flattenConstraints cons, l /= r ]

    (lhsM, lhsCols) = flattenedToMatrix sortfn lhs
    (rhsM, rhsCols) = flattenedToMatrix sortfn rhs

    inconsists :: [Int]
    inconsists = splitFindInconsistentRows lhsM rhsM

-- | Build a co-efficient matrix from a list of flattened constraints
-- (each a list of 'UnitPow' terms): one row per constraint, one column
-- per distinct unit, with each term's exponent added into the cell for
-- its (constraint, unit). Also returns an array mapping each column
-- number back to its unit. Columns are ordered by @sortfn@.
flattenedToMatrix :: SortFn -> [[UnitInfo]] -> (H.Matrix Double, A.Array Int UnitInfo)
flattenedToMatrix sortfn cons = (m, A.array (0, numCols - 1) (map swap uniqUnits))
  where
    m :: H.Matrix Double
    m = runSTMatrix $ do
      newM <- newMatrix 0 numRows numCols
      -- loop through all constraints, one matrix row each
      forM_ (zip cons [0 ..]) $ \ (unitPows, row) ->
        -- accumulate each term's exponent into its unit's column; non-'UnitPow'
        -- terms are skipped, just as they are when enumerating 'uniqUnits'
        forM_ [ (u, k) | UnitPow u k <- unitPows ] $ \ (u, k) ->
          case M.lookup u colMap of
            Just col -> readMatrix newM row col >>= writeMatrix newM row col . (+ k)
            Nothing  -> return ()
      return newM

    -- Identify and enumerate every unit uniquely, ordered by sortfn.
    -- Duplicates are removed with the same Ord that keys 'colMap'
    -- (keeping first occurrences) before sorting, so 'colMap', 'numCols'
    -- and the column array always agree, even if sortfn ties distinct
    -- units and would leave equal units non-adjacent for 'group'.
    uniqUnits :: [(UnitInfo, Int)]
    uniqUnits = flip zip [0 ..] . sortBy sortfn . dedupOrd $ [ u | UnitPow u _ <- concat cons ]

    -- Order-preserving duplicate removal based on Ord.
    dedupOrd :: [UnitInfo] -> [UnitInfo]
    dedupOrd = go M.empty
      where
        go _ [] = []
        go seen (x : xs)
          | x `M.member` seen = go seen xs
          | otherwise         = x : go (M.insert x () seen) xs

    -- map units to their unique column number
    colMap :: M.Map UnitInfo Int
    colMap = M.fromList uniqUnits

    numRows :: Int
    numRows = length cons

    numCols :: Int
    numCols = M.size colMap

-- | Negate the exponent of every term in a flattened constraint (used
-- when moving terms across the equals sign). A term that is not a
-- 'UnitPow' is treated as having exponent 1 and becomes
-- @UnitPow u (-1)@, instead of failing on an incomplete lambda pattern.
negateCons :: [UnitInfo] -> [UnitInfo]
negateCons = map negateTerm
  where
    negateTerm :: UnitInfo -> UnitInfo
    negateTerm (UnitPow u k) = UnitPow u (-k)
    negateTerm u             = UnitPow u (-1)

-- | Flip the sign of the exponent of a power of a 'UnitParamPosAbs' or
-- 'UnitParamImpAbs' parameter; any other unit is returned unchanged.
negatePosAbs :: UnitInfo -> UnitInfo
negatePosAbs u = case u of
  UnitPow p@(UnitParamPosAbs _) k -> UnitPow p (-k)
  UnitPow p@(UnitParamImpAbs _) k -> UnitPow p (-k)
  _                               -> u

--------------------------------------------------

-- | Units that should appear on the right-hand-side of the matrix during
-- solving: powers of a 'UnitName' or of a 'UnitParamEAPAbs' parameter.
isUnitRHS :: UnitInfo -> Bool
isUnitRHS u = case u of
  UnitPow (UnitName _) _        -> True
  UnitPow (UnitParamEAPAbs _) _ -> True
  _                             -> False

-- | Shift UnitNames/EAPAbs poly units to the RHS, and all else to the
-- LHS; every term that crosses the equals sign has its exponent negated.
-- This is 'shiftTermsBy' with "belongs on the left" meaning "not an RHS unit".
shiftTerms :: ([UnitInfo], [UnitInfo]) -> ([UnitInfo], [UnitInfo])
shiftTerms = shiftTermsBy (not . isUnitRHS)

-- | Shift terms based on function f (<- True, False ->).
--
-- Terms satisfying @f@ end up on the LHS and all others on the RHS. Terms
-- that move across have their exponents negated by 'negateCons'. On each
-- side, the terms that stayed come first, in their original order,
-- followed by the moved terms.
shiftTermsBy :: (UnitInfo -> Bool) -> ([UnitInfo], [UnitInfo]) -> ([UnitInfo], [UnitInfo])
shiftTermsBy f (lhs, rhs) =
  (stayLeft ++ negateCons toLeft, stayRight ++ negateCons toRight)
  where
    (stayLeft, toRight) = partition f lhs
    (toLeft, stayRight) = partition f rhs


-- | Translate all constraints into a LHS, RHS side of units.
--
-- Each @ConEq u1 u2@ becomes @(flattenUnits u1, flattenUnits u2)@.
-- NOTE(review): only 'ConEq' is matched, so any other 'Constraint' form
-- causes a pattern-match failure. Confirm that callers never pass one.
flattenConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
flattenConstraints = map flattenSides
  where
    flattenSides (ConEq u1 u2) = (flattenUnits u1, flattenUnits u2)

--------------------------------------------------
-- Matrix solving functions based on HMatrix

-- | Returns the given matrix transformed into Reduced Row Echelon Form.
-- The record of elementary operations produced by the worker is discarded.
--
-- Alternative backend (currently disabled): (a', den, r) = Flint.rref a
rref :: H.Matrix Double -> H.Matrix Double
rref a = reduced
  where
    (_ops, reduced) = rrefMatrices' a 0 0 []

-- | Provenance of matrices: a description of one elementary row operation
-- applied by 'rrefMatrices''.
data RRefOp
  = ElemRowSwap Int Int         -- ^ swapped row with row
  | ElemRowMult Int Double      -- ^ scaled row by constant
  | ElemRowAdds [(Int, Int)]    -- ^ set of added row onto row ops; a pair
                                --   (i, p) records that a multiple of row p
                                --   was added onto row i
  deriving (Eq, Ord, Show)

-- | Worker for 'rref' and 'provenance'. It reduces @a@ column by column
-- towards reduced row echelon form and records every elementary operation
-- it applies.
--
-- Arguments are the matrix, the current column @j@, the number @k@ of
-- columns skipped so far because they had no non-zero entry at or below
-- the pivot row, and an accumulator of (operation matrix, description)
-- pairs, most recent first. The current pivot row is therefore @j - k@.
--
-- Invariant: the matrix a is in rref except within the submatrix
-- (j-k, j) to (n, n).
rrefMatrices' :: H.Matrix Double -> Int -> Int -> [(H.Matrix Double, RRefOp)] ->
                 ([(H.Matrix Double, RRefOp)], H.Matrix Double)
rrefMatrices' a j k mats
  -- Base cases: every row has a pivot, or every column has been visited.
  | pivotRow == n || j == m = (mats, a)

  -- The pivot cell is zero. Either move on to the next column (nothing
  -- non-zero at or below it), or swap a lower row with a non-zero entry
  -- in this column into the pivot position and retry.
  | pivot == 0 =
      case findIndex (/= 0) (getColumnBelow a (pivotRow, j)) of
        Nothing     -> rrefMatrices' a (j + 1) (k + 1) mats
        Just offset ->
          let other   = pivotRow + offset
              swapMat = elemRowSwap n other pivotRow
          in rrefMatrices' (swapMat <> a) j k
                           ((swapMat, ElemRowSwap other pivotRow) : mats)

  -- Non-zero pivot: scale it to 1 if needed, then cancel every other
  -- non-zero entry in this column and advance to the next column.
  | otherwise = rrefMatrices' cleared (j + 1) k matsCleared
  where
    n        = rows a
    m        = cols a
    pivotRow = j - k
    pivot    = a `atIndex` (pivotRow, j)

    -- Scale the pivot row, unless the pivot is already exactly 1.
    scale  = recip pivot
    scaler = elemRowMult n pivotRow scale
    (scaled, matsScaled)
      | pivot /= 1 = (scaler <> a, (scaler, ElemRowMult pivotRow scale) : mats)
      | otherwise  = (a, mats)

    -- Cancel all other non-zero entries of column j with one combined
    -- matrix, built in the ST monad, instead of multiplying together one
    -- elementary matrix per cancellation. Only the pivot row differs
    -- between 'a' and 'scaled', so reading the multipliers from 'a' is
    -- equivalent.
    (cleared, matsCleared)
      | anyWritten = (elimMat <> scaled, (elimMat, ElemRowAdds elimOps) : matsScaled)
      | otherwise  = (scaled, matsScaled)

    (anyWritten, elimOps, elimMat) = runST $ do
      em <- newMatrix 0 n n :: ST s (STMatrix s Double)
      mapM_ (\d -> writeMatrix em d d 1) [0 .. n - 1]
      let go written ops r
            | r >= n                  = return (written, ops)
            | r == pivotRow           = go written ops (r + 1)
            | a `atIndex` (r, j) == 0 = go written ops (r + 1)
            | otherwise               = do
                writeMatrix em r pivotRow (negate (a `atIndex` (r, j)))
                go True ((r, pivotRow) : ops) (r + 1)
      (written, ops) <- go False [] 0
      frozen <- freezeMatrix em
      return (written, ops, frozen)

-- | The values of column @j@ of @a@ from row @i@ down to the last row,
-- read top to bottom. The entry at (i, j) itself is included as the
-- first element.
getColumnBelow :: H.Matrix Double -> (Int, Int) -> [Double]
getColumnBelow a (i, j) = [ x | row <- H.toLists column, x <- row ]
  where
    column = subMatrix (i, j) (rows a - i, 1) a

-- 'Elementary row operation' matrices

-- | @elemRowMult n i k@ is the diagonal matrix whose diagonal is all 1s
-- except for @k@ at position @i@. For @0 <= i < n@ this is @n@x@n@, and
-- left-multiplying by it scales row @i@ by @k@.
elemRowMult :: Int -> Int -> Double -> H.Matrix Double
elemRowMult n i k = diag (H.fromList diagonal)
  where
    diagonal = concat [replicate i 1.0, [k], replicate (n - i - 1) 1.0]

-- | @elemRowSwap n i j@ is the @n@x@n@ identity with rows @i@ and @j@
-- exchanged, built by selecting rows of 'ident' with '?'. When @i == j@
-- it is the identity itself.
elemRowSwap :: Int -> Int -> Int -> H.Matrix Double
elemRowSwap n i j
  | i == j    = ident n
  | otherwise = ident n ? rowOrder
  where
    lo = min i j
    hi = max i j
    rowOrder = concat [[0 .. lo - 1], [hi], [lo + 1 .. hi - 1], [lo], [hi + 1 .. n - 1]]

--------------------------------------------------

-- | Graph from origin to dest.: each row index maps to a set of row indices.
type GraphCol = IM.IntMap IS.IntSet

-- | Graph from dest. to origin, with the same representation as 'GraphCol'.
type Provenance = IM.IntMap IS.IntSet

-- | Translate one elementary row operation into a row-dependency graph.
--
-- * 'ElemRowMult' changes no dependencies and yields the empty graph.
-- * @ElemRowSwap i j@ links @i@ and @j@ to each other.
-- * @ElemRowAdds l@ maps each pair's first row @i@ to @{i, p}@ and its
--   second row @p@ to @{p}@.
--
-- Entries for a repeated key are merged with 'IS.union'. Plain
-- 'IM.fromList' keeps only the last value for a duplicate key, which
-- would drop edges whenever a row occurs in more than one pair.
--
-- NOTE(review): the 'ElemRowAdds' edges are keyed by the row that was
-- updated (destination -> origins), while the 'GraphCol' alias is
-- documented as origin -> dest. Confirm which direction
-- 'graphColCombine' and 'provenance' expect.
opToGraphCol :: RRefOp -> GraphCol
opToGraphCol ElemRowMult{}     = IM.empty
opToGraphCol (ElemRowSwap i j) = IM.fromList [ (i, IS.singleton j), (j, IS.singleton i) ]
opToGraphCol (ElemRowAdds l)   =
  IM.fromListWith IS.union $
    concat [ [(i, IS.fromList [i, j]), (j, IS.singleton j)] | (i, j) <- l ]

-- | Compose two graphs. Every target set in @g2@ is rewritten by replacing
-- each member @i@ with @g1@'s targets for @i@, or with @i@ itself when
-- @g1@ has no entry for it. Keys present in @g2@ take the rewritten entry;
-- keys present only in @g1@ keep their @g1@ entry.
graphColCombine :: GraphCol -> GraphCol -> GraphCol
graphColCombine g1 g2 = IM.union (IM.map follow g2) g1
  where
    follow targets =
      IS.unions [ IM.findWithDefault (IS.singleton i) i g1 | i <- IS.toList targets ]

-- | Reverse every edge. @j@ is in the result's set for @i@ exactly when
-- @i@ is in @g@'s set for @j@. Keys with empty sets in @g@ contribute
-- nothing.
invertGraphCol :: GraphCol -> GraphCol
invertGraphCol g =
  IM.unionsWith IS.union
    [ IM.fromSet (const (IS.singleton j)) targets | (j, targets) <- IM.toList g ]

-- | Row-reduce @m@ and also return a row-dependency graph. The graph is
-- built by translating each recorded operation with 'opToGraphCol',
-- folding the results with 'graphColCombine', and inverting with
-- 'invertGraphCol'. It is intended to be a 'Provenance' (dest. -> origin)
-- graph.
provenance :: H.Matrix Double -> (H.Matrix Double, Provenance)
provenance m = (reduced, invertGraphCol combined)
  where
    (steps, reduced) = rrefMatrices' m 0 0 []
    -- 'steps' lists the most recent operation first.
    combined = foldl' graphColCombine IM.empty [ opToGraphCol op | (_, op) <- steps ]

-- Worker functions:

-- | Indices of the rows judged inconsistent: all rows minus the first
-- consistent subset of rows that is found.
--
-- @coA@ is the coefficient matrix and @augA@ the augmented matrix of the
-- same system. They are assumed to have the same number of rows (TODO:
-- confirm at the call site). A system with fewer than two rows reports
-- nothing.
findInconsistentRows :: H.Matrix Double -> H.Matrix Double -> [Int]
findInconsistentRows coA augA
  | numRows < 2 = []
  | otherwise   = allRows \\ consistent
  where
    numRows = rows augA
    allRows = [0 .. numRows - 1]

    consistent = firstConsistent candidates

    candidates
      -- If the space is relatively small, try every subset of rows. Larger
      -- subsets are tried first so that as few rows as possible are
      -- reported; the raw 'powerset' order is not sorted by size. 'sortBy'
      -- is stable, so equally sized subsets keep their original preference
      -- order.
      | numRows < 16 = sortBy (comparing (Down . length)) (powerset (reverse allRows))
      -- Otherwise only try dropping successively longer prefixes of rows.
      | otherwise    = tails allRows

    -- The empty subset always passes 'tryRows'. Falling back to it keeps
    -- this total without relying on 'head'.
    firstConsistent cs = case filter tryRows cs of
      (c:_) -> c
      []    -> []

    powerset :: [b] -> [[b]]
    powerset = filterM (const [True, False])

    -- Rouché–Capelli theorem is that if the rank of the coefficient
    -- matrix is not equal to the rank of the augmented matrix then
    -- the system of linear equations is inconsistent.
    tryRows :: [Int] -> Bool
    tryRows [] = True
    tryRows ns = rank (coA ? ns) == rank (augA ? ns)

-- | Create unique names for all of the inferred implicit polymorphic
-- unit variables. The renaming map is generated from the list and then
-- applied back to that same list.
chooseImplicitNames :: [(VV, UnitInfo)] -> [(VV, UnitInfo)]
chooseImplicitNames vars = replaceImplicitNames namesMap vars
  where
    namesMap = genImplicitNamesMap vars

-- | Assign a fresh name to every distinct 'UnitParamPosAbs' and
-- 'UnitParamImpAbs' occurring anywhere inside @x@.
--
-- Candidate names are @'a@ .. @'z@, then @'aa@, @'ab@, and so on. Any name
-- already used by a 'UnitParamEAPAbs' or 'UnitParamEAPUse' in @x@ is
-- skipped. Each abstract unit maps to @UnitParamEAPAbs (name, name)@.
genImplicitNamesMap :: Data a => a -> M.Map UnitInfo UnitInfo
genImplicitNamesMap x = M.fromList (zipWith assign absUnits newNames)
  where
    assign absU newN = (absU, UnitParamEAPAbs (newN, newN))

    allUnits :: [UnitInfo]
    allUnits = universeBi x

    absUnits = nub [ u | u@(UnitParamPosAbs _) <- allUnits ]
            ++ nub [ u | u@(UnitParamImpAbs _) <- allUnits ]

    eapNames = nub ( [ n | UnitParamEAPAbs (_, n)      <- allUnits ]
                  ++ [ n | UnitParamEAPUse ((_, n), _) <- allUnits ] )

    newNames = [ name | name <- map ('\'' :) nameGen, name `notElem` eapNames ]

    -- every lowercase string, shortest first, in lexicographic order
    nameGen = concat [ replicateM len ['a' .. 'z'] | len <- [1 ..] ]

-- | Rewrite every 'UnitParamPosAbs' and 'UnitParamImpAbs' inside the
-- structure using @implicitMap@. Parameters without an entry, and all other
-- units, are left unchanged.
replaceImplicitNames :: Data a => M.Map UnitInfo UnitInfo -> a -> a
replaceImplicitNames implicitMap = transformBi rename
  where
    rename :: UnitInfo -> UnitInfo
    rename u = case u of
      UnitParamPosAbs _ -> M.findWithDefault u u implicitMap
      UnitParamImpAbs _ -> M.findWithDefault u u implicitMap
      _                 -> u

-- | Identifies the variables that need to be annotated in order for
-- inference or checking to work.
--
-- The constraints are turned into a matrix and reduced with 'rref'. A
-- column holding the first non-zero entry of some reduced row is
-- "uncritical". The units of all remaining columns are returned in column
-- order, except plain 'UnitName's.
criticalVariables :: Constraints -> [UnitInfo]
criticalVariables [] = []
criticalVariables cons =
  [ unit
  | idx <- A.indices colA
  , not (idx `IS.member` leadingCols)
  , let unit = colA A.! idx
  , not (isPlainName unit)
  ]
  where
    (unsolvedM, _, colA) = constraintsToMatrix cons
    -- column index of the leading non-zero entry of each reduced row
    leadingCols = IS.fromList . mapMaybe (findIndex (/= 0)) . H.toLists $ rref unsolvedM
    isPlainName (UnitName _) = True
    isPlainName _            = False

-- | Returns just the list of constraints that were identified as
-- being possible candidates for inconsistency, if there is a problem.
inconsistentConstraints :: Constraints -> Maybe Constraints
inconsistentConstraints :: Constraints -> Maybe Constraints
inconsistentConstraints [] = Maybe Constraints
forall a. Maybe a
Nothing
inconsistentConstraints Constraints
cons
  | Bool -> Bool
not (Constraints -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null Constraints
direct) = Constraints -> Maybe Constraints
forall a. a -> Maybe a
Just Constraints
direct
  | [Int] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Int]
inconsists   = Maybe Constraints
forall a. Maybe a
Nothing
  | Bool
otherwise         = Constraints -> Maybe Constraints
forall a. a -> Maybe a
Just [ Constraint
con | (Constraint
con, Int
i) <- Constraints -> [Int] -> [(Constraint, Int)]
forall a b. [a] -> [b] -> [(a, b)]
zip Constraints
cons [Int
0..], Int
i Int -> [Int] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [Int]
inconsists ]
  where
    (Matrix Double
_, Matrix Double
_, [Int]
inconsists, Array Int UnitInfo
_, Array Int UnitInfo
_) = Constraints
-> (Matrix Double, Matrix Double, [Int], Array Int UnitInfo,
    Array Int UnitInfo)
constraintsToMatrices Constraints
cons
    direct :: Constraints
direct = [([UnitInfo], UnitInfo)] -> Constraints
detectInconsistency ([([UnitInfo], UnitInfo)] -> Constraints)
-> [([UnitInfo], UnitInfo)] -> Constraints
forall a b. (a -> b) -> a -> b
$ SortFn -> Constraints -> [([UnitInfo], UnitInfo)]
genUnitAssignments' SortFn
colSort Constraints
cons