{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DeriveGeneric #-}
module Network.PushNotify.APN
( newSession
, newMessage
, newMessageWithCustomPayload
, hexEncodedToken
, rawToken
, sendMessage
, sendSilentMessage
, sendRawMessage
, alertMessage
, emptyMessage
, setAlertMessage
, setBadge
, setCategory
, setSound
, clearAlertMessage
, clearBadge
, clearCategory
, clearSound
, closeSession
, isOpen
, ApnSession
, JsonAps
, JsonApsAlert
, JsonApsMessage
, ApnMessageResult(..)
, ApnToken
) where
import Control.Concurrent
import Control.Concurrent.QSem
import Control.Exception
import Control.Monad
import Data.Aeson
import Data.Aeson.Types
import Data.ByteString (ByteString)
import Data.Char (toLower)
import Data.Default (def)
import Data.Either
import Data.Int
import Data.IORef
import Data.Map.Strict (Map)
import Data.Maybe
import Data.Text (Text)
import Data.Time.Clock.POSIX
import Data.X509
import Data.X509.CertificateStore
import GHC.Generics
import Network.HTTP2.Client
import Network.HTTP2.Client.FrameConnection
import Network.HTTP2.Client.Helpers
import Network.TLS hiding (sendData)
import Network.TLS.Extra.Cipher
import System.IO.Error
import System.Mem.Weak
import System.Random
import qualified Data.ByteString as S
import qualified Data.ByteString.Base16 as B16
import qualified Data.ByteString.Lazy as L
import qualified Data.List as DL
import qualified Data.Map.Strict as M
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Network.HTTP2 as HTTP2
import qualified Network.HPACK as HTTP2
-- | A session with the APN service: a pool of reusable connections, the
-- settings used to open new ones, and a background thread that closes
-- connections which have not been used recently.
data ApnSession = ApnSession
    { apnSessionPool :: !(IORef [ApnConnection])
      -- ^ Connections not currently checked out by 'withConnection'.
    , apnSessionConnectionInfo :: !ApnConnectionInfo
      -- ^ Certificate paths, host and limits used when opening connections.
    , apnSessionConnectionManager :: !ThreadId
      -- ^ Thread running 'manage' over the pool; killed by 'closeSession'.
    , apnSessionOpen :: !(IORef Bool)
      -- ^ Set to False by 'closeSession'; checked by 'ensureOpen'.
    }
-- | Everything needed to open a connection to the APN service.
data ApnConnectionInfo = ApnConnectionInfo
    { aciCertPath :: !FilePath
      -- ^ Client certificate file (first argument to 'credentialLoadX509').
    , aciCertKey :: !FilePath
      -- ^ Private key file (second argument to 'credentialLoadX509').
    , aciCaPath :: !FilePath
      -- ^ Path given to 'readCertificateStore' for the CA store.
    , aciHostname :: !Text
      -- ^ Host to connect to; also sent as the @:authority@ header.
    , aciMaxConcurrentStreams :: !Int
      -- ^ HTTP/2 max-concurrent-streams setting and per-connection
      -- semaphore size.
    , aciTopic :: !ByteString
      -- ^ Value of the @apns-topic@ request header.
    }
-- | One HTTP/2 connection to the APN service.
data ApnConnection = ApnConnection
    { apnConnectionConnection :: !Http2Client
      -- ^ The underlying HTTP/2 client.
    , apnConnectionInfo :: !ApnConnectionInfo
      -- ^ Settings this connection was opened with.
    , apnConnectionWorkerPool :: !QSem
      -- ^ Bounds the number of concurrent 'sendApnRaw' calls on this
      -- connection (sized by 'aciMaxConcurrentStreams').
    , apnConnectionLastUsed :: !Int64
      -- ^ POSIX seconds of the last checkout; 'manage' closes connections
      -- whose value is too old.
    , apnConnectionFlowControlWorker :: !ThreadId
      -- ^ Thread periodically updating the incoming flow-control window.
    , apnConnectionOpen :: !(IORef Bool)
      -- ^ Set to False on GOAWAY and by 'closeApnConnection'.
    }
-- | A device token, stored as the hex-encoded bytes that are appended to
-- the @/3/device/@ request path. Build with 'rawToken' or 'hexEncodedToken'.
newtype ApnToken = ApnToken { unApnToken :: ByteString }
-- | Types with a designated value to return when an IO error is caught
-- (see 'catchIOErrors').
class SpecifyError a where
    -- | The value substituted for a failed IO action.
    isAnError :: a
-- | Build a token from the raw (binary) device-token bytes; they are
-- hex-encoded for use in the request path.
rawToken
    :: ByteString -- ^ Raw token bytes.
    -> ApnToken
rawToken bytes = ApnToken (B16.encode bytes)
-- | Build a token from its hex text form.
--
-- The text is hex-decoded and re-encoded, which normalises it to the
-- encoder's output case. With the tuple-returning @B16.decode@ used here,
-- any undecodable remainder is dropped.
hexEncodedToken
    :: Text -- ^ Hex-encoded device token.
    -> ApnToken
hexEncodedToken txt = ApnToken (B16.encode validPrefix)
  where
    (validPrefix, _undecodedRest) = B16.decode (TE.encodeUtf8 txt)
-- | Outcome of sending one notification.
data ApnMessageResult = ApnMessageResultOk
                        -- ^ The APN service answered with status 200.
                      | ApnMessageResultFatalError
                        -- ^ Returned for client-error statuses
                        -- (e.g. 400, 403, 405, 413).
                      | ApnMessageResultTemporaryError
                        -- ^ Returned for statuses such as 429, 500 and 503,
                        -- for stream errors, and for caught IO errors.
                      | ApnMessageResultTokenNoLongerValid
                        -- ^ Returned for status 410.
    deriving (Enum, Eq, Show)
-- | An IO error while sending is reported as a temporary error.
instance SpecifyError ApnMessageResult where
    isAnError = ApnMessageResultTemporaryError
-- | The alert part of a notification; encoded with keys @title@ and @body@.
data JsonApsAlert = JsonApsAlert
    { jaaTitle :: !Text
      -- ^ Alert title.
    , jaaBody :: !Text
      -- ^ Alert body text.
    } deriving (Generic, Show)
-- | Generic encoding. The 3-character @jaa@ prefix is stripped from field
-- names and the rest is lower-cased (@jaaTitle@ becomes @title@).
instance ToJSON JsonApsAlert where
    toJSON = genericToJSON alertOptions
      where
        alertOptions = defaultOptions { fieldLabelModifier = map toLower . drop 3 }
-- | Contents of the @aps@ dictionary. Encoded with keys @alert@, @badge@,
-- @sound@ and @category@. Build it with 'emptyMessage' / 'alertMessage'
-- and the @set*@ / @clear*@ functions.
data JsonApsMessage
    = JsonApsMessage
    { jamAlert :: !(Maybe JsonApsAlert)
      -- ^ Optional title/body alert.
    , jamBadge :: !(Maybe Int)
      -- ^ Optional badge number.
    , jamSound :: !(Maybe Text)
      -- ^ Optional sound name.
    , jamCategory :: !(Maybe Text)
      -- ^ Optional category identifier.
    } deriving (Generic, Show)
-- | A message with no alert, badge, sound or category set.
emptyMessage :: JsonApsMessage
emptyMessage = JsonApsMessage
    { jamAlert = Nothing
    , jamBadge = Nothing
    , jamSound = Nothing
    , jamCategory = Nothing
    }
-- | Set the @sound@ field of a message.
setSound
    :: Text           -- ^ Sound name.
    -> JsonApsMessage
    -> JsonApsMessage
setSound sound msg = msg { jamSound = pure sound }

-- | Remove the @sound@ field from a message.
clearSound
    :: JsonApsMessage
    -> JsonApsMessage
clearSound msg = msg { jamSound = Nothing }
-- | Set the @category@ field of a message.
setCategory
    :: Text           -- ^ Category identifier.
    -> JsonApsMessage
    -> JsonApsMessage
setCategory category msg = msg { jamCategory = pure category }

-- | Remove the @category@ field from a message.
clearCategory
    :: JsonApsMessage
    -> JsonApsMessage
clearCategory msg = msg { jamCategory = Nothing }
-- | Set the @badge@ field of a message.
setBadge
    :: Int            -- ^ Badge number.
    -> JsonApsMessage
    -> JsonApsMessage
setBadge badge msg = msg { jamBadge = pure badge }

-- | Remove the @badge@ field from a message.
clearBadge
    :: JsonApsMessage
    -> JsonApsMessage
clearBadge msg = msg { jamBadge = Nothing }
-- | A message carrying only a title/body alert.
alertMessage
    :: Text           -- ^ Alert title.
    -> Text           -- ^ Alert body.
    -> JsonApsMessage
alertMessage title text = emptyMessage { jamAlert = Just (JsonApsAlert title text) }

-- | Set (or replace) the title/body alert of a message.
setAlertMessage
    :: Text           -- ^ Alert title.
    -> Text           -- ^ Alert body.
    -> JsonApsMessage
    -> JsonApsMessage
setAlertMessage title text msg = msg { jamAlert = Just (JsonApsAlert title text) }

-- | Remove the alert from a message.
clearAlertMessage
    :: JsonApsMessage
    -> JsonApsMessage
clearAlertMessage msg = msg { jamAlert = Nothing }
-- | Generic encoding of the @aps@ dictionary. The 3-character @jam@ prefix
-- is stripped and the rest lower-cased (@jamBadge@ becomes @badge@).
--
-- Unset ('Nothing') fields are omitted rather than encoded as @null@. A
-- @"badge":null@ or @"sound":null@ entry is not a number or string as the
-- APN payload keys expect, and it wastes payload bytes.
instance ToJSON JsonApsMessage where
    toJSON = genericToJSON defaultOptions
        { fieldLabelModifier = drop 3 . map toLower
        , omitNothingFields = True
        }
-- | A complete notification payload. Encoded with keys @aps@ and
-- @appspecificcontent@.
data JsonAps
    = JsonAps
    { jaAps :: !JsonApsMessage
      -- ^ The @aps@ dictionary.
    , jaAppSpecificContent :: !(Maybe Text)
      -- ^ Optional extra text, set by 'newMessageWithCustomPayload'.
    } deriving (Generic, Show)
-- | Generic encoding of the top-level payload. The 2-character @ja@ prefix
-- is stripped and the rest lower-cased (@jaAps@ becomes @aps@).
--
-- A missing custom payload is omitted instead of sending
-- @"appspecificcontent":null@. This matches the @aps@ encoding and saves
-- payload bytes.
instance ToJSON JsonAps where
    toJSON = genericToJSON defaultOptions
        { fieldLabelModifier = drop 2 . map toLower
        , omitNothingFields = True
        }
-- | Wrap an @aps@ message into a payload with no custom content.
newMessage
    :: JsonApsMessage
    -> JsonAps
newMessage aps = JsonAps aps Nothing
-- | Wrap an @aps@ message into a payload together with custom text
-- (encoded under @appspecificcontent@).
newMessageWithCustomPayload
    :: JsonApsMessage -- ^ The @aps@ dictionary.
    -> Text           -- ^ Custom content.
    -> JsonAps
newMessageWithCustomPayload message = JsonAps message . Just
-- | Create a new APN session.
--
-- The certificate, key and CA store are loaded once up front. The call
-- fails with 'error' if any of them cannot be loaded. A background thread
-- ('manage') is started that closes pooled connections unused for more
-- than 600 seconds.
--
-- If the session becomes unreachable while still open, it is closed by a
-- finalizer.
newSession
    :: FilePath   -- ^ Private key file (becomes 'aciCertKey').
    -> FilePath   -- ^ Client certificate file (becomes 'aciCertPath').
    -> FilePath   -- ^ Path given to 'readCertificateStore'.
    -> Bool       -- ^ 'True' selects the development host.
    -> Int        -- ^ Max concurrent streams per connection.
    -> ByteString -- ^ Value for the @apns-topic@ header.
    -> IO ApnSession
newSession certKey certPath caPath dev maxparallel topic = do
    let hostname = if dev
            then "api.development.push.apple.com"
            else "api.push.apple.com"
        connInfo = ApnConnectionInfo certPath certKey caPath hostname maxparallel topic
    certsOk <- checkCertificates connInfo
    unless certsOk $ error "Unable to load certificates and/or the private key"
    connections <- newIORef []
    connectionManager <- forkIO $ manage 600 connections
    openRef <- newIORef True
    let session = ApnSession connections connInfo connectionManager openRef
    -- The finalizer is keyed on the IORef rather than the ApnSession record.
    -- System.Mem.Weak warns that finalizers on ordinary Haskell values can
    -- run at unpredictable times because of optimisations. The open check
    -- makes it a no-op after an explicit 'closeSession', which would
    -- otherwise throw "Session is already closed" from the finalizer.
    void $ mkWeakIORef openRef $ do
        stillOpen <- readIORef openRef
        when stillOpen $ closeSession session
    return session
-- | Close a session. This marks it closed, stops the connection manager
-- thread, and closes every pooled connection.
--
-- Calls 'error' if the session was already closed.
closeSession :: ApnSession -> IO ()
closeSession s = do
    wasOpen <- atomicModifyIORef' (apnSessionOpen s) $ \prev -> (False, prev)
    unless wasOpen $ error "Session is already closed"
    killThread $ apnSessionConnectionManager s
    drained <- atomicModifyIORef' (apnSessionPool s) $ \conns -> ([], conns)
    forM_ drained closeApnConnection
-- | Whether 'closeSession' has not yet been called on this session.
isOpen :: ApnSession -> IO Bool
isOpen session = readIORef (apnSessionOpen session)
-- | Run an action with a connection checked out of the session's pool.
--
-- If the pool is empty, a new connection is opened. Otherwise a random
-- pooled connection is taken. Choosing and removing it happen inside a
-- single 'atomicModifyIORef'', so a concurrent caller can never make us
-- remove a connection we did not pick or index past the end of a shrunk
-- list. A pooled connection whose open flag is False (e.g. after GOAWAY)
-- is closed and another one is tried.
--
-- After the action the connection goes back to the pool with its
-- last-used time refreshed. If the action throws, the connection is
-- closed instead of being leaked, and the exception is re-raised.
--
-- Calls 'error' (via 'ensureOpen') if the session is closed.
withConnection :: ApnSession -> (ApnConnection -> IO a) -> IO a
withConnection s action = do
    ensureOpen s
    seed <- randomRIO (0, maxBound :: Int)
    picked <- atomicModifyIORef' pool (takeAt seed)
    case picked of
        Nothing -> do
            conn <- newConnection s
            runAndRelease conn
        Just conn -> do
            open <- readIORef (apnConnectionOpen conn)
            if open
                then do
                    currtime <- round <$> getPOSIXTime :: IO Int64
                    runAndRelease (conn { apnConnectionLastUsed = currtime })
                else do
                    discard conn
                    withConnection s action
  where
    pool = apnSessionPool s
    -- Remove and return the element at (seed mod length); total on all lists.
    takeAt _ [] = ([], Nothing)
    takeAt seed conns =
        case splitAt (seed `mod` length conns) conns of
            (before, c : after) -> (before ++ after, Just c)
            (_, [])             -> (conns, Nothing)
    -- Run the action, then return the connection to the pool. If the
    -- action throws, close the connection and re-raise.
    runAndRelease conn = do
        res <- action conn `onException` discard conn
        atomicModifyIORef' pool (\conns -> (conn : conns, ()))
        return res
    -- Best-effort close of a connection we are giving up on. Synchronous
    -- failures (the peer may already be gone) are ignored; asynchronous
    -- exceptions such as 'killThread' are re-thrown.
    discard conn = closeApnConnection conn `catch` \e ->
        case fromException e of
            Just (SomeAsyncException _) -> throwIO e
            Nothing                     -> return ()
-- | Report whether both the CA store and the client credential
-- (certificate + key) can be loaded from the configured paths.
checkCertificates :: ApnConnectionInfo -> IO Bool
checkCertificates aci = do
    storeLoaded <- isJust <$> readCertificateStore (aciCaPath aci)
    credentialLoaded <- isRight <$> credentialLoadX509 (aciCertPath aci) (aciCertKey aci)
    pure (storeLoaded && credentialLoaded)
-- | Replace the element at 0-based index @n@ with @newVal@.
-- Indices outside the list (negative or past the end) leave it unchanged,
-- rather than failing with a non-exhaustive pattern as before.
replaceNth :: Int -> a -> [a] -> [a]
replaceNth _ _ [] = []
replaceNth n newVal (x:xs)
    | n == 0    = newVal : xs
    | otherwise = x : replaceNth (n - 1) newVal xs
-- | Remove the element at 0-based index @n@.
-- Indices outside the list (negative or past the end) leave it unchanged,
-- rather than failing with a non-exhaustive pattern as before.
removeNth :: Int -> [a] -> [a]
removeNth _ [] = []
removeNth n (x:xs)
    | n == 0    = xs
    | otherwise = x : removeNth (n - 1) xs
-- | Background loop: once a minute, remove from the pool and close every
-- connection whose last-used time is more than @timeout@ seconds ago.
--
-- A failure while closing one connection (the server may already have
-- dropped it) no longer kills this thread. Only synchronous exceptions are
-- ignored, so 'killThread' from 'closeSession' still stops the loop.
manage
    :: Int64                 -- ^ Idle timeout in seconds.
    -> IORef [ApnConnection] -- ^ The session's connection pool.
    -> IO ()
manage timeout ioref = forever $ do
    currtime <- round <$> getPOSIXTime :: IO Int64
    let minTime = currtime - timeout
        stillFresh conn = apnConnectionLastUsed conn >= minTime
    -- partition yields (kept, expired): the kept ones stay in the pool.
    expiredOnes <- atomicModifyIORef' ioref (DL.partition stillFresh)
    mapM_ closeQuietly expiredOnes
    threadDelay 60000000
  where
    closeQuietly conn = closeApnConnection conn `catch` \e ->
        case fromException e of
            Just (SomeAsyncException _) -> throwIO e
            Nothing                     -> return ()
-- | Open a new TLS + HTTP/2 connection to the session's APN host on
-- port 443.
--
-- The CA store and client credential are re-read from the configured
-- paths. If loading fails, the failed pattern match in 'IO' throws an
-- IOError. A background thread calls '_updateWindow' on the incoming flow
-- control once per second. The connection's open flag is set to False
-- when the server sends GOAWAY.
newConnection :: ApnSession -> IO ApnConnection
newConnection apnSession = do
    let aci = apnSessionConnectionInfo apnSession
        hostname = aciHostname aci
        maxConcurrentStreams = aciMaxConcurrentStreams aci
    Just castore <- readCertificateStore $ aciCaPath aci
    Right credential <- credentialLoadX509 (aciCertPath aci) (aciCertKey aci)
    let credentials = Credentials [credential]
        shared = def { sharedCredentials = credentials
                     , sharedCAStore = castore }
        clip = ClientParams
            { clientUseMaxFragmentLength = Nothing
              -- The second component is only an extra service-identification
              -- blob. Use a real (empty) value instead of 'undefined', which
              -- would crash if anything ever forced it.
            , clientServerIdentification = (T.unpack hostname, "")
            , clientUseServerNameIndication = True
            , clientWantSessionResume = Nothing
            , clientShared = shared
            , clientHooks = def
                { onCertificateRequest = const . return . Just $ credential }
            , clientDebug = DebugParams { debugSeed = Nothing, debugPrintSeed = const $ return () }
            , clientSupported = def
                { supportedVersions = [ TLS12 ]
                , supportedCiphers = ciphersuite_strong }
            }
        conf = [ (HTTP2.SettingsMaxFrameSize, 16384)
               , (HTTP2.SettingsMaxConcurrentStreams, maxConcurrentStreams)
               , (HTTP2.SettingsMaxHeaderBlockSize, 4096)
               , (HTTP2.SettingsInitialWindowSize, 65536)
               , (HTTP2.SettingsEnablePush, 1)
               ]
    httpFrameConnection <- newHttp2FrameConnection (T.unpack hostname) 443 (Just clip)
    openRef <- newIORef True
    -- GOAWAY only marks the connection as closed; 'withConnection' checks
    -- the flag before handing the connection out again.
    let handleGoAway _ = writeIORef openRef False
    client <- newHttp2Client httpFrameConnection 4096 4096 conf handleGoAway ignoreFallbackHandler
    linkAsyncs client
    flowWorker <- forkIO $ forever $ do
        void $ _updateWindow $ _incomingFlowControl client
        threadDelay 1000000
    workersem <- newQSem maxConcurrentStreams
    currtime <- round <$> getPOSIXTime :: IO Int64
    return $ ApnConnection client aci workersem currtime flowWorker openRef
-- | Shut down one connection. It marks the connection closed, stops its
-- flow-control thread, sends GOAWAY and closes the HTTP/2 client.
--
-- '_close' runs even if sending GOAWAY throws (e.g. the peer already hung
-- up), so the underlying connection is not leaked. The exception is still
-- propagated.
closeApnConnection :: ApnConnection -> IO ()
closeApnConnection connection = do
    writeIORef (apnConnectionOpen connection) False
    killThread (apnConnectionFlowControlWorker connection)
    _gtfo client HTTP2.NoError "" `finally` _close client
  where
    client = apnConnectionConnection connection
-- | Send an already-encoded JSON payload to one device.
-- IO errors are reported as 'ApnMessageResultTemporaryError'.
sendRawMessage
    :: ApnSession
    -> ApnToken
    -> ByteString -- ^ Encoded payload.
    -> IO ApnMessageResult
sendRawMessage s token payload =
    catchIOErrors (withConnection s sendOn)
  where
    sendOn conn = sendApnRaw conn token payload
-- | JSON-encode a payload and send it to one device.
-- IO errors are reported as 'ApnMessageResultTemporaryError'.
sendMessage
    :: ApnSession
    -> ApnToken
    -> JsonAps
    -> IO ApnMessageResult
sendMessage s token payload =
    let encoded = L.toStrict (encode payload)
    in catchIOErrors . withConnection s $ \conn -> sendApnRaw conn token encoded
-- | Send a payload whose @aps@ dictionary contains only
-- @content-available: 1@. IO errors are reported as
-- 'ApnMessageResultTemporaryError'.
sendSilentMessage
    :: ApnSession
    -> ApnToken
    -> IO ApnMessageResult
sendSilentMessage s token =
    catchIOErrors (withConnection s deliver)
  where
    deliver conn = sendApnRaw conn token silentPayload
    silentPayload = "{\"aps\":{\"content-available\":1}}"
-- | Call 'error' with "Session is closed" unless the session is still open.
ensureOpen :: ApnSession -> IO ()
ensureOpen s = isOpen s >>= \open -> unless open (error "Session is closed")
-- | Send one payload on a connection and classify the response.
--
-- Concurrent sends on a connection are bounded by its worker semaphore.
-- Stream errors and 'TooMuchConcurrency' map to
-- 'ApnMessageResultTemporaryError'. Every HTTP status now maps to a
-- result (see 'statusToResult'), and a missing @:status@ is treated as
-- temporary. Previously an unlisted status such as 404 raised a
-- pattern-match failure, which 'catchIOErrors' does not catch.
sendApnRaw
    :: ApnConnection
    -> ApnToken
    -> ByteString -- ^ Encoded JSON payload.
    -> IO ApnMessageResult
sendApnRaw connection token message = bracket_
    (waitQSem workerPool)
    (signalQSem workerPool) $ do
        res <- _startStream client $ \stream ->
            let initStream = headers stream requestHeaders id
                handler isfc osfc = do
                    upload message (HTTP2.setEndHeader . HTTP2.setEndStream) client (_outgoingFlowControl client) stream osfc
                    (errOrHeaders, _, _) <- waitStream stream isfc ignorePushPromise
                    return $ either (const ApnMessageResultTemporaryError) resultFromHeaders errOrHeaders
            in StreamDefinition initStream handler
        return $ either (const ApnMessageResultTemporaryError) id res
  where
    workerPool = apnConnectionWorkerPool connection
    aci = apnConnectionInfo connection
    client = apnConnectionConnection connection
    requestHeaders =
        [ ( ":method", "POST" )
        , ( ":scheme", "https" )
        , ( ":authority", TE.encodeUtf8 (aciHostname aci) )
        , ( ":path", "/3/device/" `S.append` unApnToken token )
        , ( "apns-topic", aciTopic aci ) ]
    -- Ignore push promises. The previous handler printed their headers
    -- to stdout, which a library must not do.
    ignorePushPromise _ _ _ _ _ = return ()
    resultFromHeaders hdrs =
        maybe ApnMessageResultTemporaryError statusToResult (DL.lookup ":status" hdrs)

-- | Map an HTTP @:status@ value from the APN service to a result.
-- Statuses not listed are classified by class: 5xx is temporary and
-- anything else is fatal.
statusToResult :: ByteString -> ApnMessageResult
statusToResult status = case status of
    "200" -> ApnMessageResultOk
    "400" -> ApnMessageResultFatalError
    "403" -> ApnMessageResultFatalError
    "404" -> ApnMessageResultFatalError
    "405" -> ApnMessageResultFatalError
    "410" -> ApnMessageResultTokenNoLongerValid
    "413" -> ApnMessageResultFatalError
    "429" -> ApnMessageResultTemporaryError
    _ | "5" `S.isPrefixOf` status -> ApnMessageResultTemporaryError
      | otherwise                 -> ApnMessageResultFatalError
-- | Run an action. If it throws an IOError, return the type's designated
-- 'isAnError' value instead.
catchIOErrors :: SpecifyError a => IO a -> IO a
catchIOErrors action = action `catchIOError` \_ -> return isAnError