{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

module Metro.TP.Crypto
  ( Crypto
  , crypto
  , crypto_
  , CryptoMethod (..)
  , methodEcb
  , methodCbc
  , methodCfb
  , methodCtr

  , makeCrypto
  ) where

import           Control.Monad        (when)
import           Crypto.Cipher.Types  (BlockCipher (..), Cipher (..), IV (..),
                                       KeySizeSpecifier (..), ivAdd, nullIV)
import           Crypto.Error         (CryptoFailable (..))
import           Data.Binary          (Binary (..), decode, encode)
import           Data.Binary.Get      (getByteString, getWord32be)
import           Data.Binary.Put      (putByteString, putWord32be)
import           Data.ByteString      (ByteString, empty)
import qualified Data.ByteString      as B (append, length, replicate, take)
import           Data.ByteString.Lazy (fromStrict, toStrict)
import qualified Data.ByteString.Lazy as LB (cycle, fromStrict, take, toStrict)
import qualified Data.Text            as T (pack)
import           Data.Text.Encoding   (encodeUtf8)
import           Metro.Class          (Transport (..))
import           Metro.Utils          (recvEnough)
import           UnliftIO

newtype BlockLength = BlockLength Int
  deriving (Show, Eq)

instance Binary BlockLength where
  get = BlockLength . fromIntegral <$> getWord32be
  put (BlockLength l) = putWord32be $ fromIntegral l

data Block = Block
    { msgSize :: !Int
    , encData :: !ByteString
    }
    deriving (Show, Eq)

instance Binary Block where
  get = do
    pktSize <- fromIntegral <$> getWord32be
    msgSize <- fromIntegral <$> getWord32be
    encData <- getByteString $ pktSize - 4
    return Block {..}
  put Block {..} = do
    putWord32be $ fromIntegral $ B.length encData + 4
    putWord32be $ fromIntegral msgSize
    putByteString encData

makeBlock :: Int -> ByteString -> Block
makeBlock bSize msg = Block size msg0
  where size = B.length msg
        fixedSize = ceiling (fromIntegral size / fromIntegral bSize * 1.0) * bSize
        msg0 = if size < fixedSize then msg `B.append` B.replicate (fixedSize - size) 0
                                   else msg

getMsg :: Block -> ByteString
getMsg Block {..} = B.take msgSize encData

prepareBlock
  :: BlockCipher cipher
  => (cipher -> IV cipher -> ByteString -> ByteString)
  -> cipher -> IV cipher -> Block -> Block
prepareBlock f c iv b = b { encData = f c iv (encData b) }

data CryptoMethod cipher = CryptoMethod
    { encrypt :: cipher -> IV cipher -> ByteString -> ByteString
    , decrypt :: cipher -> IV cipher -> ByteString -> ByteString
    , needIV  :: Bool
    }

data Crypto cipher tp = Crypto
    { readBuffer   :: TVar ByteString
    , cryptoMethod :: CryptoMethod cipher
    , readIV       :: TVar (IV cipher)
    , writeIV      :: TVar (IV cipher)
    , cipher       :: cipher
    , tp           :: tp
    }

instance (Transport tp, BlockCipher cipher) => Transport (Crypto cipher tp) where
  data TransportConfig (Crypto cipher tp) =
    CryptoConfig (CryptoMethod cipher) cipher (IV cipher) (TransportConfig tp)
  newTransport (CryptoConfig cryptoMethod cipher iv config) = do
    readBuffer <- newTVarIO empty
    tp <- newTransport config
    readIV  <- newTVarIO iv
    writeIV <- newTVarIO iv
    return Crypto {..}
  recvData (Crypto buf method ivr _ cipher tp) _ = do
    hbs <- recvEnough buf tp 4
    iv <- readTVarIO ivr
    case decode (fromStrict hbs) of
      BlockLength len -> do
        bs <- getMsg
          . prepareBlock (decrypt method) cipher iv
          . decode
          . fromStrict
          . (hbs <>) <$> recvEnough buf tp len

        when (needIV method) $
          atomically $ writeTVar ivr (ivAdd iv (B.length bs))

        return bs
  sendData (Crypto _ method _ ivw cipher tp) bs = do
    iv <- readTVarIO ivw

    when (needIV method) $
      atomically $ writeTVar ivw (ivAdd iv (B.length bs))

    sendData tp
      . toStrict
      . encode
      . prepareBlock (encrypt method) cipher iv
      $ makeBlock (blockSize cipher) bs
  closeTransport (Crypto _ _ _ _ _ tp) = closeTransport tp

crypto
  :: BlockCipher cipher
  => CryptoMethod cipher
  -> cipher
  -> TransportConfig tp
  -> TransportConfig (Crypto cipher tp)
crypto method cipher = crypto_ method cipher nullIV

crypto_
  :: BlockCipher cipher
  => CryptoMethod cipher
  -> cipher
  -> IV cipher
  -> TransportConfig tp
  -> TransportConfig (Crypto cipher tp)
crypto_ = CryptoConfig

methodEcb :: BlockCipher cipher => CryptoMethod cipher
methodEcb = CryptoMethod (ignoreIV ecbEncrypt) (ignoreIV ecbDecrypt) False
  where ignoreIV f c _ = f c

methodCbc :: BlockCipher cipher => CryptoMethod cipher
methodCbc = CryptoMethod cbcEncrypt cbcDecrypt True

methodCfb :: BlockCipher cipher => CryptoMethod cipher
methodCfb = CryptoMethod cfbEncrypt cfbDecrypt  True

methodCtr :: BlockCipher cipher => CryptoMethod cipher
methodCtr = CryptoMethod ctrCombine ctrCombine True

getCryptoMethod :: BlockCipher cipher => cipher -> String -> Maybe (CryptoMethod cipher)
getCryptoMethod _ "CBC" = Just methodCbc
getCryptoMethod _ "cbc" = Just methodCbc
getCryptoMethod _ "CFB" = Just methodCfb
getCryptoMethod _ "cfb" = Just methodCfb
getCryptoMethod _ "ECB" = Just methodEcb
getCryptoMethod _ "ecb" = Just methodEcb
getCryptoMethod _ "CTR" = Just methodCtr
getCryptoMethod _ "ctr" = Just methodCtr
getCryptoMethod _ _     = Nothing

makeCrypto
  :: forall cipher tp. (BlockCipher cipher, Cipher cipher)
  => cipher -> String -> String -> TransportConfig tp -> TransportConfig (Crypto cipher tp)
makeCrypto cipher method key c =
  case getCryptoMethod cipher method of
    Nothing -> error "crypto method not support"
    Just m  ->
      case cipherInit key0 of
        CryptoFailed e         -> error $ "Cipher init failed " ++ show e
        CryptoPassed (newCipher :: cipher) ->
          crypto m newCipher c

  where size = getKeySize $ cipherKeySize cipher
        key0 =
          LB.toStrict
          . LB.take (fromIntegral size)
          . LB.cycle
          . LB.fromStrict
          . encodeUtf8
          $ T.pack key


getKeySize :: KeySizeSpecifier -> Int
getKeySize (KeySizeRange _ x) = x
getKeySize (KeySizeEnum xs)   = maximum xs
getKeySize (KeySizeFixed x)   = x