module OpenSSL.Session
(
SSLContext
, context
, contextSetPrivateKey
, contextSetCertificate
, contextSetPrivateKeyFile
, contextSetCertificateFile
, contextSetCiphers
, contextSetDefaultCiphers
, contextCheckPrivateKey
, VerificationMode(..)
, contextSetVerificationMode
, contextSetCAFile
, contextSetCADirectory
, contextGetCAStore
, SSL
, connection
, accept
, connect
, read
, write
, lazyRead
, lazyWrite
, shutdown
, ShutdownType(..)
, getPeerCertificate
, getVerifyResult
, sslSocket
) where
import Prelude hiding (read, ioError)
import Control.Concurrent (threadWaitWrite, threadWaitRead)
import Control.Concurrent.QSem
import Control.Exception (bracket_, finally)
import Foreign
import Foreign.C
import System.IO.Error (mkIOError, ioError, eofErrorType, catch, isEOFError)
import System.IO.Unsafe
import System.Posix.Types (Fd(..))

import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
import qualified Data.ByteString.Unsafe as B
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Internal as L
import Network.Socket (Socket(..))

import OpenSSL.EVP.PKey
import OpenSSL.Utils (failIfNull, failIf)
import OpenSSL.X509 (X509, X509_, wrapX509, withX509Ptr)
import OpenSSL.X509.Store
-- | Opaque tag for OpenSSL's @SSL_CTX@ structure; only ever used behind a 'Ptr'.
data SSLContext_
-- | An OpenSSL @SSL_CTX@ paired with a single-unit 'QSem' that
-- 'withContext' uses to serialize access to it.  The 'ForeignPtr' frees the
-- context with @SSL_CTX_free@ when finalized (see 'context').
newtype SSLContext = SSLContext (QSem, ForeignPtr SSLContext_)
-- | Opaque tag for OpenSSL's @SSL_METHOD@.
data SSLMethod_
-- SSL_CTX_new: allocate a context for the given method (NULL on failure).
foreign import ccall unsafe "SSL_CTX_new" _ssl_ctx_new :: Ptr SSLMethod_ -> IO (Ptr SSLContext_)
-- Address of SSL_CTX_free, used as the ForeignPtr finalizer in 'context'.
foreign import ccall unsafe "&SSL_CTX_free" _ssl_ctx_free :: FunPtr (Ptr SSLContext_ -> IO ())
-- SSLv23_method: the method every context created by 'context' uses.
foreign import ccall unsafe "SSLv23_method" _ssl_method :: IO (Ptr SSLMethod_)
-- | Create a new SSL context using OpenSSL's @SSLv23_method@.
--
-- The underlying @SSL_CTX@ is released by @SSL_CTX_free@ when its
-- 'ForeignPtr' is finalized. Each context gets a single-unit 'QSem'
-- that 'withContext' uses to serialize access to the raw pointer.
--
-- If @SSL_CTX_new@ returns NULL (allocation or method failure), this
-- fails via 'failIfNull'. Without that check, later calls would
-- silently receive a context wrapping a null pointer.
context :: IO SSLContext
context = do
    ctxPtr <- _ssl_method >>= _ssl_ctx_new >>= failIfNull
    ctxFp  <- newForeignPtr _ssl_ctx_free ctxPtr
    sem    <- newQSem 1
    return $ SSLContext (sem, ctxFp)
-- | Run an action with exclusive access to the raw @SSL_CTX@ pointer.
--
-- The context's semaphore is held for the duration of the action and is
-- released afterwards, even if the action throws. The 'ForeignPtr' is
-- kept alive while the action runs.
--
-- 'bracket_' runs the acquire step masked. With a bare
-- @waitQSem >> finally ...@ an asynchronous exception could arrive after
-- 'waitQSem' returns but before the 'finally' handler is installed, and
-- the unit would never be returned. That would deadlock every later
-- user of this context.
withContext :: SSLContext -> (Ptr SSLContext_ -> IO a) -> IO a
withContext (SSLContext (sem, ctxfp)) action =
    bracket_ (waitQSem sem) (signalQSem sem) (withForeignPtr ctxfp action)
