module OpenSSL.Session
(
SSLContext
, context
, contextAddOption
, contextRemoveOption
, contextSetPrivateKey
, contextSetCertificate
, contextSetPrivateKeyFile
, contextSetCertificateFile
, contextSetCertificateChainFile
, contextSetCiphers
, contextSetDefaultCiphers
, contextCheckPrivateKey
, VerificationMode(..)
, contextSetVerificationMode
, contextSetCAFile
, contextSetCADirectory
, contextGetCAStore
, SSL
, SSLResult(..)
, connection
, fdConnection
, addOption
, removeOption
, setTlsextHostName
, accept
, tryAccept
, connect
, tryConnect
, read
, tryRead
, readPtr
, tryReadPtr
, write
, tryWrite
, writePtr
, tryWritePtr
, lazyRead
, lazyWrite
, shutdown
, tryShutdown
, ShutdownType(..)
, getPeerCertificate
, getVerifyResult
, sslSocket
, sslFd
, SSLOption(..)
, SomeSSLException
, ConnectionAbruptlyTerminated
, ProtocolError(..)
, SSLContext_
, withContext
, SSL_
, withSSL
) where
import Prelude hiding (
read, ioError, mapM, mapM_)
import Control.Concurrent (threadWaitWrite, threadWaitRead, runInBoundThread)
import Control.Concurrent.MVar
import Control.Exception
import Control.Monad (unless)
import Data.Foldable (mapM_, forM_)
import Data.Traversable (mapM)
import Data.Typeable
import Data.Maybe (fromMaybe)
import Data.IORef
import Foreign
import Foreign.C
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Internal as L
import System.IO.Unsafe
import System.Posix.Types (Fd(..))
import Network.Socket (Socket(..))
import OpenSSL.ERR
import OpenSSL.EVP.PKey
import OpenSSL.EVP.Internal
import OpenSSL.SSL.Option
import OpenSSL.Utils
import OpenSSL.X509 (X509, X509_, wrapX509, withX509Ptr)
import OpenSSL.X509.Store
-- | Raw verification callback handed to OpenSSL.  The 'Bool' argument is the
-- flag OpenSSL passes in (its preverification result, per
-- @SSL_CTX_set_verify@), the pointer is the raw @X509_STORE_CTX@, and the
-- returned 'Bool' is handed back to OpenSSL as the callback's verdict.
type VerifyCb = Bool -> Ptr X509_STORE_CTX -> IO Bool

-- | Turn a 'VerifyCb' into a C function pointer.  Each 'FunPtr' produced here
-- must eventually be released with 'freeHaskellFunPtr'.
foreign import ccall "wrapper" mkVerifyCb :: VerifyCb -> IO (FunPtr VerifyCb)
-- | Opaque C type of an OpenSSL @SSL_CTX@.
data SSLContext_

-- | An SSL context.
--
-- The raw @SSL_CTX@ pointer lives in an 'MVar', so access is serialised
-- through 'withContext'.  The 'IORef' holds the verification callback
-- currently installed via 'contextSetVerificationMode' (if any).  The
-- finalizer set up by 'context' frees that callback together with the
-- @SSL_CTX@.
data SSLContext = SSLContext { ctxMVar :: MVar (Ptr SSLContext_)
                             , ctxVfCb :: IORef (Maybe (FunPtr VerifyCb))
                             }
                  deriving Typeable
-- | Opaque C type of an OpenSSL @SSL_METHOD@.
data SSLMethod_

foreign import ccall unsafe "SSL_CTX_new" _ssl_ctx_new :: Ptr SSLMethod_ -> IO (Ptr SSLContext_)
foreign import ccall unsafe "SSL_CTX_free" _ssl_ctx_free :: Ptr SSLContext_ -> IO ()
foreign import ccall unsafe "SSLv23_method" _ssl_method :: IO (Ptr SSLMethod_)

-- | Create a fresh SSL context using @SSLv23_method@.
--
-- Construction runs under 'mask_' so that the finalizer is always attached
-- to the new 'MVar'.  When the context is collected, the finalizer frees the
-- @SSL_CTX@ and any verification callback stored in the context.
context :: IO SSLContext
context = mask_ $ do
  method  <- _ssl_method
  ctxPtr  <- failIfNull =<< _ssl_ctx_new method
  vfCbRef <- newIORef Nothing
  ctxVar  <- newMVar ctxPtr
  let finalizer = do
        _ssl_ctx_free ctxPtr
        mapM_ freeHaskellFunPtr =<< readIORef vfCbRef
  _ <- mkWeakMVar ctxVar finalizer
  return SSLContext { ctxMVar = ctxVar, ctxVfCb = vfCbRef }
