{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.JOSE.AESKW
  (
    aesKeyWrap
  , aesKeyUnwrap
  ) where
import Control.Monad.State
import Crypto.Cipher.Types
import Data.Bits (xor)
import Data.ByteArray as BA hiding (replicate, xor)
import Data.Memory.Endian (BE(..), toBE)
import Data.Memory.PtrMethods (memCopy)
import Data.Word (Word64)
import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (peek, peekElemOff, poke, pokeElemOff)
import System.IO.Unsafe (unsafePerformIO)
-- | Initial value of the 64-bit integrity register.
--
-- 'aesKeyWrap' seeds its running state with this value, and
-- 'aesKeyUnwrap' accepts a ciphertext only if the state it recovers is
-- equal to it.  Every byte is @0xA6@ (the default initial value given
-- in RFC 3394), so the constant reads the same in either byte order.
iv :: Word64
iv = 0xA6A6A6A6A6A6A6A6
-- | Perform one round of the wrapping loop.
--
-- The 'Word64' carried in the state is the running integrity register.
-- The register and the selected 64-bit word of the buffer are
-- concatenated into a 16-byte block and encrypted with the cipher in
-- ECB mode.  The first eight bytes of the result, XORed with the round
-- counter in big-endian byte order, become the new register; the last
-- eight bytes overwrite the selected word in place.
aesKeyWrapStep
  :: BlockCipher128 cipher
  => cipher
  -> Ptr Word64
  -- ^ buffer being wrapped, addressed as 64-bit words
  -> (Int, Int)
  -- ^ (round counter, index of the word to process)
  -> StateT Word64 IO ()
aesKeyWrapStep cipher buf (counter, slot) = do
  acc <- get
  (encHi, encLo) <- lift $ do
    block <- peekElemOff buf slot
    -- The cipher input goes into scrubbed memory because it contains
    -- key material.
    plain :: ScrubbedBytes <- alloc 16 $ \q -> do
      poke q acc
      pokeElemOff q 1 block
    let enc = ecbEncrypt cipher plain
    withByteArray enc $ \q -> do
      hi <- peek q
      lo <- peekElemOff q 1
      return (hi, lo)
  put (encHi `xor` unBE (toBE (fromIntegral counter)))
  lift $ pokeElemOff buf slot encLo
-- | Wrap a secret under the given cipher.
--
-- The output is eight bytes longer than the input: the first eight
-- bytes hold the final integrity register and the remainder holds the
-- transformed input.  Six passes are made over the input, one 64-bit
-- word at a time, with the round counter starting at 1.
--
-- NOTE: only @length m \`div\` 8@ whole words are processed.  If the
-- input length is not a multiple of 8, the trailing bytes are copied
-- into the output but never touched by the cipher, so callers must
-- supply a multiple of 8 bytes.
aesKeyWrap
  :: (ByteArrayAccess m, ByteArray c, BlockCipher128 cipher)
  => cipher
  -> m
  -> c
aesKeyWrap cipher m = unsafePerformIO $ do
  -- Copy the input behind an eight-byte gap reserved for the register.
  wrapped <- withByteArray m $ \src ->
    alloc (plainLen + 8) $ \dst ->
      memCopy (dst `plusPtr` 8) src plainLen
  -- Run every round in place, then store the final register in the gap.
  withByteArray wrapped $ \buf ->
    execStateT (mapM_ (aesKeyWrapStep cipher buf) schedule) iv >>= poke buf
  return wrapped
  where
    plainLen = BA.length m
    blockCount = plainLen `div` 8
    -- Word indices start at 1 because word 0 of the buffer is the gap.
    schedule =
      zip [1 ..] [ i | _ <- [1 .. 6 :: Int], i <- [1 .. blockCount] ]
-- | Perform one round of the unwrapping loop.
--
-- The 'Word64' carried in the state is the running integrity register.
-- The register is XORed with the round counter in big-endian byte
-- order, concatenated with the selected 64-bit word of the buffer and
-- decrypted with the cipher in ECB mode.  The first eight bytes of the
-- result become the new register; the last eight bytes overwrite the
-- selected word in place.
aesKeyUnwrapStep
  :: BlockCipher128 cipher
  => cipher
  -> Ptr Word64
  -- ^ buffer being unwrapped, addressed as 64-bit words
  -> (Int, Int)
  -- ^ (round counter, index of the word to process)
  -> StateT Word64 IO ()
aesKeyUnwrapStep cipher buf (counter, slot) = do
  acc <- get
  (decHi, decLo) <- lift $ do
    block <- peekElemOff buf slot
    let unmasked = acc `xor` unBE (toBE (fromIntegral counter))
    -- The cipher input is built in scrubbed memory.
    sealed :: ScrubbedBytes <- alloc 16 $ \q -> do
      poke q unmasked
      pokeElemOff q 1 block
    let dec = ecbDecrypt cipher sealed
    withByteArray dec $ \q -> do
      hi <- peek q
      lo <- peekElemOff q 1
      return (hi, lo)
  put decHi
  lift $ pokeElemOff buf slot decLo
-- | Unwrap a secret under the given cipher.
--
-- The input is the eight-byte integrity register followed by the
-- wrapped data; on success the output is the input minus those first
-- eight bytes, transformed back by running the rounds of 'aesKeyWrap'
-- in reverse order.
--
-- Returns 'Nothing' when
--
-- * the input is shorter than 8 bytes or its length is not a multiple
--   of 8 (such input cannot have been produced from whole 64-bit
--   words, and processing it would read or copy outside the buffer or
--   leave trailing bytes untouched), or
--
-- * the recovered integrity register does not equal 'iv'.
aesKeyUnwrap
  :: (ByteArrayAccess c, ByteArray m, BlockCipher128 cipher)
  => cipher
  -> c
  -> Maybe m
aesKeyUnwrap cipher c
  -- Validate the length before touching any pointer: a negative @n@
  -- would reach 'memCopy' and 'peek' would read past a short buffer.
  | n < 0 || n `mod` 8 /= 0 = Nothing
  | otherwise = unsafePerformIO $ do
      -- Copy everything after the register into a fresh buffer that is
      -- then unwrapped in place.
      m <- withByteArray c $ \p' ->
        alloc n $ \p ->
          memCopy p (p' `plusPtr` 8) n
      -- The first eight bytes of the input seed the running register.
      a <- withByteArray c $ \p' -> peek p'
      a' <- withByteArray m $ \p -> do
        let n' = n `div` 8
        let tMax = n' * 6
        -- Counter runs from tMax down to 1 while the word index walks
        -- the buffer from last to first, repeatedly.  When n' is 0 the
        -- counter list is empty, so 'cycle' is never forced.
        let coords = zip [tMax,tMax-1..1] (cycle [n'-1,n'-2..0])
        execStateT (mapM_ (aesKeyUnwrapStep cipher p) coords) a
      return $ if a' == iv then Just m else Nothing
  where
    -- Length of the wrapped data without the leading register.
    n = BA.length c - 8