{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE OverloadedStrings #-}

module Tahoe.CHK.Merkle (
    MerkleTree (MerkleNode, MerkleLeaf),
    Direction (..),
    leaf,
    leafNumberToNodeNumber,
    breadthFirstList,
    merklePathLengthForSize,
    makeTree,
    makeTreePartial,
    merkleProof,
    neededHashes,
    firstLeafNum,
    rootHash,
    pairHash,
    emptyLeafHash,
    size,
    height,
    mapTree,
    merklePath,
    leafHashes,
    -- exported for testing in ghci
    treeFromRows,
    buildTreeOutOfAllTheNodes,
) where

import Data.Binary (Binary (get, put))
import Data.Binary.Get (getRemainingLazyByteString)
import Data.Binary.Put (putByteString)
import Data.TreeDiff.Class (ToExpr)
import GHC.Generics (Generic)

import Data.List.HT (
    padLeft,
 )
import Data.Tuple.HT (
    mapFst,
 )

import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LBS

import Data.Text (
    pack,
 )
import qualified Data.Text as T
import Data.Text.Encoding (
    encodeUtf8,
 )

import Data.ByteString.Base32 (
    encodeBase32Unpadded,
 )

import Tahoe.CHK.Crypto (
    taggedHash,
    taggedPairHash,
 )

import Crypto.Hash (HashAlgorithm (hashDigestSize))
import Crypto.Hash.Algorithms (SHA256 (SHA256))
import Tahoe.Util (
    chunkedBy,
    nextPowerOf,
    toBinary,
 )

{- | A binary Merkle tree in which every node carries a hash value.

 Interior nodes carry the hash computed from their two children (see
 'makeTree' and 'pairHash'); leaves carry the hash they were built from.
-}
data MerkleTree
    = -- | A leaf holding its hash value.
      MerkleLeaf B.ByteString
    | -- | An interior node holding its own hash, then its left and right
      -- subtrees.
      MerkleNode B.ByteString MerkleTree MerkleTree
    deriving (Eq, Ord, Generic, ToExpr)

{- | A constructor for a MerkleLeaf that enforces correct byte string length
 (error on incorrect length).

 The required length is the SHA256 digest size (32 bytes). This is the same
 @hashDigestSize SHA256@ used by 'emptyLeafHash', 'pairHash' and the
 'Binary' instance, so the check cannot drift out of sync with them. Any
 other length raises an error that reports the length actually supplied.
-}
leaf :: B.ByteString -> MerkleTree
leaf bs
    | B.length bs == hashDigestSize SHA256 = MerkleLeaf bs
    | otherwise =
        error $ "Constructed MerkleLeaf with hash of length " <> show (B.length bs)

{- | Count the number of nodes in a tree. Interior nodes and leaves each
 count as one.
-}
size :: MerkleTree -> Int
size (MerkleLeaf _) = 1
size (MerkleNode _ left right) = 1 + size left + size right

{- | Measure the height of a tree: the number of levels from the root down to
 a leaf. A lone leaf has height 1.

 Only the leftmost path is followed. The result therefore describes the whole
 tree only when every leaf sits at the same depth.
-}
height :: MerkleTree -> Int
height = walk 1
  where
    walk levels (MerkleLeaf _) = levels
    walk levels (MerkleNode _ leftChild _) = walk (levels + 1) leftChild

{- | Apply a function to every node of a tree (interior nodes and leaves) and
 collect the results in pre-order. Each node's result comes before the
 results for its left subtree, and those come before the results for its
 right subtree.
-}
mapTree :: (MerkleTree -> a) -> MerkleTree -> [a]
mapTree f tree = f tree : descendants
  where
    descendants = case tree of
        MerkleLeaf _ -> []
        MerkleNode _ left right -> mapTree f left ++ mapTree f right

{- | Render a tree with every hash written as unpadded base32. Each child is
 wrapped in parentheses, e.g. @MerkleNode ROOT (MerkleLeaf A) (MerkleLeaf B)@.
-}
instance Show MerkleTree where
    show (MerkleLeaf value) =
        "MerkleLeaf " <> T.unpack (encodeBase32Unpadded value)
    show (MerkleNode value left right) =
        concat
            [ "MerkleNode "
            , T.unpack (encodeBase32Unpadded value)
            , " ("
            , show left
            , ")"
            , " ("
            , show right
            , ")"
            ]

{- | The placeholder hash used for a padding leaf, derived from the leaf's
 index. The index is rendered as decimal text, UTF-8 encoded, and passed to
 'taggedHash' with the tag @"Merkle tree empty leaf"@ and a SHA256-sized
 digest.
-}
emptyLeafHash :: Int -> B.ByteString
emptyLeafHash leafIndex = taggedHash digestSize tag (encodeUtf8 (pack (show leafIndex)))
  where
    digestSize :: Int
    digestSize = hashDigestSize SHA256

    tag :: B.ByteString
    tag = "Merkle tree empty leaf"

{- | Compute an interior node's hash from the hashes of its left and right
 children. This uses 'taggedPairHash' with the tag
 @"Merkle tree internal node"@ and a SHA256-sized digest.
-}
pairHash :: B.ByteString -> B.ByteString -> B.ByteString
pairHash leftHash rightHash =
    taggedPairHash (hashDigestSize SHA256) "Merkle tree internal node" leftHash rightHash

-- | The hash stored at the top of a tree: a leaf's own value or a node's value.
rootHash :: MerkleTree -> B.ByteString
rootHash tree = case tree of
    MerkleLeaf value -> value
    MerkleNode value _ _ -> value

{- | Like 'makeTree' but raises an error instead of returning Nothing. This
 happens when 'makeTree' declines to build a tree, i.e. for an empty list.
-}
makeTreePartial :: [B.ByteString] -> MerkleTree
makeTreePartial leaves =
    case makeTree leaves of
        Just tree -> tree
        Nothing -> error "Merkle.makeTreePartial failed to make a tree"

{- | Make a merkle tree whose leaves are the given values. Returns Nothing when
 no values are given.

 Let @n@ be the number of values. The leaf list is extended with
 @'emptyLeafHash' i@ for every index @i@ from @n@ up to
 @'nextPowerOf' 2 n - 1@. The result is built by repeatedly splitting the
 leaf list in half, and each parent's hash is the 'pairHash' of its
 children's root hashes.

 Every leaf is created with 'leaf', so forcing a leaf whose value is not
 32 bytes long raises an error.
-}
makeTree :: [B.ByteString] -> Maybe MerkleTree
makeTree [] = Nothing
makeTree values = Just (build (values ++ padding))
  where
    valueCount :: Int
    valueCount = length values

    -- Placeholder hashes for leaf positions valueCount .. nextPowerOf 2
    -- valueCount - 1. When valueCount is already a power of two this is
    -- presumably empty (depends on nextPowerOf in Tahoe.Util).
    padding :: [B.ByteString]
    padding = emptyLeafHash <$> [valueCount .. nextPowerOf 2 valueCount - 1]

    -- Build a subtree from a non-empty list of leaf values by halving it.
    build :: [B.ByteString] -> MerkleTree
    build [value] = leaf value
    build leafValues =
        let (lefts, rights) = splitAt (length leafValues `div` 2) leafValues
         in join (build lefts) (build rights)

    -- Make a parent of two subtrees, hashing their root hashes together.
    join :: MerkleTree -> MerkleTree -> MerkleTree
    join left right = MerkleNode (pairHash (rootHash left) (rootHash right)) left right

-- | Represent a direction to take when walking down a binary tree.
data Direction
    = -- | Descend into the left child.
      TurnLeft
    | -- | Descend into the right child.
      TurnRight
    deriving (Show, Ord, Eq)

{- | Return a list of tuples of node numbers and corresponding merkle hashes.
 The node numbers correspond to a numbering of the nodes in the tree where the
 root node is numbered 1, each node's left child is the node's number times
 two, and the node's right child is the node's number times two plus one.

 Each tuple is the sibling of one node on the path from the root to the
 target leaf, listed from the top of the tree downwards. The path comes from
 'merklePath' for this tree's 'height'.
-}
merkleProof :: MerkleTree -> Int -> Maybe [(Int, B.ByteString)]
merkleProof tree targetLeaf = merkleProof' rootNodeNum tree path
  where
    rootNodeNum :: Int
    rootNodeNum = 1

    path :: [Direction]
    path = merklePath (height tree) targetLeaf

{- | Compute the path to a leaf from the root of a merkle tree of a certain
 height.

 The leaf number is written in binary by 'toBinary', with 'TurnLeft' and
 'TurnRight' as the two digits. The result is then left-padded with
 'TurnLeft' to @height' - 1@ directions.

 NOTE(review): this assumes 'toBinary' emits the most significant digit
 first and that 'padLeft' never truncates. In that case a leaf number too
 large for the tree yields a path longer than @height' - 1@. Confirm against
 Tahoe.Util / Data.List.HT.
-}
merklePath :: Int -> Int -> [Direction]
merklePath height' leafNum = padLeft TurnLeft pathLength digits
  where
    pathLength :: Int
    pathLength = height' - 1

    digits :: [Direction]
    digits = toBinary TurnLeft TurnRight leafNum

{- | Compute the length of a merkle path (the number of 'Direction's from the
 root down to a leaf) through a tree that must hold the given number of
 leaves.

 The result is the exponent of the smallest power of two that is at least
 @size'@, i.e. @ceiling (logBase 2 size')@ for @size' >= 1@. It is 0 for
 @size' <= 1@.

 Integer arithmetic is used on purpose. @logBase 2@ on 'Double' can land just
 above an exact power of two, e.g. @logBase 2 (2 ^ 29) == 29.000000000000004@,
 and @ceiling@ then reports one level too many.
-}
merklePathLengthForSize :: Int -> Int
merklePathLengthForSize size' = go 0 1
  where
    -- `capacity` is 2 ^ `bits`, the number of leaves a tree with paths of
    -- length `bits` can hold.
    go :: Int -> Int -> Int
    go bits capacity
        | capacity >= size' = bits
        -- Doubling would overflow Int. Any size' above capacity then needs
        -- exactly one more level.
        | capacity > maxBound `div` 2 = bits + 1
        | otherwise = go (bits + 1) (capacity * 2)