-- | Run an action with the raw @SSL_CTX@ pointer while holding the context's
-- 'MVar', so that concurrent users of the same context are serialised.
withContext :: SSLContext -> (Ptr SSLContext_ -> IO a) -> IO a
withContext ctxt action = withMVar (ctxMVar ctxt) action
-- | Inspect the context's 'MVar' without blocking and discard the answer.
-- The only purpose is to reference the context.  It is passed to
-- 'wrapX509Store' in 'contextGetCAStore', presumably so that the store keeps
-- the context alive (TODO confirm against wrapX509Store).
touchContext :: SSLContext -> IO ()
touchContext ctxt = do
  _ <- isEmptyMVar (ctxMVar ctxt)
  return ()
foreign import ccall unsafe "HsOpenSSL_SSL_CTX_set_options"
    _SSL_CTX_set_options :: Ptr SSLContext_ -> CLong -> IO CLong
foreign import ccall unsafe "HsOpenSSL_SSL_CTX_clear_options"
    _SSL_CTX_clear_options :: Ptr SSLContext_ -> CLong -> IO CLong

-- | Turn an 'SSLOption' on for the context.  The option mask returned by
-- the C call is ignored.
contextAddOption :: SSLContext -> SSLOption -> IO ()
contextAddOption ctxt opt =
  withContext ctxt $ \ctxPtr -> do
    _ <- _SSL_CTX_set_options ctxPtr (optionToIntegral opt)
    return ()

-- | Turn an 'SSLOption' off for the context.  The option mask returned by
-- the C call is ignored.
contextRemoveOption :: SSLContext -> SSLOption -> IO ()
contextRemoveOption ctxt opt =
  withContext ctxt $ \ctxPtr -> do
    _ <- _SSL_CTX_clear_options ctxPtr (optionToIntegral opt)
    return ()
-- | Load a file into the context using a @SSL_CTX_use_*_file@-style loader.
-- The file is tried first as PEM (@SSL_FILETYPE_PEM@, 1).  If that fails, it
-- is tried as ASN.1/DER (@SSL_FILETYPE_ASN1@, 2).  Throws (via 'failIf_') if
-- the second attempt also fails.
--
-- A failed PEM attempt pushes entries onto OpenSSL's error queue, so those
-- entries are drained before retrying.  Otherwise a successful DER load would
-- leave stale errors behind, to be picked up later by unrelated error
-- reporting ('getError' in 'throwSSLException', and @SSL_get_error@, which
-- requires an empty queue).  Draining also means a failing DER load reports
-- its own error.
contextLoadFile :: (Ptr SSLContext_ -> CString -> CInt -> IO CInt)
                -> SSLContext -> String -> IO ()
contextLoadFile f ctxt path =
  withContext ctxt $ \ctx ->
    withCString path $ \cpath -> do
      result <- f ctx cpath (1)
      unless (result == 1) $ do
        drainErrorQueue
        f ctx cpath (2) >>= failIf_ (/= 1)
  where
    -- Pop entries off the (per-thread) OpenSSL error queue until it is empty.
    drainErrorQueue = do
      e <- getError
      unless (e == 0) drainErrorQueue
foreign import ccall unsafe "SSL_CTX_use_PrivateKey"
    _ssl_ctx_use_privatekey :: Ptr SSLContext_ -> Ptr EVP_PKEY -> IO CInt
foreign import ccall unsafe "SSL_CTX_use_certificate"
    _ssl_ctx_use_certificate :: Ptr SSLContext_ -> Ptr X509_ -> IO CInt

-- | Install the private key of a key pair into the context.
-- Throws (via 'failIf_') unless OpenSSL returns 1.
contextSetPrivateKey :: KeyPair k => SSLContext -> k -> IO ()
contextSetPrivateKey ctxt key =
  withContext ctxt $ \ctx ->
    withPKeyPtr' key $ \keyPtr -> do
      rc <- _ssl_ctx_use_privatekey ctx keyPtr
      failIf_ (/= 1) rc

-- | Install a certificate into the context.
-- Throws (via 'failIf_') unless OpenSSL returns 1.
contextSetCertificate :: SSLContext -> X509 -> IO ()
contextSetCertificate ctxt cert =
  withContext ctxt $ \ctx ->
    withX509Ptr cert $ \certPtr -> do
      rc <- _ssl_ctx_use_certificate ctx certPtr
      failIf_ (/= 1) rc
foreign import ccall unsafe "SSL_CTX_use_PrivateKey_file"
    _ssl_ctx_use_privatekey_file :: Ptr SSLContext_ -> CString -> CInt -> IO CInt
foreign import ccall unsafe "SSL_CTX_use_certificate_file"
    _ssl_ctx_use_certificate_file :: Ptr SSLContext_ -> CString -> CInt -> IO CInt

