{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

{- |
Module      : Test.Certs.Temp
Copyright   : (c) 2023 Tim Emiola
Maintainer  : Tim Emiola <adetokunbo@emio.la>
SPDX-License-Identifier: BSD3

Enables configuration and generation of temporary certificates
-}
module Test.Certs.Temp
  ( -- * generate certificates
    withCertPaths
  , withCertFilenames
  , withCertPathsInTmp
  , withCertPathsInTmp'
  , generateAndStore

    -- * configuration
  , Config (..)
  , defaultConfig

    -- * certificate filenames
  , CertPaths (..)
  , keyPath
  , certificatePath
  )
where

import qualified Data.ByteString as BS
import Data.Maybe (catMaybes)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Time (UTCTime, addUTCTime, getCurrentTime, nominalDay)
import Numeric.Natural (Natural)
import qualified OpenSSL.PEM as SSL
import qualified OpenSSL.RSA as SSL
import qualified OpenSSL.Random as SSL
import qualified OpenSSL.X509 as SSL
import System.FilePath ((</>))
import System.IO.Temp (getCanonicalTemporaryDirectory, withTempDirectory)


{- | Specifies the location to write the temporary certificates

The two basenames are relative to 'cpDir'; use 'keyPath' and 'certificatePath'
to obtain the full paths.
-}
data CertPaths = CertPaths
  { cpKey :: !FilePath
  -- ^ the basename of the private key file
  , cpCert :: !FilePath
  -- ^ the basename of the certificate file
  , cpDir :: !FilePath
  -- ^ the directory containing the certificate files
  }
  deriving (Eq, Show)


-- | The path of the generated key file: 'cpKey' joined onto 'cpDir'
keyPath :: CertPaths -> FilePath
keyPath CertPaths {cpDir = dir, cpKey = key} = dir </> key


-- | The path of the generated certificate file: 'cpCert' joined onto 'cpDir'
certificatePath :: CertPaths -> FilePath
certificatePath CertPaths {cpDir = dir, cpCert = cert} = dir </> cert


{- | A @CertPaths@ for the given directory that uses the default basenames for
the certificate files:

* @cpKey@ is @key.pem@
* @cpCert@ is @certificate.pem@
-}
defaultBasenames :: FilePath -> CertPaths
defaultBasenames dir =
  CertPaths
    { cpDir = dir
    , cpKey = "key.pem"
    , cpCert = "certificate.pem"
    }


{- | Configure some details of the generated certificates

The optional fields are only included in the certificate's distinguished name
when they are set.
-}
data Config = Config
  { cCommonName :: !Text
  -- ^ the certificate common name
  , cDurationDays :: !Natural
  -- ^ the certificate's duration in days
  , cProvince :: !(Maybe Text)
  -- ^ the optional @ST@ entry of the distinguished name
  , cCity :: !(Maybe Text)
  -- ^ the optional @L@ entry of the distinguished name
  , cOrganization :: !(Maybe Text)
  -- ^ the optional @O@ entry of the distinguished name
  , cCountry :: !(Maybe Text)
  -- ^ the optional @C@ entry of the distinguished name
  }
  deriving (Eq, Show)


{- | A default value for @'Config'@: CN=localhost, duration is 365 days.

None of the optional distinguished name fields are set.
-}
defaultConfig :: Config
defaultConfig =
  Config
    { cCommonName = "localhost"
    , cDurationDays = 365
    , cCountry = Nothing
    , cProvince = Nothing
    , cCity = Nothing
    , cOrganization = Nothing
    }


{- | The distinguished name entries described by a 'Config', as
@(attribute, value)@ pairs.

Entries are emitted in the order @C@, @ST@, @L@, @O@, @CN@; optional fields
that are 'Nothing' are left out, while @CN@ is always present.
-}
asDistinguished :: Config -> [(String, String)]
asDistinguished c =
  catMaybes
    [ entry "C" (cCountry c)
    , entry "ST" (cProvince c)
    , entry "L" (cCity c)
    , entry "O" (cOrganization c)
    , entry "CN" (Just (cCommonName c))
    ]
  where
    -- pair an attribute name with its value, when the value is present
    entry :: String -> Maybe Text -> Maybe (String, String)
    entry attr = fmap (\value -> (attr, Text.unpack value))


{- | A validity interval that starts at the current time and lasts the given
number of days.

The result is @(start, end)@, where @end@ is @ndays@ nominal days after
@start@.
-}
validityNow :: Natural -> IO (UTCTime, UTCTime)
validityNow ndays = do
  start <- getCurrentTime
  let lifetime = nominalDay * fromIntegral ndays
  pure (start, addUTCTime lifetime start)


-- | The size argument passed to the RSA key generator by 'genCerts'
testKeySize :: Int
testKeySize = 4096


-- | The public exponent argument passed to the RSA key generator by 'genCerts'
testExponent :: Integer
testExponent = 257


{- | Generate a self-signed certificate and its private key.

A fresh RSA key pair is generated; the certificate uses the distinguished name
derived from the 'Config' as both issuer and subject, is valid from now for
'cDurationDays' days, and is signed with the generated key pair.

The result is @(certificate PEM, private key PEM)@, in that order.
-}
genCerts :: Config -> IO (String, String)
genCerts config = do
  -- set up values to use in the certificate fields
  let distinguished = asDistinguished config
  serialNumber <- bytesToInteger <$> SSL.randBytes 8
  (start, end) <- validityNow $ cDurationDays config

  -- generate an RSA key pair
  kp <- SSL.generateRSAKey' testKeySize $ fromIntegral testExponent

  -- create and sign a certificate using the private key of the key pair;
  -- issuer and subject are the same, i.e. the certificate is self-signed
  cert <- SSL.newX509
  SSL.setVersion cert 2
  SSL.setSerialNumber cert serialNumber
  SSL.setIssuerName cert distinguished
  SSL.setSubjectName cert distinguished
  SSL.setNotBefore cert start
  SSL.setNotAfter cert end
  SSL.setPublicKey cert kp
  SSL.signX509 cert kp Nothing

  -- the PEM representation of the private key (no passphrase)
  privString <- SSL.writePKCS8PrivateKey kp Nothing

  -- the PEM representation of the certificate
  certString <- SSL.writeX509 cert

  pure (certString, privString)
  where
    -- interpret the random bytes as a big-endian unsigned integer; the strict
    -- fold avoids building a chain of thunks for the accumulator
    bytesToInteger :: BS.ByteString -> Integer
    bytesToInteger = BS.foldl' (\acc w -> acc * 256 + fromIntegral w) 0


{- | Write the PEM-encoded private key and certificate to the files named by
the 'CertPaths'.

The second argument is written to 'keyPath', the third to 'certificatePath'.
-}
storeCerts :: CertPaths -> String -> String -> IO ()
storeCerts cp rsaKey signedCert = do
  writeFile (keyPath cp) rsaKey
  writeFile (certificatePath cp) signedCert


{- | Generate and store certificate files as specified as @'CertPaths'@

The private key is written to 'keyPath' and the certificate to
'certificatePath'.
-}
generateAndStore :: CertPaths -> Config -> IO ()
generateAndStore cp config = do
  -- genCerts yields the certificate first, storeCerts expects the key first
  (certificate, privKey) <- genCerts config
  storeCerts cp privKey certificate


{- | Like 'withCertPaths', but allows the @CertPath@ filenames to be specified

The first argument builds the 'CertPaths' from the temporary directory that is
created below @parentDir@; the certificates are generated there before the
action is run.
-}
withCertFilenames
  :: (FilePath -> CertPaths)
  -> FilePath
  -> Config
  -> (CertPaths -> IO a)
  -> IO a
withCertFilenames mkCertPath parentDir config useCerts =
  withTempDirectory parentDir "temp-certs" $ \tmpDir -> do
    let certPaths = mkCertPath tmpDir
    generateAndStore certPaths config
    useCerts certPaths


{- | Create certificates in a temporary directory below @parentDir@, specify the
locations using @CertPaths@, use them, then delete them

The files are named @key.pem@ and @certificate.pem@; use 'withCertFilenames' to
choose different names.
-}
withCertPaths :: FilePath -> Config -> (CertPaths -> IO a) -> IO a
withCertPaths = withCertFilenames defaultBasenames


-- | Like 'withCertPaths' with the system @TEMP@ dir as the @parentDir@
withCertPathsInTmp :: Config -> (CertPaths -> IO a) -> IO a
withCertPathsInTmp config action = do
  parentDir <- getCanonicalTemporaryDirectory
  withCertPaths parentDir config action


-- | Like 'withCertPathsInTmp' using a default @'Config'@, i.e. 'defaultConfig'
withCertPathsInTmp' :: (CertPaths -> IO a) -> IO a
withCertPathsInTmp' = withCertPathsInTmp defaultConfig