{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Database.CQRS.ReadModel.AggregateStore
( AggregateStore
, makeAggregateStore
, Response(..)
) where
import Control.Monad.Trans (MonadIO(..))
import Data.Hashable (Hashable)
import qualified Control.Monad.Except as Exc
import qualified Control.Concurrent.STM as STM
import qualified Data.HashPSQ as HashPSQ
import qualified Data.Time as T
import qualified Database.CQRS as CQRS
-- | Answer returned when an 'AggregateStore' is queried for one stream.
data Response eventId aggregate = Response
  { lastEventId :: Maybe eventId
    -- ^ Identifier of the latest event reflected in 'aggregate'. It is
    -- 'Nothing' when the query processed no event and had no cached entry
    -- to fall back on.
  , aggregate :: aggregate
    -- ^ The aggregate value, either taken from the cache or freshly
    -- computed.
  , eventCount :: Int
    -- ^ Number of events processed by this very query; @0@ when the answer
    -- came straight from the cache.
  , totalEventCount :: Int
    -- ^ Number of events accounted for in 'aggregate' overall, i.e. the
    -- previously cached count plus the events processed by this query.
  }
-- | Read model that answers queries with the aggregate of a single stream,
-- keeping recently computed aggregates in an in-memory cache.
data AggregateStore streamFamily aggregate = AggregateStore
  { streamFamily :: streamFamily
    -- ^ Stream family the individual streams are obtained from.
  , aggregator
      :: CQRS.Aggregator
           (CQRS.EventWithContext' (CQRS.StreamType streamFamily))
           aggregate
    -- ^ Aggregator run over a stream's events to produce the aggregate.
  , initialAggregate :: CQRS.StreamIdentifier streamFamily -> aggregate
    -- ^ Starting aggregate for a stream that has no cached entry.
  , cache
      :: Cache
           (CQRS.StreamIdentifier streamFamily)
           (CQRS.EventIdentifier (CQRS.StreamType streamFamily))
           aggregate
    -- ^ Cached aggregates, keyed by stream identifier.
  , lagTolerance :: T.NominalDiffTime
    -- ^ How long after being brought up to date a cached aggregate may be
    -- returned without re-reading the stream.
  }
-- | Build an 'AggregateStore' whose cache starts out empty.
makeAggregateStore
  :: MonadIO m
  => streamFamily
  -- ^ Stream family the streams are fetched from.
  -> CQRS.Aggregator
       (CQRS.EventWithContext' (CQRS.StreamType streamFamily))
       aggregate
  -- ^ Aggregator used to fold a stream's events into the aggregate.
  -> (CQRS.StreamIdentifier streamFamily -> aggregate)
  -- ^ Starting aggregate for a stream without a cached entry.
  -> T.NominalDiffTime
  -- ^ Lag tolerance: how long a cached aggregate is served as is.
  -> Int
  -- ^ Maximum number of cache entries; going over it evicts the entry
  -- with the lowest (oldest) timestamp.
  -> m (AggregateStore streamFamily aggregate)
makeAggregateStore family streamAggregator mkInitial tolerance capacity = do
  -- Both TVars are created within one transaction, as in a single unit.
  freshCache <- liftIO . STM.atomically $ do
    valuesVar <- STM.newTVar HashPSQ.empty
    sizeVar <- STM.newTVar 0
    pure Cache
      { cachedValues = valuesVar
      , size = sizeVar
      , maxSize = capacity
      }
  pure AggregateStore
    { streamFamily = family
    , aggregator = streamAggregator
    , initialAggregate = mkInitial
    , cache = freshCache
    , lagTolerance = tolerance
    }
-- | An 'AggregateStore' is a read model queried by stream identifier and
-- answering with a 'Response' for that stream. The actual work is done by
-- 'aggregateStoreQuery'.
instance
  ( CQRS.StreamFamily m streamFamily
  , CQRS.Stream m (CQRS.StreamType streamFamily)
  , Exc.MonadError CQRS.Error m
  , Hashable (CQRS.StreamIdentifier streamFamily)
  , MonadIO m
  , Ord (CQRS.EventIdentifier (CQRS.StreamType streamFamily))
  , Ord (CQRS.StreamIdentifier streamFamily)
  , Show (CQRS.EventIdentifier (CQRS.StreamType streamFamily))
  )
  => CQRS.ReadModel m (AggregateStore streamFamily aggregate) where

  -- Queries are the identifier of the stream whose aggregate is wanted.
  type ReadModelQuery (AggregateStore streamFamily aggregate) =
    CQRS.StreamIdentifier streamFamily

  -- Responses carry the aggregate along with event identifiers and counts.
  type ReadModelResponse (AggregateStore streamFamily aggregate) =
    Response (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) aggregate

  query = aggregateStoreQuery
-- | Answer a query for the aggregate of one stream.
--
-- The cache is consulted first. An entry brought up to date less than
-- 'lagTolerance' ago is returned as is, with an 'eventCount' of @0@.
-- Otherwise the stream is read, starting after the cached event when there
-- is a cached entry and from the beginning when there is none, and the
-- result is written back to the cache whenever an event identifier is known
-- for it.
aggregateStoreQuery
  :: forall m streamFamily aggregate.
     ( CQRS.StreamFamily m streamFamily
     , CQRS.Stream m (CQRS.StreamType streamFamily)
     , Exc.MonadError CQRS.Error m
     , Hashable (CQRS.StreamIdentifier streamFamily)
     , MonadIO m
     , Ord (CQRS.EventIdentifier (CQRS.StreamType streamFamily))
     , Ord (CQRS.StreamIdentifier streamFamily)
     , Show (CQRS.EventIdentifier (CQRS.StreamType streamFamily))
     )
  => AggregateStore streamFamily aggregate
  -> CQRS.StreamIdentifier streamFamily
  -> m (Response
         (CQRS.EventIdentifier (CQRS.StreamType streamFamily)) aggregate)
aggregateStoreQuery
  AggregateStore
    { streamFamily
    , aggregator
    , initialAggregate
    , cache
    , lagTolerance
    }
  streamId = do
    -- The cache snapshot is taken before the clock is read.
    snapshot <- liftIO . STM.atomically . STM.readTVar $ cachedValues cache
    now <- liftIO T.getCurrentTime
    case HashPSQ.lookup streamId snapshot of
      Just (refreshedAt, entry)
        -- Still within the tolerated lag: serve the cached value untouched.
        | now < T.addUTCTime lagTolerance refreshedAt ->
            pure Response
              { lastEventId = Just (cachedLastEventId entry)
              , aggregate = cachedAggregate entry
              , eventCount = 0
              , totalEventCount = cachedEventCount entry
              }
        | otherwise -> refresh now (Just entry)
      Nothing -> refresh now Nothing
  where
    -- Bring the aggregate up to date by reading the stream, resuming from
    -- the previous cache entry when there is one.
    refresh
      :: T.UTCTime
      -> Maybe (CacheItem
                 (CQRS.EventIdentifier (CQRS.StreamType streamFamily))
                 aggregate)
      -> m (Response
             (CQRS.EventIdentifier (CQRS.StreamType streamFamily))
             aggregate)
    refresh now mPrevious = do
      let (seed, bounds, alreadyCounted) = case mPrevious of
            Just previous ->
              ( cachedAggregate previous
              , CQRS.afterEvent (cachedLastEventId previous)
              , cachedEventCount previous
              )
            Nothing -> (initialAggregate streamId, mempty, 0)
      stream <- CQRS.getStream streamFamily streamId
      (freshAggregate, mNewestId, processed) <-
        CQRS.runAggregator aggregator stream bounds seed
      let total = alreadyCounted + processed
          -- Prefer the identifier of a freshly processed event and fall
          -- back on the cached one; 'Nothing' only if neither exists.
          knownId = maybe (cachedLastEventId <$> mPrevious) Just mNewestId
          remember eventId =
            addValueToCache cache streamId now CacheItem
              { cachedAggregate = freshAggregate
              , cachedLastEventId = eventId
              , cachedEventCount = total
              }
      -- Without any event identifier there is nothing to cache.
      liftIO $ mapM_ remember knownId
      pure Response
        { lastEventId = knownId
        , aggregate = freshAggregate
        , eventCount = processed
        , totalEventCount = total
        }
-- | Value stored in the cache for one stream.
data CacheItem eventId aggregate = CacheItem
  { cachedAggregate :: aggregate
    -- ^ Aggregate as it was when the entry was written.
  , cachedLastEventId :: eventId
    -- ^ Identifier of the last event reflected in 'cachedAggregate'.
  , cachedEventCount :: Int
    -- ^ Total number of events accounted for in 'cachedAggregate'.
  }
-- | Bounded cache of aggregates shared through 'STM.TVar's.
data Cache streamId eventId aggregate = Cache
  { cachedValues
      :: STM.TVar
           (HashPSQ.HashPSQ streamId T.UTCTime (CacheItem eventId aggregate))
    -- ^ Entries keyed by stream identifier and prioritised by the time
    -- they were written, so the minimum is the oldest entry.
  , size :: STM.TVar Int
    -- ^ Number of entries, tracked alongside 'cachedValues'.
  , maxSize :: Int
    -- ^ Entry count above which the minimum-priority entry is evicted.
  }
-- | Insert or refresh the cache entry of a stream, in one transaction.
--
-- * No entry yet: the item is inserted with priority @now@ and the tracked
--   size grows by one.
-- * Existing entry with a strictly greater last event identifier: the
--   existing entry (and its timestamp) is kept.
-- * Otherwise the entry is replaced by the item with priority @now@.
--
-- If the tracked size then exceeds 'maxSize', the entry with the lowest
-- priority (the oldest timestamp) is removed.
--
-- The new size and queue are forced before being written. A lazy write
-- would store thunks that hold on to the previous queue and the item, and
-- would leave the upsert/eviction work to whichever thread reads the cache
-- next.
addValueToCache
  :: ( Hashable streamId
     , Ord eventId
     , Ord streamId
     )
  => Cache streamId eventId aggregate
  -> streamId
  -> T.UTCTime
  -> CacheItem eventId aggregate
  -> IO ()
addValueToCache Cache{cachedValues, size, maxSize} streamId now item =
  STM.atomically $ do
    queue <- STM.readTVar cachedValues
    currentSize <- STM.readTVar size
    let -- Decide what to store under the key and report the resulting size.
        upsert Nothing = (currentSize + 1, Just (now, item))
        upsert (Just existing@(_, existingItem))
          | cachedLastEventId existingItem > cachedLastEventId item =
              (currentSize, Just existing)
          | otherwise = (currentSize, Just (now, item))
        (grownSize, grownQueue) = HashPSQ.alter upsert streamId queue
        -- At most one entry is added per call, so one eviction is enough.
        (finalSize, finalQueue)
          | grownSize > maxSize = (grownSize - 1, HashPSQ.deleteMin grownQueue)
          | otherwise = (grownSize, grownQueue)
    STM.writeTVar size $! finalSize
    STM.writeTVar cachedValues $! finalQueue