module Data.KeyedPool
( KeyedPool
, createKeyedPool
, takeKeyedPool
, Managed
, managedResource
, managedReused
, managedRelease
, Reuse (..)
, dummyManaged
) where
import Control.Concurrent (forkIOWithUnmask, threadDelay)
import Control.Concurrent.STM
import Control.Exception (mask_, catch, SomeException)
import Control.Monad (join, unless, void)
import Data.Map (Map)
import Data.Maybe (isJust)
import qualified Data.Map.Strict as Map
import Data.Time (UTCTime, getCurrentTime, addUTCTime)
import Data.IORef (IORef, newIORef, mkWeakIORef)
import qualified Data.Foldable as F
import GHC.Conc (unsafeIOToSTM)
import System.IO.Unsafe (unsafePerformIO)
-- | A pool of reusable resources, indexed by a key. Idle resources are kept
-- per key and handed back out by 'takeKeyedPool'.
data KeyedPool key resource = KeyedPool
    { kpCreate :: !(key -> IO resource)
      -- ^ Creates a fresh resource for a key when no idle one is available.
    , kpDestroy :: !(resource -> IO ())
      -- ^ Destroys a resource that will not be kept in the pool.
    , kpMaxPerKey :: !Int
      -- ^ Cap on idle resources kept for a single key (enforced via 'addToList').
    , kpMaxTotal :: !Int
      -- ^ Cap on idle resources kept across all keys.
    , kpVar :: !(TVar (PoolMap key resource))
      -- ^ Pool state; shared with the reaper thread started by 'createKeyedPool'.
    , kpAlive :: !(IORef ())
      -- ^ Weak-reference anchor: 'createKeyedPool' attaches a finalizer to it
      -- that closes the pool and destroys its idle resources.
    }
-- | State of a 'KeyedPool', stored in its 'kpVar'.
data PoolMap key resource
    = PoolClosed
    -- The pool has been shut down: 'putResource' destroys returned resources,
    -- 'takeKeyedPool' always creates fresh ones, and the reaper thread exits.
    | PoolOpen
        !Int
        -- total number of idle resources across all keys
        !(Map key (PoolList resource))
        -- idle resources, grouped by key
    deriving F.Foldable
    -- The derived Foldable visits every idle resource of every key;
    -- destroyKeyedPool' uses it to destroy them all.
-- | Non-empty list of idle resources for one key. Each entry carries the time
-- it was returned to the pool. New entries are pushed at the head (see
-- 'addToList'), so the head is the most recently returned resource.
data PoolList a
    = One a !UTCTime
    -- the last (or only) entry
    | Cons
        a
        !Int
        -- number of entries in the list starting at this cell (so >= 2)
        !UTCTime
        !(PoolList a)
    deriving F.Foldable
-- | Flatten a 'PoolList' into @(timestamp, resource)@ pairs, head first.
-- The stored entry counts are dropped.
plistToList :: PoolList a -> [(UTCTime, a)]
plistToList plist =
    case plist of
        One a t -> [(t, a)]
        Cons a _ t rest -> (t, a) : plistToList rest
-- | Rebuild a 'PoolList' from @(timestamp, resource)@ pairs, keeping their
-- order. Returns 'Nothing' for an empty list.
--
-- Each 'Cons' cell records how many entries remain from that cell to the end,
-- which is why the builder counts down from the total length.
plistFromList :: [(UTCTime, a)] -> Maybe (PoolList a)
plistFromList [] = Nothing
plistFromList [(t, a)] = Just (One a t)
plistFromList entries = Just (build (length entries) entries)
  where
    build :: Int -> [(UTCTime, b)] -> PoolList b
    -- Unreachable: the count reaches the singleton case before the list empties.
    build _ [] = error "plistFromList.go []"
    build _ [(t, a)] = One a t
    build n ((t, a) : rest) = Cons a n t (build (n - 1) rest)