-- | Keep the context's underlying 'ForeignPtr' (and thus the @SSL_CTX@)
-- alive at least until this action runs.
touchContext :: SSLContext -> IO ()
touchContext (SSLContext pair) =
    case pair of
        (_, ctxFp) -> touchForeignPtr ctxFp
-- | Load a key or certificate file into a context using an OpenSSL
-- @SSL_CTX_use_*_file@-style loader.
--
-- The file is first tried as PEM (file type 1, @SSL_FILETYPE_PEM@).
-- Only if that does not return 1 is it retried as ASN.1/DER (file type 2,
-- @SSL_FILETYPE_ASN1@), and that second result must be 1, otherwise
-- 'failIf' raises.
contextLoadFile :: (Ptr SSLContext_ -> CString -> CInt -> IO CInt)
                -> SSLContext -> String -> IO ()
contextLoadFile loader ctxt path =
    withContext ctxt $ \ctxPtr ->
        withCString path $ \cPath -> do
            let tryAs fileType = loader ctxPtr cPath fileType
            pemResult <- tryAs pemFileType
            if pemResult /= 1
                then tryAs asn1FileType >>= failIf (/= 1) >> return ()
                else return ()
  where
    pemFileType, asn1FileType :: CInt
    pemFileType  = 1
    asn1FileType = 2
-- SSL_CTX_use_PrivateKey: install an in-memory private key; 1 means success.
foreign import ccall unsafe "SSL_CTX_use_PrivateKey"
    _ssl_ctx_use_privatekey :: Ptr SSLContext_ -> Ptr EVP_PKEY -> IO CInt
-- SSL_CTX_use_certificate: install an in-memory certificate; 1 means success.
foreign import ccall unsafe "SSL_CTX_use_certificate"
    _ssl_ctx_use_certificate :: Ptr SSLContext_ -> Ptr X509_ -> IO CInt
-- | Install a private key into the context with @SSL_CTX_use_PrivateKey@.
-- Raises via 'failIf' unless OpenSSL returns 1.
contextSetPrivateKey :: KeyPair k => SSLContext -> k -> IO ()
contextSetPrivateKey context key =
    withContext context $ \ctxPtr ->
        withPKeyPtr' key $ \keyPtr -> do
            rc <- _ssl_ctx_use_privatekey ctxPtr keyPtr
            _ <- failIf (/= 1) rc
            return ()
-- | Install a certificate into the context with @SSL_CTX_use_certificate@.
-- Raises via 'failIf' unless OpenSSL returns 1.
contextSetCertificate :: SSLContext -> X509 -> IO ()
contextSetCertificate context cert =
    withContext context $ \ctxPtr ->
        withX509Ptr cert $ \certPtr -> do
            rc <- _ssl_ctx_use_certificate ctxPtr certPtr
            _ <- failIf (/= 1) rc
            return ()
-- SSL_CTX_use_PrivateKey_file: load a private key from a path; the CInt
-- argument is the file type (see 'contextLoadFile'); 1 means success.
foreign import ccall unsafe "SSL_CTX_use_PrivateKey_file"
    _ssl_ctx_use_privatekey_file :: Ptr SSLContext_ -> CString -> CInt -> IO CInt
-- SSL_CTX_use_certificate_file: load a certificate from a path; same
-- file-type argument and success convention as above.
foreign import ccall unsafe "SSL_CTX_use_certificate_file"
    _ssl_ctx_use_certificate_file :: Ptr SSLContext_ -> CString -> CInt -> IO CInt
-- | Load the context's private key from a file, trying PEM first and then
-- ASN.1 (see 'contextLoadFile').
contextSetPrivateKeyFile :: SSLContext -> FilePath -> IO ()
contextSetPrivateKeyFile ctx path =
    contextLoadFile _ssl_ctx_use_privatekey_file ctx path
-- | Load the context's certificate from a file, trying PEM first and then
-- ASN.1 (see 'contextLoadFile').
contextSetCertificateFile :: SSLContext -> FilePath -> IO ()
contextSetCertificateFile ctx path =
    contextLoadFile _ssl_ctx_use_certificate_file ctx path
