{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Crypto.ECC.Edwards25519
( Scalar
, Point
, scalarGenerate
, scalarDecodeLong
, scalarEncode
, pointDecode
, pointEncode
, pointHasPrimeOrder
, toPoint
, scalarAdd
, scalarMul
, pointNegate
, pointAdd
, pointDouble
, pointMul
, pointMulByCofactor
, pointsMulVarTime
) where
import Data.Word
import Foreign.C.Types
import Foreign.Ptr
import Crypto.Error
import Crypto.Internal.ByteArray (Bytes, ScrubbedBytes, withByteArray)
import qualified Crypto.Internal.ByteArray as B
import Crypto.Internal.Compat
import Crypto.Internal.Imports
import Crypto.Random
-- | Number of bytes allocated for the buffer backing a 'Scalar'.  Every
-- scalar produced in this module is allocated with exactly this size before
-- being handed to the C routines.
--
-- NOTE(review): presumably large enough for the C library's internal scalar
-- representation on every supported platform -- confirm against the C sources.
scalarArraySize :: Int
scalarArraySize = 40
-- | A scalar value, kept as an opaque buffer of 'scalarArraySize' bytes that
-- is only ever read or written by the C routines bound in this module.  The
-- buffer is a 'ScrubbedBytes', and the 'Show' and 'NFData' instances are
-- taken from that underlying type.
newtype Scalar = Scalar ScrubbedBytes
    deriving (Show, NFData)
-- | Scalars are compared by the C routine @cryptonite_ed25519_scalar_eq@
-- rather than by comparing the backing buffers byte for byte; a non-zero
-- result from that routine means the two scalars are equal.
instance Eq Scalar where
    Scalar lhs == Scalar rhs = unsafeDoIO $
        withByteArray lhs $ \pl ->
        withByteArray rhs $ \pr -> do
            verdict <- ed25519_scalar_eq pl pr
            return (verdict /= 0)
    -- NOTE(review): NOINLINE presumably keeps the 'unsafeDoIO' call from
    -- being duplicated by inlining -- confirm.
    {-# NOINLINE (==) #-}
-- | Number of bytes allocated for the buffer backing a 'Point'.  Every point
-- produced in this module is allocated with exactly this size before being
-- handed to the C routines.
--
-- NOTE(review): presumably large enough for the C library's internal point
-- representation on every supported platform -- confirm against the C sources.
pointArraySize :: Int
pointArraySize = 160
-- | A curve point, kept as an opaque buffer of 'pointArraySize' bytes that is
-- only ever read or written by the C routines bound in this module.  Use
-- 'pointEncode' / 'pointDecode' to convert to and from the 32-byte encoding.
newtype Point = Point Bytes
    deriving NFData
-- | A point is shown as @Point@ followed by the base-16 rendering of its
-- 32-byte encoding (as produced by 'pointEncode'), parenthesised when the
-- surrounding precedence is above function application.
instance Show Point where
    showsPrec prec pt = showParen (prec > 10) (showString "Point " . shows hex)
      where
        -- Encoded form of the point, then its hexadecimal rendering.
        encoded = pointEncode pt :: Bytes
        hex     = B.convertToBase B.Base16 encoded :: Bytes
-- | Points are compared by the C routine @cryptonite_ed25519_point_eq@ rather
-- than by comparing the backing buffers byte for byte; a non-zero result from
-- that routine means the two points are equal.
instance Eq Point where
    Point lhs == Point rhs = unsafeDoIO $
        withByteArray lhs $ \pl ->
        withByteArray rhs $ \pr ->
            (/= 0) <$> ed25519_point_eq pl pr
    -- NOTE(review): NOINLINE presumably keeps the 'unsafeDoIO' call from
    -- being duplicated by inlining -- confirm.
    {-# NOINLINE (==) #-}
-- | Generate a random scalar.
--
-- Draws 64 random bytes into a 'ScrubbedBytes' buffer and converts them with
-- 'scalarDecodeLong'.  That function only fails for inputs longer than 64
-- bytes, so the 'throwCryptoError' branch is not expected to trigger here.
scalarGenerate :: MonadRandom randomly => randomly Scalar
scalarGenerate = toScalar <$> getRandomBytes 64
  where
    -- The argument type pins the random buffer to 'ScrubbedBytes'.
    toScalar :: ScrubbedBytes -> Scalar
    toScalar = throwCryptoError . scalarDecodeLong
-- | Serialize a scalar into a freshly allocated 32-byte array.  The bytes are
-- written by the C routine @cryptonite_ed25519_scalar_encode@.
scalarEncode :: B.ByteArray bs => Scalar -> bs
scalarEncode (Scalar raw) = B.allocAndFreeze 32 writeEncoding
  where
    -- Fill the output buffer from the scalar's internal representation.
    writeEncoding dst = withByteArray raw (ed25519_scalar_encode dst)
-- | Build a scalar from a byte string of at most 64 bytes.
--
-- Inputs longer than 64 bytes are rejected with
-- 'CryptoError_EcScalarOutOfBounds'.  Any shorter input (including the empty
-- one) is handed, together with its length, to the C routine
-- @cryptonite_ed25519_scalar_decode_long@ and always yields 'CryptoPassed'.
--
-- NOTE(review): the C routine presumably reduces the value modulo the group
-- order -- confirm against the C sources.
scalarDecodeLong :: B.ByteArrayAccess bs => bs -> CryptoFailable Scalar
scalarDecodeLong bs
    | inputLen > 64 = CryptoFailed CryptoError_EcScalarOutOfBounds
    | otherwise     = unsafeDoIO (withByteArray bs decode)
  where
    inputLen = B.length bs

    -- Allocate the scalar buffer and let the C code fill it from the input.
    decode src =
        CryptoPassed . Scalar <$>
            B.alloc scalarArraySize (\dst ->
                ed25519_scalar_decode_long dst src (fromIntegral inputLen))
{-# NOINLINE scalarDecodeLong #-}
-- | Add two scalars.  The sum is computed by the C routine
-- @cryptonite_ed25519_scalar_add@ into a freshly allocated scalar buffer.
scalarAdd :: Scalar -> Scalar -> Scalar
scalarAdd (Scalar lhs) (Scalar rhs) =
    Scalar (B.allocAndFreeze scalarArraySize compute)
  where
    -- Write the result into @dst@ from the two operand buffers.
    compute dst =
        withByteArray lhs $ \pl ->
            withByteArray rhs $ \pr ->
                ed25519_scalar_add dst pl pr
-- | Multiply two scalars.  The product is computed by the C routine
-- @cryptonite_ed25519_scalar_mul@ into a freshly allocated scalar buffer.
scalarMul :: Scalar -> Scalar -> Scalar
scalarMul (Scalar lhs) (Scalar rhs) =
    Scalar (B.allocAndFreeze scalarArraySize compute)
  where
    -- Write the result into @dst@ from the two operand buffers.
    compute dst =
        withByteArray lhs $ \pl ->
            withByteArray rhs $ \pr ->
                ed25519_scalar_mul dst pl pr
-- | Turn a scalar into a point by calling the C routine
-- @cryptonite_ed25519_point_base_scalarmul@ on it.
--
-- NOTE(review): judging by the C symbol name this is multiplication of the
-- curve base point by the scalar -- confirm against the C sources.
toPoint :: Scalar -> Point
toPoint (Scalar k) = Point (B.allocAndFreeze pointArraySize fill)
  where
    -- Write the resulting point into the freshly allocated buffer.
    fill dst = withByteArray k (ed25519_point_base_scalarmul dst)
-- | Serialize a point into a freshly allocated 32-byte array.  The bytes are
-- written by the C routine @cryptonite_ed25519_point_encode@.
pointEncode :: B.ByteArray bs => Point -> bs
pointEncode (Point raw) = B.allocAndFreeze 32 writeEncoding
  where
    -- Fill the output buffer from the point's internal representation.
    writeEncoding dst = withByteArray raw (ed25519_point_encode dst)
-- | Decode a point from its 32-byte encoding.
--
-- Fails with 'CryptoError_PointSizeInvalid' when the input is not exactly 32
-- bytes long, and with 'CryptoError_PointCoordinatesInvalid' when the C
-- routine @cryptonite_ed25519_point_decode_vartime@ reports failure by
-- returning zero.
--
-- NOTE(review): the @_vartime@ suffix suggests the decoding is not
-- constant-time -- confirm against the C sources.
pointDecode :: B.ByteArrayAccess bs => bs -> CryptoFailable Point
pointDecode bs
    | B.length bs /= 32 = CryptoFailed CryptoError_PointSizeInvalid
    | otherwise         = unsafeDoIO (withByteArray bs decode)
  where
    -- Allocate the point buffer, let the C code fill it, and map the
    -- returned status onto the 'CryptoFailable' result.
    decode src = do
        (status, raw) <- B.allocRet pointArraySize $ \dst ->
            ed25519_point_decode_vartime dst src
        return $ case status of
            0 -> CryptoFailed CryptoError_PointCoordinatesInvalid
            _ -> CryptoPassed (Point raw)
{-# NOINLINE pointDecode #-}
-- | Ask the C routine @cryptonite_ed25519_point_has_prime_order@ about a
-- point; the answer is 'True' when that routine returns a non-zero value.
pointHasPrimeOrder :: Point -> Bool
pointHasPrimeOrder (Point raw) = unsafeDoIO (withByteArray raw check)
  where
    -- Convert the C status word into a 'Bool'.
    check ptr = (/= 0) <$> ed25519_point_has_prime_order ptr
{-# NOINLINE pointHasPrimeOrder #-}
-- | Negate a point.  The result is computed by the C routine
-- @cryptonite_ed25519_point_negate@ into a freshly allocated point buffer.
pointNegate :: Point -> Point
pointNegate (Point src) = Point (B.allocAndFreeze pointArraySize negateInto)
  where
    -- Write the negated point into @dst@.
    negateInto dst = withByteArray src (ed25519_point_negate dst)
-- | Add two points.  The sum is computed by the C routine
-- @cryptonite_ed25519_point_add@ into a freshly allocated point buffer.
pointAdd :: Point -> Point -> Point
pointAdd (Point lhs) (Point rhs) =
    Point (B.allocAndFreeze pointArraySize compute)
  where
    -- Write the result into @dst@ from the two operand buffers.
    compute dst =
        withByteArray lhs $ \pl ->
            withByteArray rhs $ \pr ->
                ed25519_point_add dst pl pr
-- | Double a point.  The result is computed by the C routine
-- @cryptonite_ed25519_point_double@ into a freshly allocated point buffer.
pointDouble :: Point -> Point
pointDouble (Point src) = Point (B.allocAndFreeze pointArraySize doubleInto)
  where
    -- Write the doubled point into @dst@.
    doubleInto dst = withByteArray src (ed25519_point_double dst)
-- | Multiply a point by the curve cofactor.  The result is computed by the C
-- routine @cryptonite_ed25519_point_mul_by_cofactor@ into a freshly allocated
-- point buffer.
pointMulByCofactor :: Point -> Point
pointMulByCofactor (Point src) =
    Point (B.allocAndFreeze pointArraySize mulInto)
  where
    -- Write the multiplied point into @dst@.
    mulInto dst = withByteArray src (ed25519_point_mul_by_cofactor dst)
-- | Multiply a point by a scalar.  The result is computed by the C routine
-- @cryptonite_ed25519_point_scalarmul@ into a freshly allocated point buffer.
pointMul :: Scalar -> Point -> Point
pointMul (Scalar k) (Point base) =
    Point (B.allocAndFreeze pointArraySize compute)
  where
    -- Note the C routine takes the point before the scalar.
    compute dst =
        withByteArray k $ \pk ->
            withByteArray base $ \pbase ->
                ed25519_point_scalarmul dst pbase pk
-- | Combine two scalars and a point through the C routine
-- @cryptonite_ed25519_base_double_scalarmul_vartime@, returning the resulting
-- point in a freshly allocated buffer.
--
-- NOTE(review): judging by the C symbol name this presumably computes
-- @s1 * B + s2 * p@ (with @B@ the curve base point) in variable time, so it
-- should not be fed secret scalars -- confirm against the C sources.
pointsMulVarTime :: Scalar -> Scalar -> Point -> Point
pointsMulVarTime (Scalar s1) (Scalar s2) (Point p) = Point result
  where
    -- The C routine takes its arguments as: output, s1, point, s2.
    result = B.allocAndFreeze pointArraySize $ \dst ->
        withByteArray s1 $ \ps1 ->
        withByteArray s2 $ \ps2 ->
        withByteArray p  $ \pp  ->
            ed25519_base_double_scalarmul_vartime dst ps1 pp ps2
-- | Binding to the C symbol @cryptonite_ed25519_scalar_eq@.  Takes pointers
-- to two scalar buffers; the 'Eq' instance of 'Scalar' treats a non-zero
-- result as equality.
foreign import ccall "cryptonite_ed25519_scalar_eq"
    ed25519_scalar_eq :: Ptr Scalar -> Ptr Scalar -> IO CInt
-- | Binding to the C symbol @cryptonite_ed25519_scalar_encode@.  The first
-- argument is the output byte buffer (allocated with 32 bytes by
-- 'scalarEncode'); the second is the scalar to serialize.
foreign import ccall "cryptonite_ed25519_scalar_encode"
    ed25519_scalar_encode :: Ptr Word8 -> Ptr Scalar -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_scalar_decode_long@.
-- Arguments are, in order: the output scalar buffer, the input bytes and the
-- number of input bytes ('scalarDecodeLong' never passes more than 64).
foreign import ccall "cryptonite_ed25519_scalar_decode_long"
    ed25519_scalar_decode_long :: Ptr Scalar -> Ptr Word8 -> CSize -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_scalar_add@.  The first
-- argument is the output scalar buffer; the other two are the operands.
foreign import ccall "cryptonite_ed25519_scalar_add"
    ed25519_scalar_add :: Ptr Scalar -> Ptr Scalar -> Ptr Scalar -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_scalar_mul@.  The first
-- argument is the output scalar buffer; the other two are the operands.
foreign import ccall "cryptonite_ed25519_scalar_mul"
    ed25519_scalar_mul :: Ptr Scalar -> Ptr Scalar -> Ptr Scalar -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_point_encode@.  The first
-- argument is the output byte buffer (allocated with 32 bytes by
-- 'pointEncode'); the second is the point to serialize.
foreign import ccall "cryptonite_ed25519_point_encode"
    ed25519_point_encode :: Ptr Word8 -> Ptr Point -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_point_decode_vartime@.  The
-- first argument is the output point buffer; the second is the encoded input
-- ('pointDecode' only passes 32-byte inputs).  'pointDecode' treats a zero
-- result as a decoding failure.
foreign import ccall "cryptonite_ed25519_point_decode_vartime"
    ed25519_point_decode_vartime :: Ptr Point -> Ptr Word8 -> IO CInt
-- | Binding to the C symbol @cryptonite_ed25519_point_eq@.  Takes pointers to
-- two point buffers; the 'Eq' instance of 'Point' treats a non-zero result as
-- equality.
foreign import ccall "cryptonite_ed25519_point_eq"
    ed25519_point_eq :: Ptr Point -> Ptr Point -> IO CInt
-- | Binding to the C symbol @cryptonite_ed25519_point_has_prime_order@.
-- 'pointHasPrimeOrder' treats a non-zero result as 'True'.
foreign import ccall "cryptonite_ed25519_point_has_prime_order"
    ed25519_point_has_prime_order :: Ptr Point -> IO CInt
-- | Binding to the C symbol @cryptonite_ed25519_point_negate@.  The first
-- argument is the output point buffer; the second is the input point.
foreign import ccall "cryptonite_ed25519_point_negate"
    ed25519_point_negate :: Ptr Point -> Ptr Point -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_point_add@.  The first
-- argument is the output point buffer; the other two are the operands.
foreign import ccall "cryptonite_ed25519_point_add"
    ed25519_point_add :: Ptr Point -> Ptr Point -> Ptr Point -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_point_double@.  The first
-- argument is the output point buffer; the second is the input point.
foreign import ccall "cryptonite_ed25519_point_double"
    ed25519_point_double :: Ptr Point -> Ptr Point -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_point_mul_by_cofactor@.  The
-- first argument is the output point buffer; the second is the input point.
foreign import ccall "cryptonite_ed25519_point_mul_by_cofactor"
    ed25519_point_mul_by_cofactor :: Ptr Point -> Ptr Point -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_point_base_scalarmul@.  The
-- first argument is the output point buffer; the second is the scalar.  Used
-- by 'toPoint'.
foreign import ccall "cryptonite_ed25519_point_base_scalarmul"
    ed25519_point_base_scalarmul :: Ptr Point -> Ptr Scalar -> IO ()
-- | Binding to the C symbol @cryptonite_ed25519_point_scalarmul@.  Arguments
-- are, in order: the output point buffer, the input point and the scalar.
-- Used by 'pointMul'.
foreign import ccall "cryptonite_ed25519_point_scalarmul"
    ed25519_point_scalarmul :: Ptr Point -> Ptr Point -> Ptr Scalar -> IO ()
-- | Binding to the C symbol
-- @cryptonite_ed25519_base_double_scalarmul_vartime@.  Arguments are, in
-- order: the output point buffer, a first scalar, an input point and a second
-- scalar.  Used by 'pointsMulVarTime'.
foreign import ccall "cryptonite_ed25519_base_double_scalarmul_vartime"
    ed25519_base_double_scalarmul_vartime
        :: Ptr Point -> Ptr Scalar -> Ptr Point -> Ptr Scalar -> IO ()