module Network.SSL (
SSLHandle
, sslInit
, randSeed
, sslConnect
, sslRead
, sslReadWhile
, sslWrite
) where
import qualified Control.Exception as E
import Control.Monad
import Data.List
import Data.Maybe
import Data.Word
import Foreign.C
import Foreign.ForeignPtr
import Foreign.Marshal
import Foreign.Ptr
import Foreign.Storable

import Network.Socket
import Network.Stream
-- | Run an IO action for a 'Stream' method, turning a failure into a
-- 'Left' ('ErrorMisc' message) instead of an escaping exception.
--
-- Both 'E.IOException' and 'E.ErrorCall' are caught.  The read/write
-- helpers in this module signal OpenSSL failures with 'error' (an
-- 'E.ErrorCall'), which a Haskell 98 @catch@ (IOError only) would let
-- through.  Any other exception, e.g. an asynchronous one, propagates.
wrap :: IO a -> IO (Either ConnError a)
wrap m =
  fmap Right m `E.catches`
    [ E.Handler (\e -> asLeft (e :: E.IOException))
    , E.Handler (\e -> asLeft (e :: E.ErrorCall))
    ]
  where
    asLeft e = return (Left (ErrorMisc (show e)))
-- | Lets an 'SSLHandle' be used as a 'Stream'.  Bytes and 'Char's are
-- converted one-for-one through their 'Enum' codes, and every method except
-- 'close' runs through 'wrap' so the failures it catches come back as
-- 'Left' values.
instance Stream SSLHandle where
  -- The terminating newline byte is consumed by 'sslReadWhile' and not
  -- returned, so it is re-appended here as a 'Char'.
  readLine h =
      wrap $ do
        bytes <- sslReadWhile (/= lf) h
        return (map (toEnum . fromEnum) bytes ++ "\n")
    where
      lf = toEnum (fromEnum '\n')

  readBlock h n =
    wrap $ do
      bytes <- sslRead h n
      return (map (toEnum . fromEnum) bytes)

  writeBlock h str = wrap (sslWrite h (map (toEnum . fromEnum) str))

  close (SH (_, _, sock)) = sClose sock
-- | A TLS connection produced by a successful 'sslConnect'.  It holds the
-- 'SSL' object, the 'SSL_CTX' it was created from, and the underlying
-- 'Socket'.  The two OpenSSL objects are held in 'ForeignPtr's, so their
-- finalizers run once the handle becomes unreachable.
newtype SSLHandle = SH (ForeignPtr SSL,ForeignPtr SSL_CTX,Socket)
-- | Initialise OpenSSL by calling @SSL_library_init@ and then
-- @SSL_load_error_strings@.  Per the OpenSSL documentation this should
-- happen once, before any other OpenSSL call such as 'sslConnect'.
sslInit :: IO ()
sslInit = c_SSL_library_init >> c_SSL_load_error_strings
-- | Hand the given bytes, together with their count, to OpenSSL's
-- @RAND_seed@.
randSeed :: [Word8] -> IO ()
randSeed bytes =
  withArrayLen bytes $ \count ptr -> c_RAND_seed ptr (fromIntegral count)
-- | Run a TLS client handshake over an already-connected socket.
--
-- Returns 'Nothing' in any of these cases:
--
-- * the context or SSL object cannot be allocated;
-- * the socket's descriptor cannot be attached;
-- * @SSL_connect@ fails.
--
-- On success the returned handle owns both OpenSSL objects through
-- 'ForeignPtr's.
sslConnect :: Socket -> IO (Maybe SSLHandle)
sslConnect sock =
  notNull (c_SSL_CTX_new =<< c_SSLv23_client_method) $ \ctx -> do
    pctx <- newForeignPtr p_SSL_CTX_free ctx
    notNull (c_SSL_new ctx) $ \ssl -> do
      -- The finalizer must *free* the SSL object.  SSL_shutdown never
      -- releases its memory, and when run at GC time it may write a
      -- close_notify alert to a descriptor that has since been closed or
      -- reused.  SSL_free does not close the socket's descriptor.
      pssl <- newForeignPtr p_SSL_free ssl
      attached <- setSocket ssl sock
      if attached
        then handshake ssl (SH (pssl, pctx, sock))
        else return Nothing
  where
    -- SSL_connect returns 1 when the handshake completes and 0 when it was
    -- shut down.  On other returns, SSL_get_error code 2
    -- (SSL_ERROR_WANT_READ) means "try again"; anything else is a failure.
    -- The handle @h@ is carried through the retries so its ForeignPtrs stay
    -- reachable while the raw pointer is in use.
    handshake ssl h = do
      ret <- c_SSL_connect ssl
      case ret of
        1 -> return (Just h)
        0 -> return Nothing
        _ -> do
          code <- c_SSL_get_error ssl ret
          case code of
            2 -> handshake ssl h
            _ -> return Nothing

