{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}

module Web.JWT.ASAP (
  module Web.JWT.ASAP.Error
, module Web.JWT.ASAP.Env
, Expiry(..)
, MaxAge(..)
, defaultTokenExpiry
, defaultTokenMaxAge
, timedClaim
, expiringClaim
, maxAgeClaimGenerator'
, maxAgeClaimGenerator
, asapReadRsaSecret
, asapAuthHeader
, asapAuthHeaderFromEnv
, laterThanMaxAge
) where

import           Control.Applicative    (liftA2)
import           Control.Lens           (view, ( # ))
import           Control.Monad.Except   (MonadError (..))
import           Data.ByteString        (ByteString)
import           Data.ByteString.Base64 (decodeLenient)
import           Data.ByteString.Lens   (packedChars)
import           Data.IORef             (newIORef, readIORef, writeIORef)
import           Data.String            (fromString)
import qualified Data.Text              as T
import           Data.Time              (NominalDiffTime)
import           Data.Time.Clock.POSIX  (getPOSIXTime)
import           Data.UUID              (UUID)
import qualified Data.UUID              as UUID
import qualified Data.UUID.V4           as UUID
import qualified Web.JWT                as JWT
import           Web.JWT.ASAP.Env       (MonadEnv (..), asapLookupEnv)
import           Web.JWT.ASAP.Error     (HasAsapError (..))

-- | How long a newly minted token stays valid, counted from its issued-at
-- time (the @exp@ claim is set to issued-at plus this duration).
newtype Expiry = Expiry NominalDiffTime
  deriving (Eq, Ord, Show)

-- | How old (relative to its issued-at time) a cached claim may become before
-- 'maxAgeClaimGenerator' replaces it with a freshly minted one.
newtype MaxAge = MaxAge NominalDiffTime
  deriving (Eq, Ord, Show)

-- | Default token lifetime: ten minutes (600 seconds).
defaultTokenExpiry ::
  Expiry
defaultTokenExpiry = Expiry tenMinutes
  where
    tenMinutes :: NominalDiffTime
    tenMinutes = 10 * 60

-- | Default maximum age for a cached token: nine minutes (540 seconds).
--
-- This is one minute shorter than 'defaultTokenExpiry', so a cached token
-- using both defaults is replaced before its @exp@ is reached.
defaultTokenMaxAge ::
  MaxAge
defaultTokenMaxAge = MaxAge nineMinutes
  where
    nineMinutes :: NominalDiffTime
    nineMinutes = 9 * 60

-- | Build a claims set issued at the given time.
--
-- Starting from an empty claims set, this fills in:
--
-- * @iat@ with the issue time,
-- * @exp@ with the issue time plus the 'Expiry' duration,
-- * @jti@ with the textual form of the supplied UUID.
--
-- Every other field is left as in 'mempty'.
timedClaim ::
  Expiry
  -> NominalDiffTime
  -> UUID
  -> JWT.JWTClaimsSet
timedClaim (Expiry lifetime) issuedAt tokenId =
  mempty { JWT.iat = issued, JWT.exp = expires, JWT.jti = identifier }
  where
    issued = JWT.numericDate issuedAt
    expires = JWT.numericDate (issuedAt + lifetime)
    identifier = JWT.stringOrURI (UUID.toText tokenId)

-- | Mint a fresh claims set via 'timedClaim'.
--
-- The issue time is the current POSIX time from 'getPOSIXTime', and the
-- @jti@ is a new random UUID from 'UUID.nextRandom'. The two effects are
-- independent, so they are combined applicatively.
expiringClaim ::
  Expiry
  -> IO JWT.JWTClaimsSet
expiringClaim expiry =
  liftA2 (timedClaim expiry) getPOSIXTime UUID.nextRandom

-- | Generic, monad-polymorphic core of 'maxAgeClaimGenerator'.
--
-- Arguments, in order:
--
-- * the 'MaxAge' after which a cached claim counts as stale,
-- * an action yielding the current time,
-- * an action minting a brand-new claim,
-- * an action storing a claim in the cache,
-- * an action loading the cached claim.
--
-- The resulting action loads the cached claim. If the claim is stale, it
-- mints a replacement, stores it and returns it. Otherwise it returns the
-- cached claim.
--
-- A claim counts as stale when 'laterThanMaxAge' holds for its @iat@.
-- A claim with no @iat@ is never considered stale, and in that case the
-- clock action is not run.
maxAgeClaimGenerator' ::
  (Monad m) =>
  MaxAge
  -> m NominalDiffTime
  -> m JWT.JWTClaimsSet
  -> (JWT.JWTClaimsSet -> m ())
  -> m JWT.JWTClaimsSet
  -> m JWT.JWTClaimsSet
maxAgeClaimGenerator' maxAge clock mintClaim =
  regenerateWhen isStale mintClaim
  where
    isStale claim =
      case JWT.iat claim of
        Nothing -> pure False
        Just issued -> laterThanMaxAge maxAge issued <$> clock

-- | Create an 'IO' action that hands out a cached claims set and re-mints it
-- (via 'expiringClaim') once the cached one has reached the given 'MaxAge'.
--
-- The first claim is minted eagerly, when the generator itself is created.
-- The cache is an 'IORef' read and written with separate
-- 'readIORef'/'writeIORef' calls, so the check-and-replace step is not
-- atomic. NOTE(review): concurrent callers may each mint a replacement; the
-- last write wins.
maxAgeClaimGenerator ::
  MaxAge
  -> Expiry
  -> IO (IO JWT.JWTClaimsSet)
maxAgeClaimGenerator maxAge expiry = do
  cache <- newIORef =<< mint
  pure $
    maxAgeClaimGenerator'
      maxAge
      getPOSIXTime
      mint
      (writeIORef cache)
      (readIORef cache)
  where
    mint = expiringClaim expiry

-- | Parse an RSA private key with 'JWT.readRsaSecret' and wrap it as an
-- RSA signer.
--
-- If the key cannot be parsed, this throws the 'asapInvalidSecret' error
-- through 'MonadError'.
asapReadRsaSecret ::
  (HasAsapError e, MonadError e m) =>
  ByteString
  -> m JWT.Signer
asapReadRsaSecret keyBytes =
  case JWT.readRsaSecret keyBytes of
    Nothing -> throwError (asapInvalidSecret # ())
    Just privateKey -> pure (JWT.RSAPrivateKey privateKey)

-- | Sign the claims with the given signer and header, and format the result
-- as an HTTP @Authorization@ header value: @"Bearer " <> token@.
asapAuthHeader ::
  JWT.Signer
  -> JWT.JOSEHeader
  -> JWT.JWTClaimsSet
  -> T.Text
asapAuthHeader signer header claim =
  T.append bearerPrefix token
  where
    bearerPrefix = "Bearer "
    token = JWT.encodeSigned signer header claim

-- | Build a signed @Authorization@ header value from ASAP environment
-- configuration.
--
-- The function reads, in this order, @ASAP_ISSUER@, @ASAP_KEY_ID@ and
-- @ASAP_PRIVATE_KEY@ via 'asapLookupEnv'. How a missing variable is
-- reported is decided by 'asapLookupEnv'. It then:
--
-- * takes the payload after the first comma of the @ASAP_PRIVATE_KEY@ data
--   URI,
-- * base64-decodes that payload leniently ('decodeLenient'),
-- * parses the decoded bytes with 'asapReadRsaSecret' (presumably the bytes
--   are PEM text; confirm against deployment config).
--
-- Before signing, the header's @kid@ is overwritten with the key id, and the
-- claims' @iss@ is overwritten with the issuer.
asapAuthHeaderFromEnv ::
  (HasAsapError e, MonadError e m, MonadEnv m) =>
  JWT.JOSEHeader
  -> JWT.JWTClaimsSet
  -> m T.Text
asapAuthHeaderFromEnv header claim = do
  issuer <- asapLookupEnv "ASAP_ISSUER"
  keyId <- asapLookupEnv "ASAP_KEY_ID"
  privateKeyUri <- asapLookupEnv "ASAP_PRIVATE_KEY"
  signer <- asapReadRsaSecret (decodeKey privateKeyUri)
  pure $
    asapAuthHeader
      signer
      (header { JWT.kid = Just (fromString keyId) })
      (claim { JWT.iss = JWT.stringOrURI (fromString issuer) })
  where
    decodeKey = decodeLenient . view packedChars . dataUriData

-- | Extract the payload of a @data:@ URI: everything after the first comma
-- (RFC 2397 separates the media-type/encoding prefix from the data with a
-- comma).
--
-- The separator itself is not included. The previous implementation kept a
-- leading comma, which only went unnoticed because 'decodeLenient' skips
-- non-base64 characters. An input without a comma yields @""@.
dataUriData ::
  String
  -> String
dataUriData =
  drop 1 . dropWhile (/= ',')

-- | Whether a token issued at @iat@ has reached its maximum age at @time@.
--
-- @time@ should be on the same scale as 'JWT.secondsSinceEpoch'. Callers in
-- this module pass 'getPOSIXTime'. The comparison is inclusive: a token
-- exactly 'MaxAge' old already counts as too old.
laterThanMaxAge ::
  MaxAge
  -> JWT.NumericDate
  -> NominalDiffTime
  -> Bool
laterThanMaxAge (MaxAge maxAgeTime) iat time =
  time - JWT.secondsSinceEpoch iat >= maxAgeTime

-- | Load a cached value and return it, unless the predicate says it must be
-- regenerated. In that case a replacement is produced, stored and returned
-- instead.
--
-- Effects run in this order:
--
-- 1. @load@
-- 2. the predicate on the loaded value
-- 3. only when the predicate is 'True': @regenerate@, then @store@ on its
--    result
regenerateWhen ::
  Monad m =>
  (a -> m Bool)
  -> m a
  -> (a -> m ())
  -> m a
  -> m a
regenerateWhen shouldRegenerate regenerate store load =
  load >>= \current ->
    shouldRegenerate current >>= \stale ->
      if stale
        then regenerate >>= \fresh -> store fresh >> pure fresh
        else pure current