-- |
-- Module      : Crypto.Number.Serialize
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good
--
-- Fast serialization primitives for integers
{-# LANGUAGE BangPatterns #-}
module Crypto.Number.Serialize
    ( i2osp
    , os2ip
    , i2ospOf
    , i2ospOf_
    ) where

import           Crypto.Number.Basic
import           Crypto.Internal.Compat (unsafeDoIO)
import qualified Crypto.Internal.ByteArray as B
import qualified Crypto.Number.Serialize.Internal as Internal

-- | @os2ip@ (\"octet string to integer\") turns a byte string into a
-- positive integer.
--
-- The raw bytes of the input are handed to 'Internal.os2ip' together with
-- the input's length; the byte order is presumably the one produced by
-- 'i2osp' (most significant byte first) — confirm against
-- "Crypto.Number.Serialize.Internal".
os2ip :: B.ByteArrayAccess ba => ba -> Integer
os2ip bs = unsafeDoIO (B.withByteArray bs decode)
  where
        -- Decode from the pointer supplied by 'B.withByteArray'; the length
        -- comes from the original byte array.
        decode ptr = Internal.os2ip ptr (B.length bs)

-- | @i2osp@ converts a non-negative integer into a byte string.
--
-- The first byte is MSB (most significant byte); the last byte is the LSB
-- (least significant byte). The result is @'numBytes' m@ bytes long, except
-- that 0 is encoded in a single byte.
--
-- A negative integer has no representation in this encoding ('i2ospOf'
-- likewise rejects it), so it is reported with 'error' instead of being
-- passed on to 'Internal.i2osp', whose result is not inspected here.
i2osp :: B.ByteArray ba => Integer -> ba
i2osp m
    | m < 0     = error "i2osp: cannot serialize a negative integer"
    | m == 0    = B.allocAndFreeze 1  (\p -> Internal.i2osp 0 p 1 >> return ())
    | otherwise = B.allocAndFreeze sz (\p -> Internal.i2osp m p sz >> return ())
  where
        -- Lazy on purpose: a bang here would force 'numBytes' on a negative
        -- argument before the sign guard above gets a chance to run.
        sz = numBytes m

-- | Just like 'i2osp', but takes an extra parameter for size.
-- If the number is too big to fit in @len@ bytes, 'Nothing' is returned;
-- otherwise the number is padded with 0 to fit the @len@ required.
--
-- 'Nothing' is also returned when @len@ is not positive or when the
-- integer is negative.
{-# INLINABLE i2ospOf #-}
i2ospOf :: B.ByteArray ba => Int -> Integer -> Maybe ba
i2ospOf len m
    | len <= 0  = Nothing
    | m < 0     = Nothing
    | sz > len  = Nothing
    | otherwise = Just $ B.unsafeCreate len (\p -> Internal.i2ospOf m p len >> return ())
  where
        -- Deliberately lazy: 'sz' is needed only by the size guard. A strict
        -- (banged) where-binding would be forced before *any* guard runs,
        -- computing 'numBytes' even for non-positive lengths and negative
        -- integers that the first two guards are meant to reject.
        sz = numBytes m

-- | Just like 'i2ospOf' except that it doesn't expect a failure, i.e.
-- an integer larger than the number of output bytes requested.
--
-- For example, use it if you just took a modulo of the number that
-- represents the size (e.g. the RSA modulus n).
--
-- Calls 'error' whenever 'i2ospOf' would return 'Nothing'. The message
-- names the actual cause: a non-positive length, a negative integer, or an
-- integer that does not fit in @len@ bytes.
i2ospOf_ :: B.ByteArray ba => Int -> Integer -> ba
i2ospOf_ len m =
    case i2ospOf len m of
        Just bytes -> bytes
        Nothing
            -- Mirror the order of the guards in 'i2ospOf' so the reported
            -- reason matches the check that actually failed.
            | len <= 0  -> error "i2ospOf_: requested length must be positive"
            | m < 0     -> error "i2ospOf_: integer is negative"
            | otherwise -> error "i2ospOf_: integer is larger than expected"