{-# LANGUAGE RecordWildCards            #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Data.KeyStore.KS.Packet
    ( encocdeEncryptionPacket
    , decocdeEncryptionPacketE
    , encocdeSignaturePacket
    , decocdeSignaturePacketE
    -- debugging
    , testBP
    ) where

import           Data.KeyStore.KS.KS
import           Data.KeyStore.Types
import           Data.API.Types
import qualified Data.ByteString                as B
import qualified Data.ByteString.Char8          as BC
import qualified Data.ByteString.Lazy.Char8     as LBS
import           Data.ByteString.Builder
import           Data.Word
import           Data.Bits
import           Data.Char
import           Text.Printf
import           Control.Monad.RWS.Strict
import qualified Control.Monad.Except           as E


-- | A fixed byte sequence that opens every encoded packet and identifies
-- which kind of packet follows (see 'encryption_magic_word' and
-- 'signature_magic_word').
newtype MagicWord
    = MagicWord B.ByteString

-- | Magic words that open encoded packets. Both share the prefix
-- @54 ab cd@; only the fourth byte differs (@00@ for encryption packets,
-- @80@ for signature packets), so 'decodePacketE' rejects a packet of one
-- kind where the other kind is expected.
encryption_magic_word, signature_magic_word :: MagicWord
encryption_magic_word = magicWordWithTag 0x00
signature_magic_word  = magicWordWithTag 0x80

-- | Build a magic word from the shared three-byte prefix plus a tag byte.
magicWordWithTag :: Word8 -> MagicWord
magicWordWithTag tag = MagicWord (B.pack [0x54, 0xab, 0xcd, tag])


-- | Wrap RSA secret bytes in an encryption packet.
--
-- The packet is the encryption magic word, then the length-prefixed
-- safeguard, then the raw secret bytes.
encocdeEncryptionPacket :: Safeguard -> RSASecretBytes -> EncryptionPacket
encocdeEncryptionPacket sg rsb = EncryptionPacket (Binary packet)
  where
    payload = _Binary (_RSASecretBytes rsb)
    packet  = encodePacket encryption_magic_word sg payload

-- | Inverse of 'encocdeEncryptionPacket'.
--
-- Checks the encryption magic word, parses the safeguard, and returns it
-- together with the remaining bytes as the RSA secret. Fails (in 'E') on
-- a wrong magic word or a malformed packet.
decocdeEncryptionPacketE :: EncryptionPacket -> E (Safeguard,RSASecretBytes)
decocdeEncryptionPacketE ep =
    fmap rewrap (decodePacketE encryption_magic_word raw)
  where
    raw              = _Binary (_EncryptionPacket ep)
    rewrap (sg, bs)  = (sg, RSASecretBytes (Binary bs))

-- | Wrap an RSA signature in a signature packet.
--
-- The packet is the signature magic word, then the length-prefixed
-- safeguard, then the raw signature bytes.
encocdeSignaturePacket :: Safeguard -> RSASignature -> SignaturePacket
encocdeSignaturePacket sg rs = SignaturePacket (Binary packet)
  where
    payload = _Binary (_RSASignature rs)
    packet  = encodePacket signature_magic_word sg payload

-- | Inverse of 'encocdeSignaturePacket'.
--
-- Checks the signature magic word, parses the safeguard, and returns it
-- together with the remaining bytes as the RSA signature. Fails (in 'E')
-- on a wrong magic word or a malformed packet.
decocdeSignaturePacketE :: SignaturePacket -> E (Safeguard,RSASignature)
decocdeSignaturePacketE sp =
    fmap rewrap (decodePacketE signature_magic_word raw)
  where
    raw              = _Binary (_SignaturePacket sp)
    rewrap (sg, bs)  = (sg, RSASignature (Binary bs))


-- | Lay out a packet: the magic word, then the length-prefixed safeguard
-- (see 'encodeSafeguard'), then the body bytes unchanged.
encodePacket :: MagicWord -> Safeguard -> B.ByteString -> B.ByteString
encodePacket (MagicWord magic) sg body =
    magic `B.append` encodeSafeguard sg body

-- | Inverse of 'encodePacket'.
--
-- Checks that the input starts with the expected magic word. It then
-- decodes the length-prefixed safeguard and returns it together with all
-- remaining bytes. A magic-word mismatch is reported with both words in
-- hex (found first, expected second).
decodePacketE :: MagicWord -> B.ByteString -> E (Safeguard,B.ByteString)
decodePacketE (MagicWord expected) input = run input $
 do found <- splitBP (Octets (B.length expected))
    if found == expected
        then return ()
        else errorBP $ printf "bad magic word: %s/=%s" (hex found) (hex expected)
    sg   <- decodeSafeguard
    -- remainingBP drains the input, so 'run' never sees residual bytes here.
    body <- remainingBP
    return (sg, body)
  where
    hex :: B.ByteString -> String
    hex = BC.unpack . to_hex

-- | Prepend the safeguard, rendered with 'printSafeguard' and
-- length-prefixed by 'encodeLengthPacket', to the given tail bytes.
--
-- NOTE(review): 'BC.pack' truncates characters above U+00FF. This assumes
-- printSafeguard only emits ASCII/Latin-1 text; confirm.
encodeSafeguard :: Safeguard -> ShowB
encodeSafeguard sg = encodeLengthPacket (BC.pack (printSafeguard sg))

-- | Inverse of 'encodeSafeguard'.
--
-- Reads a length-prefixed chunk and parses it with 'parseSafeguard'. A
-- parse failure is raised as a decoder error.
decodeSafeguard :: BP Safeguard
decodeSafeguard = decodeLengthPacket parsePayload
  where
    parsePayload bytes = e2bp (parseSafeguard (BC.unpack bytes))

-- | Prepend @payload@ to the tail bytes, preceded by its length as an
-- 8-byte little-endian signed integer ('int64LE').
encodeLengthPacket :: B.ByteString -> ShowB
encodeLengthPacket payload rest = B.concat [lengthPrefix, payload, rest]
  where
    lengthPrefix :: B.ByteString
    lengthPrefix =
        let n = fromIntegral (B.length payload)
        in  LBS.toStrict (toLazyByteString (int64LE n))

-- | Decode a length-prefixed chunk and hand its payload to @bp@.
--
-- The input format is an 8-byte little-endian length @n@ followed by
-- @n@ payload bytes, as written by 'encodeLengthPacket'. The decoded length
-- is logged as a debug entry before the payload is extracted.
--
-- A length that does not fit in an 'Int' is rejected. Previously it was
-- silently wrapped, which could produce a negative byte count; such a
-- count let a malformed packet decode with an empty payload.
decodeLengthPacket :: (B.ByteString -> BP a) -> BP a
decodeLengthPacket bp =
 do ln_bs <- splitBP (Octets 8)
    -- splitBP guarantees exactly 8 bytes here. B.foldr visits them
    -- last-to-first, so shifting the accumulator left on each step
    -- assembles a little-endian Word64 without partial indexing.
    let ln_w64 :: Word64
        ln_w64 = B.foldr (\b acc -> acc `shiftL` 8 .|. fromIntegral b) 0 ln_bs
    btwBP $ show ln_w64
    if ln_w64 > fromIntegral (maxBound :: Int)
        then errorBP "length field out of range"
        else do
            bs <- splitBP $ Octets $ fromIntegral ln_w64
            bp bs

-- | A ByteString transformer in the style of 'ShowS': each encoder prepends
-- its output to whatever tail bytes it is given.
type ShowB = (B.ByteString -> B.ByteString)

-- | The packet-decoding monad.
--
-- * The state is the not-yet-consumed input.
-- * The writer accumulates 'LogEntry' values (see 'btwBP').
-- * Failures short-circuit with a 'Reason' (see 'throwBP').
newtype BP a = BP
    { _BP :: E.ExceptT Reason (RWS () [LogEntry] B.ByteString) a
    }
    deriving
        ( Functor
        , Applicative
        , Monad
        , E.MonadError Reason
        )

-- | Lift an 'E' result into the decoder: a 'Left' reason becomes a decoder
-- failure, and a 'Right' value is returned unchanged.
e2bp :: E a -> BP a
e2bp (Left  reason) = throwBP reason
e2bp (Right value)  = return value

-- | Run a decoder over a complete input and discard its log.
--
-- A decoder that succeeds without consuming the whole input is turned into
-- a failure; a decoder failure is passed through unchanged.
run :: B.ByteString -> BP a -> E a
run bs bp
    | not (B.null leftover)
    , Right _ <- result     = Left $ strMsg "bad packet format (residual bytes)"
    | otherwise             = result
  where
    (result, leftover, _) = runBP bs bp

-- | Execute a decoder on an input.
--
-- Returns, in order: the result, the unconsumed input, and the accumulated
-- log entries.
runBP :: B.ByteString -> BP a -> (E a,B.ByteString,[LogEntry])
runBP input parser = runRWS action () input
  where
    action = E.runExceptT (_BP parser)

-- | Debugging helper: run a decoder on an input and report what happened.
--
-- Prints every log message prefixed by @"log: "@, then the number of
-- unconsumed bytes (if any). Returns the decoded value, or calls 'error'
-- with the shown failure reason.
testBP :: B.ByteString -> BP a -> IO a
testBP bs p =
 do mapM_ (putStrLn . ("log: " ++) . le_message) logs
    if B.null leftover
        then return ()
        else putStrLn (show (B.length leftover) ++ " bytes remaining")
    either (error . show) return outcome
  where
    (outcome, leftover, logs) = runBP bs p

-- | Append a debug log entry (@le_debug = True@) to the decoder's log.
btwBP :: String -> BP ()
btwBP msg = BP (tell [entry])
  where
    entry = LogEntry { le_debug = True, le_message = msg }

-- | Abort decoding with a message prefixed by @"packet decode error: "@.
errorBP :: String -> BP a
errorBP msg = throwBP (strMsg ("packet decode error: " ++ msg))

-- | Abort decoding with the given reason.
throwBP :: Reason -> BP a
throwBP reason = E.throwError reason

-- | Consume and return exactly @n@ bytes from the front of the input.
--
-- Fails with "not enough bytes" if fewer than @n@ bytes remain. A negative
-- count is also rejected: previously it passed the bounds check and
-- silently returned an empty chunk without consuming anything, which
-- masked corrupt length fields.
splitBP :: Octets -> BP B.ByteString
splitBP (Octets n)
    | n < 0     = errorBP "negative byte count"
    | otherwise =
     do bs <- peek_remainingBP
        if n <= B.length bs
            then do
                let (bs_h, bs_r) = B.splitAt n bs
                modifyBP (const bs_r)
                return bs_h
            else errorBP "not enough bytes"

-- | Consume and return all remaining input, leaving the input empty.
remainingBP :: BP B.ByteString
remainingBP = peek_remainingBP <* modifyBP (const B.empty)

-- | Return the remaining input without consuming it.
peek_remainingBP :: BP B.ByteString
peek_remainingBP = BP (state (\s -> (s, s)))

-- | Apply a function to the remaining input.
modifyBP :: (B.ByteString->B.ByteString) -> BP ()
modifyBP = BP . modify

-- | Render a ByteString as lowercase hexadecimal, two digits per byte,
-- e.g. the bytes @[0x54, 0xab]@ become @"54ab"@.
to_hex :: B.ByteString -> B.ByteString
to_hex = BC.pack . concatMap hexByte . BC.unpack
  where
    hexByte :: Char -> String
    hexByte c = [intToDigit hi, intToDigit lo]
      where
        -- BC.unpack yields code points 0..255, so both digits are 0..15.
        (hi, lo) = ord c `divMod` 16