{-# LANGUAGE CPP #-}
{-# LANGUAGE NumDecimals #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Guess the asymptotic time complexity of a function by benchmarking it
-- on inputs of varying size.
module Test.Tasty.Bench.Fit (
  -- * Fit benchmarks
  fit,
  fits,
  mkFitConfig,
  FitConfig (..),

  -- * Complexity
  Complexity (..),
  Measurement (..),
  guessComplexity,
  evalComplexity,

  -- * Predicates
  isConstant,
  isLogarithmic,
  isLinear,
  isLinearithmic,
  isQuadratic,
  isCubic,
) where

import Control.DeepSeq (NFData)
import Data.List (maximumBy)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import Data.Map (Map)
import qualified Data.Map as M
import Data.Ord (comparing)
import System.IO.Unsafe (unsafeInterleaveIO)
import Test.Tasty (Timeout, mkTimeout)
import Test.Tasty.Bench (Benchmarkable, RelStDev (..), measureCpuTimeAndStDev, nf)
import Test.Tasty.Bench.Fit.Complexity (
  Complexity (..),
  Measurement (..),
  evalComplexity,
  guessComplexity,
  isConstant,
  isCubic,
  isLinear,
  isLinearithmic,
  isLogarithmic,
  isQuadratic,
 )

#ifdef DEBUG
import Debug.Trace
#endif

-- | Configuration for 'fit'.
--
-- Usually built with 'mkFitConfig' and then adjusted with record update
-- syntax, e.g. @(mkFitConfig f (10, 10000)) { fitRelStDev = RelStDev 0.05 }@.
data FitConfig = FitConfig
  { fitBench :: Word -> Benchmarkable
  -- ^ Which function to measure? Typically 'nf' @f@.
  , fitLow :: Word
  -- ^ The smallest size of the input.
  -- It should be as small as possible, but big enough for the main asymptotic
  -- term to dwarf constant overhead and other terms.
  , fitHigh :: Word
  -- ^ The largest size of the input.
  -- As large as practically possible, at least 100x larger than
  -- the smallest size.
  , fitTimeout :: Timeout
  -- ^ Timeout of individual measurements.
  , fitRelStDev :: RelStDev
  -- ^ Target relative standard deviation of individual measurements.
  , fitOracle :: Map Word Measurement -> Complexity
  -- ^ An oracle to determine complexity from measurements.
  -- Typically 'guessComplexity'.
  }

-- | Generate a default 'fit' configuration.
--
-- Defaults: each measurement times out after 100 seconds
-- (@'mkTimeout' 1e8@, in microseconds), aims for a 2% relative standard
-- deviation, and complexity is guessed by 'guessComplexity'.
mkFitConfig
  :: (NFData a)
  => (Word -> a)
  -- ^ Raw function to measure, without 'nf'.
  -> (Word, Word)
  -- ^ The smallest and the largest sizes of the input.
  -> FitConfig
mkFitConfig f (low, high) =
  FitConfig
    { fitBench = nf f
    , fitLow = low
    , fitHigh = high
    , fitTimeout = mkTimeout 1e8
    , fitRelStDev = RelStDev 0.02
    , fitOracle = guessComplexity
    }

-- | Determine time complexity of the function:
--
-- >>> fit $ mkFitConfig (\x -> sum [1..x]) (10, 10000)
-- 1.2153e-8 * x
-- >>> fit $ mkFitConfig (\x -> Data.List.nub [1..x]) (10, 10000)
-- 2.8369e-9 * x ^ 2
-- >>> fit $ mkFitConfig (\x -> Data.List.sort $ take (fromIntegral x) $ iterate (\n -> n * 6364136223846793005 + 1) (1 :: Int)) (10, 100000)
-- 5.2990e-8 * x * log x
--
-- This consumes the estimates produced by 'fits' until the latest one agrees
-- with both of the two preceding ones. Agreement means the power of the
-- variable is within 0.001, the power of the logarithm is equal, and the
-- multiplier is within 1%. If that never happens, the last estimate is
-- returned.
--
-- Reliable results can usually be obtained for functions that do not
-- allocate much, such as in-place vector sorts or fused list operations like
-- 'sum' @[1..x]@.
--
-- Unfortunately, fitting functions that allocate a lot is likely to be
-- disappointing. GC kicks in irregularly depending on nursery and heap sizes,
-- and often skews observations beyond any recognition.
-- Consider running such measurements with @-O0@ or at the @ghci@ prompt;
-- this is how the usage examples above were generated. Without optimizations
-- your program allocates much more and triggers GC regularly, which somewhat
-- evens out its effect.
fit :: FitConfig -> IO Complexity
fit cnf = converge <$> fits cnf

-- | Pick the first estimate that has stabilised.
--
-- Scans consecutive triples @(x, y, z)@ of estimates and returns the first
-- @z@ that is close to both @x@ and @y@. If no such triple exists (including
-- when there are fewer than three estimates), returns the last estimate.
converge :: NonEmpty Complexity -> Complexity
converge xs = case zs of
  [] -> NE.last xs
  (_, _, z) : _ -> z
  where
    ys :: [Complexity]
    ys = NE.toList xs

    -- Consecutive triples, skipping those whose third element still differs
    -- from either predecessor. 'drop' is total, unlike 'tail'; since 'ys'
    -- is non-empty the result is identical.
    zs :: [(Complexity, Complexity, Complexity)]
    zs =
      dropWhile (\(x, y, z) -> p x z || p y z) $
        zip3 ys (drop 1 ys) (drop 2 ys)

    -- Are two estimates noticeably different? They are if the variable powers
    -- differ by more than 0.001, the log powers differ, or the multipliers
    -- differ by more than 1% relative to the first one. If both multipliers
    -- are 0, the ratio is NaN, which compares False, so they count as close.
    p :: Complexity -> Complexity -> Bool
    p
      Complexity {cmplVarPower = varPow, cmplLogPower = logPow, cmplMultiplier = mult}
      Complexity {cmplVarPower = varPow', cmplLogPower = logPow', cmplMultiplier = mult'} =
        abs (varPow - varPow') > 0.001
          || logPow /= logPow'
          || abs ((mult - mult') / mult) > 0.01

-- | Same as 'fit', but interactively emits a list of complexities,
-- gradually converging to the final result.
--
-- If 'fit' takes too long, you might wish to implement your own criterion
-- of convergence atop of 'fits' directly.
--
-- The list is produced lazily. The first element becomes available after
-- measuring 'fitLow' and 'fitHigh'. Each further element is computed only
-- when demanded, after measuring up to two new input sizes:
--
-- * the midpoint of the neighbouring pair with the largest absolute time
--   increase;
-- * the geometric midpoint of the pair with the largest relative time
--   increase.
--
-- The list ends once every such midpoint has already been measured. For
-- integer sizes, this means the whole range from 'fitLow' to 'fitHigh' has
-- been covered.
fits :: FitConfig -> IO (NonEmpty Complexity)
fits FitConfig {..} = unsafeInterleaveIO $ do
  lowTime <- measure fitLow
  highTime <- measure fitHigh
  let mp :: Map Word Measurement
      mp = M.fromList [(fitLow, lowTime), (fitHigh, highTime)]
      cmpl :: Complexity
      cmpl = fitOracle mp
  -- Evaluate the first estimate (to WHNF) before handing out the list.
  cmpl `seq` (cmpl :|) <$> go mp
  where
    -- Benchmark the configured function at a given input size.
    measure :: Word -> IO Measurement
    measure =
      fmap (uncurry Measurement)
        . measureCpuTimeAndStDev fitTimeout fitRelStDev
        . fitBench

    -- Among candidate sizes not yet present in the map, measure the one with
    -- the largest gap and insert it. If no new candidate remains, return the
    -- map unchanged. Note: this only ever inserts keys absent from 'mp'.
    processGap
      :: forall t
       . (Ord t)
      => [(Word, t)]
      -> Map Word Measurement
      -> IO (Map Word Measurement)
    processGap gaps mp
      | M.null gaps' = pure mp
      | otherwise = (\m -> M.insert maxGap m mp) <$> measure maxGap
      where
        gaps' :: Map Word t
        gaps' = M.fromList gaps `M.difference` mp
        maxGap :: Word
        maxGap = fst $ maximumBy (comparing snd) $ M.toList gaps'

    -- Lazily refine the measurement map, emitting one estimate per round.
    go :: Map Word Measurement -> IO [Complexity]
    go mp = unsafeInterleaveIO $ do
      let xys :: [(Word, Double)]
          xys = M.toAscList $ fmap measTime mp
          paired :: [((Word, Double), (Word, Double))]
          paired = zip xys (drop 1 xys)

          -- Arithmetic midpoints, ranked by absolute increase in time.
          arithGaps :: [(Word, Double)]
          arithGaps =
            map
              (\((x, tx), (y, ty)) -> (round ((d x + d y) / 2), ty - tx))
              paired

          -- Geometric midpoints, ranked by relative increase in time.
          geomGaps :: [(Word, Double)]
          geomGaps =
            map
              (\((x, tx), (y, ty)) -> (round (sqrt (d x * d y)), ty / tx))
              paired

      mp' <- processGap arithGaps mp
      mp'' <- processGap geomGaps mp'
      traceShowM' (M.keys mp'')
      let cmpl :: Complexity
          cmpl = fitOracle mp''
      traceShowM' cmpl
      -- 'processGap' only inserts keys absent from its input, so the sizes
      -- differ exactly when a new point was measured. Unlike '==' on
      -- measurements, this is O(1) and cannot be fooled by a NaN timing
      -- (NaN /= NaN), which would otherwise repeat the last estimate forever.
      (cmpl :) <$> (if M.size mp'' == M.size mp then pure [] else go mp'')

-- | Convert an input size to 'Double' for floating-point arithmetic
-- (used when computing midpoints between measured sizes).
d :: Word -> Double
d = fromIntegral

-- | Debug tracing hook. When compiled with @-DDEBUG@ this is 'traceShowM';
-- otherwise it is a no-op that ignores (and never evaluates) its argument.
traceShowM' :: (Applicative m, Show a) => a -> m ()
#ifdef DEBUG
traceShowM' = traceShowM
#else
traceShowM' = const (pure ())
#endif