-- SSL_CTX_set_cipher_list: set the allowed cipher list from a C string;
-- 1 means success.
foreign import ccall unsafe "SSL_CTX_set_cipher_list"
    _ssl_ctx_set_cipher_list :: Ptr SSLContext_ -> CString -> IO CInt
-- | Set the context's cipher list with @SSL_CTX_set_cipher_list@. The
-- string uses OpenSSL's cipher-list syntax. Raises via 'failIf' unless
-- OpenSSL returns 1.
contextSetCiphers :: SSLContext -> String -> IO ()
contextSetCiphers context list = withContext context setList
  where
    setList ctxPtr = withCString list $ \cList -> do
        rc <- _ssl_ctx_set_cipher_list ctxPtr cList
        _ <- failIf (/= 1) rc
        return ()
-- | Reset the context to OpenSSL's @"DEFAULT"@ cipher list.
contextSetDefaultCiphers :: SSLContext -> IO ()
contextSetDefaultCiphers context = contextSetCiphers context "DEFAULT"
-- SSL_CTX_check_private_key: returns 1 when the loaded private key is
-- consistent with the loaded certificate.
foreign import ccall unsafe "SSL_CTX_check_private_key"
    _ssl_ctx_check_private_key :: Ptr SSLContext_ -> IO CInt
-- | Ask OpenSSL (@SSL_CTX_check_private_key@) whether the installed
-- private key matches the installed certificate. The result is 'True'
-- exactly when the call returns 1.
contextCheckPrivateKey :: SSLContext -> IO Bool
contextCheckPrivateKey context =
    withContext context $ \ctxPtr ->
        fmap (== 1) (_ssl_ctx_check_private_key ctxPtr)
-- | Peer-certificate verification policy applied by
-- 'contextSetVerificationMode'.
--
-- 'VerifyNone' disables verification (mode 0).
-- 'VerifyPeer' always requests peer verification (@SSL_VERIFY_PEER@)
-- and optionally adds the flags selected by its fields.
data VerificationMode = VerifyNone
                      | VerifyPeer {
                          vpFailIfNoPeerCert :: Bool  -- ^ also set @SSL_VERIFY_FAIL_IF_NO_PEER_CERT@ (2)
                        , vpClientOnce :: Bool        -- ^ also set @SSL_VERIFY_CLIENT_ONCE@ (4)
                        }
-- SSL_CTX_set_verify: set the verification mode bits and the verify
-- callback. This module always passes nullPtr as the callback.
foreign import ccall unsafe "SSL_CTX_set_verify"
    _ssl_set_verify_mode :: Ptr SSLContext_ -> CInt -> Ptr () -> IO ()
-- | Configure peer-certificate verification (@SSL_CTX_set_verify@) with
-- no verification callback.
--
-- 'VerifyNone' passes mode 0. 'VerifyPeer' always sets 1
-- (@SSL_VERIFY_PEER@). It adds 2 (@SSL_VERIFY_FAIL_IF_NO_PEER_CERT@) and
-- 4 (@SSL_VERIFY_CLIENT_ONCE@) when the corresponding field is 'True'.
contextSetVerificationMode :: SSLContext -> VerificationMode -> IO ()
contextSetVerificationMode context vmode =
    withContext context $ \ctxPtr ->
        _ssl_set_verify_mode ctxPtr (modeBits vmode) nullPtr
  where
    modeBits :: VerificationMode -> CInt
    modeBits VerifyNone = 0
    modeBits (VerifyPeer reqp oncep) =
        1 .|. flagIf reqp 2 .|. flagIf oncep 4

    flagIf :: Bool -> CInt -> CInt
    flagIf True  bit = bit
    flagIf False _   = 0
-- SSL_CTX_load_verify_locations: add trusted CAs from a file and/or a
-- directory (either pointer may be NULL); 1 means success.
foreign import ccall unsafe "SSL_CTX_load_verify_locations"
    _ssl_load_verify_locations :: Ptr SSLContext_ -> Ptr CChar -> Ptr CChar -> IO CInt
