-- | General-purpose utility functions for typed types.

module Michelson.Typed.Util
  ( DfsSettings (..)
  , CtorEffectsApp (..)
  , dfsInstr
  , dfsFoldInstr
  , dfsModifyInstr

  -- * Changing instruction tree structure
  , linearizeLeft

  -- * Value analysis
  , isStringValue
  , isBytesValue
  , allAtomicValues
  ) where

import Prelude hiding (Ordering(..))

import Data.Default (Default(..))
import qualified Text.Show

import Michelson.Text (MText)
import Michelson.Typed.Aliases
import Michelson.Typed.Instr
import Michelson.Typed.Value

-- | Settings controlling the behaviour of 'dfsInstr'.
data DfsSettings x = DfsSettings
  { -- | Whether 'dfsInstr' should descend into values that carry
    -- instructions of their own: lambdas and constant contracts
    -- (the latter can be passed to @CREATE_CONTRACT@).
    dsGoToValues :: Bool
  , -- | Strategy used to account for intermediate nodes of the
    -- instruction tree.
    dsCtorEffectsApp :: CtorEffectsApp x
  } deriving stock (Show)

-- | Strategy describing how effects of intermediate (non-leaf) nodes of the
-- instruction tree are accounted.
data CtorEffectsApp x = CtorEffectsApp
  { -- | Human-readable name of this strategy.
    ceaName :: Text
  , -- | Produces the final instruction and the overall effect for one
    -- intermediate node. Arguments, in order:
    --
    -- 1. Effects gathered after applying @step@ to node's children, but
    -- before applying it to the node itself.
    -- 2. Effects gathered after applying @step@ to the given intermediate
    -- node.
    -- 3. Instruction resulting after all modifications produced by @step@.
    ceaApplyEffects
      :: forall i o. Semigroup x => x -> x -> Instr i o -> (Instr i o, x)
  }

-- | A strategy is shown by its name only: the effect-combining function
-- cannot be rendered.
instance Show (CtorEffectsApp x) where
  show = show . ceaName

-- | Bottom-to-top strategy: effects of children nodes come first and
-- effects of their parent are appended after them.
-- The instruction itself is returned unchanged.
ceaBottomToTop :: CtorEffectsApp x
ceaBottomToTop = CtorEffectsApp
  { ceaName = "Apply after"
  , ceaApplyEffects = childrenThenNode
  }
  where
    childrenThenNode :: Semigroup e => e -> e -> instr -> (instr, e)
    childrenThenNode fromChildren fromNode instr =
      (instr, fromChildren <> fromNode)

-- | Default traversal: do not enter values, gather effects bottom-to-top.
instance Default (DfsSettings x) where
  def =
    DfsSettings
      False           -- dsGoToValues
      ceaBottomToTop  -- dsCtorEffectsApp

-- | Traverse a typed instruction in depth-first order.
--
-- The supplied @step@ function is applied to every visited node and returns
-- a (possibly changed) instruction together with an effect of type @x@.
-- '<>' is used to concatenate intermediate results.
--
-- * For a node with child instructions the children are traversed first,
--   the node is rebuilt from the updated children, @step@ is applied to the
--   rebuilt node, and the effects are merged via 'ceaApplyEffects' of
--   'dsCtorEffectsApp'.
-- * For a node without child instructions @step@ is applied directly.
-- * Code stored in values (@PUSH@ of a lambda, @LAMBDA@, @CREATE_CONTRACT@)
--   is entered only when 'dsGoToValues' is set; otherwise such an
--   instruction is handled as a leaf. A lambda nested deeper inside a
--   pushed value (e.g. within a pair) is never entered.
--
-- It does not look inside extra instructions (not present in Michelson):
-- 'Ext' is passed to @step@ as is.
dfsInstr ::
     forall x inp out. Semigroup x
  => DfsSettings x
  -> (forall i o. Instr i o -> (Instr i o, x))
  -> Instr inp out
  -> (Instr inp out, x)