-- | Load the context's private key from a file (PEM first, then ASN.1),
-- using 'contextLoadFile'.
contextSetPrivateKeyFile :: SSLContext -> FilePath -> IO ()
contextSetPrivateKeyFile ctxt path =
  contextLoadFile _ssl_ctx_use_privatekey_file ctxt path

-- | Load the context's certificate from a file (PEM first, then ASN.1),
-- using 'contextLoadFile'.
contextSetCertificateFile :: SSLContext -> FilePath -> IO ()
contextSetCertificateFile ctxt path =
  contextLoadFile _ssl_ctx_use_certificate_file ctxt path
foreign import ccall unsafe "SSL_CTX_use_certificate_chain_file"
    _ssl_ctx_use_certificate_chain_file :: Ptr SSLContext_ -> CString -> IO CInt

-- | Load a certificate chain file into the context.
-- Throws (via 'failIf_') unless OpenSSL returns 1.
contextSetCertificateChainFile :: SSLContext -> FilePath -> IO ()
contextSetCertificateChainFile ctxt path =
  withContext ctxt $ \ctx ->
    withCString path $ \cpath -> do
      rc <- _ssl_ctx_use_certificate_chain_file ctx cpath
      failIf_ (/= 1) rc
foreign import ccall unsafe "SSL_CTX_set_cipher_list"
    _ssl_ctx_set_cipher_list :: Ptr SSLContext_ -> CString -> IO CInt

-- | Set the context's cipher list from an OpenSSL cipher string.
-- Throws (via 'failIf_') unless OpenSSL returns 1.
contextSetCiphers :: SSLContext -> String -> IO ()
contextSetCiphers ctxt list =
  withContext ctxt $ \ctx ->
    withCString list $ \cList -> do
      rc <- _ssl_ctx_set_cipher_list ctx cList
      failIf_ (/= 1) rc

-- | Reset the context's cipher list to OpenSSL's @"DEFAULT"@ list.
contextSetDefaultCiphers :: SSLContext -> IO ()
contextSetDefaultCiphers ctxt = contextSetCiphers ctxt "DEFAULT"
foreign import ccall unsafe "SSL_CTX_check_private_key"
    _ssl_ctx_check_private_key :: Ptr SSLContext_ -> IO CInt

-- | Ask OpenSSL whether the context's private key is consistent with its
-- certificate.  Returns 'True' exactly when the C call returns 1.
contextCheckPrivateKey :: SSLContext -> IO Bool
contextCheckPrivateKey ctxt =
  withContext ctxt $ \ctx -> do
    rc <- _ssl_ctx_check_private_key ctx
    return (rc == 1)
-- | Peer-certificate verification policy for a context, applied with
-- 'contextSetVerificationMode'.
data VerificationMode = VerifyNone
                        -- ^ No peer verification (mode 0, @SSL_VERIFY_NONE@).
                      | VerifyPeer {
                          vpFailIfNoPeerCert :: Bool
                          -- ^ Also set @SSL_VERIFY_FAIL_IF_NO_PEER_CERT@ (2).
                        , vpClientOnce :: Bool
                          -- ^ Also set @SSL_VERIFY_CLIENT_ONCE@ (4).
                        , vpCallback :: Maybe (Bool -> X509StoreCtx -> IO Bool)
                          -- ^ Optional Haskell callback; it receives OpenSSL's
                          -- flag and the store context, and its result is
                          -- returned to OpenSSL.
                        }
                      deriving Typeable
foreign import ccall unsafe "SSL_CTX_set_verify"
    _ssl_set_verify_mode :: Ptr SSLContext_ -> CInt -> FunPtr VerifyCb -> IO ()

-- | Set the peer verification mode of the context.
--
-- 'VerifyNone' installs mode 0 (@SSL_VERIFY_NONE@) with no callback.
-- 'VerifyPeer' installs @SSL_VERIFY_PEER@ (1), optionally OR-ed with
-- @SSL_VERIFY_FAIL_IF_NO_PEER_CERT@ (2) and @SSL_VERIFY_CLIENT_ONCE@ (4).
-- It also installs the Haskell callback, if one is given.
--
-- A previously installed callback is NOT freed immediately.  @SSL_new@
-- copies the context's settings, including the verify callback, into every
-- 'SSL' created from it, so existing connections may still call the old
-- 'FunPtr'.  Freeing it is deferred to a finalizer on the context's 'MVar'
-- instead.  Every 'SSL' keeps that 'MVar' alive through its 'sslCtx' field.
-- As a result, retired callbacks accumulate until the context is collected.
contextSetVerificationMode :: SSLContext -> VerificationMode -> IO ()
contextSetVerificationMode context VerifyNone =
  withContext context $ \ctx ->
    _ssl_set_verify_mode ctx (0) nullFunPtr >> return ()
