{-
   Copyright 2017, Matthew Danish, Vilem Liepelt, Dominic Orchard, Andrew Rice, Mistral Contrastin

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
-}

{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Camfort.Specification.Units.InferenceBackendSBV
  ( inconsistentConstraints, criticalVariables, inferVariables, genUnitAssignments )
where

import           Camfort.Specification.Units.BackendTypes
import           Camfort.Specification.Units.Environment
import qualified Camfort.Specification.Units.InferenceBackend as MatrixBackend
import           Control.Monad
import           Data.Function (on)
import           Data.List (partition, sortBy, groupBy, nub)
import qualified Data.Map.Strict as M
import           Data.Maybe (catMaybes, fromMaybe)
import           Data.Ord (comparing)
import           Data.SBV hiding (engine, name)
import           Data.SBV.Control
import           Prelude hiding (pred)
import           System.IO.Unsafe (unsafePerformIO)

--------------------------------------------------

-- | Identifies the variables that need to be annotated in order for
-- inference or checking to work.
--
-- When the solver finds the constraints inconsistent, no suggestions are
-- made and the result is the empty list.
criticalVariables :: Constraints -> [UnitInfo]
criticalVariables cons = either (const []) snd (engine cons)

-- | Returns just the list of constraints that were identified as
-- being possible candidates for inconsistency, if there is a problem.
--
-- Each label in the solver's unsat core is looked up in the labelled
-- constraint table returned by 'engine'. Labels that cannot be found are
-- dropped. The surviving constraints are normalised by a round trip through
-- the 'Dim' representation. The empty constraint set never reaches the
-- solver and yields 'Nothing'.
inconsistentConstraints :: Constraints -> Maybe Constraints
inconsistentConstraints [] = Nothing
inconsistentConstraints cons =
  case engine cons of
    Right _ -> Nothing
    -- assuming that SBV provides a list of label names in its unsat 'core'
    Left (core, labeledCons) ->
      Just (normalise [ con | coreName <- core
                            , Just (con, _) <- [lookup coreName labeledCons] ])
  where
    normalise :: Constraints -> Constraints
    normalise = map (dimToConstraint . constraintToDim)

-- | Returns list of formerly-undetermined variables and their units.
--
-- Only assignments whose left-hand side is a plain unit variable
-- ('UnitVar') or an abstract parametric variable ('UnitParamVarAbs') are
-- reported. All 'UnitVar' results come first, followed by the parametric
-- ones. Each group keeps the order in which 'genUnitAssignments' yields it.
inferVariables :: Constraints -> [(VV, UnitInfo)]
inferVariables cons = plainVars ++ paramVars
  where
    assignments :: [(UnitInfo, UnitInfo)]
    assignments = genUnitAssignments cons

    -- We are only interested in reporting the solutions to variables.
    plainVars, paramVars :: [(VV, UnitInfo)]
    plainVars = [ (vv, u) | (UnitVar vv, u) <- assignments ]
    paramVars = [ (vv, u) | (UnitParamVarAbs (_, vv), u) <- assignments ]

--------------------------------------------------

-- FIXME: the tolerance used here is an arbitrary absolute bound.

-- | Approximate equality on 'Double'. Holds when the two values differ by
-- strictly less than 'epsilon'.
approxEq :: Double -> Double -> Bool
approxEq x y = epsilon > abs (y - x)

-- | Absolute tolerance used by 'approxEq'.
epsilon :: Double
epsilon = 0.001 -- arbitrary

--------------------------------------------------

-- | Negates the exponent of every term in a flattened unit list. This is
-- used when terms are moved from one side of an equation to the other.
--
-- A term that is not wrapped in 'UnitPow' is treated as having an implicit
-- exponent of 1, the same convention 'getUnitNamePow' uses, and becomes
-- @UnitPow u (-1)@. Such a term no longer causes a pattern-match failure.
negateCons :: [UnitInfo] -> [UnitInfo]
negateCons = map negateTerm
  where
    negateTerm :: UnitInfo -> UnitInfo
    negateTerm (UnitPow u k) = UnitPow u (negate k)
    negateTerm u             = UnitPow u (-1)

-- negatePosAbs :: UnitInfo -> UnitInfo
-- negatePosAbs (UnitPow (UnitParamPosAbs x) k) = UnitPow (UnitParamPosAbs x) (-k)
-- negatePosAbs u                               = u

--------------------------------------------------

-- | Units that should appear on the right-hand side of the equations during
-- solving: powers of named units ('UnitName') and powers of
-- 'UnitParamEAPAbs' units. Every other term belongs on the left-hand side,
-- including bare terms that are not wrapped in 'UnitPow'.
isUnitRHS :: UnitInfo -> Bool
isUnitRHS unit = case unit of
  UnitPow (UnitName _)        _ -> True
  UnitPow (UnitParamEAPAbs _) _ -> True
  _                             -> False

-- | A flattened constraint as a pair (LHS terms, RHS terms), in the form
-- produced by 'shiftTerms'.
type ShiftedConstraint = ([UnitInfo], [UnitInfo])
-- | A list of shifted constraints.
type ShiftedConstraints = [ShiftedConstraint]

-- | Shift UnitNames/EAPAbs poly units to the RHS, and all else to the LHS.
--
-- The predicate @isUnitRHS'@ decides which side a term belongs on. Terms
-- already on their correct side stay in place, in their original order.
-- Terms moved across the equals sign have their exponents negated with
-- 'negateCons' and are appended after the terms that stayed.
shiftTerms :: (UnitInfo -> Bool) -> ([UnitInfo], [UnitInfo]) -> ShiftedConstraint
shiftTerms isUnitRHS' (lhs, rhs) =
  let (lhsStay, lhsMove) = partition (not . isUnitRHS') lhs
      (rhsStay, rhsMove) = partition isUnitRHS' rhs
      newLhs             = lhsStay ++ negateCons rhsMove
      newRhs             = rhsStay ++ negateCons lhsMove
  in  (newLhs, newRhs)

-- | Translate all constraints into a LHS, RHS side of units.
--
-- Each @ConEq u1 u2@ becomes the pair @(flattenUnits u1, flattenUnits u2)@,
-- and the input order is preserved. Only 'ConEq' is matched. Any other
-- 'Constraint' constructor, if one exists, fails the pattern match.
flattenConstraints :: Constraints -> [([UnitInfo], [UnitInfo])]
flattenConstraints cons = map flattenEq cons
  where
    flattenEq (ConEq u1 u2) = (flattenUnits u1, flattenUnits u2)

--------------------------------------------------

-- type Z3 a   = Symbolic a
-- | Symbolic integer unknown. 'encodeConstraints' uses one for each
-- (LHS term, RHS name) pair.
type Symbol = SInteger

-- type UnitZ3Map = M.Map (UnitInfo, UnitInfo) Symbol

-- | A unit appearing on the left-hand side of a shifted constraint.
type LhsUnit         = UnitInfo
-- | A unit appearing on the right-hand side of a shifted constraint.
type RhsUnit         = UnitInfo
-- | Maps each solver variable name, built as @show lhsUnit ++ "_" ++ rhsName@,
-- to the (LHS unit, RHS unit) pair it stands for.
type NameUnitInfoMap = M.Map String (LhsUnit, RhsUnit)
-- | Maps each solver variable name to the SBV integer variable created for it.
type NameSIntegerMap = M.Map String SInteger

-- | Collects one (name, unit) entry for every term on the given right-hand
-- sides, in row order and with duplicates kept. One outer 'UnitPow'
-- exponent is stripped from each term, and the name is the 'show' of the
-- remaining unit.
--
-- If the right-hand sides contain no terms at all, a single placeholder
-- entry @("bogus", UnitName "bogus")@ is returned instead.
gatherRhsUnitInfoNames :: [[UnitInfo]] -> [(String, RhsUnit)]
gatherRhsUnitInfoNames rhses =
  case [ named (stripPow term) | row <- rhses, term <- row ] of
    []    -> [("bogus", UnitName "bogus")]
    names -> names
  where
    stripPow :: UnitInfo -> UnitInfo
    stripPow (UnitPow base _) = base
    stripPow other            = other

    named :: UnitInfo -> (String, UnitInfo)
    named u = (show u, u)

-- | Pairs a single RHS entry @(rhsName, rhsUnit)@ with every term on the
-- given left-hand sides. One outer 'UnitPow' exponent is stripped from
-- each LHS term. Each resulting entry is keyed by
-- @show lhsUnit ++ "_" ++ rhsName@ and maps to @(lhsUnit, rhsUnit)@.
-- Entries appear in row order, with duplicates kept.
gatherLhsUnitInfoNames :: (String, RhsUnit) -> [[UnitInfo]] -> [(String, (LhsUnit, RhsUnit))]
gatherLhsUnitInfoNames (rhsName, rhsUnit) lhsRows =
  [ (show base ++ "_" ++ rhsName, (base, rhsUnit))
  | row  <- lhsRows
  , term <- row
  , let base = stripPow term
  ]
  where
    stripPow :: UnitInfo -> UnitInfo
    stripPow (UnitPow u _) = u
    stripPow u             = u

-- | Builds the map from solver variable names to (LHS unit, RHS unit)
-- pairs. Every RHS name collected from the right-hand sides
-- ('gatherRhsUnitInfoNames') is paired with every term found on the
-- left-hand sides ('gatherLhsUnitInfoNames').
--
-- When a name occurs more than once, the value from its last occurrence
-- is kept, which is the standard 'M.fromList' behaviour.
gatherNameUnitInfoMap :: [([UnitInfo], [UnitInfo])] -> NameUnitInfoMap
gatherNameUnitInfoMap shiftedCons = M.fromList entries
  where
    lhsRows, rhsRows :: [[UnitInfo]]
    (lhsRows, rhsRows) = unzip shiftedCons

    entries :: [(String, (LhsUnit, RhsUnit))]
    entries = concat [ gatherLhsUnitInfoNames rhsEntry lhsRows
                     | rhsEntry <- gatherRhsUnitInfoNames rhsRows ]

-- | Map of RHS Names to initial powers (0). Forms the basis of the
-- solution for every unit variable.
type BasisMap = M.Map String Integer

-- | Builds the initial 'BasisMap'. Every name that 'gatherRhsUnitInfoNames'
-- collects from the right-hand sides is mapped to a power of zero.
genBasisMap :: ShiftedConstraints -> BasisMap
genBasisMap shiftedCons =
  M.fromList (map (\ (rhsName, _) -> (rhsName, 0)) rhsEntries)
  where
    rhsEntries :: [(String, UnitInfo)]
    rhsEntries = gatherRhsUnitInfoNames (map snd shiftedCons)

-- | Runs the solver and returns its substitution as a list of
-- (unit, assigned unit) pairs. The list is empty when the constraints are
-- inconsistent.
genUnitAssignments :: Constraints -> [(UnitInfo, UnitInfo)]
genUnitAssignments cons = either (const []) (subToList . fst) (engine cons)

-- | Cheap simplification applied before solving. It drops every 'ConEq'
-- constraint whose two sides are equal under '==' on 'UnitInfo'. All other
-- constraints, including any non-'ConEq' ones, are kept in their original
-- order.
basicOptimisations :: Constraints -> Constraints
basicOptimisations = filter keep
  where
    keep :: Constraint -> Bool
    keep (ConEq lhs rhs) = not (lhs == rhs)
    keep _               = True

-- | Result of running 'engine'.
--
-- * @Left (core, labeledCons)@: the solver reported the constraints
--   unsatisfiable. @core@ is the list of labels from its unsat core, and
--   @labeledCons@ maps every label to the augmented constraint it names.
-- * @Right (sub, suggests)@: the combined substitution found for the
--   units, plus the units suggested for annotation.
type EngineResult = Either ([String], [(String, AugConstraint)]) (Sub, [UnitInfo])

-- | Main working function: solves the unit constraints with Z3 through SBV.
--
--   1. Drop constraints whose two sides are identical ('basicOptimisations').
--      Flatten each side into a list of terms, then move every term to its
--      designated side ('shiftTerms' with 'isUnitRHS').
--   2. Create one SBV integer for each name in the 'NameUnitInfoMap'.
--      Assert every encoded equation as a named constraint @c1@, @c2@, ...
--   3. If the query finds the problem inconsistent, return the unsat core
--      together with the label-to-constraint table ('Left').
--   4. Otherwise 'interpret' the solver values as a substitution and apply
--      it to the constraints. Whatever remains is passed to the matrix
--      backend to solve the polymorphic part and find critical variables
--      ('Right').
--
-- 'unsafePerformIO' lets pure code call this. The intended effect is
-- running the external solver on a query built only from @cons@.
engine :: Constraints -> EngineResult
engine cons = unsafePerformIO $
  -- No 'transcript' is configured. An SMT-LIB dump here would write a file
  -- into the current working directory on every (pure) call.
  runSMTWith z3 pred
  where
    -- constraints with every term moved to its designated side
    shiftedCons :: ShiftedConstraints
    shiftedCons = map (shiftTerms isUnitRHS) . flattenConstraints $ basicOptimisations cons

    -- solver variable names, each mapped to its (LHS unit, RHS unit)
    nameUIMap :: NameUnitInfoMap
    nameUIMap = gatherNameUnitInfoMap shiftedCons

    -- basis of the solution, a.k.a. the primitive units specified by the user
    basisMap :: BasisMap
    basisMap = genBasisMap shiftedCons

    genVar :: String -> Symbolic (String, SInteger)
    genVar name = (name,) <$> sInteger name

    pred :: Symbolic EngineResult
    pred = do
      setOption $ ProduceUnsatCores True
      -- pregenerate all of the necessary existentials
      nameSIntMap <- M.fromList <$> mapM genVar (M.keys nameUIMap)

      -- label every encoded equation so it can be reported in an unsat core
      let encCons :: [(SBool, AugConstraint)]
          encCons = encodeConstraints basisMap nameUIMap nameSIntMap shiftedCons
      labeledCons <- forM (zip [(1 :: Int) ..] encCons) $ \ (i, (sbool, augCon)) -> do
        let conName = "c" ++ show i
        namedConstraint conName sbool
        return (conName, augCon)

      query $ do
        -- obtain at least 1 name, value mapping for each variable if consistent
        e_nvMap <- computeInitialNVMap nameSIntMap
        case e_nvMap of
          Left core -> return $ Left (core, labeledCons) -- inconsistent
          Right nvMap -> do
            -- interpret the suggested values as a list of substitutions
            assignSubs <- interpret nameUIMap nvMap

            let -- convert to Dim format
                dims :: [Dim]
                dims = map (\ (lhs, rhs) -> dimFromUnitInfos (lhs ++ negateCons rhs)) shiftedCons

                -- apply known substitutions from solver; drop trivial equations
                dims' :: [Dim]
                dims' = filter (not . isIdentDim) $ map (applySub assignSubs) dims

                -- convert to Constraint format
                polyCons :: Constraints
                polyCons = map dimToConstraint dims'

                -- feed back into old solver to figure out polymorphic equations
                polyAssigns :: [([UnitInfo], UnitInfo)]
                polyAssigns = MatrixBackend.genUnitAssignments polyCons

                -- convert polymorphic assignments into substitution format
                polySubs :: Sub
                polySubs = subFromList [ (u, dimFromUnitInfo units)
                                       | ([UnitPow u@(UnitParamVarAbs _) k], units) <- polyAssigns
                                       , k `approxEq` 1 ]

                criticals :: [UnitInfo]
                criticals = MatrixBackend.criticalVariables polyCons

                -- for now we'll suggest all underdetermined units but
                -- this should be cut down by considering the
                -- relationships between variables, much like we would
                -- do for polymorphic vars.
                suggests :: [UnitInfo]
                suggests = [ v | v@(UnitVar {}) <- criticals ] ++
                           [ v | v@(UnitParamVarUse {}) <- criticals ]

            return $ Right (composeSubs polySubs assignSubs, suggests)

-- | Extracts a name and an integer power from a unit term.
--
-- Assumes the unit was already simplified and flattened. Nested 'UnitPow'
-- wrappers multiply their exponents, each truncated with 'floor' first.
-- The innermost non-'UnitPow' unit supplies the name via 'show' and an
-- implicit power of 1.
--
-- NOTE(review): 'floor' silently truncates fractional exponents
-- (e.g. 0.5 becomes 0).
getUnitNamePow :: UnitInfo -> (String, Integer)
getUnitNamePow unit = case unit of
  UnitPow inner p ->
    let (innerName, innerPow) = getUnitNamePow inner
    in  (innerName, floor p * innerPow)
  _ -> (show unit, 1)

-- | Augmented constraint: the original 'Constraint' paired with the name of
-- the RHS unit whose per-unit equation it produced (see 'encodeConstraints').
type AugConstraint = (Constraint, String)

-- | Encodes every shifted constraint as linear integer equations, one for
-- each RHS name, and pairs each equation with its 'AugConstraint'.
--
-- The RHS names for a constraint @(lhs, rhs)@ are every name in
-- @basisMap@ plus any name in @rhs@. For each RHS name @r@ with power
-- @pow_r@ (its summed power in @rhs@, or 0 if @r@ is absent from @rhs@),
-- the equation is
--
-- >  lhs1_r * pow1 + lhs2_r * pow2 + ... + lhsN_r * powN == pow_r
--
-- Here @lhsI_r@ is the variable named @show lhsI ++ "_" ++ r@ in
-- @nameSIntMap@, and @powI@ is the (floored) integer power of the I-th LHS
-- term. An empty LHS produces @0 == pow_r@. The 'NameUnitInfoMap' argument
-- is unused.
--
-- Calls 'error' if a required variable is missing from @nameSIntMap@.
encodeConstraints :: BasisMap -> NameUnitInfoMap -> NameSIntegerMap -> ShiftedConstraints -> [(SBool, AugConstraint)]
encodeConstraints basisMap _ nameSIntMap = concatMap encodeOne
  where
    -- Solver variable for an LHS term with respect to an RHS name, paired
    -- with the product of the term's (floored) exponents.
    lhsSymbol :: String -> UnitInfo -> (Symbol, Integer)
    lhsSymbol rhsName (UnitPow inner p) =
      let (innerSym, innerPow) = lhsSymbol rhsName inner
      in  (innerSym, floor p * innerPow)
    lhsSymbol rhsName u =
      let varName = show u ++ "_" ++ rhsName
          missing = error ("missing variable for " ++ varName)
      in  (fromMaybe missing (M.lookup varName nameSIntMap), 1)

    -- All equations for one constraint, one for each RHS name.
    encodeOne :: ShiftedConstraint -> [(SBool, AugConstraint)]
    encodeOne (lhs, rhs) =
      [ (lhsTotal .== literal rhsPow, (con, rhsName))
      | (rhsName, rhsPow) <- rhsPowers
      , let lhsTerms = [ lhsSym * literal lhsPow
                       | lhsTerm <- lhs
                       , let (lhsSym, lhsPow) = lhsSymbol rhsName lhsTerm ]
            lhsTotal = if null lhsTerms then 0 else sum lhsTerms
      ]
      where
        con :: Constraint
        con = ConEq (foldUnits lhs) (foldUnits rhs)

        -- every RHS name with its power here (0 for basis names not mentioned)
        rhsPowers :: [(String, Integer)]
        rhsPowers = M.toList
                  . M.unionWith (+) basisMap
                  . M.fromListWith (+)
                  $ map getUnitNamePow rhs

-- showConstraints :: BasisMap -> ShiftedConstraints -> [String]
-- showConstraints basisMap = map mkMsg
--   where
-- --    mkMsg ([], rhs) = ""
--     mkMsg (lhs, rhs) = intercalate "\n" . filter (not . null) $ map (perRhs lhs) rhsPowers
--       where
--         rhsPowers = M.toList . M.unionWith (+) basisMap . M.fromListWith (+) . map getUnitNamePow $ rhs

--     perRhs lhs (rhsName, rhsPow) = msg
--       where
--         msg = intercalate " + " [ lhsName ++ "(" ++ rhsName ++ ") * " ++ show lhsPow
--                                 | lhs_i <- lhs
--                                 , let (lhsName, lhsPow) = getUnitNamePow lhs_i ] ++
--               " == " ++ rhsName ++ " * " ++ show rhsPow


-- | What is known about the value of a single solver variable.
--
-- NOTE(review): the constructor descriptions below are inferred from their
-- names. Confirm them against 'nvUnion' and 'interpret'.
data ValueInfo
  = VISet [Integer]       -- ^ presumably: concrete values observed for the variable
  | VISuggest             -- ^ presumably: variable should be suggested for annotation
  | VIParametric Integer  -- ^ presumably: variable is parametric, with this value
  deriving (Show, Eq, Ord)

-- | Maps each SInteger name to the 'ValueInfo' (the candidate integer
-- values) recorded for it from solver models.
type NameValueInfoMap = M.Map String ValueInfo

-- | Solve the current constraints and record, for every name in
-- @nameSIntMap@, the integer value(s) the solver assigned to it.
--
-- * Unsatisfiable: returns 'Left' with the core from 'getUnsatCore'.
-- * Satisfiable: the first model's values are extracted. A second query
--   is then made inside a 'push'/'pop' scope (so the extra constraint is
--   discarded afterwards), asking for a model that avoids the values
--   already seen (see 'disallowValues'). If such a model exists its
--   values are merged in with 'nvUnion'. Names that can take more than
--   one value therefore end up with more than one entry in their 'VISet'.
-- * Any other solver answer is reported with 'error', including the
--   result that was received.
computeInitialNVMap :: NameSIntegerMap -> Query (Either [String] NameValueInfoMap)
computeInitialNVMap nameSIntMap = do
  cs <- checkSat
  case cs of
    Unsat -> Left <$> getUnsatCore
    Sat -> do
      nvMap <- extractSIntValues nameSIntMap
      -- Scope the "disallow" constraint so it does not leak into later queries.
      push 1
      disallowValues nameSIntMap nvMap
      cs' <- checkSat
      merged <- case cs' of
        Sat -> M.unionWith nvUnion nvMap <$> extractSIntValues nameSIntMap
        _   -> return nvMap
      pop 1
      return $ Right merged
    _ -> error $ "computeInitialNVMap: unexpected solver result: " ++ show cs

-- identifyMultipleVISet :: NameUnitInfoMap -> NameValueInfoMap -> [UnitInfo]
-- identifyMultipleVISet nameUIMap = nub . map fst . catMaybes . map (`M.lookup` nameUIMap) . M.keys . M.filter isMultipleVISet

-- isMultipleVISet :: ValueInfo -> Bool
-- isMultipleVISet (VISet (_:_:_)) = True
-- isMultipleVISet _               = False

-- | Merge two value sets by concatenating them and removing duplicates
-- with 'nub', so each integer is kept once in first-seen order. Any
-- other pair of arguments is rejected with 'error'.
nvUnion :: ValueInfo -> ValueInfo -> ValueInfo
nvUnion (VISet lhs) (VISet rhs) = VISet (nub (lhs ++ rhs))
nvUnion a b = error (concat ["nvUnion on (", show a, ", ", show b, ")"])

-- | Read the current model value of every SInteger in the map with
-- 'getValue'. Each value is stored as a singleton 'VISet' under the same
-- name. Values are fetched in ascending key order ('M.toList').
extractSIntValues :: NameSIntegerMap -> Query NameValueInfoMap
extractSIntValues nameSIntMap = do
  pairs <- forM (M.toList nameSIntMap) $ \(key, sInt) -> do
    value <- getValue sInt
    return (key, VISet [value])
  return (M.fromList pairs)

-- | Constrain the solver so that at least one name in @nvMap@ takes a
-- value different from every value already recorded for it.
--
-- Only names whose 'VISet' is non-empty and that also occur in
-- @nameSIntMap@ contribute a clause. Each clause is the conjunction of
-- @sInt ./= v@ over the recorded values @v@. The clauses are combined
-- with 'sOr'.
--
-- NOTE(review): if no name qualifies, the disjunction is empty. SBV's
-- 'sOr' of an empty list should be false, making the added constraint
-- unsatisfiable. Confirm this against the SBV version in use.
disallowValues :: NameSIntegerMap -> NameValueInfoMap -> Query ()
disallowValues nameSIntMap nvMap =
  constrain $ sOr [ clause | Just clause <- map mkNotEq (M.toList nvMap) ]
  where
    mkNotEq :: (String, ValueInfo) -> Maybe SBool
    mkNotEq (key, VISet vals@(_:_)) = do
      sInt <- M.lookup key nameSIntMap
      return $ sAnd [ sInt ./= literal v | v <- vals ]
    mkNotEq _ = Nothing

-- disallowCurrentValues :: NameSIntegerMap -> Query ()
-- disallowCurrentValues nameSIntMap = extractSIntValues nameSIntMap >>= disallowValues nameSIntMap

-- Interpret results.

-- The nameUIMap stores the mapping between each SInteger name and
-- its corresponding (lhsU, rhsU). We sort and group the entries by
-- their lhsU, and then look up the solved integer value of each
-- SInteger name; that value is the power to which rhsU is raised.
-- Names without a single solved value are skipped. All of the rhsUs,
-- raised to their respective powers, are then combined into a single
-- Dim for each lhsU.

-- | Turn the solver's per-name exponent values into a substitution.
--
-- Each entry of @nameUIMap@ maps an SInteger name to its @(lhsU, rhsU)@
-- pair, and @nvMap@ holds the values extracted for those names. The
-- entries are grouped by @lhsU@. Within a group, a name whose value set
-- is exactly @[x]@ contributes @UnitPow rhsU x@, or 'UnitlessVar' when
-- @x == 0@. Names with no value, or with any other value shape, are
-- skipped. A group in which no name contributes is left out of the
-- resulting 'Sub'.
--
-- The work is pure; the result is only lifted into 'Query' to keep the
-- existing signature.
interpret :: NameUnitInfoMap -> NameValueInfoMap -> Query Sub
interpret nameUIMap nvMap =
  pure . subFromList . catMaybes $ map eachGroup unitGroups
  where
    getLhsU :: (String, (LhsUnit, RhsUnit)) -> LhsUnit
    getLhsU = fst . snd

    -- unitGroups =
    --   [ [(name1_1, (lhs1, rhs1)), (name1_2, (lhs1, rhs2)), ...]
    --   , [(name2_1, (lhs2, rhs1)), (name2_2, (lhs2, rhs2)), ...]
    --   , ...]
    unitGroups :: [[(String, (LhsUnit, RhsUnit))]]
    unitGroups =
      groupBy ((==) `on` getLhsU)
        . sortBy (comparing getLhsU)
        $ M.toList nameUIMap

    eachName :: (String, (LhsUnit, RhsUnit)) -> Maybe UnitInfo
    eachName (lhsName, (_, rhsU)) =
      case M.lookup lhsName nvMap of
        Just (VISet [0]) -> Just UnitlessVar
        Just (VISet [x]) -> Just $ UnitPow rhsU (fromInteger x)
        _                -> Nothing

    -- Each group corresponds to one LHS variable. groupBy never yields an
    -- empty group, but matching [] explicitly keeps the function total.
    eachGroup :: [(String, (LhsUnit, RhsUnit))] -> Maybe (LhsUnit, Dim)
    eachGroup [] = Nothing
    eachGroup unitGroup@((_, (lhsU, _)) : _) =
      case catMaybes (map eachName unitGroup) of
        []       -> Nothing
        rawUnits -> Just (lhsU, dimFromUnitInfos rawUnits)