{- | Convert a tree to a breadth-first list of its hash values. The list starts
 with the root, then each deeper level follows left to right.
-}
breadthFirstList :: MerkleTree -> [B.ByteString]
breadthFirstList tree = levelByLevel [tree]
  where
    levelByLevel :: [MerkleTree] -> [B.ByteString]
    levelByLevel [] = []
    levelByLevel level = map rootHash level ++ levelByLevel (concatMap childrenOf level)

    childrenOf :: MerkleTree -> [MerkleTree]
    childrenOf (MerkleLeaf _) = []
    childrenOf (MerkleNode _ left right) = [left, right]

{- | Construct Just a merkle proof along the pre-computed path or Nothing if
 the path runs past the leaves of the tree.

 The first argument is the node number of the given subtree. Its children
 are numbered @2 * n@ (left) and @2 * n + 1@ (right). At each step the
 sibling of the node being descended into is recorded, so the result lists
 sibling hashes from the top of the tree downwards. A path that ends before
 reaching a leaf yields the (shorter) proof gathered so far.
-}
merkleProof' :: Int -> MerkleTree -> [Direction] -> Maybe [(Int, B.ByteString)]
merkleProof' _ _ [] = Just []
-- Directions remain but there is nowhere left to go. This is how an
-- out-of-range leaf number surfaces, so report it as Nothing (per the
-- Maybe result type) rather than crashing.
merkleProof' _ (MerkleLeaf _) (_ : _) = Nothing
merkleProof' thisNodeNum (MerkleNode _ left right) (d : ds) =
    case d of
        TurnLeft ->
            ((rightChildNum, rootHash right) :) <$> merkleProof' leftChildNum left ds
        TurnRight ->
            ((leftChildNum, rootHash left) :) <$> merkleProof' rightChildNum right ds
  where
    leftChildNum :: Int
    leftChildNum = thisNodeNum * 2

    rightChildNum :: Int
    rightChildNum = thisNodeNum * 2 + 1

