{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: EVM.Opcode.Labelled
-- Copyright: 2018 Simon Shine
-- Maintainer: Simon Shine <shreddedglory@gmail.com>
-- License: MIT
--
-- This module exposes the 'LabelledOpcode' type for expressing Ethereum VM
-- opcodes with labelled jumps. Plain Ethereum VM Opcodes are not so ergonomic
-- because one has to know the exact byte offset of the target 'JUMPDEST'.
--
-- With 'Opcode' the byte offset is pushed to the stack via 'PUSH', but the
-- offset to the 'JUMPDEST' depends on all occurrences of 'PUSH' prior to
-- the label, including the 'PUSH' to the label itself.

module EVM.Opcode.Labelled
  ( Label
  , LabelledOpcode
  , TranslateError(..)
  , translate
  , labelPositions
  ) where

import           Data.Function (fix)
import           Data.List (group, sort, foldl')
import qualified Data.Map as Map
import           Data.Map (Map)
import           Data.Maybe (mapMaybe)
import qualified Data.Set as Set
import           Data.Text (Text)

import           EVM.Opcode (Opcode'(..), opcodeSize, jumpdest, concrete, jumpAnnot, jumpdestAnnot)
import           EVM.Opcode.Positional (Position, PositionalOpcode, jumpSize)
import           EVM.Opcode.Traversal (OpcodeMapper(..), mapOpcodeM)

-- | The name of a jump target: a 'JUMPDEST' declares a label, and 'JUMP' and
-- 'JUMPI' refer to it by name. For now, all labels are 'Text'.
type Label = Text

-- | An opcode whose jumps refer to named 'Label's instead of byte offsets.
--
-- Concretely, @'JUMP' "name"@ and @'JUMPI' "name"@ jump to the location
-- marked by @'JUMPDEST' "name"@.
--
-- All other opcodes are the same as in a plain 'Opcode'.
type LabelledOpcode = Opcode' Label

-- | Reasons why 'LabelledOpcode's cannot be translated into
-- 'PositionalOpcode's: a jump names a label that no 'JUMPDEST' declares, or
-- the same label is declared by more than one 'JUMPDEST'.
data TranslateError = TranslateError
  { translateErrorMissingJumpdests   :: [Label]
    -- ^ Labels that are jumped to but have no matching 'JUMPDEST'.
  , translateErrorDuplicateJumpdests :: [Label]
    -- ^ Labels that are declared by more than one 'JUMPDEST'.
  }
  deriving (Show, Eq)

-- | Replace every label with the absolute byte position of its 'JUMPDEST'.
--
-- Positions are obtained from 'labelPositions', which relies on fixed-point
-- iteration because sizes are not known up front: a labelled jump has no
-- size of its own, and the size of the positional jump it turns into
-- depends on the address being jumped to.
--
-- For example, jumping to a 'JUMPDEST' at byte offset 256 (0x100) requires
-- a 'PUSH2' instruction, which uses an additional byte and thereby pushes
-- every later 'JUMPDEST' one byte ahead.
--
-- Fails with a 'TranslateError' whenever 'labelPositions' does.
translate :: [LabelledOpcode] -> Either TranslateError [PositionalOpcode]
translate opcodes =
  labelPositions opcodes >>= \positions -> mapM (resolve positions) opcodes
  where
    -- Rewrite the label carried by one opcode (if it carries any) into the
    -- position recorded for it.
    resolve :: Map Label b -> LabelledOpcode -> Either TranslateError (Opcode' b)
    resolve positions = mapOpcodeM (onLabels (lookupPosition positions))

    -- Run @g@ on the parameter of 'JUMP', 'JUMPI' and 'JUMPDEST'. Every other
    -- opcode yields 'Nothing' (presumably "no override" for 'mapOpcodeM' —
    -- see EVM.Opcode.Traversal).
    onLabels :: Applicative f => (a -> f j) -> OpcodeMapper f a j
    onLabels g = OpcodeMapper
      { mapOnJump     = \lbl -> JUMP <$> g lbl
      , mapOnJumpi    = \lbl -> JUMPI <$> g lbl
      , mapOnJumpdest = \lbl -> JUMPDEST <$> g lbl
      , mapOnOther    = \_ -> pure Nothing
      }

    -- A label absent from the map is reported as a missing 'JUMPDEST'.
    lookupPosition :: Map Label b -> Label -> Either TranslateError b
    lookupPosition positions lbl =
      maybe (Left (TranslateError [lbl] [])) Right (Map.lookup lbl positions)


-- | Compute a @'Map' 'Label' 'Position'@ giving where each 'JUMPDEST' is
-- located, taking the sizes of all preceding opcodes into account.
--
-- The program is validated first: every label used by a 'JUMP' or 'JUMPI'
-- must be declared by a 'JUMPDEST', and no label may be declared twice.
-- Otherwise a 'TranslateError' is returned, listing each offending label
-- once, in ascending order.
labelPositions :: [LabelledOpcode] -> Either TranslateError (Map Label Position)
labelPositions opcodes =
  case (wildJumps, duplicateDests) of
    ([], []) -> Right (fixpoint opcodes)
    _        -> Left (TranslateError wildJumps duplicateDests)
  where
    -- Labels referenced by jumps.
    jumpTargets :: [Label]
    jumpTargets = mapMaybe jumpAnnot opcodes

    -- Labels declared by 'JUMPDEST's (with repetitions).
    declaredLabels :: [Label]
    declaredLabels = mapMaybe jumpdestAnnot opcodes

    -- Jump targets that are never declared (sorted, without repetitions).
    wildJumps :: [Label]
    wildJumps =
      Set.toList (Set.fromList jumpTargets `Set.difference` Set.fromList declaredLabels)

    -- Labels declared at least twice (sorted, each listed once).
    duplicateDests :: [Label]
    duplicateDests = [ lbl | (lbl : _ : _) <- group (sort declaredLabels) ]

-- | Compute a 'Map' holding the position of every 'JUMPDEST'.
--
-- Starting from an empty 'Map', 'step' is run over the whole program again
-- and again, each pass seeded with the positions found by the previous one,
-- until a pass reports that everything was already aligned.
--
-- This function may not terminate for all inputs: it keeps iterating for as
-- long as 'step' reports that positions are still moving.
fixpoint :: [LabelledOpcode] -> Map Label Position
fixpoint opcodes = fix untilStable Map.empty
  where
    untilStable :: (Map Label Position -> Map Label Position)
                -> Map Label Position
                -> Map Label Position
    untilStable again positions =
      case step positions opcodes of
        (stable, _, positions')
          | stable    -> positions'
          | otherwise -> again positions'

-- | A single pass of 'fixpoint': walk every opcode from byte offset 0, using
-- 'align' to advance the running offset and to record in the given 'Map'
-- where each 'JUMPDEST' lands.
--
-- Returns a flag that stays 'True' only if every 'JUMPDEST' was already at
-- the position recorded for it and every jump target was already known, the
-- byte offset after the last opcode, and the updated 'Map'.
step :: Map Label Position
     -> [LabelledOpcode]
     -> (Bool, Position, Map Label Position)
step labelMap = foldl' strictAlign (True, 0, labelMap)
  where
    -- 'foldl'' only evaluates the accumulator to weak head normal form, which
    -- for a tuple is just the constructor. Force every component too, so the
    -- flag and the running offset (which is also stored into the 'Map') do
    -- not build up a chain of thunks as long as the opcode list.
    strictAlign :: (Bool, Position, Map Label Position)
                -> LabelledOpcode
                -> (Bool, Position, Map Label Position)
    strictAlign acc opcode =
      case align acc opcode of
        result@(done, offset, positions) ->
          done `seq` offset `seq` positions `seq` result

-- | Process one opcode within a 'step' pass.
--
-- The accumulator holds whether the pass is still fully aligned, the byte
-- offset of the current opcode, and the label positions recorded so far.
align :: (Bool, Position, Map Label Position)
      -> LabelledOpcode
      -> (Bool, Position, Map Label Position)
align (allAligned, offset, positions) opcode =
  case opcode of
    -- A 'JUMPDEST' counts as aligned only if it is already recorded at the
    -- current offset. It is (re)recorded here either way; an unseen label or
    -- a different offset means 'fixpoint' needs another pass.
    JUMPDEST label ->
      ( allAligned && Map.lookup label positions == Just offset
      , offset + opcodeSize jumpdest
      , Map.insert label offset positions
      )

    -- A 'JUMP' or 'JUMPI' is sized from its target's recorded position. A
    -- target not recorded yet is sized as if it were at 0, and the pass is
    -- marked unaligned because that size is only an approximation.
    _ | Just target <- jumpAnnot opcode ->
          case Map.lookup target positions of
            Just targetPos -> (allAligned, offset + jumpSize targetPos, positions)
            Nothing        -> (False, offset + jumpSize 0, positions)

    -- Any straight-line opcode just advances the offset by its own size.
      | otherwise ->
          (allAligned, offset + opcodeSize (concrete opcode), positions)