module Control.Concurrent.ParallelTasks.Cache (parMapCache) where
import Control.Applicative ((<$>), (<*>))
import Control.Concurrent.STM (atomically)
import Control.Concurrent.STM.TMVar (newTMVar, putTMVar, takeTMVar)
import Control.DeepSeq (NFData, force)
import Control.Exception as E(catch, evaluate, IOException)
import Control.Exception (bracket_, mask_)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.ST (ST, runST)
import qualified Data.ByteString as BS
import Data.Int (Int64)
import Data.Serialize
import Data.String.Here.Interpolated (i)
import Data.Time.Clock (getCurrentTime)
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector as V
import Data.Vector.Algorithms.Intro as MU
import System.IO (Handle, IOMode(..), SeekMode(..), hClose, hSeek, hTell, openFile, withFile)
import System.IO (hFlush, hPutStrLn)
import Control.Concurrent.ParallelTasks.Base (ExtendedParTaskOpts(..), ParTaskOpts(..), TaskOutcome(..), parallelTasks)
-- | A byte range in the payload file: (start offset, length in bytes).
type Location = (Int64, Int64)
-- | File-path prefix shared by a cache's index file and payload file.
type CacheStem = String
-- | How a single task's result was obtained.  For a task that was killed
-- for taking too long, the task's key is kept so it can be logged.
data CacheOutcome key = CacheHit | CacheMissSuccess | CacheMissTookTooLong key
-- | 'True' exactly when the outcome is 'CacheHit'.
isCacheHit :: CacheOutcome a -> Bool
isCacheHit outcome = case outcome of
  CacheHit -> True
  _ -> False
-- | 'True' exactly when the outcome is 'CacheMissSuccess'.
isCacheMissSuccess :: CacheOutcome a -> Bool
isCacheMissSuccess outcome = case outcome of
  CacheMissSuccess -> True
  _ -> False
-- | Path of the index file for a cache: the stem followed by "-index".
indexFile :: CacheStem -> FilePath
indexFile stem = stem ++ "-index"
-- | Path of the payload file for a cache: the stem followed by "-payload".
payloadFile :: CacheStem -> FilePath
payloadFile stem = stem ++ "-payload"
-- | Load the decoded index for @cacheStem@.  If reading the index file
-- raises an 'IOException' (for example because it does not exist yet), the
-- cache is treated as empty.
readKeysFromCache :: (U.Unbox key, Serialize key) => CacheStem -> IO (U.Vector (key, Location))
readKeysFromCache cacheStem = E.catch readIndex ignoreIOError
  where
    readIndex = readKeysFromCache' <$> BS.readFile (indexFile cacheStem)
    ignoreIOError (_e :: IOException) = return U.empty
-- | Decode the contents of an index file.
--
-- The format is an 'Int64' entry count followed by that many entries of the
-- form @(key, (start, length))@.
--
-- Any decoding failure yields an empty index rather than an 'error'.  This
-- includes input too short to hold the count, such as an empty or truncated
-- index file.  A damaged index therefore behaves like a cold cache instead of
-- crashing every later run.
readKeysFromCache' :: (U.Unbox key, Serialize key) => BS.ByteString -> U.Vector (key, Location)
readKeysFromCache' origFull = either (const U.empty) id $ runGet getIndex origFull
  where
    getIndex = do
      keysAmount <- get :: Get Int64
      U.replicateM (fromIntegral keysAmount) getKeyLocation
    getKeyLocation = (,) <$> get <*> ((,) <$> get <*> get)
-- | Open the on-disk cache identified by @cacheStem@ and run @inner@ with it.
--
-- @inner@ receives three things:
--
-- * the keys already in the index, with their payload locations;
-- * an action that reads and decodes a value stored at a payload location;
-- * an action that appends a new key/value pair to the payload file.
--
-- At most @maxNewKeys@ new entries can be recorded.  After @inner@ returns,
-- the previous and new entries are merged, sorted, and written back to the
-- index file.  Progress timestamps go to @logHandle@.
withCache :: forall key value a. (Ord key, MU.Unbox key, Serialize key, NFData value, Serialize value) =>
  CacheStem -> Handle -> Int -> (U.Vector (key, Location) -> (Location -> IO value) -> (key -> value -> IO ()) -> IO a) -> IO a