-- | Address of @void SSL_free(SSL *)@, used as the 'ForeignPtr' finalizer
-- for SSL objects.
foreign import ccall "openssl/ssl.h &SSL_free"
  p_SSL_free :: FunPtr (Ptr SSL -> IO ())
-- | Read up to @n@ bytes from the TLS connection.
--
-- Bytes are requested in chunks of at most 1024.  Only the bytes each
-- @SSL_read@ call actually delivered are kept, so the result is shorter
-- than @n@ only when @SSL_read@ returns 0 (connection closed).
--
-- On a negative return, @SSL_get_error@ decides what happens:
--
-- * code 2 (SSL_ERROR_WANT_READ) retries;
-- * code 0 ends the read;
-- * any other code raises an 'error' carrying the OpenSSL error string.
--
-- A non-positive @n@ yields @[]@ without reading.
sslRead :: SSLHandle -> Int -> IO [Word8]
sslRead (SH (pssl, _, _)) n =
  withForeignPtr pssl $ \ssl -> allocaArray chunkSize (readFrom ssl n)
  where
    chunkSize = 1024

    readFrom ssl remaining buf
      | remaining <= 0 = return []
      | otherwise = do
          ret <- c_SSL_read ssl buf (fromIntegral (min remaining chunkSize))
          case compare ret 0 of
            GT -> do
              -- Only the first ret bytes of the buffer hold fresh data.
              let got = fromIntegral ret
              bs <- peekArray got buf
              rest <- readFrom ssl (remaining - got) buf
              return (bs ++ rest)
            EQ -> return []
            LT -> do
              code <- c_SSL_get_error ssl ret
              case code of
                2 -> readFrom ssl remaining buf
                0 -> return []
                _ -> getError >>= error . ("sslRead: " ++)
-- | Read bytes one at a time while they satisfy the predicate.
--
-- The first byte that fails the predicate is consumed from the connection
-- but left out of the result.  If @SSL_read@ does not return exactly one
-- byte, @SSL_get_error@ decides what happens:
--
-- * code 2 retries;
-- * code 0 ends the result;
-- * any other code raises an 'error' carrying the OpenSSL error string.
sslReadWhile :: (Word8 -> Bool) -> SSLHandle -> IO [Word8]
sslReadWhile p (SH (pssl, _, _)) =
  withForeignPtr pssl $ \ssl -> allocaArray 1 (collect ssl)
  where
    collect ssl cell = do
      got <- c_SSL_read ssl cell 1
      byte <- peek cell
      if got /= 1 || not (p byte)
        then do
          code <- c_SSL_get_error ssl got
          case code of
            2 -> collect ssl cell
            0 -> return []
            _ -> getError >>= error . ("sslReadWhile: " ++)
        else fmap (byte :) (collect ssl cell)
-- | Write the bytes to the TLS connection with a single @SSL_write@ call.
-- An empty list is a no-op.
--
-- On a non-positive return, @SSL_get_error@ decides what happens:
--
-- * code 2 repeats the same write;
-- * code 0 gives up silently;
-- * any other code raises an 'error' carrying the OpenSSL error string.
sslWrite :: SSLHandle -> [Word8] -> IO ()
sslWrite _ [] = return ()
sslWrite (SH (pssl, _, _)) bytes =
  withForeignPtr pssl $ \ssl ->
    withArrayLen bytes $ \len buf ->
      let attempt = do
            written <- c_SSL_write ssl buf (toEnum len)
            unless (written > 0) $ do
              code <- c_SSL_get_error ssl written
              case code of
                2 -> attempt
                0 -> return ()
                _ -> getError >>= error . ("sslWrite:" ++)
      in attempt
-- | Run a pointer-producing action.  Yield 'Nothing' if it returns
-- 'nullPtr'; otherwise pass the pointer to the continuation and return
-- the continuation's result unchanged.
notNull :: IO (Ptr a) -> (Ptr a -> IO (Maybe b)) -> IO (Maybe b)
notNull acquire k = acquire >>= check
  where
    check ptr
      | ptr == nullPtr = return Nothing
      | otherwise = k ptr
-- | Attach the socket's file descriptor to the SSL object.  The result is
-- 'True' exactly when @SSL_set_fd@ returns 1.
setSocket :: Ptr SSL -> Socket -> IO Bool
setSocket ssl sock = fmap (== 1) (c_SSL_set_fd ssl (fdSocket sock))
-- | Take the earliest error off OpenSSL's error queue and render it as a
-- human-readable string.
--
-- @ERR_error_string@ requires a buffer of at least 256 bytes, so this uses
-- the length-bounded @ERR_error_string_n@ with a 256-byte buffer.  The
-- code is fetched with @ERR_get_error@ rather than passed as a constant.
getError :: IO String
getError = do
  code <- c_ERR_get_error
  allocaArray errBufLen $ \buf -> do
    c_ERR_error_string_n code buf (fromIntegral errBufLen)
    peekCString buf
  where
    errBufLen = 256 :: Int

