{-# LANGUAGE TupleSections #-}
{-# LANGUAGE LambdaCase #-}
module Kafka.Producer
( KafkaProducer
, module X
, runProducer
, newProducer
, produceMessage, produceMessageBatch
, produceMessage'
, flushProducer
, closeProducer
, RdKafkaRespErrT (..)
)
where
import Control.Arrow ((&&&))
import Control.Exception (bracket)
import Control.Monad (forM, forM_, (<=<))
import Control.Monad.IO.Class (MonadIO (liftIO))
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import Data.Function (on)
import Data.List (groupBy, sortBy)
import Data.Ord (comparing)
import qualified Data.Text as Text
import Foreign.ForeignPtr (newForeignPtr_, withForeignPtr)
import Foreign.Marshal.Array (withArrayLen)
import Foreign.Ptr (Ptr, nullPtr, plusPtr)
import Foreign.Storable (Storable (..))
import Foreign.StablePtr (castStablePtrToPtr, freeStablePtr, newStablePtr)
import Kafka.Internal.RdKafka (RdKafkaMessageT (..), RdKafkaRespErrT (..), RdKafkaTypeT (..), destroyUnmanagedRdKafkaTopic, newRdKafkaT, newUnmanagedRdKafkaTopicT, rdKafkaOutqLen, rdKafkaProduce, rdKafkaProduceBatch, rdKafkaSetLogLevel)
import Kafka.Internal.Setup (Kafka (..), KafkaConf (..), KafkaProps (..), TopicConf (..), TopicProps (..), kafkaConf, topicConf)
import Kafka.Internal.Shared (pollEvents)
import Kafka.Producer.Convert (copyMsgFlags, handleProduceErr', producePartitionCInt, producePartitionInt)
import Kafka.Producer.Types (KafkaProducer (..), ImmediateError(..))
import Kafka.Producer.ProducerProperties as X
import Kafka.Producer.Types as X hiding (KafkaProducer)
import Kafka.Types as X
{-# DEPRECATED runProducer "Use 'newProducer'/'closeProducer' instead" #-}
-- | Creates a producer from @props@, runs @f@ with it and closes it
-- afterwards. The close step runs even if @f@ throws, because the whole
-- sequence is wrapped in 'bracket'.
--
-- If the producer cannot be created, @f@ is never called and the creation
-- error is returned unchanged.
runProducer :: ProducerProperties
            -> (KafkaProducer -> IO (Either KafkaError a))
            -> IO (Either KafkaError a)
runProducer props f =
  bracket (newProducer props) release use
  where
    -- Only a successfully created producer needs closing.
    release = either (const (return ())) closeProducer
    -- Propagate a creation failure; otherwise hand the producer to the caller.
    use = either (return . Left) f
-- | Creates a new producer from the given properties.
--
-- First a default delivery callback (@const mempty@, which ignores its
-- argument) is installed. Then every callback in 'ppCallbacks' is applied to
-- the configuration, so user-supplied callbacks are set after the default.
--
-- If the native client cannot be created, its error is wrapped in
-- 'KafkaError'. Otherwise the configured log level, if any, is applied
-- before the producer is returned.
newProducer :: MonadIO m => ProducerProperties -> m (Either KafkaError KafkaProducer)
newProducer pps = liftIO $ do
  conf@(KafkaConf rawConf _ _) <- kafkaConf (KafkaProps (ppKafkaProps pps))
  tConf <- topicConf (TopicProps (ppTopicProps pps))
  deliveryCallback (const mempty) conf
  mapM_ ($ conf) (ppCallbacks pps)
  created <- newRdKafkaT RdKafkaProducer rawConf
  either (return . Left . KafkaError) (finish conf tConf) created
  where
    -- Apply the optional log level, then wrap the native handle.
    finish conf tConf kafka = do
      forM_ (ppLogLevel pps) (rdKafkaSetLogLevel kafka . fromEnum)
      return . Right $ KafkaProducer (Kafka kafka) conf tConf
-- | Sends a single record to Kafka without a per-message delivery callback.
--
-- Returns @Just err@ if the record could not be handed to the client
-- immediately, and @Nothing@ otherwise. The eventual delivery outcome is not
-- reported through this function: the delivery callback ignores its report.
produceMessage :: MonadIO m
               => KafkaProducer
               -> ProducerRecord
               -> m (Maybe KafkaError)
produceMessage kp m = do
  outcome <- produceMessage' kp m ignoreReport
  pure $ either (\(ImmediateError err) -> Just err) (const Nothing) outcome
  where
    ignoreReport _ = pure ()
-- | Sends a single record to Kafka and attaches a per-message delivery
-- callback.
--
-- Before producing, pending events are polled without blocking
-- (@Timeout 0@). The callback @cb@ is passed to the native produce call as
-- the message's opaque pointer, in the form of a 'StablePtr'.
--
-- Returns @Left (ImmediateError e)@ when the record could not be handed to
-- the client at all. That happens when the topic handle could not be
-- created or when the produce call itself reported an error. In that case
-- the 'StablePtr' for @cb@ is released here.
--
-- @Right ()@ only means the record was accepted by the client; the delivery
-- outcome arrives later through @cb@.
produceMessage' :: MonadIO m
                => KafkaProducer
                -> ProducerRecord
                -> (DeliveryReport -> IO ())
                -> m (Either ImmediateError ())
produceMessage' kp@(KafkaProducer (Kafka k) _ (TopicConf tc)) msg cb = liftIO $
  fireCallbacks >> bracket (mkTopic . prTopic $ msg) closeTopic withTopic
  where
    fireCallbacks =
      pollEvents kp . Just . Timeout $ 0

    mkTopic (TopicName tn) =
      newUnmanagedRdKafkaTopicT k (Text.unpack tn) (Just tc)

    closeTopic = either mempty destroyUnmanagedRdKafkaTopic

    withTopic (Left err) = return . Left . ImmediateError . KafkaError . Text.pack $ err
    withTopic (Right topic) =
      withBS (prValue msg) $ \payloadPtr payloadLength ->
        withBS (prKey msg) $ \keyPtr keyLength -> do
          callbackPtr <- newStablePtr cb
          res <- handleProduceErr' =<< rdKafkaProduce
            topic
            (producePartitionCInt (prPartition msg))
            copyMsgFlags
            payloadPtr
            (fromIntegral payloadLength)
            keyPtr
            (fromIntegral keyLength)
            (castStablePtrToPtr callbackPtr)
          case res of
            Left err -> do
              -- The message was never enqueued, so no delivery report will
              -- consume this pointer. Free it here, otherwise the pointer
              -- and the callback closure it pins are leaked.
              freeStablePtr callbackPtr
              pure . Left . ImmediateError $ err
            Right () -> pure (Right ())
-- | Sends a list of records to Kafka with as few native calls as possible.
--
-- Pending events are polled first without blocking (@Timeout 0@). Records
-- are then sorted and grouped by @(topic, partition)@, and each group is
-- passed to @rdKafkaProduceBatch@ in a single call.
--
-- Returns only the records that failed, each paired with its error:
--
-- * If a group's topic handle cannot be created, every record in that
--   group is returned with that error.
-- * If the native batch call reports fewer accepted messages than were
--   sent, the per-message error codes are read back from the array. Every
--   record whose code is not 'RdKafkaRespErrNoError' is returned.
--
-- An empty result means no failures were reported for any group.
produceMessageBatch :: MonadIO m
                    => KafkaProducer
                    -> [ProducerRecord]
                    -> m [(ProducerRecord, KafkaError)]
produceMessageBatch kp@(KafkaProducer (Kafka k) _ (TopicConf tc)) messages = liftIO $ do
  pollEvents kp (Just $ Timeout 0)
  concat <$> forM (mkBatches messages) sendBatch
  where
    mkSortKey = prTopic &&& prPartition
    mkBatches = groupBy ((==) `on` mkSortKey) . sortBy (comparing mkSortKey)

    mkTopic (TopicName tn) = newUnmanagedRdKafkaTopicT k (Text.unpack tn) (Just tc)
    clTopic = either (return . const ()) destroyUnmanagedRdKafkaTopic

    -- All records in a batch share topic and partition, so the first record
    -- stands for the whole group. Pattern matching avoids the partial 'head'.
    sendBatch [] = return []
    sendBatch batch@(r0 : _) = bracket (mkTopic $ prTopic r0) clTopic (withTopic batch)

    withTopic ms (Left err) = return $ (, KafkaError (Text.pack err)) <$> ms
    withTopic [] (Right _) = return []
    withTopic ms@(r0 : _) (Right t) = do
      let (partInt, partCInt) = (producePartitionInt &&& producePartitionCInt) $ prPartition r0
      withForeignPtr t $ \topicPtr ->
        withNativeMessages topicPtr partInt ms $ \nativeMs ->
          withArrayLen nativeMs $ \len batchPtr -> do
            batchPtrF <- newForeignPtr_ batchPtr
            numRet <- rdKafkaProduceBatch t partCInt copyMsgFlags batchPtrF len
            if numRet == len
              then return []
              else do
                errs <- mapM (return . err'RdKafkaMessageT <=< peekElemOff batchPtr)
                             [0 .. len - 1]
                -- Pair error codes with this batch's records. They are in
                -- the same order as the native array, which is not the
                -- order of the full input list.
                return [(m, KafkaResponseError e) | (m, e) <- zip ms errs, e /= RdKafkaRespErrNoError]

    -- Builds the native message descriptions for a batch and passes them to
    -- the continuation.
    --
    -- Payload and key pointers from 'withBS' are only valid inside its
    -- scope. Each record's scope therefore stays open until the continuation
    -- returns, so every buffer stays alive across the native produce call.
    withNativeMessages _ _ [] cont = cont []
    withNativeMessages topicPtr p (m : rest) cont =
      withBS (prValue m) $ \payloadPtr payloadLength ->
        withBS (prKey m) $ \keyPtr keyLength ->
          withNativeMessages topicPtr p rest $ \natives ->
            cont (RdKafkaMessageT
                    { err'RdKafkaMessageT = RdKafkaRespErrNoError
                    , topic'RdKafkaMessageT = topicPtr
                    , partition'RdKafkaMessageT = p
                    , len'RdKafkaMessageT = payloadLength
                    , payload'RdKafkaMessageT = payloadPtr
                    , offset'RdKafkaMessageT = 0
                    , keyLen'RdKafkaMessageT = keyLength
                    , key'RdKafkaMessageT = keyPtr
                    , opaque'RdKafkaMessageT = nullPtr
                    } : natives)
-- | Closes the producer. The only work done is 'flushProducer', which waits
-- until the outbound queue is empty; no other teardown happens here.
closeProducer :: MonadIO m => KafkaProducer -> m ()
closeProducer producer = flushProducer producer
-- | Waits until the producer's outbound queue is empty.
--
-- The loop polls events with @Timeout 100@ (presumably milliseconds;
-- confirm against 'Timeout') and then re-checks the outbound queue length.
-- Once the length reaches zero, it does one last non-blocking poll
-- (@Timeout 0@) to serve any events that are still pending.
--
-- There is no upper bound on how long this waits.
flushProducer :: MonadIO m => KafkaProducer -> m ()
flushProducer kp = liftIO loop
  where
    loop = do
      pollEvents kp (Just $ Timeout 100)
      queued <- outboundQueueLength (kpKafkaPtr kp)
      if queued /= 0
        then loop
        else pollEvents kp (Just $ Timeout 0)
-- | Runs an action on a raw view of an optional 'BS.ByteString'.
--
-- @Nothing@ is passed as a null pointer with length 0. For @Just bs@, the
-- action receives a pointer to the first byte of @bs@ (the buffer's offset
-- is applied) and the length of @bs@.
--
-- The pointer is only valid while the action runs: the buffer is kept alive
-- only inside 'withForeignPtr'.
withBS :: Maybe BS.ByteString -> (Ptr a -> Int -> IO b) -> IO b
withBS mbBytes action = case mbBytes of
  Nothing -> action nullPtr 0
  Just bytes ->
    case BSI.toForeignPtr bytes of
      (fptr, offset, size) ->
        withForeignPtr fptr (\base -> action (plusPtr base offset) size)
-- | Length of the native client's outbound queue, as returned by
-- @rdKafkaOutqLen@ for the wrapped handle.
outboundQueueLength :: Kafka -> IO Int
outboundQueueLength kafka = case kafka of
  Kafka handle -> rdKafkaOutqLen handle