{- | Translate a leaf number to a node number. Leaf numbers are zero indexed
 and identify leaves of a tree from left to right. Node numbers are one
 indexed and identify nodes of a tree from top to bottom, left to right.

 The result is the zero-indexed position of the first leaf ('firstLeafNum'),
 plus the leaf number, shifted by one into one-indexed numbering.
-}
leafNumberToNodeNumber :: MerkleTree -> Int -> Int
leafNumberToNodeNumber tree leafNum = firstLeafNum tree + leafNum + 1

{- | Get a merkle proof but re-number the node numbers to be zero-indexed
 instead of one-indexed. Each node number from 'merkleProof' is reduced by
 one. The hashes and the Nothing case pass through unchanged.
-}
neededHashes :: MerkleTree -> Int -> Maybe [(Int, B.ByteString)]
neededHashes tree leafNum = map toZeroIndexed <$> merkleProof tree leafNum
  where
    toZeroIndexed :: (Int, B.ByteString) -> (Int, B.ByteString)
    toZeroIndexed (nodeNum, digest) = (nodeNum - 1, digest)

{- | Determine the smallest index into the breadth first list for the given
 tree where a leaf may be found.

 This is half the node count, rounded down. For a full tree with
 @2 ^ h - 1@ nodes that is @2 ^ (h - 1) - 1@, the number of interior nodes
 that precede the leaves. The result is only meaningful if the tree is full.
-}
firstLeafNum :: MerkleTree -> Int
firstLeafNum = (`div` 2) . size

