{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Data.MerkleLog
(
MerkleTree
, merkleTree
, encodeMerkleTree
, decodeMerkleTree
, MerkleRoot
, merkleRoot
, encodeMerkleRoot
, decodeMerkleRoot
, MerkleNodeType(..)
, MerkleProof(..)
, MerkleProofSubject(..)
, MerkleProofObject
, encodeMerkleProofObject
, decodeMerkleProofObject
, merkleProof
, merkleProof_
, runMerkleProof
, Expected(..)
, Actual(..)
, MerkleTreeException(..)
, textMessage
, isEmpty
, emptyMerkleTree
, size
, leafCount
, MerkleHash
, getHash
, merkleLeaf
, merkleNode
) where
import Control.DeepSeq
import Control.Monad
import Control.Monad.Catch
import Crypto.Hash (hash)
import Crypto.Hash.Algorithms (HashAlgorithm)
import Crypto.Hash.IO
import qualified Data.ByteArray as BA
import Data.ByteArray.Encoding
import qualified Data.ByteString as B
import qualified Data.List.NonEmpty as NE
import qualified Data.Memory.Endian as BA
import Data.String
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Word
import Foreign.Ptr
import Foreign.Storable
import GHC.Generics
import GHC.Stack
import System.IO.Unsafe
-- | Wrapper marking the value a check expected; paired with 'Actual' in
-- the payload of 'MerkleTreeException'.
newtype Expected a = Expected a
    deriving stock (Show, Eq, Ord, Generic)
    deriving anyclass (NFData)

-- | Wrapper marking the value that was actually observed; paired with
-- 'Expected' in the payload of 'MerkleTreeException'.
newtype Actual a = Actual a
    deriving stock (Show, Eq, Ord, Generic)
    deriving anyclass (NFData)
-- | Render an expected/actual pair as @Expected: <e>, Actual: <a>@, using
-- 'show' for both values.
expectedMessage :: Show a => Expected a -> Actual a -> T.Text
expectedMessage (Expected want) (Actual got) =
    mconcat ["Expected: ", sshow want, ", Actual: ", sshow got]
-- | Exceptions thrown by the functions of this module.
data MerkleTreeException
    = EncodingSizeException T.Text (Expected Int) (Actual Int)
        -- ^ Decoding failed because the input length differs from the
        -- required length. The text names the type being decoded.
    | EncodingSizeConstraintException T.Text (Expected T.Text) (Actual Int)
        -- ^ Decoding failed because the input length violates a size
        -- constraint, which is described in the expected text.
    | IndexOutOfBoundsException T.Text (Expected (Int, Int)) (Actual Int)
        -- ^ A position lies outside the expected bounds. The text names
        -- the operation.
    | InputNotInTreeException T.Text Int B.ByteString
        -- ^ The tree does not hold the given input at the given position.
    | MerkleRootNotInTreeException T.Text Int B.ByteString
        -- ^ The tree does not hold the given (encoded) root at the given
        -- position.
    | InvalidProofObjectException T.Text
        -- ^ A proof object is malformed.
    deriving stock (Eq, Generic)
    deriving anyclass (NFData)

-- | Both 'displayException' and 'show' render via 'textMessage'.
instance Exception MerkleTreeException where
    displayException e = T.unpack (textMessage e)

instance Show MerkleTreeException where
    show e = T.unpack (textMessage e)
-- | Human-readable description of a 'MerkleTreeException'.
--
-- For 'InputNotInTreeException' the base64 rendering of the input is cut
-- off after 1024 characters. The root in 'MerkleRootNotInTreeException' is
-- rendered in full.
textMessage :: MerkleTreeException -> T.Text
textMessage = \case
    EncodingSizeException ty e a ->
        wrongSize ty <> expectedMessage e a
    EncodingSizeConstraintException ty (Expected e) (Actual a) ->
        wrongSize ty <> "Expected: " <> e <> ", Actual: " <> sshow a
    IndexOutOfBoundsException ty (Expected e) (Actual a) ->
        "Index out of bounds. " <> ty
            <> ". Expected: " <> sshow e
            <> ", Actual: " <> sshow a
    InputNotInTreeException t i b ->
        notInTree t i <> T.take 1024 (b64 b)
    MerkleRootNotInTreeException t i b ->
        notInTree t i <> b64 b
    InvalidProofObjectException t ->
        "Invalid ProofObject: " <> t
  where
    -- Common prefix of both decoding-size errors.
    wrongSize :: T.Text -> T.Text
    wrongSize ty =
        "Failed to decode " <> ty <> " because the input is of wrong size. "

    -- Common prefix of both "not in tree" errors, up to the base64 payload.
    notInTree :: T.Text -> Int -> T.Text
    notInTree t i =
        "Item not in tree. " <> t <> ". Position: " <> sshow i
            <> ". Input (b64): "
-- | Pick the "not in tree" exception that matches the kind of entry: a
-- 'TreeNode' reports its encoded root, an 'InputNode' reports its bytes.
inputNotInTreeException
    :: T.Text
    -> Int
    -> MerkleNodeType a B.ByteString
    -> MerkleTreeException
inputNotInTreeException t pos = \case
    TreeNode root -> MerkleRootNotInTreeException t pos (encodeMerkleRoot root)
    InputNode bytes -> InputNotInTreeException t pos bytes
-- | A node hash computed with hash algorithm @a@, stored as raw bytes.
newtype MerkleHash a = MerkleHash BA.Bytes
    deriving stock (Eq, Ord, Generic)
    deriving newtype (NFData, BA.ByteArrayAccess)

-- | Shown as unpadded URL-safe base64.
instance Show (MerkleHash a) where
    show h =
        let encoded = convertToBase Base64URLUnpadded h :: BA.Bytes
        in map (toEnum . fromEnum) (BA.unpack encoded)
    {-# INLINEABLE show #-}
-- | Digest size in bytes of hash algorithm @a@, at any numeric type. It
-- must be selected with a type application, e.g. @hashSize \@a \@Int@.
hashSize :: forall a c . HashAlgorithm a => Num c => c
hashSize = fromIntegral (hashDigestSize (undefined :: a))
{-# INLINE hashSize #-}
-- | Decode a 'MerkleHash' from raw bytes.
--
-- Throws 'EncodingSizeException' unless the input is exactly 'hashSize'
-- bytes long.
decodeMerkleHash
    :: forall a b m
    . MonadThrow m
    => HashAlgorithm a
    => BA.ByteArrayAccess b
    => b
    -> m (MerkleHash a)
decodeMerkleHash b
    | actualLen == expectedLen = pure (MerkleHash (BA.convert b))
    | otherwise = throwM $ EncodingSizeException
        "MerkleHash"
        (Expected expectedLen)
        (Actual actualLen)
  where
    actualLen = BA.length b
    expectedLen = hashSize @a @Int
{-# INLINE decodeMerkleHash #-}
-- | One-byte prefix (0x00) hashed before the input of a leaf. It keeps
-- leaf hashes distinct from inner-node hashes.
leafTag :: BA.ByteArray a => a
leafTag = BA.singleton 0x00
{-# INLINE leafTag #-}

-- | One-byte prefix (0x01) hashed before the two child hashes of an inner
-- node.
nodeTag :: BA.ByteArray a => a
nodeTag = BA.singleton 0x01
{-# INLINE nodeTag #-}
-- | Leaf hash of the given bytes: the digest of 'leafTag' followed by the
-- input.
merkleLeaf
    :: forall a b
    . HashAlgorithm a
    => BA.ByteArrayAccess b
    => b
    -> MerkleHash a
merkleLeaf !input = MerkleHash digest
  where
    digest = BA.allocAndFreeze (hashSize @a) $ \out -> do
        ctx <- hashMutableInit @a
        merkleLeafPtr ctx input out
-- | Hash of the inner node whose left child hash is the first argument and
-- whose right child hash is the second: the digest of 'nodeTag' followed by
-- both child hashes.
merkleNode
    :: forall a
    . HashAlgorithm a
    => MerkleHash a
    -> MerkleHash a
    -> MerkleRoot a
merkleNode !left !right =
    MerkleRoot . MerkleHash $ BA.allocAndFreeze (hashSize @a) $ \out ->
        BA.withByteArray left $ \leftPtr ->
            BA.withByteArray right $ \rightPtr -> do
                ctx <- hashMutableInit @a
                merkleNodePtr ctx leftPtr rightPtr out
-- | Write to @r@ the hash of an inner node with left child hash @a@ and
-- right child hash @b@. The hash is the digest of 'nodeTag', then the
-- 'hashSize' bytes at @a@, then the 'hashSize' bytes at @b@.
--
-- The context is reset first, so one context can be reused across calls.
-- Both child hashes are fed to the context before the digest is written.
-- That ordering is what allows @r@ to point to the same memory as @a@ or
-- @b@; 'runMerkleProofInternal' relies on this by passing its accumulator
-- as both a child and the result.
merkleNodePtr
    :: forall a
    . HashAlgorithm a
    => MutableContext a
    -> Ptr (MerkleHash a)
    -> Ptr (MerkleHash a)
    -> Ptr (MerkleHash a)
    -> IO ()
merkleNodePtr !ctx !a !b !r = do
    hashMutableReset ctx
    hashMutableUpdate ctx (nodeTag @BA.Bytes)
    -- The child hashes are raw pointers, so they go straight into the
    -- low-level context with hashInternalUpdate instead of being wrapped
    -- in a byte array.
    BA.withByteArray ctx $ \ctxPtr -> do
        hashInternalUpdate @a ctxPtr (castPtr a) (hashSize @a)
        hashInternalUpdate ctxPtr (castPtr b) (hashSize @a)
        hashInternalFinalize ctxPtr (castPtr r)
-- | Write to @r@ the leaf hash of @b@: the digest of 'leafTag' followed by
-- the bytes of @b@.
--
-- The context is reset first, so one context can be reused across calls.
-- @r@ must point to at least 'hashSize' writable bytes.
merkleLeafPtr
    :: forall a b
    . HashAlgorithm a
    => BA.ByteArrayAccess b
    => MutableContext a
    -> b
    -> Ptr (MerkleHash a)
    -> IO ()
merkleLeafPtr !ctx !b !r = do
    hashMutableReset ctx
    hashMutableUpdate ctx (leafTag @BA.Bytes)
    hashMutableUpdate ctx b
    -- Finalize directly into the caller-provided buffer @r@.
    BA.withByteArray ctx $ \ctxPtr ->
        hashInternalFinalize @a ctxPtr (castPtr r)
-- | An entry of a Merkle tree. A 'TreeNode' is an already computed root,
-- copied into the tree unchanged. An 'InputNode' is raw input that is
-- hashed as a leaf.
data MerkleNodeType a b
    = TreeNode (MerkleRoot a)
    | InputNode b
    deriving stock (Show, Eq, Ord, Generic, Functor)
    deriving anyclass (NFData)
-- | A Merkle tree over hash algorithm @a@: all node hashes stored back to
-- back in one byte buffer, with the root as the last hash.
newtype MerkleTree a = MerkleTree BA.Bytes
    deriving stock (Eq, Generic)
    deriving newtype (NFData, BA.ByteArrayAccess)

-- | Shown as unpadded URL-safe base64 of the whole buffer.
instance Show (MerkleTree a) where
    show t =
        let encoded = convertToBase Base64URLUnpadded t :: BA.Bytes
        in map (toEnum . fromEnum) (BA.unpack encoded)
    {-# INLINEABLE show #-}
-- | Build a Merkle tree from a list of entries.
--
-- Each 'InputNode' is hashed as a leaf (see 'merkleLeafPtr'). Each
-- 'TreeNode' is copied into the tree unchanged and acts as a leaf whose
-- hash is already known.
--
-- For @n@ entries the result stores @2n - 1@ hashes of 'hashSize' bytes
-- each, in the order they are computed. Every entry is written as it is
-- consumed, and every inner node is written immediately after its right
-- subtree (post-order), so the root is the last hash. For the empty list
-- the tree holds a single hash: the plain hash of the empty byte string.
merkleTree
    :: forall a b
    . HasCallStack
    => HashAlgorithm a
    => BA.ByteArrayAccess b
    => [MerkleNodeType a b]
    -> MerkleTree a
merkleTree [] = MerkleTree $ BA.convert $ hash @_ @a (mempty @B.ByteString)
merkleTree !items = MerkleTree $ BA.allocAndFreeze (tsize * hashSize @a) $ \ptr -> do
    !ctx <- hashMutableInit @a
    let
        -- @go out entries stack@: @out@ is where the next hash is written,
        -- @entries@ are the entries not yet consumed, and @stack@ holds
        -- (level, location) of the roots of the subtrees built so far, most
        -- recent first. The level is 0 for a single entry and is
        -- incremented on every merge.
        go
            :: Ptr (MerkleHash a)
            -> [MerkleNodeType a b]
            -> [(Int, Ptr (MerkleHash a))]
            -> IO ()
        -- Two subtrees of equal level on top of the stack: merge them. The
        -- older one (@ib@) is the left child.
        go !i t ((!a, !ia) : (!b, !ib) : s) | a == b = do
            merkleNodePtr ctx ib ia i
            go (i `plusPtr` hs) t ((succ a, i) : s)
        go !i (InputNode h : t) !s = do
            merkleLeafPtr ctx h i
            go (i `plusPtr` hs) t ((0, i) : s)
        go !i (TreeNode h : t) !s = do
            BA.copyByteArrayToPtr h i
            go (i `plusPtr` hs) t ((0, i) : s)
        -- All entries consumed: merge the remaining subtrees from the top of
        -- the stack (rightmost) towards the bottom (leftmost).
        go !i [] ((!a, !ia) : (!_, !ib) : s) = do
            merkleNodePtr ctx ib ia i
            go (i `plusPtr` hs) [] ((succ a, i) : s)
        -- A single subtree left: it is the root, already written last.
        go _ [] [_] = return ()
        -- Unreachable: @items@ is non-empty here, so the stack is too.
        go _ [] [] = error "code invariant violation"
    go ptr items []
  where
    !isize = length items
    -- Number of hashes in a tree with isize leaves.
    !tsize = isize + (isize - 1)
    !hs = hashSize @a
-- | Whether the tree is byte-for-byte equal to 'emptyMerkleTree'. The
-- comparison uses 'BA.constEq'.
isEmpty :: forall a . HashAlgorithm a => MerkleTree a -> Bool
isEmpty t = BA.constEq (emptyMerkleTree @a) t
{-# INLINE isEmpty #-}

-- | The tree built from no entries.
emptyMerkleTree :: forall a . HashAlgorithm a => MerkleTree a
emptyMerkleTree = merkleTree @a ([] :: [MerkleNodeType a B.ByteString])
{-# INLINEABLE emptyMerkleTree #-}
-- | Binary representation of a tree: its hash buffer, copied as is.
encodeMerkleTree :: BA.ByteArray b => MerkleTree a -> b
encodeMerkleTree tree = BA.convert tree
{-# INLINE encodeMerkleTree #-}

-- | Number of hashes (leaves and inner nodes) stored in the tree.
size :: forall a . HashAlgorithm a => MerkleTree a -> Int
size tree = quot (BA.length tree) (hashSize @a)
{-# INLINE size #-}
-- | Decode a 'MerkleTree' from its binary representation, as produced by
-- 'encodeMerkleTree'.
--
-- Every tree built by 'merkleTree' stores an odd number of hashes: one for
-- the empty tree and @2n - 1@ for @n@ entries. The decoder rejects, with
-- 'EncodingSizeConstraintException':
--
-- * input whose length is not a multiple of 'hashSize';
-- * input holding an even number of hashes, including zero.
--
-- Such buffers can never come from 'merkleTree'. Accepting them would make
-- 'merkleRoot' index hash @-1@ and 'leafCount' report a wrong count.
decodeMerkleTree
    :: forall a b m
    . MonadThrow m
    => HashAlgorithm a
    => BA.ByteArrayAccess b
    => b
    -> m (MerkleTree a)
decodeMerkleTree b
    | r /= 0 = throwM $ sizeError "multiple of "
    | even q = throwM $ sizeError "odd multiple of "
    | otherwise = return $ MerkleTree $ BA.convert b
  where
    -- q: number of whole hashes; r: leftover bytes.
    (q, r) = BA.length b `divMod` hashSize @a
    sizeError what = EncodingSizeConstraintException
        "MerkleTree"
        (Expected $ what <> sshow (hashSize @a @Int))
        (Actual $ BA.length b)
{-# INLINE decodeMerkleTree #-}
-- | The root hash of a Merkle tree.
newtype MerkleRoot a = MerkleRoot (MerkleHash a)
    deriving stock (Eq, Ord, Generic)
    deriving newtype (Show, NFData, BA.ByteArrayAccess)

-- | The root of a tree, which 'merkleTree' stores as the last hash.
merkleRoot :: forall a . HashAlgorithm a => MerkleTree a -> MerkleRoot a
merkleRoot t = MerkleRoot (getHash t lastIndex)
  where
    lastIndex = size t - 1
{-# INLINE merkleRoot #-}

-- | Binary representation of a root: its raw hash bytes.
encodeMerkleRoot :: BA.ByteArray b => MerkleRoot a -> b
encodeMerkleRoot root = BA.convert root
{-# INLINE encodeMerkleRoot #-}
-- | Decode a 'MerkleRoot' from raw hash bytes. Fails like
-- 'decodeMerkleHash' if the length is not 'hashSize'.
decodeMerkleRoot
    :: MonadThrow m
    => HashAlgorithm a
    => BA.ByteArrayAccess b
    => b
    -> m (MerkleRoot a)
decodeMerkleRoot bytes = MerkleRoot <$> decodeMerkleHash bytes
{-# INLINE decodeMerkleRoot #-}
-- | The encoded steps of a Merkle proof: a 12-byte header followed by
-- fixed-size steps (see 'merkleProof' for the layout).
newtype MerkleProofObject a = MerkleProofObject BA.Bytes
    deriving stock (Eq, Generic)
    deriving anyclass (NFData)
    deriving newtype (BA.ByteArrayAccess)

-- | Shown as unpadded URL-safe base64.
instance Show (MerkleProofObject a) where
    show obj =
        let encoded = convertToBase Base64URLUnpadded obj :: BA.Bytes
        in map (toEnum . fromEnum) (BA.unpack encoded)
    {-# INLINEABLE show #-}

-- | Binary representation of a proof object: its bytes, copied as is.
encodeMerkleProofObject :: BA.ByteArray b => MerkleProofObject a -> b
encodeMerkleProofObject obj = BA.convert obj
{-# INLINE encodeMerkleProofObject #-}
-- | Decode a 'MerkleProofObject' from its binary representation.
--
-- The input must start with a 12-byte header whose first 4 bytes are the
-- big-endian step count. The header is followed by exactly that many steps
-- of 'stepSize' bytes each. Throws:
--
-- * 'EncodingSizeConstraintException' if the input is shorter than the
--   header;
-- * 'EncodingSizeException' if the length does not match the step count.
--
-- The other header bytes are not validated.
decodeMerkleProofObject
    :: forall a b m
    . MonadThrow m
    => HashAlgorithm a
    => BA.ByteArrayAccess b
    => b
    -> m (MerkleProofObject a)
decodeMerkleProofObject bytes
    -- A 12-byte object (zero steps) is valid, so the bound is "at least".
    | len < 12 = throwM
        $ EncodingSizeConstraintException
            "MerkleProofObject"
            (Expected "at least 12")
            (Actual len)
    | len /= expectedLen = throwM
        $ EncodingSizeException
            "MerkleProofObject"
            (Expected expectedLen)
            (Actual len)
    | otherwise = return $ MerkleProofObject $ BA.convert bytes
  where
    len = BA.length bytes
    -- Evaluated lazily, i.e. only after the header is known to be present.
    stepCount = fromIntegral $ BA.fromBE $ peekBA @(BA.BE Word32) bytes
    expectedLen = proofObjectSizeInBytes @a stepCount
-- | Size in bytes of one proof step: a side byte followed by a hash.
stepSize :: forall a . HashAlgorithm a => Int
stepSize = 1 + hashSize @a
{-# INLINE stepSize #-}

-- | Size in bytes of a proof object with the given number of steps: the
-- 12-byte header plus the steps.
proofObjectSizeInBytes :: forall a . HashAlgorithm a => Int -> Int
proofObjectSizeInBytes steps = 12 + steps * stepSize @a
{-# INLINE proofObjectSizeInBytes #-}
-- | The entry whose inclusion a 'MerkleProof' demonstrates.
newtype MerkleProofSubject a = MerkleProofSubject
    { _getMerkleProofSubject :: MerkleNodeType a B.ByteString }
    deriving stock (Show, Eq, Ord, Generic)
    deriving anyclass (NFData)

-- | An inclusion proof: the subject plus the encoded proof steps.
-- 'runMerkleProof' recomputes a root from it.
data MerkleProof a = MerkleProof
    { _merkleProofSubject :: !(MerkleProofSubject a)
    , _merkleProofObject :: !(MerkleProofObject a)
    }
    deriving stock (Show, Eq, Generic)
    deriving anyclass (NFData)
-- | Construct a proof that the given entry is the leaf at position @pos@
-- of the tree.
--
-- Layout of the proof object:
--
-- * a 12-byte header: the big-endian 'Word32' step count, then @pos@ as a
--   big-endian 'Word64';
-- * one step per level, from the leaf level up to the root. Each step is a
--   side byte (0x00 if the sibling is the left child, 0x01 if it is the
--   right child) followed by the sibling's hash.
--
-- Throws:
--
-- * 'IndexOutOfBoundsException' if @pos@ is not a valid leaf index;
-- * 'InputNotInTreeException' or 'MerkleRootNotInTreeException' (via
--   'inputNotInTreeException') if the hash stored at that leaf does not
--   match the entry.
merkleProof
    :: forall a m
    . MonadThrow m
    => HashAlgorithm a
    => MerkleNodeType a B.ByteString
    -> Int
    -> MerkleTree a
    -> m (MerkleProof a)
merkleProof a pos t
    | pos < 0 || pos >= leafCount t = throwM $ IndexOutOfBoundsException
        "merkleProof"
        (Expected (0,leafCount t - 1))
        (Actual pos)
    | not (BA.constEq (view t tpos) (inputHash a)) = throwM
        $ inputNotInTreeException "merkleProof" pos a
    | otherwise = return $ MerkleProof
        { _merkleProofSubject = MerkleProofSubject a
        , _merkleProofObject = MerkleProofObject go
        }
  where
    -- The hash the entry has (or would have) as a leaf of the tree.
    inputHash (InputNode bytes) = merkleLeaf @a bytes
    inputHash (TreeNode (MerkleRoot bytes)) = bytes
    -- tpos: index of the leaf in the tree's hash array.
    -- path: (side, index) of each sibling hash, leaf level first.
    (tpos, path) = proofPath pos (leafCount t)
    go = BA.allocAndFreeze (proofObjectSizeInBytes @a (length path)) $ \ptr -> do
        pokeBE @Word32 ptr $ fromIntegral $ length path
        pokeBE @Word64 (ptr `plusPtr` 4) (fromIntegral pos)
        let pathPtr = ptr `plusPtr` 12
        -- x is the byte offset of each step within the step area.
        forM_ (path `zip` [0, fromIntegral (stepSize @a) ..]) $ \((s, i), x) -> do
            poke (pathPtr `plusPtr` x) (sideWord8 s)
            BA.copyByteArrayToPtr (view t i) (pathPtr `plusPtr` succ x)
-- | Construct a single proof through a chain of trees.
--
-- The first list element locates the subject in the first tree. Each later
-- element gives the position at which the root of the preceding tree
-- appears, as a 'TreeNode', in the next tree. The steps of the individual
-- proofs are concatenated, so running the result yields the root of the
-- last tree. The header holds the total step count and the position from
-- the first element.
--
-- Any element of the chain can fail with the exceptions of 'merkleProof'.
merkleProof_
    :: forall a m
    . MonadThrow m
    => HashAlgorithm a
    => MerkleNodeType a B.ByteString
    -> NE.NonEmpty (Int, MerkleTree a)
    -> m (MerkleProof a)
merkleProof_ a l
    = MerkleProof (MerkleProofSubject a) . MerkleProofObject . assemble <$> go a (NE.toList l)
  where
    -- Prove each link; the previous tree's root becomes the next subject.
    go _ [] = return []
    go sub ((pos, tree) : t) = do
        MerkleProof (MerkleProofSubject _) (MerkleProofObject o) <- merkleProof sub pos tree
        (:) (strip o) <$> go (TreeNode $ merkleRoot tree) t
    -- Split a proof object into its step count and its steps, dropping the
    -- 12-byte header.
    strip o = (peekBeBA o :: Word32, BA.drop 12 o)
    -- Prepend a fresh header: the summed step count (4 bytes) and the
    -- position from the first chain element (8 bytes), both big-endian.
    assemble ps =
        let (s, os) = unzip ps
        in BA.concat
            $ BA.allocAndFreeze 4 (flip pokeBE $ sum s)
            : BA.allocAndFreeze 8 (flip (pokeBE @Word64) $ fromIntegral $ fst $ NE.head l)
            : os
-- | @proofPath pos n@ locates leaf @pos@ in a tree with @n@ leaves laid out
-- by 'merkleTree'.
--
-- It returns two things:
--
-- * the index of the leaf in the tree's hash array;
-- * the siblings needed to recompute the root, ordered from the leaf level
--   up to the root. Each sibling is given by its side ('L' if it is the
--   left child, 'R' if it is the right child) and its index in the hash
--   array.
--
-- Layout used: a subtree with @n > 1@ leaves starting at array offset
-- @off@ has a left subtree of @k = k2 n@ leaves in @2k - 1@ slots. The
-- right subtree follows it, and the subtree's own root is stored last, at
-- @off + 2n - 2@. The left root is therefore at @off + 2k - 2@ and the
-- right root at @off + 2n - 3@.
--
-- Assumes @0 <= pos < n@, as checked by 'merkleProof'.
proofPath
    :: Int
    -> Int
    -> (Int, [(Side, Int)])
proofPath b c = go 0 b c []
  where
    -- go off m n acc: leaf m of the n-leaf subtree starting at offset off.
    -- Each new sibling is prepended, so acc ends up ordered leaf level first.
    go !treeOff _ 1 !acc = (treeOff, acc)
    go !treeOff !m !n !acc
        | m < k = go treeOff m k $ (R, treeOff + 2 * n - 3) : acc
        | otherwise = go (treeOff + 2 * k - 1) (m - k) (n - k)
            $ (L, treeOff + 2 * k - 2) : acc
      where
        k = k2 n
-- | Recompute the root a proof leads to. The proof is valid for a tree
-- exactly when the result equals that tree's 'merkleRoot'.
runMerkleProof :: forall a . HashAlgorithm a => MerkleProof a -> MerkleRoot a
runMerkleProof (MerkleProof (MerkleProofSubject subj) (MerkleProofObject obj)) =
    MerkleRoot (MerkleHash (runMerkleProofInternal @a subj obj))
-- | Recompute a root hash from a proof subject and the bytes of a proof
-- object.
--
-- Starts from the subject's hash: the leaf hash for an 'InputNode', the
-- root itself for a 'TreeNode'. Each step then combines the accumulator
-- with the step's sibling hash, on the side given by the step's side byte.
-- Only the step count is read from the header; the position is ignored.
--
-- NOTE(review): @obj@ is not bounds-checked. It is assumed to be a
-- well-formed proof object, as produced by 'merkleProof' or accepted by
-- 'decodeMerkleProofObject'. A side byte other than 0x00/0x01 throws
-- 'InvalidProofObjectException' inside the pure allocation, i.e. when the
-- result is forced.
runMerkleProofInternal
    :: forall a b c d
    . HashAlgorithm a
    => BA.ByteArrayAccess b
    => BA.ByteArrayAccess c
    => BA.ByteArray d
    => MerkleNodeType a b
    -> c
    -> d
runMerkleProofInternal subj obj = BA.allocAndFreeze (hashSize @a) $ \ptr -> do
    ctx <- hashMutableInit @a
    -- Seed the accumulator (ptr) with the subject's hash.
    case subj of
        InputNode x -> merkleLeafPtr ctx x ptr
        TreeNode x -> BA.copyByteArrayToPtr x ptr
    BA.withByteArray obj $ \objPtr -> do
        stepCount <- fromIntegral <$> peekBE @Word32 objPtr
        forM_ [0 .. stepCount - 1] $ \(i :: Int) -> do
            -- Skip the 12-byte header. Each step is a side byte followed by
            -- the sibling hash.
            let off = 12 + i * stepSize @a
            -- The accumulator is both a child and the output; merkleNodePtr
            -- reads its children before writing the result.
            peekByteOff @Word8 objPtr off >>= \case
                -- Sibling is the left child.
                0x00 -> merkleNodePtr ctx (objPtr `plusPtr` succ off) ptr ptr
                -- Sibling is the right child.
                0x01 -> merkleNodePtr ctx ptr (objPtr `plusPtr` succ off) ptr
                _ -> throwM $ InvalidProofObjectException "runMerkleProofInternal"
-- | @k2 n@ is the largest power of two strictly smaller than @n@, for
-- @n >= 2@. 'proofPath' uses it as the number of leaves in the left
-- subtree of a tree with @n@ leaves. For @n <= 2@ the result is 1.
--
-- The result is found by repeated doubling on 'Int' rather than with
-- @logBase@ on 'Double'. The floating-point version depends on rounding:
-- for large @n@ (e.g. @2^53@), @logBase 2 (n - 1)@ rounds up to an
-- integer, the result becomes @n@ itself, and 'proofPath' then never
-- terminates. The test @k >= n - k@ means @2 * k >= n@ but cannot
-- overflow.
k2 :: Int -> Int
k2 i = until (\k -> k >= i - k) (* 2) 1
{-# INLINE k2 #-}
-- | Side of a sibling hash in a proof step: 'L' if the sibling is the left
-- child, 'R' if it is the right child.
data Side = L | R
    deriving (Show, Eq)

-- | Byte encoding of a 'Side' in a proof object: 0x00 for 'L', 0x01 for
-- 'R'.
sideWord8 :: Side -> Word8
sideWord8 side = case side of
    L -> 0x00
    R -> 0x01
{-# INLINE sideWord8 #-}
-- | Zero-copy view of the @i@-th hash stored in the tree.
view :: forall a . HashAlgorithm a => MerkleTree a -> Int -> BA.View BA.Bytes
view (MerkleTree bytes) i = BA.view bytes offset width
  where
    width = hashSize @a
    offset = i * width
{-# INLINE view #-}

-- | Copy of the @i@-th hash stored in the tree.
getHash :: HashAlgorithm a => MerkleTree a -> Int -> MerkleHash a
getHash t i = MerkleHash (BA.convert (view t i))
{-# INLINE getHash #-}
-- | Number of leaves in the tree: 0 for the empty tree, otherwise @n@ for a
-- tree of @2n - 1@ hashes.
leafCount :: HashAlgorithm a => MerkleTree a -> Int
leafCount t = if isEmpty t then 0 else size t `div` 2 + 1
{-# INLINE leafCount #-}
-- | Read a big-endian value from memory, in host byte order.
peekBE :: forall a . BA.ByteSwap a => Storable a => Ptr (BA.BE a) -> IO a
peekBE ptr = fmap BA.fromBE (peek ptr)
{-# INLINE peekBE #-}

-- | Write a host-order value to memory in big-endian byte order.
pokeBE :: forall a . BA.ByteSwap a => Storable a => Ptr (BA.BE a) -> a -> IO ()
pokeBE ptr x = poke ptr (BA.toBE x)
{-# INLINE pokeBE #-}
-- | Read a 'Storable' value from the start of a byte array.
--
-- Uses 'unsafePerformIO'. This is only referentially transparent if the
-- byte array is not mutated (presumably true for the immutable arrays used
-- here, such as 'BA.Bytes'; confirm for new callers).
--
-- NOTE(review): there is no length check. The caller must guarantee that
-- @bytes@ holds at least @sizeOf (undefined :: a)@ bytes.
peekBA :: forall a b . Storable a => BA.ByteArrayAccess b => b -> a
peekBA bytes = unsafePerformIO (BA.withByteArray bytes peek)
{-# INLINE peekBA #-}

-- | Read a big-endian value from the start of a byte array, in host byte
-- order. Same preconditions as 'peekBA'.
peekBeBA :: forall a b . BA.ByteSwap a => Storable a => BA.ByteArrayAccess b => b -> a
peekBeBA bytes = BA.fromBE (peekBA @(BA.BE a) bytes)
{-# INLINE peekBeBA #-}
-- | Unpadded URL-safe base64 rendering of a byte array, as 'T.Text'.
b64 :: BA.ByteArrayAccess a => a -> T.Text
b64 bytes = T.decodeUtf8 (convertToBase Base64URLUnpadded bytes)
{-# INLINE b64 #-}
-- | 'show' a value into any string-like type.
sshow :: Show a => IsString b => a -> b
sshow x = fromString (show x)
{-# INLINE sshow #-}