module Test.Maybench where

import Control.Monad (when,replicateM)
import Control.Monad.State (MonadIO, liftIO)
import Data.Maybe (maybe, isJust, fromJust)
import System.Cmd (system) -- ideally this should use System.Process in the future, but for the sake of a first version this will do.
import System.Directory (findExecutable)
import System.IO (putStr,hPutStr,hClose,hGetContents)
import System.Process (readProcessWithExitCode, runInteractiveProcess, waitForProcess)
import System.Time

import Test.Maybench.Command (CommandModifier, Command(Cmd), modifyCmd)

-- | The outcome of timing an action repeatedly with 'bench'.
data Benchmark = Benchmark
    { benchIters :: Int         -- ^ number of iterations requested
    , benchTimes :: [TimeDiff]  -- ^ one measured duration per iteration, in run order
    }

-- | Build a command by applying the modifier (obtained through 'modifyCmd')
-- to an empty @Cmd "" [] ""@, then execute it with 'runC'.
-- Returns the command's (stdout, stderr).
run :: MonadIO m => CommandModifier m -> m (String, String)
run cmd = do
    applyModifier <- modifyCmd cmd
    runC (applyModifier (Cmd "" [] ""))

-- | Resolve the command's executable with 'findExecutable', echo the command
-- line to stdout, then run it (feeding the command's input on stdin) and
-- return its (stdout, stderr).
--
-- Calls 'fail' in IO with @"cannot find <name>"@ when the executable cannot
-- be located.
runC :: MonadIO m => Command -> m (String, String)
runC (Cmd program args input) = liftIO $ do
    located <- findExecutable program
    exe <- case located of
        Nothing   -> fail ("cannot find " ++ program)
        Just path -> return path
    putStr "Running... "
    putStrLn (unwords (map quoteIfSpaced (exe : args)))
    runProcessWithInput exe args input
  where
    -- Display-only quoting: words containing a space are rendered with
    -- 'show' so the echoed line stays readable. This is not shell escaping.
    quoteIfSpaced s
        | any (== ' ') s = show s
        | otherwise      = s

-- | Run an executable with the given arguments and feed it @input@ on stdin.
-- Returns everything the process wrote to stdout and stderr.
--
-- 'readProcessWithExitCode' drains stdout and stderr concurrently while the
-- input is being written. A sequential approach has a deadlock: write all
-- input, then read stdout to EOF, then read stderr. It hangs as soon as the
-- child fills a pipe buffer that nobody is reading, for example a large
-- input echoed back by the child, or a lot of output on stderr.
--
-- The exit code is currently discarded, so a non-zero exit is not reported.
runProcessWithInput :: FilePath -> [String] -> String -> IO (String, String)
runProcessWithInput cmd args input = do
    -- TODO: inspect the exit code instead of ignoring it.
    (_exitCode, output, err) <- readProcessWithExitCode cmd args input
    return (output, err)

-- | Run @action@ @reps@ times and record the duration of each run, measured
-- as the difference between two 'getClockTime' readings.
--
-- The optional setup and cleanup actions run before and after every
-- iteration, outside the measured region. Their results are discarded, as is
-- the value returned by @action@. That value is not forced, so work it defers
-- lazily is not included in the measurement.
bench :: Maybe (IO a) -- ^ setup, run before each iteration (not timed)
      -> IO b         -- ^ action to time
      -> Maybe (IO c) -- ^ cleanup, run after each iteration (not timed)
      -> Int          -- ^ iterations
      -> IO Benchmark
bench setup action cleanup reps = fmap (Benchmark reps) (replicateM reps timedRun)
  where
    timedRun = do
        runOptional setup
        started <- getClockTime
        _ <- action
        finished <- getClockTime
        runOptional cleanup
        return (diffClockTimes finished started)

    -- Run the action if one was supplied and throw away its result.
    runOptional :: Maybe (IO x) -> IO ()
    runOptional = maybe (return ()) (>> return ())

-- | 'bench' without setup or cleanup: time @action@ over @reps@ iterations.
benchSimple :: IO a -> Int -> IO Benchmark
benchSimple action reps = bench Nothing action Nothing reps

-- | Time a single run of the shell command @cmd@ through 'system'.
--
-- The shell commands @setup@ and @cleanup@ run immediately before and after
-- it, outside the measured region. Exit codes of all three commands are
-- ignored. The duration goes through 'averageTimeDiffs', so it is reported in
-- whole seconds. The command string is returned alongside for labelling.
timeProgram :: String -> String -> String -> IO (String, TimeDiff)
timeProgram cmd setup cleanup =
    fmap summarize (bench (Just (system setup)) (system cmd) (Just (system cleanup)) 1)
  where
    summarize result = (cmd, averageTimeDiffs (benchTimes result))

-- | Average a list of durations, working in whole seconds.
--
-- 1. Each 'TimeDiff' is collapsed with 'timeDiffToSeconds', which ignores
--    picoseconds.
-- 2. The mean is taken with integer 'div'.
-- 3. The result is converted back with 'secondsToTimeDiff'.
--
-- An empty list yields a zero duration instead of a divide-by-zero
-- exception. That case occurs, for example, when zero iterations are
-- requested.
averageTimeDiffs :: [TimeDiff] -> TimeDiff
averageTimeDiffs [] = secondsToTimeDiff 0
averageTimeDiffs tds = secondsToTimeDiff (sum secs `div` length secs)
  where
    secs = map timeDiffToSeconds tds

-- | Run 'timeProgram' @n@ times with the same setup/cleanup commands and
-- average the measured durations with 'averageTimeDiffs'. Returns the
-- command string alongside the average.
averageTime :: String -> String -> String -> Int -> IO (String, TimeDiff)
averageTime cmd setup cleanup n = do
    runs <- replicateM n (timeProgram cmd setup cleanup)
    let durations = [d | (_, d) <- runs]
    return (cmd, averageTimeDiffs durations)

-- | Render a labelled duration as a sentence, such as
-- @"\"make\" took 1 minutes, 5 seconds"@.
--
-- Only strictly positive fields (years down to seconds) are listed, and
-- picoseconds are ignored. When no field is positive the message reads
-- @"<cmd> took less than a second."@. The command is shown with 'show', so
-- it appears in double quotes.
showTimeDiff :: (String, TimeDiff) -> String
showTimeDiff (cmd, td) =
    case components of
      [] -> show cmd ++ " took less than a second."
      xs -> show cmd ++ " took " ++ joinWith ", " xs
  where
    components =
      [ show n ++ " " ++ unit
      | (field, unit) <- [ (tdYear, "years"), (tdMonth, "months"), (tdDay, "days")
                         , (tdHour, "hours"), (tdMin, "minutes"), (tdSec, "seconds") ]
      , let n = field td
      , n > 0
      ]
    -- Join with a separator between elements only. The earlier local helper
    -- also appended the separator after the last element, leaving a
    -- trailing ", ".
    joinWith _ [] = []
    joinWith _ [x] = x
    joinWith sep (x:rest) = x ++ sep ++ joinWith sep rest

-- | Print the 'showTimeDiff' rendering of a labelled duration on its own line.
printTimeDiff :: (String, TimeDiff) -> IO ()
printTimeDiff entry = putStrLn (showTimeDiff entry)

-- | Lengths of calendar units, in seconds. A month is approximated as 30 days
-- and a year as 365 days.
minute, hour, day, month, year :: Int
minute = 60
hour   = 60 * minute
day    = 24 * hour
month  = 30 * day
year   = 365 * day

-- | Collapse a 'TimeDiff' into a whole number of seconds, using the fixed
-- unit lengths 'minute' through 'year' (30-day months, 365-day years).
-- 'tdPicosec' is ignored.
timeDiffToSeconds :: TimeDiff -> Int
timeDiffToSeconds td = sum [field td * scale | (field, scale) <- weights]
  where
    weights = [ (tdSec, 1), (tdMin, minute), (tdHour, hour)
              , (tdDay, day), (tdMonth, month), (tdYear, year) ]

-- | Build a 'TimeDiff' holding only the given seconds (all other fields,
-- including picoseconds, zero) and spread it across units with
-- 'normalizeTimeDiff'.
secondsToTimeDiff :: Int -> TimeDiff
secondsToTimeDiff sec = normalizeTimeDiff secondsOnly
  where
    secondsOnly = TimeDiff { tdYear = 0, tdMonth = 0, tdDay = 0, tdHour = 0
                           , tdMin = 0, tdSec = sec, tdPicosec = 0 }

-- | Describe how long the second command took as a percentage of the first
-- command's time. The percentage is truncated (not rounded) to two decimal
-- places, and both durations are compared in whole seconds via
-- 'timeDiffToSeconds'.
--
-- Returns 'Nothing' when the first (baseline) duration is zero seconds,
-- because the ratio is undefined. Guarding only on "both durations are
-- zero" would let a zero baseline reach @x / 0@ and then
-- @truncate Infinity :: Int@, which yields a meaningless number.
compareTimes :: (String, TimeDiff) -> (String, TimeDiff) -> Maybe String
compareTimes (cmd1, td1) (cmd2, td2)
    | baseline == 0 = Nothing
    | otherwise     = Just $ show cmd2 ++ " took " ++ show pct
                             ++ "% of the time " ++ show cmd1 ++ " took."
  where
    baseline = timeDiffToSeconds td1
    pct      = percentage (fromIntegral (timeDiffToSeconds td2))
                          (fromIntegral baseline) :: Double
    -- x / y expressed as a percentage, truncated to two decimal places.
    percentage :: Double -> Double -> Double
    percentage x y = fromIntegral (truncate (x / y * 10000) :: Int) / 100