-- | Load trusted CA certificates from a single file
-- (@SSL_CTX_load_verify_locations@ with a null directory argument).
-- Raises via 'failIf' unless OpenSSL returns 1.
contextSetCAFile :: SSLContext -> FilePath -> IO ()
contextSetCAFile context path =
    withContext context $ \ctxPtr ->
        withCString path $ \cPath ->
            _ssl_load_verify_locations ctxPtr cPath nullPtr
                >>= failIf (/= 1)
                >> return ()
-- | Load trusted CA certificates from a directory
-- (@SSL_CTX_load_verify_locations@ with a null file argument).
-- Raises via 'failIf' unless OpenSSL returns 1.
contextSetCADirectory :: SSLContext -> FilePath -> IO ()
contextSetCADirectory context path =
    withContext context $ \ctxPtr ->
        withCString path $ \cDir -> do
            rc <- _ssl_load_verify_locations ctxPtr nullPtr cDir
            _ <- failIf (/= 1) rc
            return ()
-- SSL_CTX_get_cert_store: borrow the context's X509_STORE (owned by the
-- context, not by the caller).
foreign import ccall unsafe "SSL_CTX_get_cert_store"
    _ssl_get_cert_store :: Ptr SSLContext_ -> IO (Ptr X509_STORE)
-- | Return the context's certificate store (@SSL_CTX_get_cert_store@).
--
-- The wrapper is handed 'touchContext' for this context.
-- NOTE(review): presumably this keeps the owning context alive for as
-- long as the store is in use; confirm against 'wrapX509Store'.
contextGetCAStore :: SSLContext -> IO X509Store
contextGetCAStore context =
    withContext context $ \ctxPtr -> do
        storePtr <- _ssl_get_cert_store ctxPtr
        wrapX509Store (touchContext context) storePtr
-- | Opaque tag for OpenSSL's per-connection @SSL@ structure.
data SSL_
-- | A TLS connection. Fields, in order:
-- a single-unit 'QSem' serializing OpenSSL calls (see 'withSSL');
-- the @SSL@ object, freed by @SSL_free@ when finalized;
-- the socket's file descriptor, used to wait for readiness when OpenSSL
-- wants transport I\/O;
-- and the underlying 'Socket'.
newtype SSL = SSL (QSem, ForeignPtr SSL_, Fd, Socket)
-- SSL_new: create a connection object from a context (NULL on failure).
foreign import ccall unsafe "SSL_new" _ssl_new :: Ptr SSLContext_ -> IO (Ptr SSL_)
-- Address of SSL_free, used as the ForeignPtr finalizer in 'connection'.
foreign import ccall unsafe "&SSL_free" _ssl_free :: FunPtr (Ptr SSL_ -> IO ())
-- SSL_set_fd: attach a socket descriptor. The C function returns an int
-- status, but it is imported as IO () here, so that status is discarded.
foreign import ccall unsafe "SSL_set_fd" _ssl_set_fd :: Ptr SSL_ -> CInt -> IO ()
-- | Create a new 'SSL' object for an already-connected socket.
--
-- @SSL_new@ runs under the context lock, and a null result raises via
-- 'failIfNull'. The socket's descriptor is attached with @SSL_set_fd@
-- (whose status is not checked). The @SSL@ object is freed by @SSL_free@
-- when its 'ForeignPtr' is finalized. No handshake is performed; call
-- 'accept' or 'connect' next.
connection :: SSLContext -> Socket -> IO SSL
connection context sock@(MkSocket fd _ _ _ _) = do
    sslPtr <- withContext context $ \ctxPtr -> do
        p <- failIfNull =<< _ssl_new ctxPtr
        _ssl_set_fd p fd
        return p
    sslFp <- newForeignPtr _ssl_free sslPtr
    lock  <- newQSem 1
    return $ SSL (lock, sslFp, Fd fd, sock)
-- | Run an action with exclusive access to the raw @SSL@ pointer.
--
-- The connection's semaphore is held for the duration of the action,
-- so an OpenSSL call and the @SSL_get_error@ that classifies it (see
-- 'sslIOInner') are not interleaved with another thread's calls. The
-- 'ForeignPtr' is kept alive while the action runs.
--
-- 'bracket_' runs the acquire step masked. With a bare
-- @waitQSem >> finally ...@ an asynchronous exception (e.g. 'killThread'
-- on a reader) landing between the two could leave the unit taken, and
-- every later read/write on this connection would block forever.
withSSL :: SSL -> (Ptr SSL_ -> IO a) -> IO a
withSSL (SSL (sem, ssl, _, _)) action =
    bracket_ (waitQSem sem) (signalQSem sem) (withForeignPtr ssl action)
