{-# LANGUAGE OverloadedStrings, BangPatterns #-}

-- | Run a Warp TLS server that binds its listening socket first and then
-- switches the process to a given group and user before serving requests.
module Network.Wai.Handler.WarpTLS.UserId
  ( CertFile
  , KeyFile
  , GroupName
  , UserName
  , runTlsWithGroupUserName
  , runTlsWithGroupUserNameClientCert
  ) where

import Control.Arrow ((***), (&&&))
import Control.Exception (bracket, bracketOnError)
import Control.Monad (when)
import Data.Semigroup ((<>))
import Data.List (uncons, unfoldr)
import Data.Default
import Data.Streaming.Network (bindPortTCP)
import Data.X509
import System.Posix
  ( groupID, getGroupEntryForName, setGroupID, setGroups
  , userID, getUserEntryForName, setUserID, getEffectiveUserID
  )
import Network.Socket (Socket, withSocketsDo, close)
import Network.TLS
import Network.Wai (Application)
import Network.Wai.Handler.Warp
  ( Settings, HostPreference, getPort, getHost )
import Network.Wai.Handler.WarpTLS
  ( runTLSSocket, tlsSettingsChainMemory, TLSSettings(..) )

import qualified Data.ByteString as BS

-- | Path to a PEM file holding the server certificate, optionally followed
-- by further certificates that are sent as the chain.
type CertFile = FilePath

-- | Path to the server's private key file.
type KeyFile = FilePath

-- | Name of the group the process switches to after binding.
type GroupName = String

-- | Name of the user the process switches to after binding.
type UserName = String

-- | Serve @app@ over TLS on the port and host taken from @set@.
--
-- The listening socket is bound first; the process then switches to the
-- given group and user before any request is served. Both the certificate
-- and key files are read before the socket is bound and before privileges
-- change.
--
-- Throws an 'IOError' if the certificate file is empty (or only a newline),
-- i.e. contains nothing to use as the leaf certificate.
runTlsWithGroupUserName
  :: (CertFile, KeyFile)
  -> (GroupName, UserName)
  -> Settings
  -> Application
  -> IO ()
runTlsWithGroupUserName (crt, key) (g, u) set app = do
  tset <- loadTlsSettings (crt, key)
  serveTlsAs tset (g, u) set app

-- | Callback that decides whether a client's certificate chain is accepted.
type OnClientCertificate = CertificateChain -> IO CertificateUsage

-- | Like 'runTlsWithGroupUserName', but also requests a certificate from
-- each client ('tlsWantClientCert') and hands the presented chain to @occ@
-- through the 'onClientCertificate' server hook.
--
-- The other server hooks are reset to their defaults ('def').
runTlsWithGroupUserNameClientCert
  :: (CertFile, KeyFile)
  -> OnClientCertificate
  -> (GroupName, UserName)
  -> Settings
  -> Application
  -> IO ()
runTlsWithGroupUserNameClientCert (crt, key) occ (g, u) set app = do
  tset <- loadTlsSettings (crt, key)
  let tset' = tset
        { tlsWantClientCert = True
        , tlsServerHooks = def { onClientCertificate = occ }
        }
  serveTlsAs tset' (g, u) set app

-- | Read the certificate bundle and key and build the TLS settings.
--
-- The first PEM block of the certificate file becomes the leaf certificate
-- and the remaining blocks become the chain. Fails with a descriptive
-- 'IOError' (instead of the previous @Prelude.head@ crash) when the file
-- yields no block at all.
loadTlsSettings :: (CertFile, KeyFile) -> IO TLSSettings
loadTlsSettings (crt, key) = do
  pem <- BS.readFile crt
  (!c, !cs) <- maybe
    (ioError (userError ("runTlsWithGroupUserName: no certificate found in " ++ crt)))
    return
    (separateChain pem)
  !k <- BS.readFile key
  return (tlsSettingsChainMemory c cs k)

-- | Bind the listening socket, drop to the given group/user, and run the TLS
-- server on it; the socket is closed when the server returns or throws.
serveTlsAs :: TLSSettings -> (GroupName, UserName) -> Settings -> Application -> IO ()
serveTlsAs tset (g, u) set app =
  withSocketsDo $
    bracket
      (bindPortTCPWithName (g, u) (getPort set) (getHost set))
      close
      (\sock -> runTLSSocket tset set sock app)

-- | Bind a TCP socket on port @p@ / host @h@, then switch the process to
-- group @g@ and user @u@.
--
-- Both name lookups happen before any ID is changed, so an unknown name
-- fails without leaving the process half-switched. The group is changed
-- before the user because changing the user first may remove the right to
-- change the group. If anything fails after the bind, the socket is closed
-- before the exception propagates.
bindPortTCPWithName :: (GroupName, UserName) -> Int -> HostPreference -> IO Socket
bindPortTCPWithName (g, u) p h =
  bracketOnError (bindPortTCP p h) close $ \sock -> do
    gid <- groupID <$> getGroupEntryForName g
    uid <- userID <$> getUserEntryForName u
    euid <- getEffectiveUserID
    -- setuid/setgid leave the supplementary group list untouched, so a
    -- process started as root would otherwise keep root's extra groups.
    -- Only attempted as root: setgroups needs privilege, and non-root runs
    -- keep their previous behaviour.
    when (euid == 0) $ setGroups [gid]
    setGroupID gid
    setUserID uid
    return sock

-- | Split a PEM bundle into its first chunk (the leaf certificate) and the
-- remaining chunks (the chain).
--
-- Returns 'Nothing' when 'separate' yields no chunk, which happens only for
-- an empty input or a lone newline.
separateChain :: BS.ByteString -> Maybe (BS.ByteString, [BS.ByteString])
separateChain = uncons . separate

-- | Marker that terminates each PEM certificate block.
endCertificate :: BS.ByteString
endCertificate = "-----END CERTIFICATE-----"

-- | Cut PEM text into chunks, each ending with 'endCertificate'.
--
-- Text after the last marker (other than a single newline) becomes one
-- final chunk without a marker.
separate :: BS.ByteString -> [BS.ByteString]
separate = unfoldr separateOne

-- | Take one chunk up to and including the next 'endCertificate' marker,
-- returning it together with the unconsumed rest.
--
-- Stops on an empty remainder or a single trailing newline. If no marker
-- is found, the whole remainder is returned as the chunk and the rest is
-- empty.
separateOne :: BS.ByteString -> Maybe (BS.ByteString, BS.ByteString)
separateOne "" = Nothing
separateOne "\n" = Nothing
separateOne ccs = Just (c <> ec, cs)
  where
    -- c: text before the marker; ec: the marker itself (or "" if absent);
    -- cs: everything after the marker.
    (c, (ec, cs)) =
      (id *** BS.splitAt (BS.length endCertificate)) $
        BS.breakSubstring endCertificate ccs