{- | Serialize a MerkleTree to bytes by concatenating the hashes of all of its
 nodes (interior nodes and leaves) in breadth-first order, as produced by
 'breadthFirstList'. Decoding splits the input into SHA256-sized hashes and
 rebuilds the tree with 'buildTreeOutOfAllTheNodes'.

 This serialization includes no framing so the only thing we can do is
 consume all available input. Use this instance with `isolate` and bring
 your own framing mechanism to determine how many bytes to process.

 Decoding fails in two cases:

 * The input length is not a whole number of hashes, i.e. a trailing partial
   hash would otherwise have been turned into a node.
 * The number of hashes cannot form a full tree.
-}
instance Binary MerkleTree where
    put = putByteString . B.concat . breadthFirstList

    get = do
        bytes <- LBS.toStrict <$> getRemainingLazyByteString
        let hashSize = hashDigestSize SHA256
        if B.length bytes `mod` hashSize /= 0
            then fail "MerkleTree input is not a whole number of hashes"
            else
                maybe (fail "could not construct MerkleTree") pure $
                    buildTreeOutOfAllTheNodes (chunkedBy hashSize bytes)

{- | Get a list of all of the leaf hashes of a tree from left to right. Interior
 node hashes are omitted.
-}
leafHashes :: MerkleTree -> [B.ByteString]
leafHashes tree = collect tree []
  where
    -- Prepend the leaves of a subtree to the leaves already gathered to its
    -- right.
    collect :: MerkleTree -> [B.ByteString] -> [B.ByteString]
    collect (MerkleLeaf h) further = h : further
    collect (MerkleNode _ l r) further = collect l (collect r further)

{- | Make a merkle tree out of a flat list of all nodes (start from
 root, then first two children, etc .. [0, 1, 2] is a two-layer
 tree, [0, 1, 2, 3, 4, 5, 6] is three-layer, etc).

 Returns Nothing unless the number of nodes is one less than a positive
 power of two (see 'validMerkleSize'). The stored hashes are taken as-is and
 are not re-verified against their children.
-}
buildTreeOutOfAllTheNodes :: [B.ByteString] -> Maybe MerkleTree
buildTreeOutOfAllTheNodes nodes
    | validMerkleSize nodes =
        -- Match instead of using the partial `head`. For valid sizes
        -- treeFromRows yields the single root node.
        case treeFromRows [] (clumpRows powersOfTwo nodes) of
            root : _ -> Just root
            [] -> Nothing
    | otherwise = Nothing

