{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}

{- |
Module      : Test.Certs.Temp
Copyright   : (c) 2023 Tim Emiola
Maintainer  : Tim Emiola <adetokunbo@emio.la>
SPDX-License-Identifier: BSD3

Enables configuration and generation of temporary certificates
-}
module Test.Certs.Temp
  ( -- * generate certificates
    withCertPaths
  , withCertFilenames
  , withCertPathsInTmp
  , withCertPathsInTmp'
  , generateAndStore

    -- * configuration
  , Config (..)
  , defaultConfig

    -- * certificate filenames
  , CertPaths (..)
  , keyPath
  , certificatePath
  )
where

import qualified Data.ByteString as BS
import Data.Maybe (catMaybes)
import Data.Text (Text)
import qualified Data.Text as Text
import Data.Time (UTCTime, addUTCTime, getCurrentTime, nominalDay)
import Numeric.Natural (Natural)
import qualified OpenSSL.EVP.Digest as SSL
import qualified OpenSSL.PEM as SSL
import qualified OpenSSL.RSA as SSL
import qualified OpenSSL.Random as SSL
import qualified OpenSSL.X509 as SSL
import System.FilePath ((</>))
import System.IO.Temp (getCanonicalTemporaryDirectory, withTempDirectory)


-- | Where the temporary certificate files are written: a directory plus the
-- basenames of the private key and certificate files inside it.
data CertPaths = CertPaths
  { -- | the basename of the private key file
    cpKey :: !FilePath
  , -- | the basename of the certificate file
    cpCert :: !FilePath
  , -- | the directory containing the certificate files
    cpDir :: !FilePath
  }
  deriving (Eq, Show)


-- | The full path of the generated private key file: the 'cpKey' basename
-- joined onto the 'cpDir' directory.
keyPath :: CertPaths -> FilePath
keyPath CertPaths {cpDir = dir, cpKey = keyName} = dir </> keyName


-- | The full path of the generated certificate file: the 'cpCert' basename
-- joined onto the 'cpDir' directory.
certificatePath :: CertPaths -> FilePath
certificatePath CertPaths {cpDir = dir, cpCert = certName} = dir </> certName


{- | A @'CertPaths'@ rooted at the given directory that uses the default
basenames for the certificate files:

* @cpKey@ is @key.pem@
* @cpCert@ is @certificate.pem@
-}
defaultBasenames :: FilePath -> CertPaths
defaultBasenames dir =
  -- the parameter is not named cpDir so it does not shadow the field selector
  CertPaths
    { cpDir = dir
    , cpKey = "key.pem"
    , cpCert = "certificate.pem"
    }


-- | Configure some details of the generated certificates.
--
-- The same distinguished name is used as both the issuer and the subject.
-- Each optional field maps to one attribute of that name and is left out of
-- it when 'Nothing'.
data Config = Config
  { cCommonName :: !Text
  -- ^ the certificate common name (the @CN@ attribute); always present
  , cDurationDays :: !Natural
  -- ^ the certificate's duration in days, counted from when it is generated
  , cProvince :: !(Maybe Text)
  -- ^ the state or province (the @ST@ attribute)
  , cCity :: !(Maybe Text)
  -- ^ the city or locality (the @L@ attribute)
  , cOrganization :: !(Maybe Text)
  -- ^ the organization (the @O@ attribute)
  , cCountry :: !(Maybe Text)
  -- ^ the country (the @C@ attribute); X.509 expects a two-letter ISO 3166
  -- code, which is not validated here
  }
  deriving (Eq, Show)


-- | A default @'Config'@: the common name is @localhost@, the certificate is
-- valid for 365 days, and no other distinguished-name attributes are set.
defaultConfig :: Config
defaultConfig =
  Config
    { cCommonName = "localhost"
    , cDurationDays = 365
    , cProvince = Nothing
    , cCity = Nothing
    , cOrganization = Nothing
    , cCountry = Nothing
    }


-- | Render the naming fields of a @'Config'@ as a distinguished name.
--
-- Attributes appear in the order C, ST, L, O, CN. Optional attributes whose
-- value is 'Nothing' are dropped; CN is always included.
asDistinguished :: Config -> [(String, String)]
asDistinguished c = catMaybes [toAttr label value | (label, value) <- attrs]
  where
    attrs :: [(String, Maybe Text)]
    attrs =
      [ ("C", cCountry c)
      , ("ST", cProvince c)
      , ("L", cCity c)
      , ("O", cOrganization c)
      , ("CN", Just (cCommonName c))
      ]
    toAttr :: String -> Maybe Text -> Maybe (String, String)
    toAttr label = fmap (\v -> (label, Text.unpack v))


-- | The validity window of a certificate created now.
--
-- Returns @(start, end)@: @start@ is the current time and @end@ lies exactly
-- @ndays@ days later.
validityNow :: Natural -> IO (UTCTime, UTCTime)
validityNow ndays = withEnd <$> getCurrentTime
  where
    lifetime = fromIntegral ndays * nominalDay
    withEnd start = (start, lifetime `addUTCTime` start)


