--
-- MinIO Haskell SDK, (C) 2017-2019 MinIO, Inc.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
--     http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
{-# LANGUAGE BangPatterns #-}

module Network.Minio.Sign.V4 where

import qualified Conduit as C
import qualified Data.ByteString as B
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Char8 as B8
import qualified Data.ByteString.Lazy as LB
import Data.CaseInsensitive (mk)
import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict as Map
import qualified Data.HashSet as Set
import qualified Data.Time as Time
import Lib.Prelude
import qualified Network.HTTP.Conduit as NC
import Network.HTTP.Types (Header, parseQuery)
import qualified Network.HTTP.Types as H
import Network.Minio.Data.ByteString
import Network.Minio.Data.Crypto
import Network.Minio.Data.Time
import Network.Minio.Errors
import Text.Printf (printf)

-- | Case-folded names of the headers that are left out when signing a
-- request: 'getHeadersToSign' drops any header whose folded name is a
-- member of this set.
ignoredHeaders :: Set.HashSet ByteString
ignoredHeaders = Set.fromList (CI.foldedCase <$> unsignedNames)
  where
    unsignedNames =
      [ H.hAuthorization,
        H.hContentType,
        H.hUserAgent
      ]

-- | Record bundling the intermediate values of a V4 signing computation.
--
-- 'debugPrintSignV4Data' pattern-matches on the constructor positionally,
-- so the order of the fields below must not be changed without updating
-- it as well.
--
-- NOTE(review): the field descriptions follow the field names and the
-- labels printed by 'debugPrintSignV4Data'; confirm them against the code
-- that actually builds this record.
data SignV4Data = SignV4Data
  { -- | Timestamp of the signature (printed as "Timestamp").
    sv4SignTime :: UTCTime,
    -- | Credential scope (presumably as produced by 'mkScope').
    sv4Scope :: ByteString,
    -- | Canonical request (presumably as produced by 'mkCanonicalRequest').
    sv4CanonicalRequest :: ByteString,
    -- | Header name/value pairs that took part in the signature.
    sv4HeadersToSign :: [(ByteString, ByteString)],
    -- | Resulting pairs (printed as "Output").
    sv4Output :: [(ByteString, ByteString)],
    -- | String to sign (presumably as produced by 'mkStringToSign').
    sv4StringToSign :: ByteString,
    -- | Signing key; 'debugPrintSignV4Data' prints it as raw bytes.
    sv4SigningKey :: ByteString
  }
  deriving (Show)

-- | Credentials and options consumed by 'signV4', 'signV4Stream' and
-- 'signV4PostPolicy'.
--
-- NOTE(review): the derived 'Show' instance renders 'spSecretKey' in
-- clear text, so avoid logging values of this type.
data SignParams = SignParams
  { -- | Access key; it is embedded in the credential part of the
    -- @Authorization@ header or of @X-Amz-Credential@.
    spAccessKey :: Text,
    -- | Secret key; only used to derive the signing key ('mkSigningKey').
    spSecretKey :: Text,
    -- | Signing time; used for @X-Amz-Date@, the scope and the signing key.
    spTimeStamp :: UTCTime,
    -- | Region; 'Nothing' is treated as the empty string by the signers.
    spRegion :: Maybe Text,
    -- | When set, 'signV4' produces presigned-URL query parameters and
    -- this value is emitted as @X-Amz-Expires@.
    spExpirySecs :: Maybe Int,
    -- | Payload hash; 'Nothing' is replaced by @UNSIGNED-PAYLOAD@.
    spPayloadHash :: Maybe ByteString
  }
  deriving (Show)

-- | Write every field of a 'SignV4Data' to stdout. Intended for
-- debugging only: the signing key is emitted as its raw bytes, each one
-- followed by a single space.
debugPrintSignV4Data :: SignV4Data -> IO ()
debugPrintSignV4Data
  SignV4Data
    { sv4SignTime = signTime,
      sv4Scope = scope,
      sv4CanonicalRequest = canonicalRequest,
      sv4HeadersToSign = headersToSign,
      sv4Output = output,
      sv4StringToSign = stringToSign,
      sv4SigningKey = signingKey
    } = do
    B8.putStrLn "SignV4Data:"
    field "Timestamp: " (print signTime)
    field "Scope: " (B8.putStrLn scope)
    B8.putStrLn "Canonical Request:"
    B8.putStrLn canonicalRequest
    field "Headers to Sign: " (print headersToSign)
    field "Output: " (print output)
    field "StringToSign: " (B8.putStrLn stringToSign)
    field "SigningKey: " (dumpBytes signingKey)
    B8.putStrLn "END of SignV4Data ========="
  where
    -- print a label (without newline) and then the value
    field label action = B8.putStr label >> action
    -- a single byte followed by a separating space
    spaced byte = B.singleton byte <> " "
    dumpBytes bytes = do
      mapM_ (B.putStr . spaced) (B.unpack bytes)
      B8.putStrLn ""

-- | Assemble the @Authorization@ header of a V4-signed request from the
-- access key, the credential scope, the @;@-separated list of signed
-- header names and the computed signature.
mkAuthHeader :: Text -> ByteString -> ByteString -> ByteString -> H.Header
mkAuthHeader accessKey scope signedHeaderKeys sign =
  (H.hAuthorization, credentialPart <> signedHeadersPart <> signaturePart)
  where
    credentialPart =
      "AWS4-HMAC-SHA256 Credential=" <> toUtf8 accessKey <> "/" <> scope
    signedHeadersPart = ", SignedHeaders=" <> signedHeaderKeys
    signaturePart = ", Signature=" <> sign

-- | Sign a request with AWS Signature V4 and return the pairs the caller
-- has to attach to it.
--
-- * When 'spExpirySecs' is 'Nothing' (a normal request) the result is a
--   list of headers: the @Authorization@ header (name in folded case),
--   @X-Amz-Date@ and @x-amz-content-sha256@.
--
-- * When 'spExpirySecs' is set, a presigned request is being created.
--   The value is interpreted as a number of seconds and the result is
--   the list of query parameters to append to the URL, starting with
--   @X-Amz-Signature@.
--
-- The request's own headers and query string take part in the signature
-- in both cases.
signV4 :: SignParams -> NC.Request -> [(ByteString, ByteString)]
signV4 !sp !req
  | presigning = ("X-Amz-Signature", signature) : presignParams
  | otherwise =
    [ (CI.foldedCase authName, authValue),
      dateHeader,
      payloadHashHeader
    ]
  where
    presigning = isJust (spExpirySecs sp)
    timestamp = spTimeStamp sp
    region = fromMaybe "" (spRegion sp)
    scope = mkScope timestamp region

    -- pairs that are signed as headers for normal requests
    dateHeader = ("X-Amz-Date", awsTimeFormatBS timestamp)
    payloadHashHeader =
      ( "x-amz-content-sha256",
        fromMaybe "UNSIGNED-PAYLOAD" (spPayloadHash sp)
      )
    extraHeaders
      | presigning = []
      | otherwise =
        [(mk name, value) | (name, value) <- [dateHeader, payloadHashHeader]]
    headersToSign = getHeadersToSign (NC.requestHeaders req ++ extraHeaders)
    signedHeaderKeys = B.intercalate ";" (sort (map fst headersToSign))

    -- query parameters that are signed (and returned) for presigned URLs
    presignParams =
      [ ("X-Amz-Algorithm", "AWS4-HMAC-SHA256"),
        ("X-Amz-Credential", B.concat [toUtf8 (spAccessKey sp), "/", scope]),
        dateHeader,
        ("X-Amz-Expires", maybe "" showBS (spExpirySecs sp)),
        ("X-Amz-SignedHeaders", signedHeaderKeys)
      ]
    extraQuery
      | presigning = [(name, Just value) | (name, value) <- presignParams]
      | otherwise = []
    signedQuery = parseQuery (NC.queryString req) ++ extraQuery

    -- canonical request -> string to sign -> signing key -> signature
    canonicalRequest =
      mkCanonicalRequest
        False
        sp
        (NC.setQueryString signedQuery req)
        headersToSign
    stringToSign = mkStringToSign timestamp scope canonicalRequest
    signingKey = mkSigningKey timestamp region (toUtf8 (spSecretKey sp))
    signature = computeSignature stringToSign signingKey

    -- only demanded for normal (header-signed) requests
    (authName, authValue) =
      mkAuthHeader (spAccessKey sp) scope signedHeaderKeys signature

-- | Credential scope of a signature:
-- @\<yyyymmdd\>\/\<region\>\/s3\/aws4_request@.
mkScope :: UTCTime -> Text -> ByteString
mkScope ts region =
  B.concat [dateStamp, "/", toUtf8 region, "/s3/aws4_request"]
  where
    dateStamp = toUtf8 (Time.formatTime Time.defaultTimeLocale "%Y%m%d" ts)

-- | Select the headers that take part in the signature. Header names are
-- case-folded, values are passed through 'stripBS', and every header
-- whose folded name is in 'ignoredHeaders' is dropped. The input order
-- is preserved.
getHeadersToSign :: [Header] -> [(ByteString, ByteString)]
getHeadersToSign !h =
  [ (name, stripBS value)
    | (ciName, value) <- h,
      let name = CI.foldedCase ciName,
      not (Set.member name ignoredHeaders)
  ]

-- | Build the canonical request: method, URI-encoded path, canonical
-- query string, canonical headers, signed header names and the payload
-- hash, joined by newlines.
--
-- The first argument selects the payload-hash line: 'True' uses the
-- fixed @STREAMING-AWS4-HMAC-SHA256-PAYLOAD@ marker, 'False' uses
-- 'spPayloadHash' (or @UNSIGNED-PAYLOAD@ when it is 'Nothing').
-- The query string is taken from the request itself; the headers are
-- exactly the pairs passed in, sorted.
mkCanonicalRequest ::
  Bool ->
  SignParams ->
  NC.Request ->
  [(ByteString, ByteString)] ->
  ByteString
mkCanonicalRequest !isStreaming !sp !req !headersForSign =
  B.intercalate
    "\n"
    [ NC.method req,
      uriEncode False (NC.path req),
      canonicalQuery,
      canonicalHeaders,
      signedHeaders,
      payloadHash
    ]
  where
    -- query parameters are encoded first and sorted afterwards
    encodeParam (name, value) =
      (uriEncode True name, maybe "" (uriEncode True) value)
    renderParam (name, value) = B.concat [name, "=", value]
    canonicalQuery =
      B.intercalate "&" . map renderParam . sort . map encodeParam $
        parseQuery (NC.queryString req)

    orderedHeaders = sort headersForSign
    canonicalHeaders =
      B.concat
        [B.concat [name, ":", value, "\n"] | (name, value) <- orderedHeaders]
    signedHeaders = B.intercalate ";" (map fst orderedHeaders)

    payloadHash
      | isStreaming = "STREAMING-AWS4-HMAC-SHA256-PAYLOAD"
      | otherwise = fromMaybe "UNSIGNED-PAYLOAD" (spPayloadHash sp)

-- | Build the string to sign: the algorithm name, the request time, the
-- credential scope and the SHA-256 of the canonical request, one per
-- line.
mkStringToSign :: UTCTime -> ByteString -> ByteString -> ByteString
mkStringToSign ts !scope !canonicalRequest =
  B.intercalate "\n" [algorithm, requestTime, scope, requestDigest]
  where
    algorithm = "AWS4-HMAC-SHA256"
    requestTime = awsTimeFormatBS ts
    requestDigest = hashSHA256 canonicalRequest

-- | Derive the signing key by chaining four 'hmacSHA256RawBS' steps,
-- starting from @"AWS4" <> secretKey@ and feeding in, in order, the
-- date, the region, the service name (@s3@) and @aws4_request@. The
-- output of each step is the second argument of the next one.
mkSigningKey :: UTCTime -> Text -> ByteString -> ByteString
mkSigningKey ts region !secretKey = signingKey
  where
    seedKey = B.concat ["AWS4", secretKey]
    dateKey = hmacSHA256RawBS (awsDateFormatBS ts) seedKey
    regionKey = hmacSHA256RawBS (toUtf8 region) dateKey
    serviceKey = hmacSHA256RawBS "s3" regionKey
    signingKey = hmacSHA256RawBS "aws4_request" serviceKey

-- | Signature of a string to sign under the given signing key: the
-- result of 'hmacSHA256' rendered with 'digestToBase16'.
computeSignature :: ByteString -> ByteString -> ByteString
computeSignature !toSign !key = digestToBase16 mac
  where
    mac = hmacSHA256 toSign key

-- | Sign a POST policy. Takes the (already validated) post-policy JSON
-- and the signing parameters, and returns the form-data fields for the
-- POST upload that carry just the signature (@x-amz-signature@) and the
-- Base64-encoded policy (@policy@).
signV4PostPolicy ::
  ByteString ->
  SignParams ->
  Map.HashMap Text ByteString
signV4PostPolicy !postPolicyJSON !sp =
  Map.fromList
    [ ("x-amz-signature", signature),
      ("policy", encodedPolicy)
    ]
  where
    -- the Base64 form of the policy is itself the string to sign
    encodedPolicy = Base64.encode postPolicyJSON
    secretKey = toUtf8 (spSecretKey sp)
    signingKey =
      mkSigningKey (spTimeStamp sp) (fromMaybe "" (spRegion sp)) secretKey
    signature = computeSignature encodedPolicy signingKey

-- | Number of payload bytes (64 KiB) carried by every full chunk of a
-- streaming-signed upload; see 'signedStreamLength' and 'signV4Stream'.
chunkSizeConstant :: Int
chunkSizeConstant = 65536

-- | Number of hexadecimal digits needed to write @n@, for @n > 0@.
--
-- Any @n <= 0@ yields @0@. Zero keeps its historical result, and
-- negative inputs now terminate: the previous @n == 0@ guard combined
-- with 'div' (which rounds towards negative infinity, so
-- @(-1) `div` 16 == -1@) recursed forever on them.
base16Len :: Integral a => a -> Int
base16Len n
  | n <= 0 = 0
  | otherwise = 1 + base16Len (n `quot` 16)

-- | Total body length of a streaming-signed upload for a payload of
-- @dataLen@ bytes: every full chunk of 'chunkSizeConstant' bytes, the
-- shorter trailing chunk (if any) and the final empty chunk, each
-- including its framing.
--
-- Structure of a chunk:
--   string(IntHexBase(chunk-size)) + ";chunk-signature=" + signature + \r\n + chunk-data + \r\n
signedStreamLength :: Int64 -> Int64
signedStreamLength dataLen =
  fullChunks * encodedLen chunkLen + trailingLen + terminatorLen
  where
    chunkLen = fromIntegral chunkSizeConstant
    (fullChunks, remainder) = dataLen `quotRem` chunkLen
    -- ";chunk-signature=" (17) + signature (64) + two "\r\n" pairs (2 + 2)
    framing = 17 + 64 + 2 + 2
    encodedLen size = fromIntegral (base16Len size) + framing + size
    trailingLen
      | remainder > 0 = encodedLen remainder
      | otherwise = 0
    -- the final chunk is the single digit "0" plus framing, with no data
    terminatorLen = 1 + framing

-- | Prepare a request for a streaming (chunk-signed) V4 upload.
--
-- Arguments: the payload length in bytes, the signing parameters and the
-- request to sign. The result is a function that takes the payload
-- source and returns the request with
--
-- * its headers replaced by the original headers plus @Authorization@,
--   @X-Amz-Date@, @content-encoding@ (with @aws-chunked@ added) and the
--   streaming headers @x-amz-decoded-content-length@, @content-length@
--   and @x-amz-content-sha256@;
--
-- * its body set to the source piped through a conduit that emits
--   'chunkSizeConstant'-sized signed chunks, then a shorter last chunk
--   (if any), then an empty final chunk. Each chunk signature is chained
--   to the previous one, starting from the seed signature of the headers.
--
-- The body conduit throws 'MErrVStreamingBodyUnexpectedEOF' if the source
-- ends before a chunk could be read in full.
signV4Stream ::
  Int64 ->
  SignParams ->
  NC.Request ->
  (C.ConduitT () ByteString (C.ResourceT IO) () -> NC.Request)
-- -> ([Header], C.ConduitT () ByteString (C.ResourceT IO) () -> NC.RequestBody)
signV4Stream !payloadLength !sp !req =
  let ts = spTimeStamp sp
      -- Add "aws-chunked" to the content-encoding header: append it to
      -- the first existing value (any further content-encoding headers
      -- are dropped) or add a fresh header when there is none.
      addContentEncoding hs =
        let ceMay = headMay $ filter (\(x, _) -> x == "content-encoding") hs
         in case ceMay of
              Nothing -> ("content-encoding", "aws-chunked") : hs
              Just (_, ce) ->
                ("content-encoding", ce <> ",aws-chunked")
                  : filter (\(x, _) -> x /= "content-encoding") hs
      -- headers to be added to the request
      datePair = ("X-Amz-Date", awsTimeFormatBS ts)
      computedHeaders =
        addContentEncoding $
          datePair : NC.requestHeaders req
      -- headers specific to streaming signature
      -- (content-length is the length of the encoded body, not of the
      -- payload; the payload length goes in x-amz-decoded-content-length)
      signedContentLength = signedStreamLength payloadLength
      streamingHeaders :: [Header]
      streamingHeaders =
        [ ("x-amz-decoded-content-length", showBS payloadLength),
          ("content-length", showBS signedContentLength),
          ("x-amz-content-sha256", "STREAMING-AWS4-HMAC-SHA256-PAYLOAD")
        ]
      headersToSign = getHeadersToSign $ computedHeaders ++ streamingHeaders
      signedHeaderKeys = B.intercalate ";" $ sort $ map fst headersToSign
      finalQP = parseQuery (NC.queryString req)
      -- 1. Compute Seed Signature
      -- 1.1 Canonical Request
      canonicalReq =
        mkCanonicalRequest
          True
          sp
          (NC.setQueryString finalQP req)
          headersToSign
      region = fromMaybe "" $ spRegion sp
      scope = mkScope ts region
      accessKey = spAccessKey sp
      secretKey = spSecretKey sp
      -- 1.2 String toSign
      stringToSign = mkStringToSign ts scope canonicalReq
      -- 1.3 Compute signature
      -- 1.3.1 compute signing key
      signingKey = mkSigningKey ts region $ toUtf8 secretKey
      -- 1.3.2 Compute signature
      seedSignature = computeSignature stringToSign signingKey
      -- 1.3.3 Compute Auth Header
      authHeader = mkAuthHeader accessKey scope signedHeaderKeys seedSignature
      -- 1.4 Updated headers for the request
      finalReqHeaders = authHeader : (computedHeaders ++ streamingHeaders)
      -- headersToAdd = authHeader : datePair : streamingHeaders

      -- Lower-case hexadecimal rendering of a chunk size.
      toHexStr n = B8.pack $ printf "%x" n
      -- Number of full chunks and size of the trailing partial chunk.
      (numParts, lastPSize) = payloadLength `quotRem` fromIntegral chunkSizeConstant
      -- Function to compute string to sign for each chunk. It chains the
      -- previous signature, the hash of the empty string and the hash of
      -- the current chunk.
      chunkStrToSign prevSign currChunkHash =
        B.intercalate
          "\n"
          [ "AWS4-HMAC-SHA256-PAYLOAD",
            awsTimeFormatBS ts,
            scope,
            prevSign,
            hashSHA256 "",
            currChunkHash
          ]
      -- Read exactly n bytes from upstream and return them as a strict
      -- bytestring; fail if upstream ends early.
      mustTakeN n = do
        bs <- LB.toStrict <$> (C.takeCE n C..| C.sinkLazy)
        when (B.length bs /= n) $
          throwIO MErrVStreamingBodyUnexpectedEOF
        return bs
      -- Emit the signed chunks. Arguments: number of full chunks still to
      -- send, size of the trailing partial chunk, and the signature of
      -- the previous chunk (the seed signature for the first one).
      signerConduit n lps prevSign =
        -- First case encodes a full chunk of length
        -- 'chunkSizeConstant'.
        if
            | n > 0 -> do
              bs <- mustTakeN chunkSizeConstant
              let strToSign = chunkStrToSign prevSign (hashSHA256 bs)
                  nextSign = computeSignature strToSign signingKey
                  chunkBS =
                    toHexStr chunkSizeConstant
                      <> ";chunk-signature="
                      <> nextSign
                      <> "\r\n"
                      <> bs
                      <> "\r\n"
              C.yield chunkBS
              signerConduit (n -1) lps nextSign

            -- Second case encodes the last chunk which is smaller than
            -- 'chunkSizeConstant'
            | lps > 0 -> do
              bs <- mustTakeN $ fromIntegral lps
              let strToSign = chunkStrToSign prevSign (hashSHA256 bs)
                  nextSign = computeSignature strToSign signingKey
                  chunkBS =
                    toHexStr lps <> ";chunk-signature="
                      <> nextSign
                      <> "\r\n"
                      <> bs
                      <> "\r\n"
              C.yield chunkBS
              -- nothing but the final empty chunk is left
              signerConduit 0 0 nextSign

            -- Last case encodes the final signature chunk that has no
            -- data.
            | otherwise -> do
              let strToSign = chunkStrToSign prevSign (hashSHA256 "")
                  nextSign = computeSignature strToSign signingKey
                  lastChunkBS = "0;chunk-signature=" <> nextSign <> "\r\n\r\n"
              C.yield lastChunkBS
   in \src ->
        req
          { NC.requestHeaders = finalReqHeaders,
            NC.requestBody =
              NC.requestBodySource signedContentLength $
                src C..| signerConduit numParts lastPSize seedSignature
          }