{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DeriveDataTypeable  #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}

------------------------------------------------------------------------------
module Snap.Internal.Http.Server.TLS
  ( TLSException(..)
  , withTLS
  , bindHttps
  , httpsAcceptFunc
  , sendFileFunc
  ) where

------------------------------------------------------------------------------
import           Data.ByteString.Char8             (ByteString)
import qualified Data.ByteString.Char8             as S
import           Data.Typeable                     (Typeable)
import           Network.Socket                    (Socket)
#ifdef OPENSSL
import           Control.Exception                 (Exception, bracketOnError, finally, onException, throwIO)
import           Control.Monad                     (when)
import           Data.ByteString.Builder           (byteString)
import qualified Network.Socket                    as Socket
import           OpenSSL                           (withOpenSSL)
import           OpenSSL.Session                   (SSL, SSLContext)
import qualified OpenSSL.Session                   as SSL
import           Prelude                           (Bool, FilePath, IO, Int, Maybe (..), Monad (..), Show, flip, fromIntegral, fst, not, ($), ($!), (.))
import           Snap.Internal.Http.Server.Address (getAddress, getSockAddr)
import           Snap.Internal.Http.Server.Socket  (acceptAndInitialize)
import qualified System.IO.Streams                 as Streams
import qualified System.IO.Streams.SSL             as SStreams

#else
import           Control.Exception                 (Exception, throwIO)
import           Prelude                           (Bool, FilePath, IO, Int, Show, id, ($))
#endif
------------------------------------------------------------------------------
import           Snap.Internal.Http.Server.Types   (AcceptFunc (..), SendFileHandler)
------------------------------------------------------------------------------

data TLSException = TLSException S.ByteString
  deriving (Show, Typeable)
instance Exception TLSException

#ifndef OPENSSL
type SSLContext = ()
type SSL = ()

------------------------------------------------------------------------------
sslNotSupportedException :: TLSException
sslNotSupportedException = TLSException $ S.concat [
    "This version of snap-server was not built with SSL "
  , "support.\n"
  , "Please compile snap-server with -fopenssl to enable it."
  ]


------------------------------------------------------------------------------
withTLS :: IO a -> IO a
withTLS = id


------------------------------------------------------------------------------
barf :: IO a
barf = throwIO sslNotSupportedException


------------------------------------------------------------------------------
bindHttps :: ByteString -> Int -> FilePath -> Bool -> FilePath
          -> IO (Socket, SSLContext)
bindHttps _ _ _ _ _ = barf


------------------------------------------------------------------------------
httpsAcceptFunc :: Socket -> SSLContext -> AcceptFunc
httpsAcceptFunc _ _ = AcceptFunc $ \restore -> restore barf


------------------------------------------------------------------------------
sendFileFunc :: SSL -> Socket -> SendFileHandler
sendFileFunc _ _ _ _ _ _ _ = barf


#else
------------------------------------------------------------------------------
withTLS :: IO a -> IO a
withTLS = withOpenSSL


------------------------------------------------------------------------------
bindHttps :: ByteString
          -> Int
          -> FilePath
          -> Bool
          -> FilePath
          -> IO (Socket, SSLContext)
bindHttps bindAddress bindPort cert chainCert key =
    withTLS $
    bracketOnError
        (do (family, addr) <- getSockAddr bindPort bindAddress
            sock <- Socket.socket family Socket.Stream 0
            return (sock, addr)
            )
        (Socket.close . fst)
        $ \(sock, addr) -> do
             Socket.setSocketOption sock Socket.ReuseAddr 1
             Socket.bindSocket sock addr
             Socket.listen sock 150

             ctx <- SSL.context
             SSL.contextSetPrivateKeyFile ctx key
             if chainCert
               then SSL.contextSetCertificateChainFile ctx cert
               else SSL.contextSetCertificateFile ctx cert

             certOK <- SSL.contextCheckPrivateKey ctx
             when (not certOK) $ do
                 throwIO $ TLSException certificateError
             return (sock, ctx)
  where
    certificateError =
      "OpenSSL says that the certificate doesn't match the private key!"


------------------------------------------------------------------------------
httpsAcceptFunc :: Socket
                -> SSLContext
                -> AcceptFunc
httpsAcceptFunc boundSocket ctx =
    AcceptFunc $ \restore ->
    acceptAndInitialize boundSocket restore $ \(sock, remoteAddr) -> do
        localAddr                <- Socket.getSocketName sock
        (localPort, localHost)   <- getAddress localAddr
        (remotePort, remoteHost) <- getAddress remoteAddr
        ssl                      <- restore (SSL.connection ctx sock)

        restore (SSL.accept ssl) `onException` Socket.close sock

        (readEnd, writeEnd) <- SStreams.sslToStreams ssl

        let cleanup = (do Streams.write Nothing writeEnd
                          SSL.shutdown ssl $! SSL.Unidirectional)
                        `finally` Socket.close sock

        return $! ( sendFileFunc ssl
                  , localHost
                  , localPort
                  , remoteHost
                  , remotePort
                  , readEnd
                  , writeEnd
                  , cleanup
                  )


------------------------------------------------------------------------------
sendFileFunc :: SSL -> SendFileHandler
sendFileFunc ssl buffer builder fPath offset nbytes = do
    Streams.unsafeWithFileAsInputStartingAt (fromIntegral offset) fPath $ \fileInput0 -> do
        fileInput <- Streams.takeBytes (fromIntegral nbytes) fileInput0 >>=
                     Streams.map byteString
        input     <- Streams.fromList [builder] >>=
                     flip Streams.appendInputStream fileInput
        output    <- Streams.makeOutputStream sendChunk >>=
                     Streams.unsafeBuilderStream (return buffer)
        Streams.supply input output
        Streams.write Nothing output

  where
    sendChunk (Just s) = SSL.write ssl s
    sendChunk Nothing  = return $! ()
#endif