{- | Increasing consecutive powers of 2 from 2 ^ 0 to the maximum value
 representable in `Int` (63 values, ending with 2 ^ 62).
-}
powersOfTwo :: [Int]
powersOfTwo = take 63 (iterate (* 2) 1)

{- | Determine whether a list of nodes is a possible representation of a
 merkle tree.

 It is possible if the number of elements in the list is one less than a
 positive power of 2 (1, 3, 7, 15, ... elements), i.e. the node count of a
 full binary tree. Only the powers listed in 'powersOfTwo' are considered.
 The check is a membership test on that finite list, so it is total for
 every input length.
-}
validMerkleSize :: [a] -> Bool
validMerkleSize nodes = (length nodes + 1) `elem` drop 1 powersOfTwo

{- | Reorganize a flat list of merkle tree node values into a list of lists of
 merkle tree node values. Each inner list gives the values from left to right
 at a particular height in the tree. The head of the outer list gives the
 leaves.

 The flat list is consumed front to back, one clump per length. Each clump is
 placed after the clumps built from the rest of the list, so the clump taken
 first (the root row) ends up last.
-}
clumpRows ::
    -- | The numbers of elements of the flat list to take to make this (the
    -- head) and subsequent (the tail) clumps.
    [Int] ->
    -- | The values of the nodes themselves.
    [B.ByteString] ->
    [[B.ByteString]]
clumpRows _ [] = []
clumpRows [] _ = error "Ran out of clump lengths (too many nodes!)"
clumpRows (count : counts) values = clumpRows counts remaining ++ [row]
  where
    (row, remaining) = splitAt count values

{- | Assemble tree nodes bottom-up from rows of node values.

 The first call passes no children and the rows as produced by 'clumpRows',
 deepest row first. The first row becomes 'MerkleLeaf's. Each following row
 supplies the values of the parents of the nodes built so far, with children
 paired left to right by 'mTree'. Any input shape not covered by the clauses
 below is reported with 'error'.
-}
treeFromRows ::
    -- | Some children to attach to a list of nodes representing the next
    -- shallowest level of the tree.
    [MerkleTree] ->
    -- | The values of the nodes to create at the next shallowest level of the
    -- tree.
    [[B.ByteString]] ->
    -- | The nodes forming the shallowest level of the tree.  If we built a
    -- full tree, there will be exactly one node here.
    [MerkleTree]
-- Nothing built yet: the first row holds the leaves.
treeFromRows [] (leafValues : upperRows) =
    treeFromRows (map MerkleLeaf leafValues) upperRows
-- No rows left: whatever has been built is the result.
treeFromRows built [] = built
-- Exactly two children and a single one-value row: this is the root.
treeFromRows [left, right] [[rootValue]] = [MerkleNode rootValue left right]
-- At least two children and another row: build that row of parents from the
-- children, then keep climbing.
treeFromRows built@(_ : _ : _) (parentValues : upperRows) =
    treeFromRows (mTree built parentValues) upperRows
treeFromRows built rows =
    error $ "treeFromRows not sure what to do with " <> show built <> " " <> show rows

{- | Build one row of parent nodes from the row of children beneath it. This
 is the inner recursion used by 'treeFromRows'.

 Children are consumed two at a time, left to right, and each pair is
 combined with the next value from the row. There must be exactly two
 children per value. Any mismatch (too few or too many values, or an odd
 number of children) is reported with a descriptive 'error'. This relies on
 pattern matching rather than the partial @head@/@tail@, which previously
 turned a short row into an opaque Prelude error hidden inside a node's hash.
-}
mTree :: [MerkleTree] -> [B.ByteString] -> [MerkleTree]
mTree [left, right] [value] = [MerkleNode value left right]
mTree (left : right : moreChildren) (value : moreValues) =
    MerkleNode value left right : mTree moreChildren moreValues
mTree children values =
    error $ "mTree not sure what to do with " <> show children <> " " <> show values