{- |
Module      :  Camfort.Specification.Units.Analysis.Consistent
Description :  Analysis to verify units are consistent.
Copyright   :  (c) 2017, Dominic Orchard, Andrew Rice, Mistral Contrastin, Matthew Danish
License     :  Apache-2.0

Maintainer  :  dom.orchard@gmail.com
Stability   :  experimental
-}

{-# LANGUAGE ExistentialQuantification #-}

module Camfort.Specification.Units.Analysis.Consistent
  ( ConsistencyError
  , ConsistencyReport(Consistent, Inconsistent)
  , checkUnits
  ) where

import           Camfort.Analysis (ExitCodeOfReport(..), Describe(..))
import           Camfort.Specification.Units.Analysis (UnitAnalysis, runInference)
import qualified Camfort.Specification.Units.Annotation as UA
import qualified Camfort.Specification.Units.BackendTypes as B
import           Camfort.Specification.Units.Environment
import           Camfort.Specification.Units.InferenceBackend (inconsistentConstraints)
import           Camfort.Specification.Units.Monad
import           Control.DeepSeq
import           Control.Monad.Reader (asks)
import           Control.Monad.State (get)
import           Data.Generics.Uniplate.Operations
import           Data.List (partition, find, group, sort)
import qualified Data.Map.Strict as M
import           Data.Maybe (maybeToList, maybe)
import qualified Language.Fortran.AST as F
import qualified Language.Fortran.Util.Position as FU

-- | A report that summarises unit consistency of one program file.
data ConsistencyReport
    -- | All units were consistent. Holds the analysed program file (its
    -- annotation type is existentially hidden) and the number of
    -- non-parametric variables that were checked.
  = forall a. Consistent (F.ProgramFile a) Int
    -- | An inconsistency was found in units of the program; the details
    -- are carried by the 'ConsistencyError'.
  | Inconsistent ConsistencyError
-- | Evaluate a report as far as the available instances allow.
--
-- The constructor is always forced, so @pure $!! report@ in 'checkUnits'
-- actually performs the consistency check instead of being a no-op. For
-- 'Consistent', the variable count is also forced.
--
-- The program files and constraints inside are left unevaluated. The
-- existential field of 'Consistent' carries no 'NFData' evidence for its
-- annotation type.
-- NOTE(review): deeper forcing would need 'NFData' instances for the
-- Fortran AST and 'Constraints'; confirm whether they exist.
instance NFData ConsistencyReport where
  rnf (Consistent _ nVars) = rnf nVars
  rnf (Inconsistent err)   = err `seq` ()
-- | Human-readable rendering of a consistency report. An inconsistent
-- report defers to the 'Show' instance of its 'ConsistencyError'; a
-- consistent one names the file and the number of variables checked.
instance Show ConsistencyReport where
  show (Inconsistent err) = show err
  show (Consistent pf nVars) =
    "\n" ++ F.pfGetFilename pf ++ ": Consistent " ++ show nVars ++ " variables checked."

-- | Process exit code for a consistency report: 0 when every unit was
-- consistent, 1 when an inconsistency was found.
instance ExitCodeOfReport ConsistencyReport where
  exitCodeOf report =
    case report of
      Inconsistent _ -> 1
      Consistent {}  -> 0

-- | No methods are overridden: this relies entirely on the default method
-- implementations of 'Describe'.
instance Describe ConsistencyReport

-- | Details of a failed consistency check. It holds two fields:
--
-- * the program file produced by units analysis, used for the file name
--   and to locate the source spans of the offending constraints;
-- * the constraints reported as unsolvable.
data ConsistencyError =
  Inconsistency (F.ProgramFile UA) Constraints

-- | Render an inconsistency report.
--
-- The output is the file name followed by one line per distinct, relevant
-- unsolvable constraint. Each line is prefixed with the start of the source
-- span the constraint came from, when such a span can be found.
instance Show ConsistencyError where
  show (Inconsistency pf cons) = concat [ "\n", fname, ": Inconsistent:\n", reportErrors ]
    where
      fname :: String
      fname = F.pfGetFilename pf

      reportErrors :: String
      reportErrors = unlines [ maybe "" showSS ss ++ str | (ss, str) <- reports ]
        where
          -- Sort, then keep the first element of each group of equal
          -- entries. This removes duplicate (span, message) lines; the
          -- pattern (r:_) is total because 'group' never yields [].
          reports :: [(Maybe FU.SrcSpan, String)]
          reports =
            [ r
            | (r:_) <- group (sort (map reportError (filter relevantConstraints cons)))
            ]

          showSS :: FU.SrcSpan -> String
          showSS = (++ ": ") . (" - at " ++) . showSpanStart

          -- Only report constraints that are not trivially true and that do
          -- not mention a 'UnitParamLitAbs' unit on either side.
          relevantConstraints :: Constraint -> Bool
          relevantConstraints c = not (isPolymorphic0 c) && not (isReflexive c)

          isPolymorphic0 :: Constraint -> Bool
          isPolymorphic0 (ConEq UnitParamLitAbs{} _) = True
          isPolymorphic0 (ConEq _ UnitParamLitAbs{}) = True
          isPolymorphic0 _                           = False

          -- A constraint is trivially true when it is an equation with
          -- identical sides, or a conjunction made only of such equations.
          -- Conjunctions are handled recursively rather than rejected, so a
          -- 'ConConj' among the unsolvable constraints cannot abort the
          -- whole report.
          isReflexive :: Constraint -> Bool
          isReflexive (ConEq u1 u2) = u1 == u2
          isReflexive (ConConj cs)  = all isReflexive cs

      -- Pair a constraint with its source location (if any) and a
      -- simplified, renamed and re-oriented pretty-printed form.
      reportError :: Constraint -> (Maybe FU.SrcSpan, String)
      reportError con = (errSpan, pprintConstr . orient . unrename . shift . simplify $ con)
        where
          errSpan :: Maybe FU.SrcSpan
          errSpan = findCon con

          -- Choose a readable orientation for an equation:
          --   * if either side has power 0, move negative-power terms across;
          --   * put a lone unit variable or parametric-variable use on the left;
          --   * if every term has a negative power, negate both sides
          --     and orient again.
          orient :: Constraint -> Constraint
          orient (ConEq u v)
            | 0 `elem` [unitPower u, unitPower v] =
                balanceConEq ((< 0) . unitPower) (ConEq u v)
          orient (ConEq u (UnitVar v))         = ConEq (UnitVar v) u
          orient (ConEq u (UnitParamVarUse v)) = ConEq (UnitParamVarUse v) u
          orient (ConEq u v)
            | all ((< 0) . unitPower) (lhs ++ rhs) =
                orient (ConEq (foldUnits (negateCons lhs)) (foldUnits (negateCons rhs)))
            where
              lhs, rhs :: [UnitInfo]
              lhs = flattenUnits u
              rhs = flattenUnits v
          orient c = c

          -- Overall exponent of a unit term; unitless terms count as 0.
          unitPower :: UnitInfo -> Double
          unitPower (UnitPow u k) = unitPower u * k
          unitPower UnitlessLit   = 0
          unitPower UnitlessVar   = 0
          unitPower _             = 1

          -- When reporting inconsistent constraints, shift all the
          -- UnitNames (e.g. m, kg) and Polymorphic Units (e.g. 'a,
          -- 'b) to the right-hand-side, and other things to the left.
          shift :: Constraint -> Constraint
          shift = shiftConEq isUnitRHS

          -- Units that should appear on the right-hand-side of the matrix during solving
          isUnitRHS :: UnitInfo -> Bool
          isUnitRHS (UnitPow (UnitName _) _)        = True
          isUnitRHS (UnitPow (UnitParamEAPAbs _) _) = True
          isUnitRHS _                               = False

          -- Normalise by a round trip through the backend's 'Dim' form.
          simplify :: Constraint -> Constraint
          simplify = B.dimToConstraint . B.constraintToDim

      -- Find the span of the first annotated node whose attached constraint
      -- contains (via 'universeBi') a constraint equal to 'con' after
      -- conversion to 'Dim'.
      findCon :: Constraint -> Maybe FU.SrcSpan
      findCon con = lookupWith (eq con) constraints
        where -- constraintToDim normalises as it builds the Dim, so we can use dimParamEq directly.
              eq c1 c2 = or [ B.constraintToDim c1 `B.dimParamEq` B.constraintToDim c2'
                            | c2' <- universeBi c2 ]

      -- Every (constraint, span) pair attached to the annotated program.
      -- 'lookupWith' returns the first match, so expressions take priority
      -- over statements, then arguments, declarators and program units.
      constraints :: [(Constraint, FU.SrcSpan)]
      constraints =
        [ (c, FU.getSpan x)
        | x <- (universeBi pf :: [F.Expression UA])
        , c <- maybeToList (UA.getConstraint x)
        ]
        ++
        [ (c, FU.getSpan x)
        | x <- (universeBi pf :: [F.Statement UA])
        , c <- maybeToList (UA.getConstraint x)
        ]
        ++
        [ (c, FU.getSpan x)
        | x <- (universeBi pf :: [F.Argument UA])
        , c <- maybeToList (UA.getConstraint x)
        ]
        ++
        [ (c, FU.getSpan x)
        | x <- (universeBi pf :: [F.Declarator UA])
        , c <- maybeToList (UA.getConstraint x)
        ]
        ++
        -- Why reverse? So that PUFunction and PUSubroutine appear
        -- first in the list, before PUModule.
        reverse
          [ (c, FU.getSpan x)
          | x <- (universeBi pf :: [F.ProgramUnit UA])
          , c <- maybeToList (UA.getConstraint x)
          ]

-- | No methods are overridden: this relies entirely on the default method
-- implementations of 'Describe'.
instance Describe ConsistencyError

{-| Check units-of-measure for a program.

Runs unit inference in order to collect unsolvable constraints.

* If there are none, the result is 'Consistent', carrying the original
  program file and the number of non-parametric variables in the
  variable-to-unit map.
* Otherwise the result is 'Inconsistent', carrying the program file as
  annotated by the units analysis together with the offending constraints.
-}
checkUnits :: UnitAnalysis ConsistencyReport
checkUnits = do
  pf <- asks unitProgramFile
  (eCons, state) <- runInference runInconsistentConstraints
  let -- number of 'real' variables checked, e.g. not parametric
      nVars = M.size (M.filter (not . isParametricUnit) (usVarUnitMap state))

      -- the program file after units analysis is done
      pfUA :: F.ProgramFile UA
      pfUA = usProgramFile state

      report = maybe (Consistent pf nVars) (Inconsistent . Inconsistency pfUA) eCons
  pure $!! report
  where
    isParametricUnit :: UnitInfo -> Bool
    isParametricUnit UnitParamPosAbs{} = True
    isParametricUnit UnitParamPosUse{} = True
    isParametricUnit UnitParamVarAbs{} = True
    isParametricUnit UnitParamVarUse{} = True
    isParametricUnit _                 = False

-- | Return the value of the first pair whose key satisfies the predicate,
-- or 'Nothing' if no key does. Like 'lookup', but matching keys with a
-- predicate instead of by equality.
lookupWith :: (a -> Bool) -> [(a, b)] -> Maybe b
lookupWith matches pairs =
  case [ val | (key, val) <- pairs, matches key ] of
    (val:_) -> Just val
    []      -> Nothing

-- | Return a possible list of unsolvable constraints, computed from the
-- constraints held in the solver state. 'Nothing' means they are
-- consistent (see how 'checkUnits' interprets the result).
runInconsistentConstraints :: UnitSolver (Maybe Constraints)
runInconsistentConstraints = inconsistentConstraints . usConstraints <$> get

-- | Clear out the unique names in the UnitInfos.
--
-- Every variable-style 'UnitInfo' name is a pair. It is rewritten so that
-- both components equal the second one; presumably the first component is
-- the unique name introduced by renaming. The rewrite is applied everywhere
-- inside the given value via 'transformBi'. Other 'UnitInfo's are left
-- unchanged.
unrename :: Data a => a -> a
unrename = transformBi stripUnique
  where
    stripUnique :: UnitInfo -> UnitInfo
    stripUnique (UnitVar (_, s))                      = UnitVar (s, s)
    stripUnique (UnitParamVarAbs ((_, f), (_, s)))    = UnitParamVarAbs ((f, f), (s, s))
    stripUnique (UnitParamVarUse ((_, f), (_, s), i)) = UnitParamVarUse ((f, f), (s, s), i)
    stripUnique (UnitParamEAPAbs (_, s))              = UnitParamEAPAbs (s, s)
    stripUnique (UnitParamEAPUse ((_, s), i))         = UnitParamEAPUse ((s, s), i)
    stripUnique other                                 = other

-- | Show only the start position of the 'SrcSpan', using that position's
-- own 'Show' instance; the end position is ignored.
showSpanStart :: FU.SrcSpan -> String
showSpanStart ss =
  case ss of
    FU.SrcSpan start _end -> show start

-- | Shift terms to the right if predicate f is satisfied and to the left otherwise.
--
-- Both sides are flattened into terms.
--
-- * Left-hand terms satisfying @f@ move to the right, with their powers
--   negated.
-- * Right-hand terms not satisfying @f@ move to the left, with their powers
--   negated.
--
-- Conjunctions are processed element-wise.
shiftConEq :: (UnitInfo -> Bool) -> Constraint -> Constraint
shiftConEq f (ConConj cs) = ConConj (map (shiftConEq f) cs)
shiftConEq f (ConEq l r)  = ConEq newLeft newRight
  where
    (moveRight, stayLeft) = partition f (flattenUnits l)
    (stayRight, moveLeft) = partition f (flattenUnits r)
    newLeft  = foldUnits (stayLeft ++ negateCons moveLeft)
    newRight = foldUnits (stayRight ++ negateCons moveRight)

-- | Balance equations by shifting terms that satisfy predicate f.
--
-- On each side, the flattened terms satisfying @f@ move to the opposite
-- side with their powers negated; terms not satisfying @f@ stay put.
-- Conjunctions are processed element-wise.
balanceConEq :: (UnitInfo -> Bool) -> Constraint -> Constraint
balanceConEq f (ConConj cs) = ConConj (map (balanceConEq f) cs)
balanceConEq f (ConEq l r) =
  let (movedFromL, keptL) = partition f (flattenUnits l)
      (movedFromR, keptR) = partition f (flattenUnits r)
      newLeft             = foldUnits (keptL ++ negateCons movedFromR)
      newRight            = foldUnits (keptR ++ negateCons movedFromL)
  in ConEq newLeft newRight

-- | Negate the power of every term.
--
-- An explicit 'UnitPow' has its exponent negated. Any other term @u@ is
-- wrapped as @UnitPow u (-1)@.
negateCons :: [UnitInfo] -> [UnitInfo]
negateCons = map negatePower
  where
    negatePower :: UnitInfo -> UnitInfo
    negatePower (UnitPow u k) = UnitPow u (negate k)
    negatePower u             = UnitPow u (-1)