-- SSL_accept / SSL_connect: drive the server / client handshake.
-- These are imported as safe calls (no `unsafe`), unlike most bindings
-- here. NOTE(review): presumably because they perform socket I/O and may
-- take a while; confirm.
foreign import ccall "SSL_accept" _ssl_accept :: Ptr SSL_ -> IO CInt
foreign import ccall "SSL_connect" _ssl_connect :: Ptr SSL_ -> IO CInt
-- SSL_get_error: classify a non-positive return from an SSL I/O call.
foreign import ccall unsafe "SSL_get_error" _ssl_get_error :: Ptr SSL_ -> CInt -> IO CInt
-- | Render an @SSL_get_error@ code as a human-readable message.
--
-- Codes 2 and 3 (want read / want write) are intercepted by this module's
-- callers before they reach here, so this function renders them with the
-- generic "unknown error" text.
sslErrorToString :: CInt -> String
sslErrorToString code = case code of
    0 -> "SSL: No error"
    1 -> "SSL: ssl protocol error"
    4 -> "SSL: want X509 lookup"
    5 -> "SSL: syscall error"
    6 -> "SSL: connection cleanly closed"
    7 -> "SSL: want connect"
    8 -> "SSL: want accept"
    _ -> "SSL: unknown error " ++ show code
-- | Outcome of one non-fatal OpenSSL call attempt, as classified by
-- 'sslDoHandshake' and 'sslIOInner'.
data SSLIOResult = Done CInt   -- ^ the call succeeded with this non-negative return value
                 | WantRead    -- ^ retry once the socket is readable
                 | WantWrite   -- ^ retry once the socket is writable
                 deriving (Eq)
-- | Repeatedly run an @SSL_accept@ / @SSL_connect@ / @SSL_shutdown@-shaped
-- call until it completes.
--
-- Each attempt runs under 'withSSL'. A non-negative return is the final
-- result. For a negative return, @SSL_get_error@ decides the next step:
-- code 2 or 3 means wait for the socket to become readable or writable
-- (outside the lock) and try again. Any other code is raised via 'fail'
-- with 'sslErrorToString'.
sslDoHandshake :: (Ptr SSL_ -> IO CInt) -> SSL -> IO CInt
sslDoHandshake action ssl@(SSL (_, _, fd, _)) = loop
  where
    loop = do
        outcome <- withSSL ssl attempt
        case outcome of
            Done n    -> return n
            WantRead  -> threadWaitRead fd >> loop
            WantWrite -> threadWaitWrite fd >> loop

    attempt sslPtr = do
        rc <- action sslPtr
        if rc >= 0
            then return (Done rc)
            else do
                code <- _ssl_get_error sslPtr rc
                case code of
                    2 -> return WantRead
                    3 -> return WantWrite
                    _ -> fail (sslErrorToString code)
-- | Perform the server side of the TLS handshake (@SSL_accept@). It
-- retries while OpenSSL waits for socket I\/O and raises via 'failIf'
-- unless the handshake finally returns 1.
accept :: SSL -> IO ()
accept ssl = do
    rc <- sslDoHandshake _ssl_accept ssl
    _ <- failIf (/= 1) rc
    return ()
-- | Perform the client side of the TLS handshake (@SSL_connect@). It
-- retries while OpenSSL waits for socket I\/O and raises via 'failIf'
-- unless the handshake finally returns 1.
connect :: SSL -> IO ()
connect ssl = do
    rc <- sslDoHandshake _ssl_connect ssl
    _ <- failIf (/= 1) rc
    return ()
