{-# LANGUAGE OverloadedStrings #-}

-- | Facilities for comparing values for equality.  While 'Eq'
-- instances are defined, these are not useful when NaNs are involved,
-- and do not *explain* the differences.
module Futhark.Data.Compare
  ( compareValues,
    compareSeveralValues,
    Tolerance (..),
    Mismatch,
  )
where

import Data.List (intersperse)
import qualified Data.Text as T
import qualified Data.Vector.Storable as SVec
import Futhark.Data

-- | Two values differ in some way.  The 'Show' instance produces a
-- human-readable explanation.
data Mismatch
  = -- | An element differs.  Fields: the value number, the
    -- multi-dimensional index of the differing element within the
    -- array (empty for a scalar), and the actual and expected element
    -- values rendered as text.
    PrimValueMismatch Int [Int] T.Text T.Text
  | -- | The shapes differ.  Fields: the value number, the actual
    -- shape, and the expected shape.
    ArrayShapeMismatch Int [Int] [Int]
  | -- | The element types differ.  Fields: the value number, the
    -- actual element type, and the expected element type.
    TypeMismatch Int T.Text T.Text
  | -- | A different number of values was received than expected.
    -- Fields: the number received, and the number expected.
    ValueCountMismatch Int Int

-- | Render any 'Show'-able value as 'T.Text', using its 'show' string.
showText :: Show a => a -> T.Text
showText x = T.pack (show x)

-- | A human-readable description of how two values are not the same.
--
-- Arguments, in order: a label identifying the value (possibly
-- including an index within it), a description of what is being
-- compared (prefixed to the expected value), the actual value, and the
-- expected value.
explainMismatch :: T.Text -> T.Text -> T.Text -> T.Text -> T.Text
explainMismatch i what got expected =
  mconcat ["Value #", i, ": expected ", what, expected, ", got ", got]

-- | Renders a single-line explanation of the mismatch, labelled with
-- the number of the offending value.
instance Show Mismatch where
  show mismatch = T.unpack (describe mismatch)
    where
      describe (PrimValueMismatch vi js got expected) =
        explainMismatch (position vi js) "" got expected
      describe (ArrayShapeMismatch i got expected) =
        explainMismatch (showText i) "array of shape " (showText got) (showText expected)
      describe (TypeMismatch i got expected) =
        explainMismatch (showText i) "value of type " got expected
      describe (ValueCountMismatch got expected) =
        "Expected " <> showText expected <> " values, got " <> showText got

      -- A scalar is labelled by its value number alone; an array
      -- element additionally gets its index, e.g. "0 index [1,2]".
      position vi [] = showText vi
      position vi js =
        showText vi <> " index [" <> T.intercalate "," (map showText js) <> "]"

-- | The maximum relative tolerance used when comparing floating-point
-- results, e.g. @Tolerance 0.002@ (0.2%), which is a fine default if
-- you have no particular opinion.  It is scaled by the magnitude of
-- the expected values, and also acts as a lower bound on the absolute
-- tolerance.
newtype Tolerance
  = Tolerance Double
  deriving (Show, Eq, Ord)

-- | Convert a 'Tolerance' into the floating-point type being compared,
-- going through an exact 'Rational' intermediate.
toleranceFloat :: RealFloat a => Tolerance -> a
toleranceFloat tol = case tol of
  Tolerance x -> fromRational (toRational x)

-- | Compare two Futhark values for equality, returning every
-- 'Mismatch' found; an empty list means they agree (floats within the
-- given tolerance).  The value is reported as value number 0.
compareValues :: Tolerance -> Value -> Value -> [Mismatch]
compareValues tol got expected = compareValue tol 0 got expected

-- | As 'compareValues', but compares several values pairwise.  The two
-- lists must have the same length; if they do not, a single
-- 'ValueCountMismatch' is reported and nothing else is compared.
-- Otherwise the pairs are compared and numbered by their position.
compareSeveralValues :: Tolerance -> [Value] -> [Value] -> [Mismatch]
compareSeveralValues tol got expected =
  if numGot /= numExpected
    then [ValueCountMismatch numGot numExpected]
    else concat [compareValue tol k g e | (k, g, e) <- zip3 [0 ..] got expected]
  where
    numGot = length got
    numExpected = length expected

-- | Convert a flat (row-major) element offset into a multi-dimensional
-- index for an array of the given shape.  For example, offset 5 in an
-- array of shape @[2,3]@ is index @[1,2]@.  A scalar (empty shape) has
-- the empty index.
unflattenIndex :: [Int] -> Int -> [Int]
unflattenIndex shape = go strides
  where
    -- The stride of each dimension is the product of all dimensions
    -- after it; the innermost dimension has stride 1.
    strides = drop 1 (scanr (*) 1 shape)
    go [] _ = []
    go (stride : rest) offset =
      let (q, r) = offset `quotRem` stride
       in q : go rest r

-- | Compare a received value (numbered @i@) against the expected one.
--
-- A shape difference is reported as a single 'ArrayShapeMismatch', and
-- an element-type difference as a single 'TypeMismatch'.  Otherwise
-- every differing element is reported, in index order, as a
-- 'PrimValueMismatch' carrying its multi-dimensional index.  Integers
-- and booleans must match exactly.  Floats are compared against an
-- absolute tolerance derived from @tol@ and the expected values (see
-- 'tolerance'); a NaN matches a NaN, and infinities match if their
-- signs agree.
--
-- The mismatch list is produced lazily, so a caller that only looks at
-- the first few mismatches does not pay for scanning the whole array.
compareValue :: Tolerance -> Int -> Value -> Value -> [Mismatch]
compareValue tol i got_v expected_v
  | valueShape got_v == valueShape expected_v =
      case (got_v, expected_v) of
        (I8Value _ got_vs, I8Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (I16Value _ got_vs, I16Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (I32Value _ got_vs, I32Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (I64Value _ got_vs, I64Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (U8Value _ got_vs, U8Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (U16Value _ got_vs, U16Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (U32Value _ got_vs, U32Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (U64Value _ got_vs, U64Value _ expected_vs) ->
          compareExact got_vs expected_vs
        (F16Value _ got_vs, F16Value _ expected_vs) ->
          compareFloat got_vs expected_vs
        (F32Value _ got_vs, F32Value _ expected_vs) ->
          compareFloat got_vs expected_vs
        (F64Value _ got_vs, F64Value _ expected_vs) ->
          compareFloat got_vs expected_vs
        (BoolValue _ got_vs, BoolValue _ expected_vs) ->
          compareExact got_vs expected_vs
        _ ->
          [ TypeMismatch
              i
              (primTypeText $ valueElemType got_v)
              (primTypeText $ valueElemType expected_v)
          ]
  | otherwise =
      [ArrayShapeMismatch i (valueShape got_v) (valueShape expected_v)]
  where
    -- Mismatches are reported with multi-dimensional indices.
    unflatten = unflattenIndex (valueShape got_v)

    value :: Show a => a -> T.Text
    value = T.pack . show

    {-# INLINE compareGen #-}
    {-# INLINE compareExact #-}
    {-# INLINE compareFloat #-}
    {-# INLINE compareFloatElement #-}
    {-# INLINE compareElement #-}

    -- Exact (==) comparison, used for integers and booleans alike.
    compareExact :: (SVec.Storable a, Eq a, Show a) => SVec.Vector a -> SVec.Vector a -> [Mismatch]
    compareExact = compareGen compareElement

    -- Float comparison; the absolute tolerance is derived once from
    -- the expected values and shared by all elements.
    compareFloat :: (SVec.Storable a, RealFloat a, Show a) => SVec.Vector a -> SVec.Vector a -> [Mismatch]
    compareFloat got_vs expected_vs =
      compareGen (compareFloatElement abstol) got_vs expected_vs
      where
        abstol = tolerance (toleranceFloat tol) expected_vs

    -- Apply an element comparison at every flat position and keep the
    -- mismatches in index order.  Walking the element lists (instead of
    -- indexing) keeps this total and lets consumers stop early.
    compareGen ::
      SVec.Storable t =>
      (Int -> t -> t -> Maybe Mismatch) ->
      SVec.Vector t ->
      SVec.Vector t ->
      [Mismatch]
    compareGen cmp got expected =
      [ mismatch
        | (j, g, e) <- zip3 [0 ..] (SVec.toList got) (SVec.toList expected),
          Just mismatch <- [cmp j g e]
      ]

    compareElement :: (Show a, Eq a) => Int -> a -> a -> Maybe Mismatch
    compareElement j got expected
      | got == expected = Nothing
      | otherwise = Just $ PrimValueMismatch i (unflatten j) (value got) (value expected)

    compareFloatElement :: (Show a, RealFloat a) => a -> Int -> a -> a -> Maybe Mismatch
    compareFloatElement abstol j got expected
      | isNaN got,
        isNaN expected =
          Nothing
      | isInfinite got,
        isInfinite expected,
        signum got == signum expected =
          Nothing
      | abs (got - expected) <= abstol = Nothing
      | otherwise = Just $ PrimValueMismatch i (unflatten j) (value got) (value expected)

-- | Compute the absolute tolerance for comparing a vector of
-- floating-point values.  It is the relative tolerance @tol@ scaled by
-- the largest finite magnitude among the expected values, but never
-- less than @tol@ itself.  NaNs and infinities are skipped: they have
-- no meaningful magnitude and would otherwise poison the result.
tolerance :: (RealFloat a, SVec.Storable a) => a -> SVec.Vector a -> a
tolerance tol = SVec.foldl' widen tol . SVec.filter (not . nanOrInf)
  where
    -- Scale by the magnitude, so negative values widen the tolerance
    -- exactly as much as positive ones (previously an all-negative
    -- vector only got the tiny absolute floor @tol@).
    widen t v = max t $ tol * abs v
    nanOrInf x = isInfinite x || isNaN x