-- | @unsigned long ERR_get_error(void)@: returns and removes the earliest
-- queued error code (0 when the queue is empty).
foreign import ccall "openssl/err.h ERR_get_error"
  c_ERR_get_error :: IO CULong

-- | @void ERR_error_string_n(unsigned long e, char *buf, size_t len)@:
-- writes at most @len@ bytes, NUL-terminated.
foreign import ccall "openssl/err.h ERR_error_string_n"
  c_ERR_error_string_n :: CULong -> CString -> CSize -> IO ()
-- | Opaque tag for OpenSSL's @SSL_CTX@ structure.  This module only uses it
-- behind a 'Ptr' or 'ForeignPtr' and never inspects its contents.
data SSL_CTX
-- | Opaque tag for OpenSSL's @SSL@ structure.  This module only uses it
-- behind a 'Ptr' or 'ForeignPtr' and never inspects its contents.
data SSL
-- Raw OpenSSL imports.  Each Haskell signature mirrors the C prototype
-- noted in its comment.

-- | Create a context from a method pointer.  The method is passed as an
-- untyped @Ptr ()@.
foreign import ccall "openssl/ssl.h SSL_CTX_new"
  c_SSL_CTX_new :: Ptr () -> IO (Ptr SSL_CTX)

-- | Method pointer for client contexts, handed to 'c_SSL_CTX_new'.
foreign import ccall "openssl/ssl.h SSLv23_client_method"
  c_SSLv23_client_method :: IO (Ptr ())

-- NOTE(review): declared as returning (); the C function returns int, and
-- in OpenSSL 1.1+ SSL_library_init is a macro rather than a linkable
-- symbol, so this import may fail to link there.  Confirm the target
-- OpenSSL version.
foreign import ccall "openssl/ssl.h SSL_library_init"
  c_SSL_library_init :: IO ()

foreign import ccall "openssl/ssl.h SSL_load_error_strings"
  c_SSL_load_error_strings :: IO ()

-- | @RAND_seed(buf, num)@: a byte buffer and its length.
foreign import ccall "openssl/rand.h RAND_seed"
  c_RAND_seed :: Ptr Word8 -> CInt -> IO ()

foreign import ccall "openssl/ssl.h SSL_new"
  c_SSL_new :: Ptr SSL_CTX -> IO (Ptr SSL)

-- | Address of SSL_shutdown, typed so it fits a 'ForeignPtr' finalizer.
-- NOTE(review): the C function returns int (ignored here).  It performs a
-- TLS shutdown and does not free the SSL object.
foreign import ccall "openssl/ssl.h &SSL_shutdown"
  p_SSL_shutdown :: FunPtr (Ptr SSL -> IO ())

-- | Address of SSL_CTX_free, typed so it fits a 'ForeignPtr' finalizer.
foreign import ccall "openssl/ssl.h &SSL_CTX_free"
  p_SSL_CTX_free :: FunPtr (Ptr SSL_CTX -> IO ())

-- | Attach a file descriptor to an SSL object (1 on success).
foreign import ccall "openssl/ssl.h SSL_set_fd"
  c_SSL_set_fd :: Ptr SSL -> CInt -> IO CInt

foreign import ccall "openssl/ssl.h SSL_connect"
  c_SSL_connect :: Ptr SSL -> IO CInt

-- | Classify the return value of a previous SSL_* call on the same object.
foreign import ccall "openssl/ssl.h SSL_get_error"
  c_SSL_get_error :: Ptr SSL -> CInt -> IO CInt

-- | @ERR_error_string(e, buf)@.  NOTE(review): OpenSSL requires @buf@ to
-- be at least 256 bytes long.
foreign import ccall "openssl/err.h ERR_error_string"
  c_ERR_error_string :: CULong -> CString -> IO CString

-- | @SSL_read(ssl, buf, num)@: returns the number of bytes read, or <= 0.
foreign import ccall "openssl/ssl.h SSL_read"
  c_SSL_read :: Ptr SSL -> Ptr Word8 -> CInt -> IO CInt

-- | @SSL_write(ssl, buf, num)@: returns the number of bytes written, or <= 0.
foreign import ccall "openssl/ssl.h SSL_write"
  c_SSL_write :: Ptr SSL -> Ptr Word8 -> CInt -> IO CInt