{-# LANGUAGE DataKinds #-}

{- |
    Module: EVM.CSE
    Description: Common subexpression elimination for the Expr AST
-}

module EVM.CSE (BufEnv, StoreEnv, eliminateExpr, eliminateProps) where

import Control.Monad.State
import Data.Map (Map)
import Data.Map qualified as Map

import EVM.Traversals
import EVM.Types

-- | State threaded through the elimination pass: maps expressions to
-- variable names.
data BuilderState = BuilderState
  { bufs :: Map (Expr Buf) Int
    -- ^ buffer expressions recorded so far, each with its variable number
  , stores :: Map (Expr Storage) Int
    -- ^ storage expressions recorded so far, each with its variable number
  , count :: Int
    -- ^ next unused variable number; a single counter serves both maps
  }
  deriving (Show)

-- | Variable number to buffer expression; this is the buffer table returned
-- by 'eliminateExpr' and 'eliminateProps'.
type BufEnv = Map Int (Expr Buf)
-- | Variable number to storage expression; this is the storage table
-- returned by 'eliminateExpr' and 'eliminateProps'.
type StoreEnv = Map Int (Expr Storage)

-- | Starting state for a pass: nothing recorded yet and numbering begins at 0.
initState :: BuilderState
initState = BuilderState
  { bufs = mempty
  , stores = mempty
  , count = 0
  }


-- | Replace a single buffer-write or storage-write node by a 'GVar'.
--
-- 'WriteWord', 'WriteByte' and 'CopySlice' nodes are looked up in the @bufs@
-- table (and added when missing) and become @GVar (BufVar n)@. 'SStore'
-- nodes are handled the same way through the @stores@ table and become
-- @GVar (StoreVar n)@. Both tables take fresh numbers from the shared
-- @count@ field, so one number is never handed to both a buffer and a store.
-- Any other node is returned unchanged.
go :: Expr a -> State BuilderState (Expr a)
go = \case
  -- buffers
  e@(WriteWord {}) -> internBuf e
  e@(WriteByte {}) -> internBuf e
  e@(CopySlice {}) -> internBuf e
  -- storage
  e@(SStore {}) -> internStore e
  e -> pure e
  where
    -- Reuse the number already recorded for this buffer, or record it under
    -- the next free number.
    internBuf :: Expr Buf -> State BuilderState (Expr Buf)
    internBuf e = do
      s <- get
      case Map.lookup e s.bufs of
        Just v -> pure $ GVar (BufVar v)
        Nothing -> do
          let
            next = s.count
            bs' = Map.insert e next s.bufs
          put $ s{bufs = bs', count = next + 1}
          pure $ GVar (BufVar next)

    -- Same as internBuf, but for storage expressions and the stores table.
    internStore :: Expr Storage -> State BuilderState (Expr Storage)
    internStore e = do
      s <- get
      case Map.lookup e s.stores of
        Just v -> pure $ GVar (StoreVar v)
        Nothing -> do
          let
            next = s.count
            ss' = Map.insert e next s.stores
          put $ s{count = next + 1, stores = ss'}
          pure $ GVar (StoreVar next)

-- | Swap the keys and values of a map.
--
-- Pairs are taken in ascending key order and rebuilt with 'Map.fromList', so
-- if two keys carry the same 'Int' the entry for the greater key is the one
-- kept.
invertKeyVal :: Map a Int -> Map Int a
invertKeyVal m = Map.fromList [(n, x) | (x, n) <- Map.toList m]

-- | Common subexpression elimination pass for Expr
--
-- Applies 'go' across the expression with 'mapExprM', threading the
-- 'BuilderState' that records which nodes have been replaced.
eliminateExpr' :: Expr a -> State BuilderState (Expr a)
eliminateExpr' e = mapExprM go e

-- | Run the elimination pass on one expression, starting from 'initState'.
--
-- Returns the rewritten expression together with the buffer and storage
-- tables, each inverted so that it maps a variable number back to the
-- expression that number replaced.
eliminateExpr :: Expr a -> (Expr a, BufEnv, StoreEnv)
eliminateExpr e =
  let (e', st) = runState (eliminateExpr' e) initState in
  (e', invertKeyVal st.bufs, invertKeyVal st.stores)

-- | Common subexpression elimination pass for Prop
--
-- Applies 'go' across the proposition with 'mapPropM', threading the
-- 'BuilderState'.
eliminateProp' :: Prop -> State BuilderState Prop
eliminateProp' prop = mapPropM go prop

-- | Common subexpression elimination pass for list of Prop
--
-- The props are processed in list order inside a single 'State' computation,
-- so a node recorded while rewriting one prop is reused by the later ones.
eliminateProps' :: [Prop] -> State BuilderState [Prop]
eliminateProps' props = mapM eliminateProp' props


-- | Common subexpression elimination pass for list of Prop
--
-- Runs the pass from 'initState' and returns the rewritten props together
-- with the buffer and storage tables, each inverted so that it maps a
-- variable number back to the expression that number replaced.
eliminateProps :: [Prop] -> ([Prop], BufEnv, StoreEnv)
eliminateProps props =
  let (props', st) = runState (eliminateProps' props) initState in
  (props', invertKeyVal st.bufs, invertKeyVal st.stores)