module Network.QUIC.Packet.Token (
    CryptoToken (..),
    isRetryToken,
    generateToken,
    generateRetryToken,
    encryptToken,
    decryptToken,
) where

import qualified UnliftIO.Exception as E
import qualified Crypto.Token as CT
import Data.UnixTime
import Foreign.C.Types
import Network.ByteOrder

import Network.QUIC.Imports
import Network.QUIC.Types

----------------------------------------------------------------

-- | Plaintext contents of a token built by 'generateToken' or
-- 'generateRetryToken', before encryption / after decryption.
data CryptoToken = CryptoToken
    { tokenQUICVersion :: Version
    -- ^ QUIC version the token was generated for.
    , tokenCreatedTime :: TimeMicrosecond
    -- ^ Generation time; only whole seconds survive 'encodeCryptoToken'.
    , tokenCIDs :: Maybe (CID, CID, CID)
    -- ^ @Just (local, remote, original local)@ for a Retry token,
    -- 'Nothing' for a plain token.
    }

-- | A token is a Retry token exactly when it carries connection IDs.
isRetryToken :: CryptoToken -> Bool
isRetryToken = isJust . tokenCIDs

----------------------------------------------------------------

-- | Build a plain (non-Retry) token for the given version, stamped with the
-- current time from 'getTimeMicrosecond'.
generateToken :: Version -> IO CryptoToken
generateToken ver = stamp <$> getTimeMicrosecond
  where
    -- No connection IDs: this is what makes it a non-Retry token.
    stamp now = CryptoToken ver now Nothing

-- | Build a Retry token for the given version, stamped with the current time
-- and recording the local, remote and original local connection IDs
-- (in that argument order).
generateRetryToken :: Version -> CID -> CID -> CID -> IO CryptoToken
generateRetryToken ver l r o = do
    now <- getTimeMicrosecond
    pure
        CryptoToken
            { tokenQUICVersion = ver
            , tokenCreatedTime = now
            , tokenCIDs = Just (l, r, o)
            }

----------------------------------------------------------------

-- | Serialise a 'CryptoToken' with 'encodeCryptoToken', then encrypt the
-- resulting bytes with the given token manager.
encryptToken :: CT.TokenManager -> CryptoToken -> IO Token
encryptToken mgr ct = do
    plaintext <- encodeCryptoToken ct
    CT.encryptToken mgr plaintext

-- | Decrypt a token with the given token manager and parse its plaintext.
--
-- Returns 'Nothing' when 'CT.decryptToken' yields 'Nothing' or when
-- 'decodeCryptoToken' rejects the decrypted bytes.
decryptToken :: CT.TokenManager -> Token -> IO (Maybe CryptoToken)
decryptToken mgr token =
    CT.decryptToken mgr token >>= maybe (return Nothing) decodeCryptoToken

----------------------------------------------------------------

-- | Buffer size for the largest encoded 'CryptoToken' (a Retry token):
-- version, seconds, type tag, then three CID slots of one length byte plus
-- 20 CID bytes each. Evaluates to 76.
cryptoTokenSize :: Int
cryptoTokenSize = versionLen + secondsLen + tagLen + 3 * cidSlotLen
  where
    versionLen = 4
    secondsLen = 8
    tagLen = 1
    cidSlotLen = 1 + 20

-- | Serialise a 'CryptoToken' into plaintext bytes for encryption.
--
-- Layout:
--
--   * 4 bytes: QUIC version ('write32')
--   * 8 bytes: creation time in whole seconds ('write64'); the sub-second
--     part of the timestamp is dropped
--   * 1 byte: @0@ for a plain token, @1@ for a Retry token
--   * Retry tokens only: three CID slots (local, remote, original local),
--     each a length byte, the CID bytes, then zero padding up to 20 bytes
--
-- The write buffer is sized by 'cryptoTokenSize' (the Retry-token maximum).
encodeCryptoToken :: CryptoToken -> IO Token
encodeCryptoToken (CryptoToken (Version ver) tim mcids) =
    withWriteBuffer cryptoTokenSize $ \wbuf -> do
        write32 wbuf ver
        let CTime s = utSeconds tim
        write64 wbuf $ fromIntegral s
        case mcids of
            Nothing -> write8 wbuf 0
            Just (l, r, o) -> do
                write8 wbuf 1
                putCID wbuf l
                putCID wbuf r
                putCID wbuf o
  where
    -- Width of the CID area in each slot (excluding the length byte).
    cidSlotBytes :: Int
    cidSlotBytes = 20

    -- Write one length-prefixed CID and pad it to the fixed slot width.
    putCID :: WriteBuffer -> CID -> IO ()
    putCID wbuf cid = do
        let (bytes, len) = unpackCID cid
            padLen = cidSlotBytes - fromIntegral len
        write8 wbuf len
        copyShortByteString wbuf bytes
        -- 'ff' only moves the write offset, so skipped bytes would keep
        -- whatever the freshly allocated buffer happened to contain. Write
        -- explicit zeros so the plaintext depends only on the token; the
        -- decoder skips these bytes regardless of their value.
        if padLen > 0
            then sequence_ (replicate padLen (write8 wbuf 0))
            -- A CID longer than the slot (QUIC v1 caps CIDs at 20 bytes)
            -- keeps the previous behaviour of rewinding to the slot end.
            else ff wbuf padLen

-- | Parse decrypted token bytes, yielding 'Nothing' when
-- 'decodeCryptoToken'' throws (presumably on short or malformed input).
decodeCryptoToken :: Token -> IO (Maybe CryptoToken)
decodeCryptoToken token =
    either discard Just <$> E.try (decodeCryptoToken' token)
  where
    -- Pins the exception type caught by 'E.try'. unliftio's 'E.try' only
    -- catches synchronous exceptions; asynchronous ones are rethrown.
    discard :: E.SomeException -> Maybe CryptoToken
    discard _ = Nothing

-- | Parse the plaintext layout written by 'encodeCryptoToken'.
--
-- Any non-zero type byte is treated as a Retry token. Each CID length byte
-- is clamped to 20 before that many bytes are extracted, and the remainder
-- of the 20-byte slot is skipped. The creation time is rebuilt with zero
-- microseconds. Malformed input is expected to surface as an exception from
-- the underlying reads (the caller 'decodeCryptoToken' catches it).
decodeCryptoToken' :: ByteString -> IO CryptoToken
decodeCryptoToken' token = withReadBuffer token $ \rbuf -> do
    let getCID = do
            rawLen <- fromIntegral <$> read8 rbuf
            let len = min rawLen 20 :: Int
            cid <- makeCID <$> extractShortByteString rbuf len
            ff rbuf (20 - len)
            return cid
    ver <- Version <$> read32 rbuf
    secs <- read64 rbuf
    let created = UnixTime (CTime (fromIntegral secs)) 0
    tag <- read8 rbuf
    -- Applicative sequencing in IO reads local, remote, original local in order.
    cids <-
        if tag == 0
            then return Nothing
            else (\l r o -> Just (l, r, o)) <$> getCID <*> getCID <*> getCID
    return $ CryptoToken ver created cids