{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PostfixOperators #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ViewPatterns #-}

-- | Guess complexity from data.
module Test.Tasty.Bench.Fit.Complexity (
  Complexity (..),
  Measurement (..),
  guessComplexity,
  evalComplexity,

  -- * Predicates
  isConstant,
  isLogarithmic,
  isLinear,
  isLinearithmic,
  isQuadratic,
  isCubic,
) where

import Control.DeepSeq (NFData)
import Data.List (intercalate, minimumBy)
import Data.List.Infinite (Infinite (..), (...))
import qualified Data.List.NonEmpty as NE
import Data.Map (Map)
import qualified Data.Map as M
import Data.Ord (comparing)
import GHC.Generics (Generic)
import Math.Regression.Simple (
  Fit (..),
  V2 (..),
  levenbergMarquardt1WithYerrors,
  levenbergMarquardt2WithYerrors,
  linear,
 )
import Text.Printf (printf)
import Prelude hiding (log)
import qualified Prelude as P

#ifdef DEBUG
import Debug.Trace
#endif

-- | Natural logarithm of a problem size, as a 'Double'.
--
-- Shadows 'P.log' within this module. For an argument of @0@ it returns @0@
-- rather than negative infinity, so that downstream arithmetic stays finite.
log :: Word -> Double
log x
  | x >= 1 = P.log (d x)
  | otherwise = 0

-- | 'Complexity' @a@ @b@ @k@ represents a time complexity
-- \( k \, x^a \log^b x \), where \( x \) is problem's size.
data Complexity = Complexity
  { cmplVarPower :: !Double
  -- ^ Power \( a \) of the variable \( x \).
  , cmplLogPower :: !Word
  -- ^ Power \( b \) of \( \log x \).
  , cmplMultiplier :: !Double
  -- ^ Constant multiplier \( k \).
  }
  deriving (Eq, Ord, Generic)

-- No methods are given here, so the class defaults apply (presumably the
-- Generic-based default of 'NFData', which is why 'Complexity' derives
-- 'Generic' -- confirm against the deepseq documentation).
instance NFData Complexity

-- | Is the complexity \( f(x) = k \)?
--
-- Holds when both the power of \( x \) and the power of \( \log x \) are
-- exactly zero; the multiplier is not inspected.
isConstant :: Complexity -> Bool
isConstant = \case
  Complexity {cmplVarPower = 0, cmplLogPower = 0} -> True
  _ -> False

-- | Is the complexity \( f(x) = k \log x \)?
--
-- Holds when the power of \( x \) is exactly zero and the power of
-- \( \log x \) is exactly one; the multiplier is not inspected.
isLogarithmic :: Complexity -> Bool
isLogarithmic = \case
  Complexity {cmplVarPower = 0, cmplLogPower = 1} -> True
  _ -> False

-- | Is the complexity \( f(x) = k \, x \)?
--
-- Holds when the power of \( x \) is exactly one and the power of
-- \( \log x \) is exactly zero; the multiplier is not inspected.
isLinear :: Complexity -> Bool
isLinear = \case
  Complexity {cmplVarPower = 1, cmplLogPower = 0} -> True
  _ -> False

-- | Is the complexity \( f(x) = k \, x \log x \)?
--
-- Holds when both the power of \( x \) and the power of \( \log x \) are
-- exactly one; the multiplier is not inspected.
isLinearithmic :: Complexity -> Bool
isLinearithmic = \case
  Complexity {cmplVarPower = 1, cmplLogPower = 1} -> True
  _ -> False

-- | Is the complexity \( f(x) = k \, x^2 \)?
--
-- Holds when the power of \( x \) is exactly two and the power of
-- \( \log x \) is exactly zero; the multiplier is not inspected.
isQuadratic :: Complexity -> Bool
isQuadratic = \case
  Complexity {cmplVarPower = 2, cmplLogPower = 0} -> True
  _ -> False

-- | Is the complexity \( f(x) = k \, x^3 \)?
--
-- Holds when the power of \( x \) is exactly three and the power of
-- \( \log x \) is exactly zero; the multiplier is not inspected.
isCubic :: Complexity -> Bool
isCubic = \case
  Complexity {cmplVarPower = 3, cmplLogPower = 0} -> True
  _ -> False

-- | Renders a complexity as a product of up to three factors joined by
-- @" * "@: the multiplier, the power of @x@ and the power of @log x@.
-- A factor is omitted when it is trivial (multiplier @1@, power @0@), so a
-- value with multiplier @1@ and both powers @0@ renders as the empty string.
instance Show Complexity where
  show Complexity {..} =
    intercalate " * " $
      filter
        (not . null)
        [ case cmplMultiplier of
            1 -> ""
            _ -> printf "%.2g" cmplMultiplier
        , case cmplVarPower of
            0 -> ""
            1 -> "x"
            _ -> "x ^ " <> round3 cmplVarPower
        , case cmplLogPower of
            0 -> ""
            1 -> "log x"
            _ -> "(log x) ^ " <> show cmplLogPower
        ]
    where
      -- Show a power without a fractional part when it is a whole number
      -- representable as 'Word', and with three decimals otherwise.
      round3 :: Double -> String
      round3 x = if x == d x' then show x' else printf "%.3f" x
        where
          x' :: Word
          x' = round x

-- | Evaluate time complexity for a given size of the problem.
--
-- Computes \( k \, x^a \log^b x \) from the fields of the 'Complexity';
-- the logarithm is the module-local 'log', which maps @0@ to @0@.
evalComplexity :: Complexity -> Word -> Double
evalComplexity Complexity {..} x =
  -- Parentheses only spell out the default precedence of (**) and (^) over (*).
  cmplMultiplier * (d x ** cmplVarPower) * (log x ^ cmplLogPower)

-- | Pick the most plausible candidate from a list of fitted complexities
-- paired with their weighted sum of squared residuals.
--
-- The residual of each candidate is multiplied by penalties for a
-- non-integer power of the variable and for a high power of the logarithm;
-- the candidate with the smallest penalised value wins.
--
-- The list must be non-empty: 'minimumBy' fails on an empty list.
bestOf :: [(Complexity, Double)] -> Complexity
bestOf = fst . minimumBy (comparing weigh)
  where
    weigh :: (Complexity, Double) -> Double
    weigh (Complexity {..}, wssr) =
      wssr
        * powPenalty
        -- Penalty for high power of logarithm.
        * d (max 1 cmplLogPower)
      where
        -- Distance from the power to the nearest whole number
        -- (rounded as a 'Word', because 'd' takes a 'Word').
        diff :: Double
        diff = abs (cmplVarPower - d (round cmplVarPower))

        -- Penalty for non-integer power.
        powPenalty :: Double
        powPenalty
          | diff == 0 = 1
          -- Severe penalty for almost integer powers
          | diff < 0.05 = 100
          | diff < 0.15 = 32
          | otherwise = 10

-- | Represents a time measurement for a given problem's size.
data Measurement = Measurement
  { measTime :: !Double
  -- ^ Measured value.
  , measStDev :: !Double
  -- ^ Uncertainty of the measured value; it is passed to the fitting
  -- routines as the error of the observation.
  }
  deriving (Eq, Ord, Generic)

-- | Renders a measurement as @value ± deviation@, each formatted
-- with the @%.3g@ conversion of 'printf'.
instance Show Measurement where
  show (Measurement t err) = printf "%.3g ± %.3g" t err

-- No methods are given here, so the class defaults apply (presumably the
-- Generic-based default of 'NFData', which is why 'Measurement' derives
-- 'Generic' -- confirm against the deepseq documentation).
instance NFData Measurement

-- | Guess time complexity from a map where keys
-- are problem's sizes and values are time measurements (or instruction counts).
--
-- >>> :set -XNumDecimals
-- >>> guessComplexity $ Data.Map.fromList $ map (\(x, t) -> (x, Measurement t 1)) [(2, 4), (3, 10), (4, 15), (5, 25)]
-- 0.993 * x ^ 2
-- >>> guessComplexity $ Data.Map.fromList $ map (\(x, t) -> (x, Measurement t 1)) [(1e2, 2.1), (1e3, 2.9), (1e4, 4.1), (1e5, 4.9)]
-- 0.433 * log x
--
-- This function uses following simplifying assumptions:
--
-- * All coefficients are non-negative.
-- * The power of \( \log x \) ('cmplLogPower') is unlikely to be \( > 1 \).
-- * The power of \( x \) ('cmplVarPower') is unlikely to be fractional.
--
-- This function is unsuitable to guess
-- [superpolynomial](https://en.wikipedia.org/wiki/Time_complexity#Superpolynomial_time)
-- and higher classes of complexity.
guessComplexity :: Map Word Measurement -> Complexity
guessComplexity xys =
  trace'
    ("guessComplexity " ++ show (M.assocs xys))
    bestOf
    (takeUntilLocalMin cmpls)
  where
    -- Pairs of candidate fits for each power of the logarithm: 0, 1, 2, ...
    cmpls :: Infinite ((Complexity, Double), (Complexity, Double))
    cmpls = fmap (guessComplexityForFixedLog xys) (0 ...)

    -- Collect candidates while moving to the next power of the logarithm
    -- still lowers the residual of at least one candidate of the pair;
    -- the first pair that is not improved upon is included and ends the list.
    takeUntilLocalMin
      :: Infinite ((Complexity, Double), (Complexity, Double))
      -> [(Complexity, Double)]
    takeUntilLocalMin ((c1, c2) :< (c3, c4) :< cs)
      | snd c1 > snd c3 || snd c2 > snd c4 =
          c1 : c2 : takeUntilLocalMin ((c3, c4) :< cs)
      | otherwise =
          [c1, c2]

-- | For a fixed power of the logarithm, produce two candidate fits, each
-- paired with its residual: one using the fitted power of the variable
-- (clamped to be non-negative) and one using that power rounded to the
-- nearest whole number.
guessComplexityForFixedLog
  :: Map Word Measurement
  -> Word
  -> ((Complexity, Double), (Complexity, Double))
guessComplexityForFixedLog xys logPow = trace' msg res
  where
    -- varPow might be negative here, so always pass it through mkCmpl
    V2 _ varPow = guessComplexityWithoutLog xys logPow

    mkCmpl :: Double -> (Complexity, Double)
    mkCmpl varPow' = guessComplexityForFixedPowAndLog xys varPow' logPow

    res :: ((Complexity, Double), (Complexity, Double))
    res@((res1, wssr1), (res2, wssr2)) =
      (mkCmpl (max 0 varPow), mkCmpl (d (round varPow)))

    -- Debug message; only emitted when trace' is backed by a real tracer.
    msg :: String
    msg =
      printf
        "forFixedLog:\n\t%s, RSS %.4g\n\t%s, RSS %.4g"
        (show res1)
        wssr1
        (show res2)
        wssr2

-- | Fit both the multiplier and the power of the variable for a fixed power
-- of the logarithm. The result is @'V2' multiplier power@, the final
-- parameters of a two-parameter Levenberg-Marquardt fit.
guessComplexityWithoutLog :: Map Word Measurement -> Word -> V2
guessComplexityWithoutLog (M.assocs -> xys) logPow = finish
  where
    -- Fit y_i ~ a x_i^b, which is equivalent to log y_i ~ log a + b log x_i.
    -- This is not ideal, because minimizing the sum of (log y_i - log a - b log x_i) ^ 2
    -- is not equivalent to minimizing the sum of (y_i - a * x_i^b) ^ 2, but close enough,
    -- so we are going to use it as a starting point for Levenberg-Marquardt.
    -- The known logarithmic factor is divided out of y_i before taking logs.
    V2 b0 la0 =
      linear (\(x, Measurement y _) -> (log x, P.log (y / log x ^ logPow))) xys

    -- Initial parameters: multiplier and non-negative power.
    start :: V2
    start = V2 (exp la0) (max 0 b0)

    Fit {fitParams = finish} =
      NE.last $
        levenbergMarquardt2WithYerrors
          ( \(V2 mult varPow) (x, Measurement y err) ->
              -- (observed value, model value, gradient of the model with
              -- respect to (mult, varPow), error of the observation)
              ( y
              , mult * d x ** varPow * log x ^ logPow
              , V2
                  (d x ** varPow * log x ^ logPow)
                  -- d/d(varPow) of x ** varPow contributes one more log x
                  (mult * d x ** varPow * log x ^ (logPow + 1))
              , err
              )
          )
          start
          xys

-- | Fit only the multiplier for fixed powers of the variable and of the
-- logarithm. Returns the resulting 'Complexity' together with the weighted
-- sum of squared residuals of the fit.
guessComplexityForFixedPowAndLog
  :: Map Word Measurement
  -> Double
  -> Word
  -> (Complexity, Double)
guessComplexityForFixedPowAndLog (M.assocs -> xys) varPow logPow = (res, wssr)
  where
    -- We want to find a which minimizes \sum_i (y_i - a f(x_i))^2 for f(x) = x^b * log^c x.
    -- Then d/da = 0 means that \sum_i (2 a f(x_i)^2 - 2 f(x_i) y_i) = 0
    -- or equivalently a = \sum_i f(x_i) y_i / \sum_i f(x_i)^2.
    eval :: Word -> Double
    eval x = d x ** varPow * log x ^ logPow

    sumXY :: Double
    sumXY = sum $ map (\(x, Measurement y _) -> eval x * y) xys

    sumX2 :: Double
    sumX2 = sum $ map (\(x, _) -> eval x ** 2) xys

    -- Closed-form least-squares multiplier, used as the starting point.
    start :: Double
    start = sumXY / sumX2

    ft :: Fit Double
    ft =
      NE.last $
        levenbergMarquardt1WithYerrors
          ( \mult (x, Measurement y err) ->
              -- (observed value, model value, derivative of the model with
              -- respect to mult, error of the observation)
              ( y
              , mult * d x ** varPow * log x ^ logPow
              , d x ** varPow * log x ^ logPow
              , err
              )
          )
          start
          xys

    res :: Complexity
    res =
      Complexity
        { cmplMultiplier = fitParams ft
        , cmplVarPower = varPow
        , cmplLogPower = logPow
        }

    wssr :: Double
    wssr = fitWSSR ft

-- | Convert a 'Word' to a 'Double'.
d :: Word -> Double
d = fromIntegral

-- | Debug tracing hook: behaves like 'trace' when the @DEBUG@ CPP macro is
-- defined, and otherwise ignores the message and returns the value unchanged.
trace' :: String -> b -> b
#ifdef DEBUG
trace' = trace
#else
trace' = const id
#endif