{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
module Database.CQL.IO.Pool
( Pool
, create
, destroy
, purge
, with
, PoolSettings
, defSettings
, idleTimeout
, maxConnections
, maxTimeouts
, poolStripes
) where
import Control.AutoUpdate
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.STM
import Control.Exception
import Control.Lens ((^.), makeLenses, view)
import Control.Monad.IO.Class
import Control.Monad
import Data.Foldable (forM_, mapM_, find)
import Data.Function (on)
import Data.Hashable
import Data.IORef
import Data.Sequence (Seq, ViewL (..), (|>), (><))
import Data.Semigroup ((<>))
import Data.Time.Clock (UTCTime, NominalDiffTime, getCurrentTime, diffUTCTime)
import Data.Vector (Vector, (!))
import Database.CQL.IO.Connection (Connection)
import Database.CQL.IO.Exception (ConnectionError (..), ignore)
import Database.CQL.IO.Log
import qualified Data.Sequence as Seq
import qualified Data.Vector as Vec
-- | Tunable parameters of a connection pool.
data PoolSettings = PoolSettings
{ _idleTimeout :: !NominalDiffTime
-- ^ An unreferenced connection whose timestamp is older than this is
-- destroyed by the reaper.
, _maxConnections :: !Int
-- ^ A new connection is only created while a stripe's connection count is
-- below this value, i.e. the limit applies per stripe.
, _maxTimeouts :: !Int
-- ^ Once a connection's recorded timeout count exceeds this value, the next
-- response timeout on it destroys the connection.
, _poolStripes :: !Int
-- ^ Number of stripes created for the pool; also the modulus used to map a
-- thread to its stripe.
}
-- | A striped pool of connections.
data Pool = Pool
{ _createFn :: !(IO Connection)
-- ^ Action run whenever a new connection has to be created.
, _destroyFn :: !(Connection -> IO ())
-- ^ Action run on every connection that leaves the pool.
, _logger :: !Logger
-- ^ Logger used for pool diagnostics.
, _settings :: !PoolSettings
-- ^ The settings the pool was created with.
, _maxRefs :: !Int
-- ^ An existing connection is only handed out again while its reference
-- count is below this value.
, _currentTime :: !(IO UTCTime)
-- ^ Clock used for connection timestamps (an auto-updating cache of
-- 'getCurrentTime', see 'create').
, _stripes :: !(Vector Stripe)
-- ^ The stripes; a thread picks one based on the hash of its 'ThreadId'.
, _finaliser :: !(IORef ())
-- ^ Dummy reference that only serves as the key of the weak-reference
-- finaliser registered in 'create'.
}
-- | A pooled connection together with its bookkeeping data.
data Resource = Resource
{ tstamp :: !UTCTime
-- ^ Set when the connection is created and refreshed each time it is put back.
, refcnt :: !Int
-- ^ Reference count: 1 on creation, incremented when the connection is handed
-- out again, decremented when it is put back.
, timeouts :: !Int
-- ^ Number of response timeouts recorded for this connection.
, value :: !Connection
-- ^ The connection itself.
} deriving Show
-- | Outcome of the STM part of 'take1'.
data Box
= New !(IO Resource)
-- ^ A slot was reserved; run the action (outside STM) to create the connection.
| Used !Resource
-- ^ An existing connection was selected.
| Empty
-- ^ Nothing is available and no new connection may be created.
-- | One independent partition of the pool.
data Stripe = Stripe
{ conns :: !(TVar (Seq Resource))
-- ^ The connections currently held by this stripe.
, inUse :: !(TVar Int)
-- ^ Number of connections that exist or are currently being created for this
-- stripe: incremented before a connection is created, decremented when
-- creation fails or a connection is destroyed or reaped. 'take1' waits until
-- it equals the length of 'conns'.
}
-- Template Haskell: generate a lens for every underscore-prefixed field of
-- 'PoolSettings' and 'Pool' (e.g. '_idleTimeout' yields 'idleTimeout').
makeLenses ''PoolSettings
makeLenses ''Pool
-- | Default pool settings: 60 seconds idle timeout, at most 2 connections
-- per stripe, 16 tolerated timeouts per connection and 4 stripes.
defSettings :: PoolSettings
defSettings = PoolSettings
    { _idleTimeout    = 60
    , _maxConnections = 2
    , _maxTimeouts    = 16
    , _poolStripes    = 4
    }
-- | Create a new pool.
--
-- Arguments, in order: the action that opens a connection, the action that
-- closes one, the logger, the pool settings and the maximum reference count
-- per connection.
--
-- Besides allocating the stripes this starts the background 'reaper' thread
-- and registers a weak-reference finaliser that cancels the reaper and
-- destroys the pool.
--
-- NOTE(review): the reaper thread is started with the pool value itself and
-- therefore also references the finaliser key; confirm that the finaliser
-- can actually fire while the reaper is still running.
create :: IO Connection -> (Connection -> IO ()) -> Logger -> PoolSettings -> Int -> IO Pool
create mk del g s k = do
    clock <- mkAutoUpdate defaultUpdateSettings { updateAction = getCurrentTime }
    strps <- Vec.replicateM (s^.poolStripes) newStripe
    token <- newIORef ()
    let p = Pool mk del g s k clock strps token
    r <- async (reaper p)
    void $ mkWeakIORef (p^.finaliser) (cancel r >> destroy p)
    return p
  where
    -- A stripe starts without connections and with a zero connection count.
    newStripe = Stripe <$> newTVarIO Seq.empty <*> newTVarIO 0
-- | Destroy the pool. This is currently the same operation as 'purge'.
destroy :: Pool -> IO ()
destroy p = purge p
-- | Run an action with a connection taken from the calling thread's stripe.
--
-- Yields 'Nothing' when 'take1' could not provide a connection, otherwise
-- 'Just' the action's result. If the action throws, 'cleanup' decides what
-- happens to the connection and the exception is re-thrown; on success the
-- connection is put back unchanged.
with :: MonadIO m => Pool -> (Connection -> IO a) -> m (Maybe a)
with p f = liftIO $ do
    s <- stripe p
    -- Acquisition and release run with asynchronous exceptions masked; only
    -- the user action itself runs in the restored masking state.
    mask $ \restore -> do
        taken <- take1 p s
        forM taken $ \v -> do
            x <- restore (f (value v)) `catch` cleanup p s v
            put p s v id
            return x
-- | Remove every connection from every stripe and destroy it.
--
-- The stripe's connection counter is reduced by the number of connections
-- removed, in the same transaction that empties the stripe. This mirrors
-- what 'reaper' and 'destroyR' do when they remove connections. Without it
-- the counter would stay above the (now zero) number of pooled connections
-- and 'take1', which waits for both numbers to be equal, would block forever
-- on a purged stripe.
--
-- Only the removed connections are subtracted (rather than resetting the
-- counter to zero) so that connections which are still being created remain
-- accounted for.
--
-- Exceptions thrown while destroying an individual connection are ignored.
purge :: Pool -> IO ()
purge p = Vec.forM_ (p^.stripes) $ \s -> do
    cs <- atomically $ do
        removed <- swapTVar (conns s) Seq.empty
        modifyTVar' (inUse s) (subtract (Seq.length removed))
        return removed
    mapM_ (ignore . view destroyFn p . value) cs
-- | Exception handler used by 'with': decide the fate of the connection the
-- failed action was using, then re-throw the exception.
--
-- A 'ResponseTimeout' puts the connection back with its timeout counter
-- incremented, unless the counter already exceeds the configured maximum,
-- in which case the connection is destroyed. Any other exception destroys
-- the connection.
cleanup :: Pool -> Stripe -> Resource -> SomeException -> IO a
cleanup p s r x = dispose >> throwIO x
  where
    dispose = case fromException x of
        Just (ResponseTimeout {})
            | timeouts r > p^.settings.maxTimeouts -> do
                logInfo (p^.logger) $ string8 (show (value r)) <> ": Too many timeouts."
                destroyR p s r
            | otherwise -> put p s r incrTimeouts
        _ -> destroyR p s r
-- | Obtain a connection from the given stripe.
--
-- The decision is made in a single transaction which first waits until the
-- stripe's connection counter equals the number of pooled connections:
--
-- * below the per-stripe connection limit, a slot is reserved and a new
--   connection is created afterwards, outside the transaction;
-- * otherwise the connection with the lowest reference count is shared,
--   provided that count is below the pool's maximum;
-- * otherwise 'Nothing' is returned.
--
-- If creating a new connection throws, the reserved slot is released again.
take1 :: Pool -> Stripe -> IO (Maybe Resource)
take1 p s = atomically decide >>= realise
  where
    decide = do
        c <- readTVar (conns s)
        u <- readTVar (inUse s)
        check (u == Seq.length c)
        if u < p^.settings.maxConnections
            then do
                writeTVar (inUse s) $! u + 1
                mkNew p
            else case Seq.viewl (Seq.unstableSortBy (compare `on` refcnt) c) of
                r :< rr | refcnt r < p^.maxRefs -> use s r rr
                _                              -> return Empty

    realise (New io) = do
        x <- io `onException` atomically (modifyTVar' (inUse s) (subtract 1))
        atomically (modifyTVar' (conns s) (|> x))
        return (Just x)
    realise (Used x) = return (Just x)
    realise Empty    = return Nothing
-- | Hand out an existing connection: store it at the back of the stripe
-- (the given remainder followed by the connection) with its reference count
-- incremented. The returned 'Used' box carries the resource as it was
-- passed in, i.e. with the previous reference count.
use :: Stripe -> Resource -> Seq Resource -> STM Box
use s r rr = Used r <$ (writeTVar (conns s) $! rr |> bumped)
  where
    bumped = r { refcnt = refcnt r + 1 }
{-# INLINE use #-}
-- | Produce a 'New' box holding the action that creates a fresh resource.
-- The action reads the pool clock, then opens a connection; the resource
-- starts with a reference count of 1 and no recorded timeouts. Nothing is
-- executed inside the transaction itself.
mkNew :: Pool -> STM Box
mkNew p = return (New open)
  where
    open = do
        now <- p^.currentTime
        c   <- p^.createFn
        return (Resource now 1 0 c)
{-# INLINE mkNew #-}
-- | Return a connection to its stripe.
--
-- The stripe's current entry for the connection (matched by connection
-- equality) gets a fresh timestamp and a decremented reference count, is
-- then passed through the supplied modifier, and is moved to the back of
-- the sequence.
--
-- If the stripe no longer contains the connection, nothing is written. The
-- connection can only be missing because 'destroyR' or 'purge' removed it
-- in the meantime, and both of those also destroy it. Re-inserting it would
-- put a destroyed connection back into circulation and make the stripe hold
-- more connections than its connection counter records, a state in which
-- 'take1' waits indefinitely.
put :: Pool -> Stripe -> Resource -> (Resource -> Resource) -> IO ()
put p s r f = do
    now <- p^.currentTime
    let updated x = f x { tstamp = now, refcnt = refcnt x - 1 }
    atomically $ do
        rs <- readTVar (conns s)
        let (xs, rr) = Seq.breakl ((value r ==) . value) rs
        case Seq.viewl rr of
            -- Removed concurrently: do not resurrect it.
            EmptyL  -> return ()
            y :< ys -> writeTVar (conns s) $! (xs >< ys) |> updated y
-- | Remove a connection from its stripe and destroy it.
--
-- If the stripe still contains the connection, every matching entry is
-- filtered out and the stripe's connection counter is decremented by one.
-- The destroy action is run regardless of whether the connection was found;
-- exceptions it throws are ignored.
destroyR :: Pool -> Stripe -> Resource -> IO ()
destroyR p s r = do
    atomically $ do
        rs <- readTVar (conns s)
        when (any ((value r ==) . value) rs) $ do
            modifyTVar' (inUse s) (subtract 1)
            writeTVar (conns s) $! Seq.filter ((value r /=) . value) rs
    ignore $ p^.destroyFn $ value r
-- | Background loop that evicts idle connections.
--
-- Once per second every stripe is scanned. Connections with a reference
-- count of zero whose timestamp is older than the idle timeout are removed
-- from the stripe (adjusting its connection counter) and then destroyed.
-- Failures while logging or destroying a single connection are ignored.
reaper :: Pool -> IO ()
reaper p = forever $ do
    threadDelay 1000000
    now <- p^.currentTime
    Vec.forM_ (p^.stripes) (reap now)
  where
    isStale now r = refcnt r == 0 && now `diffUTCTime` tstamp r > p^.settings.idleTimeout

    reap now s = do
        expired <- atomically $ do
            cs <- readTVar (conns s)
            let (dead, alive) = Seq.partition (isStale now) cs
            -- Leave the stripe untouched when there is nothing to evict.
            unless (Seq.null dead) $ do
                writeTVar (conns s) alive
                modifyTVar' (inUse s) (subtract (Seq.length dead))
            return dead
        forM_ expired $ \v -> ignore $ do
            logDebug (p^.logger) $ "Reaping idle connection: " <> string8 (show (value v))
            p^.destroyFn $ (value v)
-- | Select the stripe for the calling thread: the hash of its 'ThreadId'
-- modulo the configured number of stripes is used as index.
stripe :: Pool -> IO Stripe
stripe p = pick <$> myThreadId
  where
    pick tid = (p^.stripes) ! (hash tid `mod` (p^.settings.poolStripes))
{-# INLINE stripe #-}
-- | Record one more response timeout on a resource.
incrTimeouts :: Resource -> Resource
incrTimeouts r@Resource { timeouts = n } = r { timeouts = n + 1 }
{-# INLINE incrTimeouts #-}