withCache cacheStem logHandle maxNewKeys inner = withFile (payloadFile cacheStem) ReadWriteMode $ \payloadHandle -> do
  prevKeys <- readKeysFromCache cacheStem
  newKeys <- MU.new maxNewKeys
  (newKeysVar, mutex) <- atomically $ (,) <$> newTMVar 0 <*> newTMVar ()
  -- Every seek+read/write on the shared payload handle holds @mutex@.
  -- 'bracket_' releases the mutex even if the task holding it is killed
  -- (e.g. parMapCache's tasks can be killed for taking too long).  Without
  -- this, one interrupted task would deadlock all the others.
  let withMutex :: IO r -> IO r
      withMutex = bracket_ (atomically $ takeTMVar mutex) (atomically $ putTMVar mutex ())

      readValue :: Location -> IO value
      readValue (start, len) = do
        val <- withMutex $ do
          hSeek payloadHandle AbsoluteSeek (toInteger start)
          BS.hGet payloadHandle (fromIntegral len)
        evaluate $ either error force $ runGet get val

      writeValue :: key -> value -> IO ()
      writeValue k v = do
        newPayload <- evaluate $ runPut (put v)
        withMutex $ do
          hSeek payloadHandle SeekFromEnd 0
          start <- hTell payloadHandle
          BS.hPut payloadHandle newPayload
          -- Masked so an asynchronous kill cannot land between taking and
          -- restoring the counter and leave newKeysVar empty.
          mask_ $ do
            n <- atomically $ takeTMVar newKeysVar
            MU.write newKeys n (k, (fromInteger start, fromIntegral $ BS.length newPayload))
            atomically $ putTMVar newKeysVar (succ n)
  result <- inner prevKeys readValue writeValue
  printTime logHandle "Combining keys "
  numNewKeys <- atomically $ takeTMVar newKeysVar
  let endPrevKeys = U.length prevKeys
  joinedKeys <- flip MU.unsafeGrow numNewKeys =<< U.unsafeThaw prevKeys
  -- Append the newly written entries after the previous ones.
  MU.copy (MU.slice endPrevKeys numNewKeys joinedKeys) (MU.slice 0 numNewKeys newKeys)
  printTime logHandle "Sorting keys "
  -- Sorting by (key, location) keeps the index ordered by key, which
  -- binary searches over the loaded index rely on.
  MU.sort joinedKeys
  frozenJoinedKeys <- U.unsafeFreeze joinedKeys
  printTime logHandle "Writing index "
  withFile (indexFile cacheStem) WriteMode $ \indexHandle -> do
    BS.hPut indexHandle $ runPut $ put (fromIntegral (U.length frozenJoinedKeys) :: Int64)
    U.mapM_ (BS.hPut indexHandle . runPut . (\(k, (start, len)) -> put k >> put start >> put len)) frozenJoinedKeys
  return result
-- | Write @msg@ followed immediately by the current time to @h@, then
-- flush @h@.
printTime :: Handle -> String -> IO ()
printTime h msg = do
  now <- getCurrentTime
  hPutStrLnFlush h (msg ++ show now)
-- | Binary search for @tgt@ in a vector of key/value pairs sorted by key.
--
-- Returns the value paired with a matching key, or 'Nothing' if the key is
-- absent.  If the key occurs more than once, any one of its values may be
-- returned.
binarySearch :: (Ord key, MU.Unbox key, MU.Unbox v) => key -> U.Vector (key, v) -> Maybe v
binarySearch tgt v = go 0 (U.length v - 1)
  where
    -- Search the inclusive index range [imin, imax].
    go imin imax
      | imax < imin = Nothing
      | otherwise =
          let imid = imin + (imax - imin) `div` 2
              (k, x) = v U.! imid
          in case compare k tgt of
               GT -> go imin (imid - 1)
               LT -> go (imid + 1) imax
               EQ -> Just x
-- | Map @process@ over @inputs@ in parallel via 'parallelTasks', backed by an
-- on-disk result cache.
--
-- Each input is identified by @getKey@.  If its key is found in the cache
-- (stem @dir ++ "/cache"@), the stored output is read back instead of running
-- @process@.  Otherwise @process@ runs and its output is appended to the
-- cache.  Hit/miss statistics and the keys of tasks killed for taking too long
-- are written to @dir ++ "/parmap-log"@.
parMapCache :: forall input output key m. (MonadIO m, Ord key, Show key, MU.Unbox key, NFData output, Serialize key, Serialize output) =>
  ParTaskOpts m output
  -> FilePath
  -> (input -> key)
  -> (input -> m output)
  -> [input]
  -> m (MV.IOVector output)
parMapCache opts dir getKey process inputs = do
    let numInputs = length inputs
        -- Keys by task index, giving O(1) lookup when a task times out
        -- (the original used O(n) @inputs !! n@).  The vector is boxed and
        -- lazy, so a key is only computed when it is looked up.
        keysByIndex = V.fromList (map getKey inputs)
    -- Per-task (time, outcome).  Each task stores a provisional outcome with
    -- time 0; the completion callback below fills in the time.
    vOutcome <- liftIO $ MV.new numInputs
    logFile <- liftIO $ openFile (dir ++ "/parmap-log") WriteMode
    let fullOpts = (ExtendedParTaskOpts opts
          ($ logFile)
          (\t n outcome -> case outcome of
              Success -> do
                (_, x) <- MV.read vOutcome n
                MV.write vOutcome n (t, x)
                return Nothing
              TookTooLong -> do
                let key = keysByIndex V.! n
                MV.write vOutcome n (t, CacheMissTookTooLong key)
                return $ Just [i|*** Killed task with key ${show key} for taking too long|]
          ))
    run <- wrapWorker opts
    results <- liftIO $ withCache (dir ++ "/cache") logFile numInputs $
      \cachedKeys readValue saveResult -> run $
        parallelTasks fullOpts (zipWith (processWithCache vOutcome cachedKeys readValue saveResult) [0..] inputs)
    liftIO $ do
      outcomes <- V.unsafeFreeze vOutcome
      let hits = fstFilter isCacheHit outcomes
          missSuccesses = fstFilter isCacheMissSuccess outcomes
          timedOut = V.length outcomes - V.length hits - V.length missSuccesses
      hPutStrLn logFile [i|Complete; hits: ${V.length hits}, misses: ${V.length missSuccesses}, timed out: ${timedOut}|]
      hPutStrLn logFile [i|Average cache hit time: ${average hits}|]
      hPutStrLn logFile [i|Average successful task (cache miss) time: ${average missSuccesses}|]
      hPutStrLn logFile [i|Median successful task (cache miss) time: ${median missSuccesses}|]
      hPutStrLn logFile [i|Longest successful task (cache miss) time: ${maximumV missSuccesses}|]
      hPutStrLn logFile "Details of killed tasks:"
      sequence_ [hPutStrLn logFile [i| Killed task, key: ${show k}|] | (_, CacheMissTookTooLong k) <- V.toList outcomes]
      hClose logFile
    return results
  where
    -- Run one task.  On a cache hit the stored value is read back; on a
    -- miss @process@ runs and its result is saved.  Either way a provisional
    -- outcome (time 0) is recorded at index @n@.
    processWithCache :: MV.IOVector (Double, CacheOutcome key) -> U.Vector (key, Location) -> (Location -> IO output) -> (key -> output -> IO ()) -> Int -> input -> m output
    processWithCache vOutcome cachedKeys readValue saveResult n x =
      case binarySearch theKey cachedKeys of
        Just resultLoc -> liftIO $ do
          MV.write vOutcome n (0, CacheHit)
          readValue resultLoc
        Nothing -> do
          result <- process x
          liftIO $ do
            MV.write vOutcome n (0, CacheMissSuccess)
            saveResult theKey result
          return result
      where
        theKey = getKey x

    -- Mean of the times; NaN for an empty vector.
    average :: V.Vector Double -> Double
    average xs = V.sum xs / fromIntegral (V.length xs)

    -- Largest time, or 0 for an empty vector.
    maximumV :: V.Vector Double -> Double
    maximumV = V.foldr max 0

    -- Median of the times; infinity for an empty vector.  For an even count
    -- this averages the two middle elements, at indices half-1 and half.
    median :: V.Vector Double -> Double
    median xs
      | V.null sorted = 1 / 0
      | odd len = sorted V.! half
      | otherwise = (sorted V.! (half - 1) + sorted V.! half) / 2
      where
        sorted = V.modify MU.sort xs
        len = V.length sorted
        half = len `div` 2

    -- First components of the pairs whose second component satisfies @f@.
    fstFilter :: (b -> Bool) -> V.Vector (a, b) -> V.Vector a
    fstFilter f = V.map fst . V.filter (f . snd)
-- | Write a line to a handle and flush it immediately, from any 'MonadIO'.
hPutStrLnFlush :: MonadIO m => Handle -> String -> m ()
hPutStrLnFlush h s = liftIO $ do
  hPutStrLn h s
  hFlush h