-- | Size in bits of the generated RSA key (4096 bits).
testKeySize :: Int
testKeySize = 4 * 1024


-- | Public exponent of the generated RSA key.
--
-- 65537 (F4) is the conventional choice. FIPS 186-4/186-5 require the public
-- exponent to satisfy @2^16 < e < 2^256@. Smaller exponents such as 257 fall
-- below that bound and may be rejected by verifiers that enforce it.
testExponent :: Integer
testExponent = 65537


-- | Generate a self-signed certificate together with its private key.
--
-- Returns @(certificate, privateKey)@:
--
-- * the certificate is serialised with 'SSL.writeX509';
-- * the private key is serialised with 'SSL.writePKCS8PrivateKey', with no
--   cipher or password.
--
-- Issuer and subject are both the distinguished name built from the
-- @'Config'@. The certificate is valid from now for 'cDurationDays' days.
genCerts :: Config -> IO (String, String)
genCerts config = do
  -- set up values to use in the certificate fields
  let distinguished = asDistinguished config
      -- read the random bytes as a big-endian unsigned integer; the strict
      -- fold avoids accumulating thunks
      mkSerialNum :: BS.ByteString -> Integer
      mkSerialNum = BS.foldl' (\acc w -> acc * 256 + fromIntegral w) 0
  serialNumber <- mkSerialNum <$> SSL.randBytes 8
  (start, end) <- validityNow (cDurationDays config)

  -- generate an RSA key pair
  kp <- SSL.generateRSAKey' testKeySize (fromIntegral testExponent)

  -- Request SHA-256 explicitly instead of passing Nothing to signX509.
  -- Nothing leaves the digest to the library default for the key type.
  -- NOTE(review): HsOpenSSL's default for RSA keys appears to be SHA-1,
  -- which modern TLS stacks reject; confirm against the HsOpenSSL version
  -- in use. If the lookup fails, signing falls back to that default.
  sha256 <- SSL.getDigestByName "sha256"

  -- create and sign a certificate using the private key of the key pair
  cert <- SSL.newX509
  SSL.setVersion cert 2 -- RFC 5280 encodes X.509 v3 as version value 2
  SSL.setSerialNumber cert serialNumber
  SSL.setIssuerName cert distinguished
  SSL.setSubjectName cert distinguished
  SSL.setNotBefore cert start
  SSL.setNotAfter cert end
  SSL.setPublicKey cert kp
  SSL.signX509 cert kp sha256

  -- the PEM representation of the private key
  privString <- SSL.writePKCS8PrivateKey kp Nothing

  -- the PEM representation of the certificate
  certString <- SSL.writeX509 cert

  pure (certString, privString)


-- | Write the private key text to 'keyPath' and then the certificate text to
-- 'certificatePath'. Each file is created or overwritten.
storeCerts :: CertPaths -> String -> String -> IO ()
storeCerts cp rsaKey signedCert =
  mapM_
    (uncurry writeFile)
    [ (keyPath cp, rsaKey)
    , (certificatePath cp, signedCert)
    ]


-- | Generate a fresh private key and certificate from the @'Config'@ and
-- write them to the locations given by the @'CertPaths'@.
generateAndStore :: CertPaths -> Config -> IO ()
generateAndStore cp config =
  genCerts config >>= \(certificate, privKey) ->
    storeCerts cp privKey certificate


{- | Like 'withCertPaths', but the caller chooses the @'CertPaths'@.

The first argument builds them from the path of the temporary directory that
is created below @parentDir@.
-}
withCertFilenames
  :: (FilePath -> CertPaths)
  -> FilePath
  -> Config
  -> (CertPaths -> IO a)
  -> IO a
withCertFilenames mkCertPath parentDir config useCerts =
  withTempDirectory parentDir "temp-certs" inTmpDir
  where
    -- generate the files inside the fresh directory, then hand them over
    inTmpDir tmpDir =
      let certPaths = mkCertPath tmpDir
       in generateAndStore certPaths config >> useCerts certPaths


{- | Create certificates in a temporary directory below @parentDir@.

Their locations are passed to the action as @'CertPaths'@, and the
certificates are deleted once the action is done. The files are named
@key.pem@ and @certificate.pem@.
-}
withCertPaths :: FilePath -> Config -> (CertPaths -> IO a) -> IO a
withCertPaths parentDir = withCertFilenames defaultBasenames parentDir


-- | Like 'withCertPaths', using the system's canonical temporary directory
-- as the @parentDir@.
withCertPathsInTmp :: Config -> (CertPaths -> IO a) -> IO a
withCertPathsInTmp config action =
  getCanonicalTemporaryDirectory >>= \tmpRoot ->
    withCertPaths tmpRoot config action


-- | Like 'withCertPathsInTmp', using @'defaultConfig'@ for the certificates.
withCertPathsInTmp' :: (CertPaths -> IO a) -> IO a
withCertPathsInTmp' action = withCertPathsInTmp defaultConfig action