contextSetVerificationMode context (VerifyPeer reqp oncep cbp) = do
  let mode = (1) .|.
             (if reqp then (2) else 0) .|.
             (if oncep then (4) else 0)
  withContext context $ \ctx -> mask_ $ do
    let cbRef = ctxVfCb context
    newCb <- mapM mkVerifyCb $ (<$> cbp) $ \cb pvf pStoreCtx ->
      cb pvf =<< wrapX509StoreCtx (return ()) pStoreCtx
    -- Point OpenSSL at the new callback before retiring the old one.
    _ssl_set_verify_mode ctx mode $ fromMaybe nullFunPtr newCb
    oldCb <- readIORef cbRef
    writeIORef cbRef newCb
    -- Connections created earlier may still hold the old callback, so only
    -- free it once the context itself is garbage-collected.
    forM_ oldCb $ \old ->
      mkWeakMVar (ctxMVar context) (freeHaskellFunPtr old)
    return ()
foreign import ccall unsafe "SSL_CTX_load_verify_locations"
    _ssl_load_verify_locations :: Ptr SSLContext_ -> Ptr CChar -> Ptr CChar -> IO CInt

-- | Load trusted CA certificates from one file.  The path is passed as the
-- file argument of @SSL_CTX_load_verify_locations@, with a null directory.
-- Throws (via 'failIf_') unless OpenSSL returns 1.
contextSetCAFile :: SSLContext -> FilePath -> IO ()
contextSetCAFile ctxt path =
  withContext ctxt $ \ctx ->
    withCString path $ \cpath -> do
      rc <- _ssl_load_verify_locations ctx cpath nullPtr
      failIf_ (/= 1) rc

-- | Load trusted CA certificates from a directory.  The path is passed as the
-- directory argument of @SSL_CTX_load_verify_locations@, with a null file.
-- Throws (via 'failIf_') unless OpenSSL returns 1.
contextSetCADirectory :: SSLContext -> FilePath -> IO ()
contextSetCADirectory ctxt path =
  withContext ctxt $ \ctx ->
    withCString path $ \cpath -> do
      rc <- _ssl_load_verify_locations ctx nullPtr cpath
      failIf_ (/= 1) rc
foreign import ccall unsafe "SSL_CTX_get_cert_store"
    _ssl_get_cert_store :: Ptr SSLContext_ -> IO (Ptr X509_STORE)

-- | Obtain the context's certificate store.  The store is wrapped with
-- 'touchContext', so that it refers back to the owning context.
contextGetCAStore :: SSLContext -> IO X509Store
contextGetCAStore ctxt =
  withContext ctxt $ \ctx -> do
    storePtr <- _ssl_get_cert_store ctx
    wrapX509Store (touchContext ctxt) storePtr
-- | Opaque C type of an OpenSSL @SSL@ connection object.
data SSL_

-- | A TLS/SSL connection.
data SSL = SSL { sslCtx    :: SSLContext
                 -- ^ Context the connection was created from.  Holding it
                 -- keeps the context reachable while the connection is.
               , sslMVar   :: MVar (Ptr SSL_)
                 -- ^ Raw @SSL@ pointer; access is serialised via 'withSSL'.
               , sslFd     :: Fd
                 -- ^ Descriptor attached with @SSL_set_fd@; 'sslBlock'
                 -- waits on it when an operation wants to read or write.
               , sslSocket :: Maybe Socket
                 -- ^ The 'Socket' behind 'sslFd' when created by
                 -- 'connection'; 'Nothing' for 'fdConnection'.
               }
           deriving Typeable
foreign import ccall unsafe "SSL_new" _ssl_new :: Ptr SSLContext_ -> IO (Ptr SSL_)
foreign import ccall unsafe "SSL_free" _ssl_free :: Ptr SSL_ -> IO ()
-- SSL_set_fd returns 1 on success and 0 on failure.
foreign import ccall unsafe "SSL_set_fd" _ssl_set_fd :: Ptr SSL_ -> CInt -> IO CInt

-- | Create an 'SSL' for file descriptor @fd@ from the given context.
--
-- Allocation runs under 'mask_', so the finalizer that calls @SSL_free@ is
-- attached to the new 'MVar' before control returns.  The result of
-- @SSL_set_fd@ is checked.  If it fails, the freshly allocated @SSL@ is freed
-- before the error propagates, so it does not leak.
connection' :: SSLContext -> Fd -> Maybe Socket -> IO SSL
connection' context fd@(Fd fdInt) sock = do
  mvar <- mask_ $ do
    ssl <- withContext context $ \ctx -> do
      ssl <- _ssl_new ctx >>= failIfNull
      (_ssl_set_fd ssl fdInt >>= failIf_ (/= 1))
        `onException` _ssl_free ssl
      return ssl
    mvar <- newMVar ssl
    _ <- mkWeakMVar mvar $ _ssl_free ssl
    return mvar
  return $ SSL { sslCtx = context
               , sslMVar = mvar
               , sslFd = fd
               , sslSocket = sock
               }
