-- SPDX-FileCopyrightText: 2023 Oxhead Alpha
-- SPDX-License-Identifier: LicenseRef-MIT-OA

-- | Utilities for writing optimizer rules.
module Morley.Michelson.Optimizer.Utils
  ( pattern (:#)
  , orRule
  , orSimpleRule
  , fixpoint
  , applyOnce
  , whileApplies
  , linearizeAndReapply
  ) where

import Morley.Michelson.Optimizer.Internal.Ruleset (Rule(..))
import Morley.Michelson.Typed.Instr hiding ((:#))

{- | This is a redefinition of @(:#)@ from "Morley.Michelson.Typed.Instr" that
is particularly useful for writing optimizer rules. When matching on an
instruction @x@ that isn't a v'Seq', it behaves as if it matched on
@x :# Nop@.

When constructing instructions using this pattern, a @Nop@ on either side is
automatically removed: @l :# Nop@ builds just @l@, and @Nop :# r@ builds just
@r@.

To understand why this is useful, consider that a given instruction sequence can
appear either in the middle of a sequence, where @a :# b :# tail@ will match, or
at the end of the sequence, where only @a :# b@ will match. Thus, to cover all
cases with the ordinary @(:#)@ one would have to duplicate most rules.

This definition of @(:#)@ makes it so we can always assume there's a @tail@.
However, we don't need it when matching on single instructions.

Thus, the rule of thumb is this: if you're matching on a single instruction,
everything is fine. If you're matching on a sequence, i.e. using @(:#)@, then
always bind the tail, e.g.

@
dupSwap2dup :: Rule
dupSwap2dup = Rule $ \case
  DUP :# SWAP :# c -> Just $ DUP :# c
  _                -> Nothing
@

But this works, too:

@
ifNopNop2Drop :: Rule
ifNopNop2Drop = Rule $ \case
  IF Nop Nop -> Just DROP
  _          -> Nothing
@

-}
pattern (:#) :: Instr inp b -> Instr b out -> Instr inp out
-- Matcher: a v'Seq' is taken apart as-is; any other instruction @x@ is first
-- viewed as @Seq x Nop@, so this pattern never fails to match.
pattern l :# r <- (\case { x@Seq{} -> x; x -> Seq x Nop } -> Seq l r)
  where
    -- Builder: a literal 'Nop' on either side is dropped instead of being
    -- wrapped in a v'Seq'.
    l :# Nop = l
    Nop :# r = r
    l :# r   = Seq l r
infixr 8 :#

-- | Combine two rule fixpoints.
--
-- Both open rules are closed over the same top-level rule @topl@. The combined
-- rule tries the left one first and falls back to the right one only when the
-- left one returns 'Nothing' (via '<|>' on 'Maybe').
orRule :: (Rule -> Rule) -> (Rule -> Rule) -> (Rule -> Rule)
orRule l r topl = Rule $ \instr ->
  unRule (l topl) instr <|> unRule (r topl) instr

-- | Combine a rule fixpoint and a simple rule.
--
-- Like 'orRule', but the right-hand rule is a plain 'Rule' that does not
-- receive the top-level rule @topl@. It is tried only when the left one
-- returns 'Nothing'.
orSimpleRule :: (Rule -> Rule) -> Rule -> (Rule -> Rule)
orSimpleRule l r topl = Rule $ \instr ->
  unRule (l topl) instr <|> unRule r instr

-- | Turn rule fixpoint into rule.
--
-- The open rule @r@ is given the resulting rule itself as its top-level rule,
-- so the definition refers to its own result. The result is wrapped in
-- 'whileApplies', so it keeps being re-applied to its own output until it no
-- longer fires.
fixpoint :: (Rule -> Rule) -> Rule
fixpoint r = go
  where
    go :: Rule
    go = whileApplies (r go)

-- | Apply the rule once; if it fails, return the instruction unmodified.
--
-- Also returns a flag showing whether the rule succeeded: @Any True@ when it
-- applied, and 'mempty' (produced by 'pure' on the pair) when it did not.
applyOnce :: Rule -> Instr inp out -> (Any Bool, Instr inp out)
applyOnce r i = maybe (pure i) (Any True,) (unRule r i)

-- | Apply a rule to the same code, until it fails.
--
-- The resulting rule returns 'Nothing' if the rule does not apply even once.
-- Otherwise it returns the instruction produced by the last successful
-- application. Note that it does not terminate if the rule keeps applying to
-- its own output forever.
whileApplies :: Rule -> Rule
whileApplies r = Rule $ go <=< unRule r
  -- NB: if the rule doesn't apply even once, we want to return Nothing here,
  -- hence it's first applied above and only if successful goes into recursion.
  where
    -- Keep re-applying the rule; once it fails, return the last result.
    go :: Instr inp out -> Maybe (Instr inp out)
    go i = maybe (Just i) go (unRule r i)

-- | Append LHS of v'Seq' to RHS and re-run pointwise ocRuleset at each point.
--
--   Concretely, @Seq (Seq a b) c@ is rebuilt as @Seq a (Seq b c)@. The inner
--   @Seq b c@ is processed recursively by this function first, and then the
--   given rule is applied at most once (via 'applyOnce') to the rebuilt node.
--   Any other instruction just gets the rule applied once. Whenever the rule
--   does not apply, the instruction is returned unchanged.
--
--   The rule may re-invoke this function (see @defaultRule@), but the intent
--   is that this flattens any v'Seq'-tree right-to-left, while evaluating no
--   more than once on each node.
--
--   The reason this function re-runs the rule is that appending an instruction
--   to an already-optimised RHS of v'Seq' may produce an optimisable tree.
--
--   The argument is a local, non-structurally-recursive ocRuleset.
linearizeAndReapply :: Rule -> Instr inp out -> Instr inp out
linearizeAndReapply restart = snd . \case
  Seq (Seq a b) c ->
    applyOnce restart $ Seq a (linearizeAndReapply restart (Seq b c))
  other -> applyOnce restart other