-- SSL_read: read decrypted application data into a buffer (safe call).
foreign import ccall "SSL_read" _ssl_read :: Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt
-- SSL_get_shutdown: shutdown-state bitmask; 'sslIOInner' tests bit 2
-- (SSL_RECEIVED_SHUTDOWN) to tell a clean close from an abrupt one.
foreign import ccall unsafe "SSL_get_shutdown" _ssl_get_shutdown :: Ptr SSL_ -> IO CInt
-- | Run one @SSL_read@ / @SSL_write@-shaped call on an already-locked
-- @SSL@ pointer and classify the outcome.
--
-- * A positive return is 'Done' with that byte count.
--
-- * A zero return is treated as the end of the connection. If bit 2
--   (@SSL_RECEIVED_SHUTDOWN@) of @SSL_get_shutdown@ is set, the peer
--   closed cleanly and an EOF 'IOError' is raised. Otherwise the
--   connection is reported as abruptly terminated via 'fail'.
--
-- * A negative return goes to @SSL_get_error@. Codes 2 and 3 become
--   'WantRead' / 'WantWrite'; anything else is raised via 'fail'.
sslIOInner :: (Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt)
           -> Ptr CChar
           -> Int
           -> Ptr SSL_
           -> IO SSLIOResult
sslIOInner op buf nbytes sslPtr =
    op sslPtr (castPtr buf) (fromIntegral nbytes) >>= classify
  where
    classify rc
      | rc > 0    = return (Done (fromIntegral rc))
      | rc == 0   = do
          state <- _ssl_get_shutdown sslPtr
          if state .&. 2 /= 0
            then ioError (mkIOError eofErrorType "" Nothing Nothing)
            else fail "SSL connection abruptly terminated"
      | otherwise = do
          code <- _ssl_get_error sslPtr rc
          case code of
            2 -> return WantRead
            3 -> return WantWrite
            _ -> fail (sslErrorToString code)
-- | Read at most @nbytes@ bytes of decrypted application data.
--
-- While OpenSSL reports that it needs transport I\/O, this waits on the
-- socket's descriptor ('threadWaitRead' / 'threadWaitWrite') and then
-- retries. Once the peer has cleanly shut down the session, the result
-- is an empty 'B.ByteString': the EOF 'IOError' raised by 'sslIOInner'
-- is caught here. Any other 'IOError' is rethrown.
--
-- A request for zero bytes returns an empty string immediately, without
-- calling @SSL_read@. @SSL_read@ with a zero length can never return a
-- positive count, so 'sslIOInner' would misreport it as a shutdown or
-- abrupt termination, or block waiting for data. As with end-of-stream,
-- the result in that case is empty.
read :: SSL -> Int -> IO B.ByteString
read _ 0 = return B.empty
read ssl@(SSL (_, _, fd, _)) nbytes = B.createAndTrim nbytes readLoop
  where
    readLoop ptr = attempt ptr `catch` onIOError

    attempt ptr = do
        result <- withSSL ssl $ sslIOInner _ssl_read (castPtr ptr) nbytes
        case result of
            Done n    -> return (fromIntegral n)
            WantRead  -> threadWaitRead fd >> attempt ptr
            WantWrite -> threadWaitWrite fd >> attempt ptr

    -- A clean TLS shutdown surfaces as an EOF IOError; map it to
    -- "0 bytes read" so createAndTrim yields an empty string.
    onIOError ioe
        | isEOFError ioe = return 0
        | otherwise      = ioError ioe
-- SSL_write: encrypt and send application data from a buffer (safe call).
foreign import ccall "SSL_write" _ssl_write :: Ptr SSL_ -> Ptr Word8 -> CInt -> IO CInt
-- | Send an entire strict 'B.ByteString' over the TLS connection.
--
-- While OpenSSL reports that it needs transport I\/O, this waits on the
-- socket's descriptor and then retries with the same buffer and length,
-- as OpenSSL requires for a repeated @SSL_write@.
--
-- An empty string is a no-op. OpenSSL documents @SSL_write@ with a
-- length of zero as undefined (newer releases: an error). A zero return
-- would also be misreported by 'sslIOInner' as a shutdown.
--
-- NOTE(review): the byte count of a successful @SSL_write@ is ignored.
-- That is only a full write if @SSL_MODE_ENABLE_PARTIAL_WRITE@ is not
-- enabled on the context; nothing visible here enables it.
write :: SSL -> B.ByteString -> IO ()
write ssl@(SSL (_, _, fd, _)) bs
    | B.null bs = return ()
    | otherwise = B.unsafeUseAsCStringLen bs writeLoop
  where
    writeLoop (ptr, len) = do
        result <- withSSL ssl $ sslIOInner _ssl_write ptr len
        case result of
            Done _    -> return ()
            WantRead  -> threadWaitRead fd >> writeLoop (ptr, len)
            WantWrite -> threadWaitWrite fd >> writeLoop (ptr, len)