-- | Wrap a connected 'Socket' in a new 'SSL' object.  The socket's descriptor
-- is attached to the connection, and the socket is remembered in
-- 'sslSocket'.
connection :: SSLContext -> Socket -> IO SSL
connection ctxt sock =
  case sock of
    MkSocket rawFd _ _ _ _ -> connection' ctxt (Fd rawFd) (Just sock)

-- | Wrap a raw file descriptor in a new 'SSL' object.  No 'Socket' is
-- recorded.
fdConnection :: SSLContext -> Fd -> IO SSL
fdConnection ctxt fd = connection' ctxt fd Nothing
-- | Run an action with the raw @SSL@ pointer while holding the connection's
-- 'MVar', serialising concurrent use of the same connection.
withSSL :: SSL -> (Ptr SSL_ -> IO a) -> IO a
withSSL ssl action = withMVar (sslMVar ssl) action
foreign import ccall unsafe "HsOpenSSL_SSL_set_options"
    _SSL_set_options :: Ptr SSL_ -> CLong -> IO CLong
foreign import ccall unsafe "HsOpenSSL_SSL_clear_options"
    _SSL_clear_options :: Ptr SSL_ -> CLong -> IO CLong
foreign import ccall unsafe "HsOpenSSL_SSL_set_tlsext_host_name"
    _SSL_set_tlsext_host_name :: Ptr SSL_ -> CString -> IO CLong

-- | Turn an 'SSLOption' on for this connection only.  The option mask
-- returned by the C call is ignored.
addOption :: SSL -> SSLOption -> IO ()
addOption ssl opt =
  withSSL ssl $ \sslPtr -> do
    _ <- _SSL_set_options sslPtr (optionToIntegral opt)
    return ()

-- | Turn an 'SSLOption' off for this connection only.  The option mask
-- returned by the C call is ignored.
removeOption :: SSL -> SSLOption -> IO ()
removeOption ssl opt =
  withSSL ssl $ \sslPtr -> do
    _ <- _SSL_clear_options sslPtr (optionToIntegral opt)
    return ()

-- | Set the TLS SNI host name sent by this connection.
-- NOTE(review): the C wrapper's return value is discarded, so a rejected
-- host name is not reported to the caller.
setTlsextHostName :: SSL -> String -> IO ()
setTlsextHostName ssl h =
  withSSL ssl $ \sslPtr ->
    withCString h $ \hPtr -> do
      _ <- _SSL_set_tlsext_host_name sslPtr hPtr
      return ()
foreign import ccall "SSL_accept" _ssl_accept :: Ptr SSL_ -> IO CInt
foreign import ccall "SSL_connect" _ssl_connect :: Ptr SSL_ -> IO CInt
foreign import ccall unsafe "SSL_get_error" _ssl_get_error :: Ptr SSL_ -> CInt -> IO CInt

-- | Throw an exception describing a failed OpenSSL call.  @loc@ names the C
-- function and @ret@ is the value it returned.
--
-- * If the OpenSSL error queue holds an error, it is popped and thrown as a
--   'ProtocolError' carrying its error string.
-- * Otherwise a return of 0 is thrown as 'ConnectionAbruptlyTerminated'.
-- * Any other return is reported from @errno@ via 'throwErrno', tagged with
--   @loc@.
throwSSLException :: String -> CInt -> IO a
throwSSLException loc ret = do
  e <- getError
  if e /= 0
    then errorString e >>= throwIO . ProtocolError
    else if ret == 0
      then throwIO ConnectionAbruptlyTerminated
      else throwErrno loc
-- | Outcome of a single non-blocking OpenSSL operation.
data SSLResult a = SSLDone a
                   -- ^ The operation finished with this result.
                 | WantRead
                   -- ^ Retry once 'sslFd' is readable (see 'sslBlock').
                 | WantWrite
                   -- ^ Retry once 'sslFd' is writable (see 'sslBlock').
                 deriving (Eq, Show, Functor, Foldable, Traversable, Typeable)
-- | Turn a non-blocking operation into a blocking one.  The operation is
-- re-run after waiting for 'sslFd' to become readable or writable, as
-- requested, until it returns 'SSLDone'.
sslBlock :: (SSL -> IO (SSLResult a)) -> SSL -> IO a
sslBlock action ssl = go
  where
    go = do
      outcome <- action ssl
      case outcome of
        SSLDone r -> return r
        WantRead  -> threadWaitRead  (sslFd ssl) >> go
        WantWrite -> threadWaitWrite (sslFd ssl) >> go
