{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}

module Web.JWT.ASAP (
  module Web.JWT.ASAP.Error
, module Web.JWT.ASAP.Env
, Expiry(..)
, MaxAge(..)
, defaultTokenExpiry
, defaultTokenMaxAge
, timedClaim
, expiringClaim
, maxAgeClaimGenerator'
, maxAgeClaimGenerator
, asapReadRsaSecret
, asapAuthHeader
, asapAuthHeaderFromEnv
, asapSignerFromEnv
, laterThanMaxAge
) where

import           Control.Applicative   (liftA2)
import           Control.Lens          (view, ( # ), _tail)
import           Control.Monad.Except  (MonadError (..))
import           Data.ByteString       (ByteString)
import           Data.ByteString.Char8 as BS (unlines)
import           Data.ByteString.Lens  (packedChars)
import           Data.IORef            (newIORef, readIORef, writeIORef)
import           Data.Semigroup
import           Data.String           (fromString)
import qualified Data.Text             as T
import           Data.Time             (NominalDiffTime)
import           Data.Time.Clock.POSIX (getPOSIXTime)
import           Data.UUID             (UUID)
import qualified Data.UUID             as UUID
import qualified Data.UUID.V4          as UUID
import qualified Web.JWT               as JWT
import           Web.JWT.ASAP.Env
import           Web.JWT.ASAP.Error

-- | How long a freshly issued token stays valid. 'timedClaim' adds this
-- duration to the issue time to obtain the @exp@ claim.
newtype Expiry = Expiry NominalDiffTime
  deriving (Eq, Ord, Show)

-- | The age (time elapsed since the @iat@ claim) from which a claim set is
-- treated as too old; see 'laterThanMaxAge' for the comparison.
newtype MaxAge = MaxAge NominalDiffTime
  deriving (Eq, Ord, Show)

-- | Default token lifetime: ten minutes (600 seconds).
defaultTokenExpiry ::
  Expiry
defaultTokenExpiry =
  Expiry tenMinutes
  where
    tenMinutes = 60 * 10

-- | Default maximum age of a reusable claim set: nine minutes (540 seconds),
-- i.e. one minute less than 'defaultTokenExpiry'.
defaultTokenMaxAge ::
  MaxAge
defaultTokenMaxAge =
  MaxAge nineMinutes
  where
    nineMinutes = 60 * 9

-- | Build a claim set carrying only time and identity information.
--
-- Starting from the empty claim set, this fills in:
--
-- * @iat@ with the given time,
-- * @exp@ with the given time plus the 'Expiry' duration,
-- * @jti@ with the textual form of the given UUID.
--
-- Every other claim keeps its 'mempty' value.
timedClaim ::
  Expiry
  -> NominalDiffTime
  -> UUID
  -> JWT.JWTClaimsSet
timedClaim (Expiry lifetime) issuedAt tokenId =
  mempty
    { JWT.iat = issued
    , JWT.exp = expires
    , JWT.jti = identifier
    }
  where
    issued = JWT.numericDate issuedAt
    expires = JWT.numericDate (issuedAt + lifetime)
    identifier = JWT.stringOrURI (UUID.toText tokenId)

-- | Produce a 'timedClaim' stamped with the current POSIX time and a newly
-- generated random UUID. The clock is read before the UUID is drawn.
expiringClaim ::
  Expiry
  -> IO JWT.JWTClaimsSet
expiringClaim expiry =
  liftA2 stamp getPOSIXTime UUID.nextRandom
  where
    stamp = timedClaim expiry

-- | Monad-agnostic core of 'maxAgeClaimGenerator'.
--
-- Arguments, in order: the maximum age, an action returning the current
-- time, an action producing a brand-new claim set, an action storing a claim
-- set, and an action reading the stored claim set.
--
-- The stored claim set is read and returned as-is unless its @iat@ is at
-- least the maximum age in the past, in which case a new claim set is
-- produced, stored and returned instead. A stored claim set without an @iat@
-- is never considered stale, and the clock is not consulted for it.
maxAgeClaimGenerator' ::
  (Monad m) =>
  MaxAge
  -> m NominalDiffTime
  -> m JWT.JWTClaimsSet
  -> (JWT.JWTClaimsSet -> m ())
  -> m JWT.JWTClaimsSet
  -> m JWT.JWTClaimsSet
maxAgeClaimGenerator' maxAge time =
  regenerateWhen isStale
  where
    isStale claim =
      case JWT.iat claim of
        Nothing       -> pure False
        Just issuedAt -> laterThanMaxAge maxAge issuedAt <$> time

-- | Create a caching claim generator backed by an 'Data.IORef.IORef'.
--
-- The outer action generates an initial claim set with 'expiringClaim' and
-- stores it. The returned inner action hands back the stored claim set, first
-- replacing it with a new one when its age has reached the given 'MaxAge'.
-- Until that happens the very same claim set (including its @jti@) is
-- returned on every call.
maxAgeClaimGenerator ::
  MaxAge
  -> Expiry
  -> IO (IO JWT.JWTClaimsSet)
maxAgeClaimGenerator maxAge expiry =
  cachedGenerator <$> (newIORef =<< freshClaim)
  where
    freshClaim =
      expiringClaim expiry
    cachedGenerator cache =
      maxAgeClaimGenerator' maxAge getPOSIXTime freshClaim (writeIORef cache) (readIORef cache)

-- | Turn raw key bytes into an RSA 'JWT.Signer'.
--
-- The bytes are handed to 'JWT.readRsaSecret'; when that yields no key the
-- 'asapInvalidSecret' error is thrown in the surrounding 'MonadError'.
asapReadRsaSecret ::
  (HasAsapError e, MonadError e m) =>
  ByteString
  -> m JWT.Signer
asapReadRsaSecret pem =
  case JWT.readRsaSecret pem of
    Just key -> pure (JWT.RSAPrivateKey key)
    Nothing  -> throwError (asapInvalidSecret # ())

-- | Render the value of an @Authorization@-style header: the literal prefix
-- @"Bearer "@ followed by the token obtained by signing the given header and
-- claim set with the given signer.
asapAuthHeader ::
  JWT.Signer
  -> JWT.JOSEHeader
  -> JWT.JWTClaimsSet
  -> T.Text
asapAuthHeader signer header claim =
  "Bearer " <> token
  where
    token = JWT.encodeSigned signer header claim

-- | Build a bearer header using configuration taken from the environment.
--
-- @ASAP_ISSUER@ becomes the @iss@ claim and @ASAP_KEY_ID@ becomes the @kid@
-- header field (overriding whatever the supplied values contained); the
-- signer comes from 'asapSignerFromEnv'. The variables are looked up in that
-- order, so the first missing or invalid one determines the error raised.
asapAuthHeaderFromEnv ::
  (HasAsapError e, MonadError e m, MonadEnv m) =>
  JWT.JOSEHeader
  -> JWT.JWTClaimsSet
  -> m T.Text
asapAuthHeaderFromEnv header claim = do
  issuer <- asapLookupEnv "ASAP_ISSUER"
  keyId <- asapLookupEnv "ASAP_KEY_ID"
  signer <- asapSignerFromEnv
  pure $ asapAuthHeader signer (withKeyId keyId) (withIssuer issuer)
  where
    withKeyId keyId =
      header { JWT.kid = Just (fromString keyId) }
    withIssuer issuer =
      claim { JWT.iss = JWT.stringOrURI (fromString issuer) }

-- | Load the signing key from the @ASAP_PRIVATE_KEY@ environment variable.
--
-- The part of the variable after its first comma (see 'dataUriData') is
-- placed between @BEGIN RSA PRIVATE KEY@ / @END RSA PRIVATE KEY@ marker lines
-- and the result is parsed with 'asapReadRsaSecret'.
-- NOTE(review): presumably the variable holds a base64 data URI — confirm
-- against the deployment configuration.
asapSignerFromEnv ::
  (HasAsapError e, MonadError e m, MonadEnv m) =>
  m JWT.Signer
asapSignerFromEnv =
  asapLookupEnv "ASAP_PRIVATE_KEY" >>= asapReadRsaSecret . pemFromDataUri
  where
    pemFromDataUri uri =
      BS.unlines
        [ "-----BEGIN RSA PRIVATE KEY-----"
        , view packedChars (dataUriData uri)
        , "-----END RSA PRIVATE KEY-----"
        ]

-- | Extract the payload of a data URI: everything after the first comma.
-- Yields the empty string when the input contains no comma at all.
dataUriData ::
  String
  -> String
dataUriData uri =
  view _tail fromFirstComma
  where
    -- Either empty, or starts with the first comma of the input.
    fromFirstComma = snd (break (== ',') uri)

-- | Decide whether a claim issued at @iat@ has reached the maximum age at
-- the given current time, i.e. whether the elapsed time since @iat@ is
-- greater than or equal to the 'MaxAge' duration.
laterThanMaxAge ::
  MaxAge
  -> JWT.NumericDate
  -> NominalDiffTime
  -> Bool
laterThanMaxAge (MaxAge limit) iat time =
  age >= limit
  where
    age = time - JWT.secondsSinceEpoch iat

-- | Return a stored value, regenerating it first when a predicate demands it.
--
-- Arguments, in order: the (effectful) staleness predicate, an action that
-- produces a replacement value, an action that stores a value, and an action
-- that reads the stored value.
--
-- The stored value is read and tested; if the predicate holds, a replacement
-- is produced, stored and returned, otherwise the stored value is returned
-- untouched and neither the producer nor the store action runs.
regenerateWhen ::
  Monad m =>
  (a -> m Bool)
  -> m a
  -> (a -> m ())
  -> m a
  -> m a
regenerateWhen predicate ma put get = do
  current <- get
  stale <- predicate current
  if stale then refresh else pure current
  where
    refresh = do
      fresh <- ma
      put fresh
      pure fresh