-- | Create a new, empty 'KeyedPool' and start its background reaper thread
-- (see 'reap').
--
-- A finalizer on the pool's internal 'IORef' closes the pool and destroys its
-- idle resources once the returned 'KeyedPool' is garbage collected. Finalizers
-- run at the garbage collector's discretion.
createKeyedPool
    :: Ord key
    => (key -> IO resource)
    -- ^ create a fresh resource for a key
    -> (resource -> IO ())
    -- ^ destroy a resource
    -> Int
    -- ^ cap on idle resources kept per key
    -> Int
    -- ^ cap on idle resources kept across all keys
    -> (SomeException -> IO ())
    -- ^ called with any exception escaping the reaper; the reaper is then restarted
    -> IO (KeyedPool key resource)
createKeyedPool create destroy maxPerKey maxTotal onReaperException = do
    stateVar <- newTVarIO (PoolOpen 0 Map.empty)

    -- The weak pointer hangs off a dedicated IORef rather than the TVar, since
    -- the reaper thread keeps the TVar reachable for as long as it runs.
    aliveRef <- newIORef ()
    _ <- mkWeakIORef aliveRef (destroyKeyedPool' destroy stateVar)

    -- Fork only after the finalizer is registered, so there is no window in
    -- which a reaper exists but nothing could ever close the pool. The unmask
    -- handle keeps the reaper interruptible even if our caller is masked.
    void $ forkIOWithUnmask $ \unmask ->
        superviseReaper (unmask (reap destroy stateVar))

    pure KeyedPool
        { kpCreate = create
        , kpDestroy = destroy
        , kpMaxPerKey = maxPerKey
        , kpMaxTotal = maxTotal
        , kpVar = stateVar
        , kpAlive = aliveRef
        }
  where
    -- Report every exception from the reaper and start it again.
    superviseReaper action =
        action `catch` \err -> do
            onReaperException err
            superviseReaper action
-- | Atomically mark the pool as closed, then destroy every idle resource it
-- held. Exceptions from @destroy@ are ignored, so one failing resource does not
-- stop the rest from being destroyed. On an already-closed pool this destroys
-- nothing.
destroyKeyedPool' :: (resource -> IO ())
                  -> TVar (PoolMap key resource)
                  -> IO ()
destroyKeyedPool' destroy var =
    atomically (swapTVar var PoolClosed)
        >>= F.traverse_ (ignoreExceptions . destroy)
-- | Body of the reaper thread started by 'createKeyedPool'.
--
-- Every 5 seconds it removes each idle resource that was returned more than
-- 30 seconds ago and destroys it, ignoring exceptions from @destroy@. While
-- the pool holds no idle resources it blocks in 'retry' instead of polling.
-- It returns once the pool has been closed.
reap :: forall key resource.
        Ord key
     => (resource -> IO ())
     -> TVar (PoolMap key resource)
     -> IO ()
reap destroy var =
    loop
  where
    loop = do
        threadDelay (5 * 1000 * 1000)
        join $ atomically $ do
            m'' <- readTVar var
            case m'' of
                PoolClosed -> return (return ())
                PoolOpen idleCount m
                    | Map.null m -> retry
                    | otherwise -> do
                        (m', toDestroy) <- findStale idleCount m
                        -- Force the new state before storing it, like the
                        -- other writers of this TVar, so no thunk over the
                        -- old map is left behind.
                        writeTVar var $! m'
                        -- Destroy outside the transaction, masked so a
                        -- resource already removed from the pool isn't leaked
                        -- halfway through.
                        return $ do
                            mask_ (mapM_ (ignoreExceptions . destroy) toDestroy)
                            loop

    -- | Split the idle map into entries to keep and resources to destroy, and
    -- return the new pool state with the idle count reduced accordingly.
    findStale :: Int
              -> Map key (PoolList resource)
              -> STM (PoolMap key resource, [resource])
    findStale idleCount m = do
        now <- unsafeIOToSTM getCurrentTime
        let isNotStale time = 30 `addUTCTime` time >= now
        let findStale' toKeep toDestroy [] =
                (Map.fromList (toKeep []), toDestroy [])
            findStale' toKeep toDestroy ((key, plist):rest) =
                findStale' toKeep' toDestroy' rest
              where
                -- Entries are pushed at the head with the current time, so
                -- each list is newest-first and the stale entries form a
                -- suffix. This assumes the wall clock did not jump backwards.
                (notStale, stale) = span (isNotStale . fst) $ plistToList plist
                toDestroy' = toDestroy . (map snd stale ++)
                -- A key whose entries are all stale is dropped from the map.
                toKeep' =
                    case plistFromList notStale of
                        Nothing -> toKeep
                        Just x -> toKeep . ((key, x) :)
        let (toKeep, toDestroy) = findStale' id id (Map.toList m)
        -- Every destroyed resource was idle, so it leaves the idle count.
        let idleCount' = idleCount - length toDestroy
        return (PoolOpen idleCount' toKeep, toDestroy)
-- | Check out a resource for @key@. Reuses the most recently returned idle
-- resource for that key if there is one; otherwise creates a fresh one with
-- 'kpCreate'. A closed pool always creates a fresh one.
--
-- Release the result with 'managedRelease'. If the 'Managed' is garbage
-- collected first, a finalizer releases it with 'DontReuse'. Only the first
-- release takes effect.
takeKeyedPool :: Ord key => KeyedPool key resource -> key -> IO (Managed resource)
takeKeyedPool kp key = mask_ $ join $ atomically $ do
    (m, mresource) <- fmap go $ readTVar (kpVar kp)
    writeTVar (kpVar kp) $! m
    -- The rest runs after the transaction commits (via join) but still under
    -- mask_. This narrows the window in which an async exception could lose a
    -- resource taken from the pool before its finalizer is registered.
    return $ do
        resource <- maybe (kpCreate kp key) return mresource
        alive <- newIORef ()
        isReleasedVar <- newTVarIO False

        let release action = mask_ $ do
                -- Flip the flag atomically so concurrent or repeated releases
                -- (including the finalizer) act at most once.
                isReleased <- atomically $ swapTVar isReleasedVar True
                unless isReleased $
                    case action of
                        Reuse -> putResource kp key resource
                        DontReuse -> ignoreExceptions $ kpDestroy kp resource

        _ <- mkWeakIORef alive $ release DontReuse
        return Managed
            { _managedResource = resource
            , _managedReused = isJust mresource
            , _managedRelease = release
            , _managedAlive = alive
            }
  where
    -- Pop the head of the key's idle list, if any, and decrement the total
    -- idle count to match.
    go PoolClosed = (PoolClosed, Nothing)
    go pcOrig@(PoolOpen idleCount m) =
        case Map.lookup key m of
            Nothing -> (pcOrig, Nothing)
            Just (One a _) ->
                (PoolOpen (idleCount - 1) (Map.delete key m), Just a)
            Just (Cons a _ _ rest) ->
                (PoolOpen (idleCount - 1) (Map.insert key rest m), Just a)
-- | Return a resource to the pool under @key@, stamped with the current time.
--
-- The resource is destroyed instead if:
--
-- * the pool is closed,
-- * the pool already holds 'kpMaxTotal' idle resources, or
-- * the key's idle list is full (see 'addToList').
--
-- NOTE(review): unlike the other destroy paths in this module, exceptions
-- from 'kpDestroy' here are not wrapped in 'ignoreExceptions'. They propagate
-- to the caller (e.g. a 'Reuse' release). Confirm this is intended.
putResource :: Ord key => KeyedPool key resource -> key -> resource -> IO ()
putResource kp key resource = do
    now <- getCurrentTime
    afterCommit <- atomically $ do
        current <- readTVar (kpVar kp)
        let (next, action) = step now current
        writeTVar (kpVar kp) $! next
        return action
    -- Any destroy happens outside the transaction.
    afterCommit
  where
    discard = kpDestroy kp resource

    step _ PoolClosed = (PoolClosed, discard)
    step now pool@(PoolOpen idleCount m)
        | idleCount >= kpMaxTotal kp = (pool, discard)
        | otherwise = maybe firstForKey pushOnto (Map.lookup key m)
      where
        -- No idle list for this key yet: start one.
        firstForKey =
            ( PoolOpen (idleCount + 1) (Map.insert key (One resource now) m)
            , return ()
            )
        -- Push onto the existing list. If it is full, the new resource comes
        -- back as 'Just' and is destroyed without changing the idle count.
        pushOnto l =
            let (l', evicted) = addToList now (kpMaxPerKey kp) resource l
                idleCount'
                    | isJust evicted = idleCount
                    | otherwise = idleCount + 1
            in ( PoolOpen idleCount' (Map.insert key l' m)
               , maybe (return ()) (kpDestroy kp) evicted
               )
-- | Push resource @x@, stamped with @now@, onto the head of a key's idle list.
--
-- The push is refused when the per-key cap @maxCount@ is at most 1, or when
-- the list already holds @maxCount@ entries. In that case the list is returned
-- unchanged and @x@ comes back in 'Just' for the caller to destroy. On success
-- the second component is 'Nothing'.
addToList :: UTCTime -> Int -> a -> PoolList a -> (PoolList a, Maybe a)
addToList now maxCount x l
    | maxCount <= 1 = rejected
    | otherwise =
        case l of
            One{} -> (Cons x 2 now l, Nothing)
            Cons _ currCount _ _
                | currCount < maxCount -> (Cons x (currCount + 1) now l, Nothing)
                | otherwise -> rejected
  where
    rejected = (l, Just x)
-- | A resource handed out by 'takeKeyedPool' (or wrapped by 'dummyManaged'),
-- together with the action that gives it back.
data Managed resource = Managed
    { _managedResource :: !resource
      -- ^ The resource itself.
    , _managedReused :: !Bool
      -- ^ 'True' when the resource came from the pool's idle list, 'False'
      -- when it was freshly created.
    , _managedRelease :: !(Reuse -> IO ())
      -- ^ Releases the resource. For values from 'takeKeyedPool', only the
      -- first call has any effect.
    , _managedAlive :: !(IORef ())
      -- ^ For values from 'takeKeyedPool', a finalizer on this IORef releases
      -- the resource with 'DontReuse' if the 'Managed' is garbage collected
      -- before being released.
    }
-- | The underlying resource.
managedResource :: Managed resource -> resource
managedResource (Managed res _ _ _) = res

-- | 'True' if the resource was taken from the pool, 'False' if freshly created.
managedReused :: Managed resource -> Bool
managedReused (Managed _ reused _ _) = reused

-- | Release the resource: 'Reuse' tries to return it to the pool, 'DontReuse'
-- destroys it.
managedRelease :: Managed resource -> Reuse -> IO ()
managedRelease (Managed _ _ releaser _) = releaser
-- | What 'managedRelease' should do with a resource.
data Reuse
    = Reuse
      -- ^ Try to return it to the pool. It is still destroyed if the pool is
      -- closed or full.
    | DontReuse
      -- ^ Destroy it, ignoring any exception the destroy action throws.
-- | Wrap a resource in a 'Managed' that is not backed by any pool. It reports
-- itself as freshly created, and releasing it does nothing.
dummyManaged :: resource -> Managed resource
dummyManaged resource =
    Managed resource False noRelease anchor
  where
    noRelease _ = return ()
    -- Only fills the strict field. No finalizer is ever attached to this IORef
    -- (the Managed constructor isn't exported, and takeKeyedPool only attaches
    -- finalizers to its own IORefs). So it is harmless if GHC floats this out
    -- and shares a single IORef across all calls.
    anchor = unsafePerformIO (newIORef ())
-- | Run an action and discard any exception it throws. The handler matches
-- 'SomeException', so asynchronous exceptions are swallowed too.
ignoreExceptions :: IO () -> IO ()
ignoreExceptions action = catch action swallow
  where
    swallow :: SomeException -> IO ()
    swallow _ = return ()