-- | Run one step of a handshake function (@SSL_accept@ / @SSL_connect@) on a
-- bound thread while holding the connection.
--
-- A return of 1 means the handshake is done.  Otherwise @SSL_get_error@
-- decides the outcome:
--
-- * 2 (@SSL_ERROR_WANT_READ@) gives 'WantRead'.
-- * 3 (@SSL_ERROR_WANT_WRITE@) gives 'WantWrite'.
-- * Anything else is thrown via 'throwSSLException', tagged with @loc@.
sslTryHandshake :: String
                -> (Ptr SSL_ -> IO CInt)
                -> SSL
                -> IO (SSLResult CInt)
sslTryHandshake loc action ssl =
  runInBoundThread $ withSSL ssl $ \sslPtr -> do
    rc <- action sslPtr
    case rc of
      1 -> return (SSLDone rc)
      _ -> do
        code <- _ssl_get_error sslPtr rc
        case code of
          2 -> return WantRead
          3 -> return WantWrite
          _ -> throwSSLException loc rc
-- | Perform the server side of the handshake, blocking until it completes.
accept :: SSL -> IO ()
accept ssl = sslBlock tryAccept ssl

-- | Try one non-blocking step of the server-side handshake.
tryAccept :: SSL -> IO (SSLResult ())
tryAccept ssl =
  fmap (fmap (const ())) (sslTryHandshake "SSL_accept" _ssl_accept ssl)

-- | Perform the client side of the handshake, blocking until it completes.
connect :: SSL -> IO ()
connect ssl = sslBlock tryConnect ssl

-- | Try one non-blocking step of the client-side handshake.
tryConnect :: SSL -> IO (SSLResult ())
tryConnect ssl =
  fmap (fmap (const ())) (sslTryHandshake "SSL_connect" _ssl_connect ssl)
foreign import ccall "SSL_read" _ssl_read :: Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt
foreign import ccall unsafe "SSL_get_shutdown" _ssl_get_shutdown :: Ptr SSL_ -> IO CInt

-- | Run one read/write-style OpenSSL call (@SSL_read@ / @SSL_write@) on a
-- bound thread while holding the connection.
--
-- A positive return is the byte count, wrapped in 'SSLDone'.  Otherwise
-- @SSL_get_error@ decides the outcome:
--
-- * 6 (@SSL_ERROR_ZERO_RETURN@) gives @'SSLDone' 0@.
-- * 2 and 3 give 'WantRead' and 'WantWrite'.
-- * Anything else is thrown via 'throwSSLException'.
--
-- @nbytes@ is narrowed to 'CInt' with 'fromIntegral'.
sslIOInner :: String
           -> (Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt)
           -> Ptr CChar
           -> Int
           -> SSL
           -> IO (SSLResult CInt)
sslIOInner loc f ptr nbytes ssl =
  runInBoundThread $ withSSL ssl $ \sslPtr -> do
    n <- f sslPtr (castPtr ptr) (fromIntegral nbytes)
    if n > 0
      then return (SSLDone (fromIntegral n))
      else classify sslPtr n
  where
    classify :: Ptr SSL_ -> CInt -> IO (SSLResult CInt)
    classify sslPtr n = do
      code <- _ssl_get_error sslPtr n
      case code of
        6 -> return (SSLDone 0)
        2 -> return WantRead
        3 -> return WantWrite
        _ -> throwSSLException loc n
-- | Blocking read of at most @nBytes@ bytes.  An empty result corresponds to
-- @'SSLDone' 0@ from 'sslIOInner'.
read :: SSL -> Int -> IO B.ByteString
read ssl nBytes = sslBlock step ssl
  where
    step s = tryRead s nBytes
-- | Non-blocking read of at most @nBytes@ bytes into a fresh 'B.ByteString'.
-- The buffer is trimmed to the number of bytes actually read.
--
-- @SSL_read@ takes an @int@ length, and 'sslIOInner' narrows with
-- 'fromIntegral'.  The requested length is therefore clamped to
-- @maxBound :: CInt@, so that an oversized request cannot wrap around to a
-- negative or tiny count.
tryRead :: SSL -> Int -> IO (SSLResult B.ByteString)
tryRead ssl nBytes = do
  (bs, result) <- B.createAndTrim' nBytes $ \bufPtr -> do
    r <- sslIOInner "SSL_read" _ssl_read (castPtr bufPtr) readLen ssl
    case r of
      SSLDone n -> return (0, fromIntegral n, SSLDone ())
      WantRead  -> return (0, 0, WantRead)
      WantWrite -> return (0, 0, WantWrite)
  return $ bs <$ result
  where
    readLen = min nBytes (fromIntegral (maxBound :: CInt))
-- | Blocking read of at most @len@ bytes into @ptr@; returns the byte count.
readPtr :: SSL -> Ptr a -> Int -> IO Int
readPtr ssl ptr len = sslBlock step ssl
  where
    step s = tryReadPtr s ptr len
