module Main where

import Control.Monad
import Control.Concurrent
import Foreign.C.Types
import Foreign.CInvoke
import Numeric
import CPUTime
import Time
import Ratio
import System.Environment
import System.Exit
import System.Mem

-- | Convert a command-line size in MiB (fractional values allowed) into
-- the byte size of ONE of the two equal halves the benchmark allocates,
-- i.e. half of the requested bytes, rounded up.
--
-- Surrounding whitespace is accepted, as with 'read'. Unparsable or
-- negative input is rejected with a descriptive error instead of the bare
-- "no parse" from 'read' (a negative value would otherwise wrap around
-- when converted to an unsigned C size).
readMiB :: Integral b => String -> b
readMiB s =
    case [v | (v, rest) <- reads s, ("", "") <- lex rest] of
        [v] | v >= 0    -> ceiling (v * 2 ^ 20 / 2 :: Double)
            | otherwise -> error ("MemSpeed: size must be non-negative: " ++ show s)
        _               -> error ("MemSpeed: cannot parse size in megabytes: " ++ show s)

-- | Benchmark entry point.
--
-- Usage: @MemSpeed megabytes-to-use [count]@ (count defaults to 10).
-- The requested megabytes are split into two halves of @sz@ bytes each
-- (see 'readMiB'). The memcpy pass therefore uses a source and a
-- destination of @sz@ bytes, and the memset pass uses one buffer of
-- @2 * sz@ bytes. Each operation runs once untimed, then @count@ times
-- under 'check'.
--
-- NOTE(review): the pointers returned by malloc are not checked for NULL.
main = do
    args <- getArgs
    (sz, cnt) <- case args of
                [n] -> return (readMiB n, 10)
                [n,c] -> return (readMiB n, read c)
                -- No arguments or more than two: report usage instead of
                -- failing with a non-exhaustive pattern match.
                _   -> putStrLn "usage: MemSpeed megabytes-to-use [count]" >> exitWith (ExitFailure 1)

    context <- newContext
    libc <- loadLibrary context "libc.so.6"
    memset <- loadSymbol libc "memset"
    memcpy <- loadSymbol libc "memcpy"
    malloc <- loadSymbol libc "malloc"
    free <- loadSymbol libc "free"

    -- memcpy: one untimed warm-up copy, then cnt timed copies of sz bytes.
    s <- cinvoke malloc (retPtr retVoid) [argCSize sz]
    d <- cinvoke malloc (retPtr retVoid) [argCSize sz]
    cinvoke memcpy retVoid [argPtr d, argPtr s, argCSize sz]
    check (cnt*sz) "memcpy" $ replicateM_ (fromIntegral cnt) $ cinvoke memcpy retVoid [argPtr d, argPtr s, argCSize sz]
    cinvoke free retVoid [argPtr s]
    cinvoke free retVoid [argPtr d]

    -- memset: one untimed warm-up fill, then cnt timed fills of 2*sz bytes.
    p <- cinvoke malloc (retPtr retVoid) [argCSize (2 * sz)]
    cinvoke memset retVoid [argPtr p, argCInt 97, argCSize (2 * sz)]
    check (2*cnt*sz) "memset" $ replicateM_ (fromIntegral cnt) $ cinvoke memset retVoid [argPtr p, argCInt 97, argCSize (2 * sz)]
    cinvoke free retVoid [argPtr p]

    -- NOTE(review): intent not visible here; presumably gives pending
    -- finalizers a chance to run before the program exits.
    performGC
    threadDelay (10^5)

-- | @check bytes label action@ times @action@ with 'timeIt' and prints the
-- throughput, i.e. @bytes@ divided by the elapsed CPU and wall-clock
-- seconds and expressed in units of 2^20 bytes. Returns the action's result.
check sz s a = do
    (result, cpuSecs, wallSecs) <- timeIt a
    let rate secs = showf 2 ((fromIntegral sz / secs) / (2 ^ 20))
    putStrLn (concat [ s, ": "
                     , rate cpuSecs, " mb/cpu sec  "
                     , rate wallSecs, " mb/clock sec  "
                     ])
    return result

-- | Snapshot taken when a timed region begins: the value of 'getCPUTime'
-- paired with the current 'ClockTime'.
type TimeIt = (Integer, ClockTime)

-- | Capture a 'TimeIt' snapshot, reading CPU time first, then wall-clock time.
timeItStart :: IO TimeIt
timeItStart = do
    cpuNow   <- getCPUTime
    clockNow <- getClockTime
    return (cpuNow, clockNow)

-- | Given a snapshot from 'timeItStart', return the CPU time and the
-- wall-clock time elapsed since then, both in seconds.
timeItEnd :: TimeIt -> IO (Double, Double)
timeItEnd (startCPU, startClock) = do
    stopCPU   <- getCPUTime
    stopClock <- getClockTime
    let elapsedCPU   = picosToSecs (stopCPU - startCPU)
        elapsedClock = diffToSecs (diffClockTimes stopClock startClock)
    return (elapsedCPU, elapsedClock)
  where
    -- getCPUTime counts picoseconds.
    picosToSecs :: Integer -> Double
    picosToSecs ps = fromIntegral ps / 1e12

    -- Only the seconds and picoseconds fields of the difference are read.
    diffToSecs td = fromIntegral (tdSec td) + fromIntegral (tdPicosec td) / 1e12

{- | @timeIt action@ runs @action@ and returns a triple of its result,
   the CPU seconds and the wall-clock seconds that elapsed.

   The result is passed through without being forced, so any lazy work
   still inside it is not included in the measured times. -}
timeIt :: IO a -> IO (a, Double, Double)
timeIt action = do
    start  <- timeItStart
    result <- action
    fmap (\(cpuSecs, wallSecs) -> (result, cpuSecs, wallSecs)) (timeItEnd start)

-- | Render @x@ in fixed-point notation with exactly @n@ fractional digits.
-- When @x >= 0@ holds, a leading space is added so the value lines up with
-- negative numbers, which carry a @-@ sign.
showf :: RealFloat a => Int -> a -> String
showf n x = padding (showFFloat (Just n) x "")
  where
    padding = if x >= 0 then (' ' :) else id