{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: EVM.Opcode.Labelled
-- Copyright: 2018 Simon Shine
-- Maintainer: Simon Shine <shreddedglory@gmail.com>
-- License: MIT
--
-- This module exposes the 'LabelledOpcode' type for expressing Ethereum VM
-- opcodes with labelled jumps. Plain Ethereum VM Opcodes are not so ergonomic
-- because one has to know the exact byte offset of the target 'JUMPDEST'.
--
-- With 'Opcode' the byte offset is pushed to the stack via 'PUSH', but the
-- offset to the 'JUMPDEST' depends on all occurrences of 'PUSH' prior to
-- the label, including the 'PUSH' to the label itself.

module EVM.Opcode.Labelled
  ( Label
  , LabelledOpcode
  , TranslateError(..)
  , translate
  , labelPositions
  ) where

import           Data.Function (fix)
import           Data.List (group, sort, foldl')
import qualified Data.Map as Map
import           Data.Map (Map)
import           Data.Maybe (mapMaybe)
import qualified Data.Set as Set
import           Data.Text (Text)

import           EVM.Opcode (Opcode'(..), opcodeSize, jumpdest, concrete, jumpAnnot, jumpdestAnnot)
import           EVM.Opcode.Positional (Position, PositionalOpcode, jumpSize)
import           EVM.Opcode.Traversal (OpcodeMapper(..), mapOpcodeM)

-- | Symbolic name of a jump target.
--
-- For now, all labels are 'Text'.
type Label = Text

-- | 'LabelledOpcode's use 'Label' to represent jumps.
--
-- In particular, @'JUMP' "name"@, @'JUMPI' "name"@ and @'JUMPDEST' "name"@
-- carry a 'Label' instead of a byte offset.
--
-- All other opcodes remain the same.
type LabelledOpcode = Opcode' Label

-- | Reasons why translating 'LabelledOpcode's into 'PositionalOpcode's failed.
--
-- Translation may fail if a jump is made to a non-occurring 'JUMPDEST' or if
-- a 'JUMPDEST' with the same label occurs more than once.
data TranslateError = TranslateError
  { translateErrorMissingJumpdests   :: [Label]
    -- ^ Labels that are jumped to but have no corresponding 'JUMPDEST'.
  , translateErrorDuplicateJumpdests :: [Label]
    -- ^ Labels that are declared by more than one 'JUMPDEST'.
  } deriving (Eq, Show)

-- | Replace all labels with absolute positions.
--
-- Positions are calculated by fixed-point iteration to account for variable
-- sizes of jumps. Labelled jumps don't have a size defined; the size of a
-- positional jump depends on the address being jumped to.
--
-- For example, if jumping to the 'JUMPDEST' on the 256th position in a
-- @['LabelledOpcode']@, this requires a 'PUSH2' instruction which uses an
-- additional byte, which pushes the 'JUMPDEST' one byte ahead.
--
-- Returns 'Left' with a 'TranslateError' when 'labelPositions' rejects the
-- input (missing or duplicate 'JUMPDEST's) or when a label cannot be found
-- in the computed position map.
translate :: [LabelledOpcode] -> Either TranslateError [PositionalOpcode]
translate opcodes = do
  labelMap <- labelPositions opcodes
  traverse (replaceLabel labelMap) opcodes
  where
    -- Rewrite a single opcode, substituting the label carried by a jump or
    -- jump destination with whatever the map associates with that label.
    replaceLabel :: Map Label b -> LabelledOpcode -> Either TranslateError (Opcode' b)
    replaceLabel = mapOpcodeM . jumpMapper . lookupLabel

    -- Apply @f@ to the parameter of 'JUMP's, 'JUMPI's and 'JUMPDEST's.
    -- Every other opcode is answered with 'Nothing'.
    jumpMapper :: Applicative m => (a -> m b) -> OpcodeMapper m a b
    jumpMapper f = OpcodeMapper
      { mapOnJump     = fmap JUMP . f
      , mapOnJumpi    = fmap JUMPI . f
      , mapOnJumpdest = fmap JUMPDEST . f
      , mapOnOther    = const (pure Nothing)
      }

    -- Let @f@ be @'lookupLabel' labelMap@. A label that is absent from the
    -- map is reported as a missing 'JUMPDEST'.
    lookupLabel :: Map Label b -> Label -> Either TranslateError b
    lookupLabel labelMap label =
      case Map.lookup label labelMap of
        Just pos -> Right pos
        Nothing  -> Left (TranslateError [label] [])


-- | Extract a @'Map' 'Label' 'Position'@ that describes where each 'JUMPDEST'
-- is located, taking into account the sizes of all prior opcodes.
--
-- The input is validated first: if any jump refers to a label without a
-- 'JUMPDEST', or any label is used by more than one 'JUMPDEST', the result is
-- a 'Left' carrying both lists of offending labels. Otherwise the positions
-- are computed by 'fixpoint'.
labelPositions :: [LabelledOpcode] -> Either TranslateError (Map Label Position)
labelPositions opcodes
  | null wildJumps && null duplicateDests = Right (fixpoint opcodes)
  | otherwise = Left (TranslateError wildJumps duplicateDests)
  where
    -- Labels that are jumped to but never declared by a 'JUMPDEST'.
    wildJumps :: [Label]
    wildJumps = jumps `missing` dests

    -- Labels declared by more than one 'JUMPDEST'.
    duplicateDests :: [Label]
    duplicateDests = duplicate dests

    -- Labels extracted from the opcodes by 'jumpAnnot'.
    jumps :: [Label]
    jumps = mapMaybe jumpAnnot opcodes

    -- Labels extracted from the opcodes by 'jumpdestAnnot'.
    dests :: [Label]
    dests = mapMaybe jumpdestAnnot opcodes

    -- Elements of @xs@ that do not occur in @ys@, deduplicated and sorted.
    missing :: Ord a => [a] -> [a] -> [a]
    missing xs ys = Set.toList (Set.difference (Set.fromList xs) (Set.fromList ys))

    -- Elements that occur more than once, each reported a single time.
    duplicate :: Ord a => [a] -> [a]
    duplicate = concatMap (take 1) . filter ((> 1) . length) . group . sort

-- | Extract a 'Map' holding the position of every 'JUMPDEST'.
--
-- Starting from an empty map, repeatedly run 'step' over all opcodes while
-- keeping track of the current position. Iteration stops as soon as a full
-- pass reports that every label was already aligned; the map produced by
-- that pass is the result.
--
-- This function may not terminate for all inputs!
fixpoint :: [LabelledOpcode] -> Map Label Position
fixpoint opcodes = flip fix Map.empty $ \go labelMap ->
  case step labelMap opcodes of
    -- Nothing moved during this pass: the positions are stable.
    (True, _, labelMap')  -> labelMap'
    -- Something moved or was still unknown: try again with the updated map.
    (False, _, labelMap') -> go labelMap'

-- | A single step in the 'fixpoint' function: go over every opcode, check
-- whether its position is already aligned, and update the map of positions
-- otherwise.
--
-- The result is a triple of: whether every label was already aligned, the
-- byte position after the last opcode, and the updated map. The fold starts
-- at position 0 under the assumption that everything is aligned.
step :: Map Label Position
     -> [LabelledOpcode]
     -> (Bool, Position, Map Label Position)
step labelMap = foldl' align (True, 0, labelMap)

-- | Process one opcode during a 'step': advance the running byte position by
-- the size of the opcode and record or verify the positions of labels.
--
-- The accumulator is @(done, currentBytePos, labelMap)@, where @done@ stays
-- 'True' only while every label seen so far was already aligned.
align :: (Bool, Position, Map Label Position)
      -> LabelledOpcode
      -> (Bool, Position, Map Label Position)

-- When encountering a 'JUMPDEST' that was either not seen before, or was seen
-- at another offset, it hasn't been aligned. In that case, update the 'Map'
-- and signal that another iteration of 'fixpoint' is necessary.
align (done, currentBytePos, labelMap) (JUMPDEST label) =
  let aligned = Map.lookup label labelMap == Just currentBytePos
  in ( done && aligned
     , currentBytePos + opcodeSize jumpdest
     , Map.insert label currentBytePos labelMap
     )

-- When encountering a 'JUMP' or a 'JUMPI', check if the destination 'JUMPDEST'
-- was seen before. If so, increment the running offset with the size of a jump
-- to that 'JUMPDEST'. If not, the offset is still approximate: the size of a
-- jump to position 0 is used and another iteration is requested.
align (done, currentBytePos, labelMap) (jumpAnnot -> Just label) =
  case Map.lookup label labelMap of
    Just bytePos -> ( done,  currentBytePos + jumpSize bytePos, labelMap )
    Nothing      -> ( False, currentBytePos + jumpSize 0,       labelMap )

-- For any straight-line opcode, just increment the offset with its size.
align (done, currentBytePos, labelMap) opcode =
  ( done, currentBytePos + opcodeSize (concrete opcode), labelMap )