-- | Non-blocking read of at most @nBytes@ bytes into @bufPtr@; returns the
-- byte count.
--
-- @SSL_read@ takes an @int@ length, and 'sslIOInner' narrows with
-- 'fromIntegral'.  The length is therefore clamped to @maxBound :: CInt@, so
-- that a huge buffer size cannot wrap around to a negative or zero count.
tryReadPtr :: SSL -> Ptr a -> Int -> IO (SSLResult Int)
tryReadPtr ssl bufPtr nBytes =
  fmap (fmap fromIntegral) (sslIOInner "SSL_read" _ssl_read (castPtr bufPtr) readLen ssl)
  where
    readLen = min nBytes (fromIntegral (maxBound :: CInt))
foreign import ccall "SSL_write" _ssl_write :: Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt

-- | Blocking write of a whole strict 'B.ByteString'.
write :: SSL -> B.ByteString -> IO ()
write ssl bs = do
  _ <- sslBlock (\s -> tryWrite s bs) ssl
  return ()
-- | Non-blocking write of a strict 'B.ByteString'.  An empty string is a
-- no-op that reports @'SSLDone' ()@ without calling OpenSSL.
tryWrite :: SSL -> B.ByteString -> IO (SSLResult ())
tryWrite ssl bs =
  if B.null bs
    then return (SSLDone ())
    else B.unsafeUseAsCStringLen bs (uncurry (tryWritePtr ssl))
-- | Blocking write of @len@ bytes starting at @ptr@.
writePtr :: SSL -> Ptr a -> Int -> IO ()
writePtr ssl ptr len = do
  _ <- sslBlock step ssl
  return ()
  where
    step s = tryWritePtr s ptr len
-- | Non-blocking write of @len@ bytes starting at @ptr@.
--
-- * A zero-length write is a no-op returning @'SSLDone' ()@, matching
--   'tryWrite' on an empty 'B.ByteString'.  OpenSSL documents that
--   @SSL_write@ should not be called with 0 bytes, and its 0 result would
--   otherwise be reported below as a spurious EPIPE.
-- * Lengths that do not fit the C @int@ taken by @SSL_write@ (negative, or
--   larger than @maxBound :: CInt@) are rejected with an EINVAL 'IOError'.
--   Previously they were narrowed with 'fromIntegral', which could wrap and
--   silently write fewer bytes than requested.
-- * If @SSL_write@ reports 0 bytes (a clean close), an EPIPE 'IOError' is
--   raised.
tryWritePtr :: SSL -> Ptr a -> Int -> IO (SSLResult ())
tryWritePtr ssl ptr len
  | len == 0 = return (SSLDone ())
  | len < 0 || len > fromIntegral (maxBound :: CInt) =
      ioError $ errnoToIOError "SSL_write" eINVAL Nothing Nothing
  | otherwise = do
      result <- sslIOInner "SSL_write" _ssl_write (castPtr ptr) len ssl
      case result of
        SSLDone 0 -> ioError $ errnoToIOError "SSL_write" ePIPE Nothing Nothing
        SSLDone _ -> return $ SSLDone ()
        WantRead  -> return WantRead
        WantWrite -> return WantWrite
-- | Lazily read the rest of the stream as a lazy 'L.ByteString'.
--
-- Chunks of 'L.defaultChunkSize' bytes are fetched on demand via
-- 'unsafeInterleaveIO'.  The stream ends at the first empty chunk returned
-- by 'read'.
lazyRead :: SSL -> IO L.ByteString
lazyRead ssl = L.fromChunks <$> chunks
  where
    chunks = unsafeInterleaveIO $ do
      chunk <- read ssl L.defaultChunkSize
      if B.null chunk
        then return []
        else (chunk :) <$> chunks
-- | Write a lazy 'L.ByteString' chunk by chunk, blocking on each chunk.
lazyWrite :: SSL -> L.ByteString -> IO ()
lazyWrite ssl = mapM_ (write ssl) . L.toChunks
foreign import ccall "SSL_shutdown" _ssl_shutdown :: Ptr SSL_ -> IO CInt

-- | How far 'shutdown' / 'tryShutdown' go.
data ShutdownType = Bidirectional
                    -- ^ Keep calling @SSL_shutdown@ after it returns 0,
                    -- until it returns 1 (see 'tryShutdown').
                  | Unidirectional
                    -- ^ Stop as soon as @SSL_shutdown@ returns 0.
                  deriving (Eq, Show, Typeable)
-- | Blocking shutdown of the TLS session (see 'tryShutdown').
shutdown :: SSL -> ShutdownType -> IO ()
shutdown ssl ty = sslBlock step ssl
  where
    step s = tryShutdown s ty
