-- SPDX-FileCopyrightText: 2020 Tocqueville Group
--
-- SPDX-License-Identifier: LicenseRef-MIT-TQ

{- | Reimplementation of some syntax sugar.

With @RebindableSyntax@ enabled, GHC desugars @do@-blocks and
@if ... then ... else ...@ expressions using whatever '(>>)' and
'ifThenElse' are in scope. This module supplies Lorentz versions of them:

* statements of a @do@-block are sequenced with '(>>)', an alias for '(#)';
* an @if@ expression branches on a 'Condition', which describes the check
  performed on the top of the stack.

Literal and label re-exports ('fromInteger', 'fromString', 'fromLabel') are
provided because @RebindableSyntax@ also takes them from the scope.

You need the following module pragmas to make it work smoothly:

{-# LANGUAGE NoApplicativeDo, RebindableSyntax #-}
{-# OPTIONS_GHC -Wno-unused-do-bind #-}

-}
module Lorentz.Rebinded
  ( (>>)
  , pure
  , return
  , ifThenElse
  , Condition (..)
  , (<.)
  , (>.)
  , (<=.)
  , (>=.)
  , (==.)
  , (/=.)
  , keepIfArgs

    -- * Re-exports required for RebindableSyntax
  , fromInteger
  , fromString
  , fromLabel
  ) where


import Prelude hiding (drop, swap, (>>), (>>=))

import Named ((:!))

import Lorentz.Arith
import Lorentz.Base
import Lorentz.Coercions
import Lorentz.Instr
import Lorentz.Macro
import Michelson.Typed.Arith
import Util.Label (Label)

-- | Alias for '(#)' used by do-blocks.
--
-- With @RebindableSyntax@, consecutive statements of a @do@-block are chained
-- with this operator, so a @do@-block composes instructions sequentially.
(>>) :: (a :-> b) -> (b :-> c) -> (a :-> c)
(>>) = (#)

-- | Predicate for @if ... then ... else ...@ construction,
-- defines a kind of operation applied to the top elements of the current stack.
--
-- Type arguments mean:
--
-- 1. Input of @if@
-- 2. Left branch input
-- 3. Right branch input
-- 4. Output of branches
-- 5. Output of @if@
data Condition arg argl argr outb out where
  -- | Consumes a 'Bool' from the top of the stack.
  Holds :: Condition (Bool ': s) s s o o
  -- | Consumes a 'Maybe'; the left branch receives the wrapped value on top
  -- of the rest of the stack, the right branch receives only the rest.
  IsSome :: Condition (Maybe a ': s) (a ': s) s o o
  -- | Same check as 'IsSome' with the branch inputs swapped.
  IsNone :: Condition (Maybe a ': s) s (a ': s) o o
  -- | Consumes an 'Either'; the left branch receives the 'Left' payload,
  -- the right branch receives the 'Right' payload.
  IsLeft :: Condition (Either l r ': s) (l ': s) (r ': s) o o
  -- | Same check as 'IsLeft' with the branch inputs swapped.
  IsRight :: Condition (Either l r ': s) (r ': s) (l ': s) o o
  -- | Consumes a list; the left branch receives an element and a list
  -- (head and tail), the right branch receives only the rest of the stack.
  IsCons :: Condition ([a] ': s) (a ': [a] ': s) s o o
  -- | Same check as 'IsCons' with the branch inputs swapped.
  IsNil :: Condition ([a] ': s) s (a ': [a] ': s) o o

  -- Checks of a single value against zero; the value is consumed.
  IsZero :: (UnaryArithOpHs Eq' a, UnaryArithResHs Eq' a ~ Bool)
         => Condition (a ': s) s s o o
  IsNotZero :: (UnaryArithOpHs Eq' a, UnaryArithResHs Eq' a ~ Bool)
         => Condition (a ': s) s s o o

  -- Comparisons of the two topmost stack elements, which must share the same
  -- comparable type; both elements are consumed.
  IsEq :: NiceComparable a => Condition (a ': a ': s) s s o o
  IsNeq :: NiceComparable a => Condition (a ': a ': s) s s o o
  IsLt :: NiceComparable a => Condition (a ': a ': s) s s o o
  IsGt :: NiceComparable a => Condition (a ': a ': s) s s o o
  IsLe :: NiceComparable a => Condition (a ': a ': s) s s o o
  IsGe :: NiceComparable a => Condition (a ': a ': s) s s o o

  -- | Explicitly named binary condition, to ensure proper order of
  -- stack arguments.
  NamedBinCondition ::
    Condition (a ': a ': s) s s o o ->
    Label n1 -> Label n2 ->
    Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o

  -- | Provide the compared arguments to @if@ branches.
  PreserveArgsBinCondition ::
    (forall st o. Condition (a ': b ': st) st st o o) ->
    Condition (a ': b ': s) (a ': b ': s) (a ': b ': s) (a ': b ': s) s

-- | Defines semantics of @if ... then ... else ...@ construction.
--
-- Under @RebindableSyntax@, @if cond then l else r@ is desugared to
-- @ifThenElse cond l r@: the 'Condition' selects the check performed on the
-- stack, @l@ is the @then@ branch and @r@ is the @else@ branch.
ifThenElse
  :: Condition arg argl argr outb out
  -> (argl :-> outb) -> (argr :-> outb) -> (arg :-> out)
ifThenElse = \case
  Holds -> if_
  IsSome -> flip ifNone
  IsNone -> ifNone
  IsLeft -> ifLeft
  IsRight -> flip ifLeft
  IsCons -> ifCons
  IsNil -> flip ifCons

  -- 'eq0' replaces the top value with a 'Bool'; 'IsNotZero' just swaps
  -- the branches passed to 'if_'.
  IsZero -> \l r -> eq0 # if_ l r
  IsNotZero -> \l r -> eq0 # if_ r l

  IsEq -> ifEq
  IsNeq -> ifNeq
  IsLt -> ifLt
  IsGt -> ifGt
  IsLe -> ifLe
  IsGe -> ifGe

  -- Unwrap both named operands, then dispatch on the underlying condition.
  NamedBinCondition condition l1 l2 -> \l r ->
    fromNamed l1 # dip (fromNamed l2) #
    ifThenElse condition l r

  -- Duplicate the two operands (a : b : s  ->  a : b : a : b : s), so that
  -- the check consumes the copies and the branches still see the originals.
  PreserveArgsBinCondition condition -> \l r ->
    dip dup # swap # dip dup # swap #
    ifThenElse condition
      -- since this pattern is commonly used when one of the branches fails,
      -- it's essential to @drop@ within branches, not after @if@ - @drop@s
      -- appearing to be dead code will be cut off
      (l # drop # drop)
      (r # drop # drop)

-- | Named version of 'IsLt'.
--
-- In this and similar operators you provide names of accepted stack operands as
-- a safety measure ensuring that they go in the expected order: the operand
-- labelled @n1@ must be on top of the stack, followed by the one labelled @n2@.
infix 4 <.
(<.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(<.) = NamedBinCondition IsLt

-- | Named version of 'IsGt'.
--
-- The operand labelled @n1@ must be on top of the stack, followed by the one
-- labelled @n2@.
infix 4 >.
(>.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(>.) = NamedBinCondition IsGt

-- | Named version of 'IsLe'.
--
-- The operand labelled @n1@ must be on top of the stack, followed by the one
-- labelled @n2@.
infix 4 <=.
(<=.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(<=.) = NamedBinCondition IsLe

-- | Named version of 'IsGe'.
--
-- The operand labelled @n1@ must be on top of the stack, followed by the one
-- labelled @n2@.
infix 4 >=.
(>=.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(>=.) = NamedBinCondition IsGe

-- | Named version of 'IsEq'.
--
-- The operand labelled @n1@ must be on top of the stack, followed by the one
-- labelled @n2@.
infix 4 ==.
(==.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(==.) = NamedBinCondition IsEq

-- | Named version of 'IsNeq'.
--
-- The operand labelled @n1@ must be on top of the stack, followed by the one
-- labelled @n2@.
infix 4 /=.
(/=.)
  :: NiceComparable a
  => Label n1 -> Label n2
  -> Condition ((n1 :! a) ': (n2 :! a) ': s) s s o o
(/=.) = NamedBinCondition IsNeq

-- | Condition modifier that keeps the stack operands of a binary comparison
-- available within @if@ branches.
--
-- Both branches start with the operands still on the stack (@a : b : s@) and
-- must end with a stack of that same shape; the whole @if@ then yields @s@.
keepIfArgs
  :: (forall st o. Condition (a ': b ': st) st st o o)
  -> Condition (a ': b ': s) (a ': b ': s) (a ': b ': s) (a ': b ': s) s
keepIfArgs = PreserveArgsBinCondition