{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
module Network.Haskoin.Block.Merkle
( MerkleBlock(..)
, MerkleRoot
, FlagBits
, PartialMerkleTree
, buildMerkleRoot
, merkleBlockTxs
, testMerkleRoot
, buildPartialMerkle
, decodeMerkleFlags
, encodeMerkleFlags
, calcTreeHeight
, calcTreeWidth
, hash2
, calcHash
, traverseAndBuild
, traverseAndExtract
, extractMatches
, splitIn
, boolsToWord8
) where
import Control.Monad (forM_, replicateM, when)
import Data.Bits
import qualified Data.ByteString as BS
import Data.Either (isRight)
import Data.Hashable
import Data.Maybe
import Data.Serialize (Serialize, encode, get,
put)
import Data.Serialize.Get (getWord32le, getWord8)
import Data.Serialize.Put (putWord32le, putWord8)
import Data.Word (Word32, Word8)
import GHC.Generics
import Network.Haskoin.Block.Common
import Network.Haskoin.Constants
import Network.Haskoin.Crypto.Hash
import Network.Haskoin.Network.Common
import Network.Haskoin.Transaction.Common
type MerkleRoot = Hash256
type FlagBits = [Bool]
type PartialMerkleTree = [Hash256]
data MerkleBlock =
MerkleBlock {
merkleHeader :: !BlockHeader
, merkleTotalTxns :: !Word32
, mHashes :: !PartialMerkleTree
, mFlags :: !FlagBits
} deriving (Eq, Show, Read, Generic, Hashable)
instance Serialize MerkleBlock where
get = do
header <- get
ntx <- getWord32le
(VarInt matchLen) <- get
hashes <- replicateM (fromIntegral matchLen) get
(VarInt flagLen) <- get
ws <- replicateM (fromIntegral flagLen) getWord8
return $ MerkleBlock header ntx hashes (decodeMerkleFlags ws)
put (MerkleBlock h ntx hashes flags) = do
put h
putWord32le ntx
putVarInt $ length hashes
forM_ hashes put
let ws = encodeMerkleFlags flags
putVarInt $ length ws
forM_ ws putWord8
decodeMerkleFlags :: [Word8] -> FlagBits
decodeMerkleFlags ws =
[ b | p <- [ 0 .. length ws * 8 - 1 ]
, b <- [ testBit (ws !! (p `div` 8)) (p `mod` 8) ]
]
encodeMerkleFlags :: FlagBits -> [Word8]
encodeMerkleFlags bs = map boolsToWord8 $ splitIn 8 bs
calcTreeHeight :: Int
-> Int
calcTreeHeight ntx | ntx < 2 = 0
| even ntx = 1 + calcTreeHeight (ntx `div` 2)
| otherwise = calcTreeHeight $ ntx + 1
calcTreeWidth :: Int
-> Int
-> Int
calcTreeWidth ntx h = (ntx + (1 `shiftL` h) - 1) `shiftR` h
buildMerkleRoot :: [TxHash]
-> MerkleRoot
buildMerkleRoot txs = calcHash (calcTreeHeight $ length txs) 0 txs
hash2 :: Hash256 -> Hash256 -> Hash256
hash2 a b = doubleSHA256 $ encode a `BS.append` encode b
calcHash :: Int
-> Int
-> [TxHash]
-> Hash256
calcHash height pos txs
| height < 0 || pos < 0 = error "calcHash: Invalid parameters"
| height == 0 = getTxHash $ txs !! pos
| otherwise = hash2 left right
where
left = calcHash (height-1) (pos*2) txs
right | pos*2+1 < calcTreeWidth (length txs) (height-1) =
calcHash (height-1) (pos*2+1) txs
| otherwise = left
buildPartialMerkle ::
[(TxHash, Bool)]
-> (FlagBits, PartialMerkleTree)
buildPartialMerkle hs = traverseAndBuild (calcTreeHeight $ length hs) 0 hs
traverseAndBuild ::
Int -> Int -> [(TxHash, Bool)] -> (FlagBits, PartialMerkleTree)
traverseAndBuild height pos txs
| height < 0 || pos < 0 = error "traverseAndBuild: Invalid parameters"
| height == 0 || not match = ([match], [calcHash height pos t])
| otherwise = (match : lb ++ rb, lh ++ rh)
where
t = map fst txs
s = pos `shiftL` height
e = min (length txs) $ (pos + 1) `shiftL` height
match = any snd $ take (e - s) $ drop s txs
(lb, lh) = traverseAndBuild (height - 1) (pos * 2) txs
(rb, rh)
| (pos * 2 + 1) < calcTreeWidth (length txs) (height - 1) =
traverseAndBuild (height - 1) (pos * 2 + 1) txs
| otherwise = ([], [])
traverseAndExtract :: Int -> Int -> Int -> FlagBits -> PartialMerkleTree
-> Maybe (MerkleRoot, [TxHash], Int, Int)
traverseAndExtract height pos ntx flags hashes
| null flags = Nothing
| height == 0 || not match = leafResult
| isNothing leftM = Nothing
| (pos*2+1) >= calcTreeWidth ntx (height-1) =
Just (hash2 lh lh, lm, lcf+1, lch)
| isNothing rightM = Nothing
| otherwise =
Just (hash2 lh rh, lm ++ rm, lcf+rcf+1, lch+rch)
where
leafResult
| null hashes = Nothing
| otherwise = Just (h, [ TxHash h | height == 0 && match ], 1, 1)
(match:fs) = flags
(h:_) = hashes
leftM = traverseAndExtract (height-1) (pos*2) ntx fs hashes
(lh,lm,lcf,lch) = fromMaybe e leftM
rightM = traverseAndExtract (height-1) (pos*2+1) ntx
(drop lcf fs) (drop lch hashes)
(rh,rm,rcf,rch) = fromMaybe e rightM
e = error "traverseAndExtract: unexpected error extracting a Maybe value"
extractMatches :: Network
-> FlagBits
-> PartialMerkleTree
-> Int
-> Either String (MerkleRoot, [TxHash])
extractMatches net flags hashes ntx
| ntx == 0 = Left
"extractMatches: number of transactions can not be 0"
| ntx > getMaxBlockSize net `div` 60 = Left
"extractMatches: number of transactions excessively high"
| length hashes > ntx = Left
"extractMatches: More hashes provided than the number of transactions"
| length flags < length hashes = Left
"extractMatches: At least one bit per node and one bit per hash"
| isNothing resM = Left
"extractMatches: traverseAndExtract failed"
| (nBitsUsed+7) `div` 8 /= (length flags+7) `div` 8 = Left
"extractMatches: All bits were not consumed"
| nHashUsed /= length hashes = Left $
"extractMatches: All hashes were not consumed: " ++ show nHashUsed
| otherwise = return (merkRoot, matches)
where
resM = traverseAndExtract (calcTreeHeight ntx) 0 ntx flags hashes
(merkRoot, matches, nBitsUsed, nHashUsed) = fromMaybe e resM
e = error "extractMatches: unexpected error extracting a Maybe value"
splitIn :: Int -> [a] -> [[a]]
splitIn _ [] = []
splitIn c xs = xs1 : splitIn c xs2
where
(xs1, xs2) = splitAt c xs
boolsToWord8 :: [Bool] -> Word8
boolsToWord8 [] = 0
boolsToWord8 xs = foldl setBit 0 (map snd $ filter fst $ zip xs [0..7])
merkleBlockTxs :: Network -> MerkleBlock -> Either String [TxHash]
merkleBlockTxs net b =
let flags = mFlags b
hs = mHashes b
n = fromIntegral $ merkleTotalTxns b
merkle = merkleRoot $ merkleHeader b
in do (root, ths) <- extractMatches net flags hs n
when (root /= merkle) $ Left "merkleBlockTxs: Merkle root incorrect"
return ths
testMerkleRoot :: Network -> MerkleBlock -> Bool
testMerkleRoot net = isRight . merkleBlockTxs net