-- | Non-blocking shutdown of the TLS session.  Runs on a bound thread while
-- holding the connection.
--
-- How each @SSL_shutdown@ return value @n@ is handled:
--
-- * 0: call @SSL_shutdown@ again for 'Bidirectional'; stop for
--   'Unidirectional'.
-- * 1: the shutdown is complete.
-- * Anything else is classified with @SSL_get_error@:
--
--     * 2 / 3 (WANT_READ / WANT_WRITE) are returned so the caller can retry.
--     * 5 (SYSCALL) counts as success if bit 2 of @SSL_get_shutdown@
--       (@SSL_RECEIVED_SHUTDOWN@) is set.
--     * Everything else is thrown via 'throwSSLException'.
tryShutdown :: SSL -> ShutdownType -> IO (SSLResult ())
tryShutdown ssl ty = runInBoundThread $ withSSL ssl loop
  where
    loop :: Ptr SSL_ -> IO (SSLResult ())
    loop sslPtr
      = do n <- _ssl_shutdown sslPtr
           case n of
             -- Our close_notify went out; the peer's has not been seen yet.
             0 | ty == Bidirectional ->
                   loop sslPtr
               | otherwise ->
                   return $ SSLDone ()
             1 ->
               return $ SSLDone ()
             -- NOTE(review): SSL_shutdown is documented to return 0, 1 or a
             -- negative value, so this branch looks unreachable.
             2 ->
               loop sslPtr
             _ -> do err <- _ssl_get_error sslPtr n
                     case err of
                       (2) -> return WantRead
                       (3) -> return WantWrite
                       (5)
                         -- SYSCALL: tolerated once the peer's close_notify
                         -- has been received.
                         -> do sd <- _ssl_get_shutdown sslPtr
                               if sd .&. (2) == 0 then
                                   throwSSLException "SSL_shutdown" n
                                 else
                                   return $ SSLDone ()
                       _ -> throwSSLException "SSL_shutdown" n
foreign import ccall "SSL_get_peer_certificate" _ssl_get_peer_cert :: Ptr SSL_ -> IO (Ptr X509_)

-- | Return the peer's certificate, or 'Nothing' when OpenSSL returns a null
-- pointer.  A non-null pointer is wrapped with 'wrapX509'.
getPeerCertificate :: SSL -> IO (Maybe X509)
getPeerCertificate ssl =
  withSSL ssl $ \sslPtr ->
    maybePeek wrapX509 =<< _ssl_get_peer_cert sslPtr
foreign import ccall "SSL_get_verify_result" _ssl_get_verify_result :: Ptr SSL_ -> IO CLong

-- | 'True' when @SSL_get_verify_result@ returns 0 (@X509_V_OK@).
getVerifyResult :: SSL -> IO Bool
getVerifyResult ssl =
  withSSL ssl $ \sslPtr ->
    (== 0) <$> _ssl_get_verify_result sslPtr
-- | Root of this module's exception hierarchy.  'ConnectionAbruptlyTerminated'
-- and 'ProtocolError' are wrapped in it by their 'toException' methods, so
-- catching 'SomeSSLException' catches both.
data SomeSSLException
    = forall e. Exception e => SomeSSLException e
      deriving Typeable

-- | Shows the wrapped exception.
instance Show SomeSSLException where
    show (SomeSSLException e) = show e

instance Exception SomeSSLException
-- | 'toException' for members of the 'SomeSSLException' hierarchy: wrap in
-- 'SomeSSLException' first.
sslExceptionToException :: Exception e => e -> SomeException
sslExceptionToException e = toException (SomeSSLException e)

-- | 'fromException' for members of the 'SomeSSLException' hierarchy: unwrap
-- 'SomeSSLException', then 'cast' to the requested type.
sslExceptionFromException :: Exception e => SomeException -> Maybe e
sslExceptionFromException se =
  case fromException se of
    Just (SomeSSLException inner) -> cast inner
    Nothing                       -> Nothing
-- | Thrown by 'throwSSLException' when an OpenSSL call returned 0 and the
-- OpenSSL error queue was empty.
data ConnectionAbruptlyTerminated
    = ConnectionAbruptlyTerminated
      deriving (Typeable, Show, Eq)

-- | Member of the 'SomeSSLException' hierarchy.
instance Exception ConnectionAbruptlyTerminated where
    toException   = sslExceptionToException
    fromException = sslExceptionFromException

-- | Carries the OpenSSL error string that 'throwSSLException' obtains via
-- 'errorString' for the error popped off the queue.
data ProtocolError
    = ProtocolError !String
      deriving (Typeable, Show, Eq)

-- | Member of the 'SomeSSLException' hierarchy.
instance Exception ProtocolError where
    toException   = sslExceptionToException
    fromException = sslExceptionFromException