-- | Lazily read the remainder of the stream as a lazy 'L.ByteString'.
--
-- Chunks of up to 'L.defaultChunkSize' bytes are fetched on demand via
-- 'unsafeInterleaveIO' until 'read' returns an empty string.
-- NOTE: this is lazy I\/O, so read errors surface only when the
-- corresponding part of the result is forced.
lazyRead :: SSL -> IO L.ByteString
lazyRead ssl = fmap L.fromChunks chunks
  where
    chunks = unsafeInterleaveIO $ do
        chunk <- read ssl L.defaultChunkSize
        if B.null chunk
            then return []
            else fmap (chunk :) chunks
-- | Send a lazy 'L.ByteString' by calling 'write' on each of its chunks
-- in order.
lazyWrite :: SSL -> L.ByteString -> IO ()
lazyWrite ssl = mapM_ (write ssl) . L.toChunks
-- SSL_shutdown: send / complete the TLS close_notify exchange; returns 1
-- when the shutdown is complete, 0 when it is not yet finished, and a
-- negative value on error or when socket I/O is needed (safe call).
foreign import ccall "SSL_shutdown" _ssl_shutdown :: Ptr SSL_ -> IO CInt
-- | How 'shutdown' closes the TLS session.
data ShutdownType = Bidirectional   -- ^ keep calling @SSL_shutdown@ until it returns 1
                  | Unidirectional  -- ^ call @SSL_shutdown@ once and return
-- | Shut down the TLS session with @SSL_shutdown@, waiting on the socket
-- whenever OpenSSL needs transport I\/O (see 'sslDoHandshake').
--
-- 'Unidirectional' returns after the first completed call.
-- 'Bidirectional' repeats the call until it returns 1.
shutdown :: SSL -> ShutdownType -> IO ()
shutdown ssl ty = do
    rc <- sslDoHandshake _ssl_shutdown ssl
    case ty of
        Bidirectional | rc /= 1 -> shutdown ssl ty
        _                       -> return ()
-- SSL_get_peer_certificate: peer's certificate or NULL. OpenSSL bumps the
-- refcount, so the caller owns the result.
-- NOTE(review): presumably 'wrapX509' takes that ownership; confirm.
foreign import ccall "SSL_get_peer_certificate" _ssl_get_peer_cert :: Ptr SSL_ -> IO (Ptr X509_)
-- | Fetch the peer's certificate (@SSL_get_peer_certificate@). Returns
-- 'Nothing' when OpenSSL returns a null pointer, e.g. because the peer
-- presented no certificate.
getPeerCertificate :: SSL -> IO (Maybe X509)
getPeerCertificate ssl =
    withSSL ssl $ \sslPtr -> _ssl_get_peer_cert sslPtr >>= maybePeek wrapX509
-- SSL_get_verify_result: result of peer-certificate verification;
-- 0 is X509_V_OK.
foreign import ccall "SSL_get_verify_result" _ssl_get_verify_result :: Ptr SSL_ -> IO CLong
-- | 'True' exactly when @SSL_get_verify_result@ reports 0 (@X509_V_OK@).
--
-- NOTE(review): OpenSSL also reports @X509_V_OK@ when the peer presented
-- no certificate at all. Callers that require one should check
-- 'getPeerCertificate' as well.
getVerifyResult :: SSL -> IO Bool
getVerifyResult ssl = withSSL ssl $ fmap (== 0) . _ssl_get_verify_result
-- | The 'Socket' this 'SSL' object was created from by 'connection'.
sslSocket :: SSL -> Socket
sslSocket (SSL conn) =
    case conn of
        (_, _, _, sock) -> sock