dfsInstr settings@DfsSettings{..} step i =
  case i of
    -- Nodes with one or two child instructions.
    Seq i1 i2 -> recursion2 Seq i1 i2
    InstrWithNotes notes i1 -> recursion1 (InstrWithNotes notes) i1
    FrameInstr p i1 -> recursion1 (FrameInstr p) i1
    Nested i1 -> recursion1 Nested i1
    DocGroup dg i1 -> recursion1 (DocGroup dg) i1
    IF_NONE i1 i2 -> recursion2 IF_NONE i1 i2
    IF_LEFT i1 i2 -> recursion2 IF_LEFT i1 i2
    IF_CONS i1 i2 -> recursion2 IF_CONS i1 i2
    IF i1 i2 -> recursion2 IF i1 i2
    MAP i1 -> recursion1 MAP i1
    ITER i1 -> recursion1 ITER i1
    LOOP i1 -> recursion1 LOOP i1
    LOOP_LEFT i1 -> recursion1 LOOP_LEFT i1
    DIP i1 -> recursion1 DIP i1
    DIPN s i1 -> recursion1 (DIPN s) i1

    -- Instructions carrying code inside a value: entered only on request.
    PUSH v -> case v of
      VLam i1 | dsGoToValues ->
        -- Using 'analyzeInstrFailure' here (and in case below) is cheap
        -- (O(n) in total) because we never make it run over the same code twice
        -- NOTE(review): 'rfAnyInstr' is defined elsewhere; it presumably
        -- unwraps the lambda body regardless of its failure status, and
        -- 'analyzeInstrFailure' restores that status for the updated body.
        recursion1 (PUSH . VLam . analyzeInstrFailure) (rfAnyInstr i1)
      -- Any other pushed value is not inspected.
      _ -> step i
    LAMBDA (VLam i1)
      | dsGoToValues ->
          recursion1 (LAMBDA . VLam . analyzeInstrFailure) (rfAnyInstr i1)
      | otherwise -> step i
    CREATE_CONTRACT contract
      | dsGoToValues ->
        -- Traverse the code of the contract and put the result back.
        let updateContractCode code = CREATE_CONTRACT $ contract{ fcCode = code }
        in recursion1 updateContractCode $ fcCode contract
      | otherwise -> step i

    -- Leaf instructions: nothing to descend into. Constructors are listed
    -- one by one rather than matched with a wildcard (presumably so that
    -- the compiler reports any newly added constructor here).
    Nop{} -> step i
    Ext{} -> step i
    AnnCAR{} -> step i
    AnnCDR{} -> step i
    DROP{} -> step i
    DROPN{} -> step i
    DUP{} -> step i
    SWAP{} -> step i
    DIG{} -> step i
    DUG{} -> step i
    SOME{} -> step i
    NONE{} -> step i
    UNIT{} -> step i
    PAIR{} -> step i
    LEFT{} -> step i
    RIGHT{} -> step i
    NIL{} -> step i
    CONS{} -> step i
    SIZE{} -> step i
    EMPTY_SET{} -> step i
    EMPTY_MAP{} -> step i
    EMPTY_BIG_MAP{} -> step i
    MEM{} -> step i
    GET{} -> step i
    UPDATE{} -> step i
    EXEC{} -> step i
    APPLY{} -> step i
    FAILWITH{} -> step i
    CAST{} -> step i
    RENAME{} -> step i
    PACK{} -> step i
    UNPACK{} -> step i
    CONCAT{} -> step i
    CONCAT'{} -> step i
    SLICE{} -> step i
    ISNAT{} -> step i
    ADD{} -> step i
    SUB{} -> step i
    MUL{} -> step i
    EDIV{} -> step i
    ABS{} -> step i
    NEG{} -> step i
    LSL{} -> step i
    LSR{} -> step i
    OR{} -> step i
    AND{} -> step i
    XOR{} -> step i
    NOT{} -> step i
    COMPARE{} -> step i
    EQ{} -> step i
    NEQ{} -> step i
    LT{} -> step i
    GT{} -> step i
    LE{} -> step i
    GE{} -> step i
    INT{} -> step i
    SELF{} -> step i
    CONTRACT{} -> step i
    TRANSFER_TOKENS{} -> step i
    SET_DELEGATE{} -> step i
    IMPLICIT_ACCOUNT{} -> step i
    NOW{} -> step i
    AMOUNT{} -> step i
    BALANCE{} -> step i
    CHECK_SIGNATURE{} -> step i
    SHA256{} -> step i
    SHA512{} -> step i
    BLAKE2B{} -> step i
    HASH_KEY{} -> step i
    STEPS_TO_QUOTA{} -> step i
    SOURCE{} -> step i
    SENDER{} -> step i
    ADDRESS{} -> step i
    CHAIN_ID{} -> step i
  where
    -- Handle a node with a single child: traverse the child, rebuild the
    -- node with the given constructor, apply @step@ to the rebuilt node and
    -- let the configured strategy combine child and node effects.
    recursion1 ::
      forall a b c d. (Instr a b -> Instr c d) -> Instr a b -> (Instr c d, x)
    recursion1 constructor i0 =
      let
        (innerI, innerX) = dfsInstr settings step i0
        (outerI, outerX) = step $ constructor innerI
      in ceaApplyEffects dsCtorEffectsApp innerX outerX outerI

    -- Same as 'recursion1', but for a node with two children. Effects of
    -- the first child are placed before effects of the second one.
    recursion2 ::
      forall i o i1 o1 i2 o2.
      (Instr i1 o1 -> Instr i2 o2 -> Instr i o) ->
      Instr i1 o1 -> Instr i2 o2 -> (Instr i o, x)
    recursion2 constructor i1 i2 =
      let
        (i1', x1) = dfsInstr settings step i1
        (i2', x2) = dfsInstr settings step i2
        (i', x) = step $ constructor i1' i2'
      in ceaApplyEffects dsCtorEffectsApp (x1 <> x2) x i'

-- | Fold-only flavour of 'dfsInstr': every visited instruction is left
-- untouched and only the effects produced by @step@ are gathered.
dfsFoldInstr
  :: forall x inp out.
      (Semigroup x)
  => DfsSettings x
  -> (forall i o. Instr i o -> x)
  -> Instr inp out
  -> x
dfsFoldInstr settings step instr = gathered
  where
    (_, gathered) = dfsInstr settings observe instr

    -- Keep the instruction as is, report its effect.
    observe :: forall i o. Instr i o -> (Instr i o, x)
    observe node = (node, step node)

-- | Modify-only flavour of 'dfsInstr': every visited instruction is
-- rewritten with @step@ and no meaningful effects are gathered
-- (the effect type is unit).
dfsModifyInstr
  :: DfsSettings ()
  -> (forall i o. Instr i o -> Instr i o)
  -> Instr inp out
  -> Instr inp out
dfsModifyInstr settings step instr = rewritten
  where
    (rewritten, _) = dfsInstr settings rewrite instr

    -- Apply the modification, produce a trivial effect.
    rewrite :: forall i o. Instr i o -> (Instr i o, ())
    rewrite node = (step node, ())

-- | Check whether instruction fails at each execution path or has at least
-- one non-failing path.
--
-- The instruction is returned wrapped into 'RfAlwaysFails' in the former
-- case and into 'RfNormal' in the latter.
--
-- This function assumes that given instruction contains no dead code
-- (contract with dead code cannot be valid Michelson contract) and may behave
-- in unexpected way if such is present. Term "dead code" includes instructions
-- which render into empty Michelson, like Morley extensions.
-- On the other hand, this function does not traverse the whole instruction tree;
-- performs fastest on left-growing combs.
--
-- Calls 'error' when an 'InstrWithNotes' or 'FrameInstr' node wraps an
-- always-failing instruction.
--
-- Often we already have information about instruction failure, use this
-- function only in cases when this info is actually unavailable or hard
-- to use.
analyzeInstrFailure :: HasCallStack => Instr i o -> RemFail Instr i o
analyzeInstrFailure = go
  where
  go :: Instr i o -> RemFail Instr i o
  go = \case
    -- These wrappers are not expected to contain always-failing code, the
    -- opposite is reported as an error.
    InstrWithNotes pn i -> case go i of
      RfNormal i0 ->
        RfNormal (InstrWithNotes pn i0)
      RfAlwaysFails i0 ->
        error $ "InstrWithNotes wraps always-failing instruction: " <> show i0
    FrameInstr s i -> case go i of
      RfNormal i0 ->
        RfNormal (FrameInstr s i0)
      RfAlwaysFails i0 ->
        error $ "FrameInstr wraps always-failing instruction: " <> show i0
    -- Only the second part of a sequence is inspected; the first part is
    -- kept as is. This is where the "no dead code" assumption is relied on.
    Seq a b -> Seq a `rfMapAnyInstr` go b
    Nop -> RfNormal Nop
    Ext e -> RfNormal (Ext e)
    -- Transparent wrappers: the verdict is the verdict for the inner code.
    Nested i -> Nested `rfMapAnyInstr` go i
    DocGroup g i -> DocGroup g `rfMapAnyInstr` go i

    -- Branching: both branches are analyzed and combined with 'rfMerge'.
    -- NOTE(review): 'rfMerge' is defined elsewhere; presumably the result
    -- always fails only when both branches do - confirm.
    IF_NONE l r -> rfMerge IF_NONE (go l) (go r)
    IF_LEFT l r -> rfMerge IF_LEFT (go l) (go r)
    IF_CONS l r -> rfMerge IF_CONS (go l) (go r)
    IF l r -> rfMerge IF (go l) (go r)

    -- Bodies of these instructions are not inspected, the instructions are
    -- always reported as 'RfNormal'.
    i@MAP{} -> RfNormal i
    i@ITER{} -> RfNormal i
    i@LOOP{} -> RfNormal i
    i@LOOP_LEFT{} -> RfNormal i
    i@LAMBDA{} -> RfNormal i
    i@DIP{} -> RfNormal i
    i@DIPN{} -> RfNormal i

    -- Primitive instructions: all are 'RfNormal' except for @FAILWITH@.
    i@AnnCAR{} -> RfNormal i
    i@AnnCDR{} -> RfNormal i
    i@DROP{} -> RfNormal i
    i@DROPN{} -> RfNormal i
    i@DUP{} -> RfNormal i
    i@SWAP{} -> RfNormal i
    i@DIG{} -> RfNormal i
    i@DUG{} -> RfNormal i
    i@PUSH{} -> RfNormal i
    i@SOME{} -> RfNormal i
    i@NONE{} -> RfNormal i
    i@UNIT{} -> RfNormal i
    i@PAIR{} -> RfNormal i
    i@LEFT{} -> RfNormal i
    i@RIGHT{} -> RfNormal i
    i@NIL{} -> RfNormal i
    i@CONS{} -> RfNormal i
    i@SIZE{} -> RfNormal i
    i@EMPTY_SET{} -> RfNormal i
    i@EMPTY_MAP{} -> RfNormal i
    i@EMPTY_BIG_MAP{} -> RfNormal i
    i@MEM{} -> RfNormal i
    i@GET{} -> RfNormal i
    i@UPDATE{} -> RfNormal i
    i@EXEC{} -> RfNormal i
    i@APPLY{} -> RfNormal i
    -- The only instruction which is always-failing on its own.
    FAILWITH -> RfAlwaysFails FAILWITH
    i@CAST -> RfNormal i
    i@RENAME -> RfNormal i
    i@PACK -> RfNormal i
    i@UNPACK -> RfNormal i
    i@CONCAT -> RfNormal i
    i@CONCAT' -> RfNormal i
    i@SLICE -> RfNormal i
    i@ISNAT -> RfNormal i
    i@ADD -> RfNormal i
    i@SUB -> RfNormal i
    i@MUL -> RfNormal i
    i@EDIV -> RfNormal i
    i@ABS -> RfNormal i
    i@NEG -> RfNormal i
    i@LSL -> RfNormal i
    i@LSR -> RfNormal i
    i@OR -> RfNormal i
    i@AND -> RfNormal i
    i@XOR -> RfNormal i
    i@NOT -> RfNormal i
    i@COMPARE -> RfNormal i
    i@EQ -> RfNormal i
    i@NEQ -> RfNormal i
    i@LT -> RfNormal i
    i@GT -> RfNormal i
    i@LE -> RfNormal i
    i@GE -> RfNormal i
    i@INT -> RfNormal i
    i@SELF{} -> RfNormal i
    i@CONTRACT{} -> RfNormal i
    i@TRANSFER_TOKENS -> RfNormal i
    i@SET_DELEGATE -> RfNormal i
    i@CREATE_CONTRACT{} -> RfNormal i
    i@IMPLICIT_ACCOUNT -> RfNormal i
    i@NOW -> RfNormal i
    i@AMOUNT -> RfNormal i
    i@BALANCE -> RfNormal i
    i@CHECK_SIGNATURE -> RfNormal i
    i@SHA256 -> RfNormal i
    i@SHA512 -> RfNormal i
    i@BLAKE2B -> RfNormal i
    i@HASH_KEY -> RfNormal i
    i@STEPS_TO_QUOTA -> RfNormal i
    i@SOURCE -> RfNormal i
    i@SENDER -> RfNormal i
    i@ADDRESS -> RfNormal i
    i@CHAIN_ID -> RfNormal i

-- | A sequence of more than 2 instructions has several representations,
-- e. g. `i1; i2; i3` can be @Seq i1 $ Seq i2 i3@ or @Seq (Seq i1 i2) i3@.
-- This function rebuilds the given instruction so that the second argument
-- of each 'Seq' is a single instruction (i. e. not a 'Seq'), which makes
-- sequences grow to the left.
-- A 'Nop' standing in the second argument of a 'Seq' is dropped on the way.
linearizeLeft :: Instr inp out -> Instr inp out
linearizeLeft = go False
  where
    -- The flag tells whether the left argument of the 'Seq' at hand is
    -- known to be linear already. Such an argument is never processed
    -- again, which keeps us away from quadratic performance.
    go :: Bool -> Instr inp out -> Instr inp out
    go leftIsLinear = \case
      -- Rotate the nested right 'Seq' to the left; the freshly built left
      -- part is linear by construction.
      Seq i1 (Seq i2 i3) ->
        go True $ Seq (go leftIsLinear (Seq i1 i2)) i3
      -- The right part is a single instruction here: make the left part
      -- linear (unless it already is) and drop the right part if it is 'Nop'.
      Seq i1 i2 ->
        let left = if leftIsLinear then i1 else linearizeLeft i1
        in case i2 of
          Nop -> left
          _ -> Seq left i2
      -- Anything but 'Seq' is returned as is.
      other -> other

----------------------------------------------------------------------------
-- Value analysis
----------------------------------------------------------------------------

-- | Extract the string stored in a value.
-- Returns 'Nothing' for anything which is not a string value.
isStringValue :: Value t -> Maybe MText
isStringValue (VC (CvString text)) = Just text
isStringValue _ = Nothing

-- | Extract the bytestring stored in a value.
-- Returns 'Nothing' for anything which is not a bytes value.
isBytesValue :: Value t -> Maybe ByteString
isBytesValue (VC (CvBytes raw)) = Just raw
isBytesValue _ = Nothing

-- | Takes a selector which checks whether an atomic value (i. e. that
-- can not contain another value) can be converted to something.
-- Recursively applies it to all atomic values in potentially
-- non-atomic value.  Collects extracted values in a list.
--
-- Options, lists, sets, pairs, unions, maps and big maps are treated as
-- containers: the selector is never applied to them, only to what they
-- hold. Elements of sets and keys of maps are wrapped into 'VC' before
-- being inspected. For maps, results for all keys precede results for all
-- stored values. Any other value is passed to the selector as is.
--
-- Perhaps one day we'll have `dfsValue`.
allAtomicValues ::
  forall t a. (forall t'. Value t' -> Maybe a) -> Value t -> [a]
allAtomicValues selector = go
  where
    go :: forall x. Value x -> [a]
    go = \case
      -- An optional value is a container as well: inspect its content
      -- (if any) rather than handing the option itself to the selector.
      VOption mv -> foldMap go mv
      VList l -> foldMap go l
      VSet s -> foldMap (go . VC) s
      VPair (l, r) -> go l <> go r
      VOr e -> either go go e
      VMap m -> goMap m
      VBigMap m -> goMap m
      -- Everything else is considered atomic.
      v -> maybeToList $ selector v

    -- Inspect all keys of a map first, then all its values.
    goMap :: Map (CValue k) (Value v) -> [a]
    goMap m = foldMap (go . VC) (keys m) <> foldMap go (toList m)