{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}

module Arithmetic.Lte
  ( -- * Special Inequalities
    zero
  , reflexive
  , reflexive#

    -- * Substitution
  , substituteL
  , substituteR

    -- * Increment
  , incrementL
  , incrementL#
  , incrementR
  , incrementR#

    -- * Decrement
  , decrementL
  , decrementL#
  , decrementR
  , decrementR#

    -- * Weaken
  , weakenL
  , weakenL#
  , weakenR
  , weakenR#

    -- * Composition
  , transitive
  , transitive#
  , plus
  , plus#

    -- * Convert Strict Inequality
  , fromStrict
  , fromStrict#
  , fromStrictSucc
  , fromStrictSucc#

    -- * Integration with GHC solver
  , constant

    -- * Lift and Unlift
  , lift
  , unlift
  ) where

import Arithmetic.Unsafe (type (:=:) (Eq), type (<) (Lt), type (<#), type (<=) (Lte), type (<=#) (Lte#))
import GHC.TypeNats (CmpNat, type (+))

import qualified GHC.TypeNats as GHC

{- | Replace the left-hand side of a non-strict inequality
with an equal number: from @b = c@ and @b <= a@ conclude @c <= a@.
-}
substituteL :: (b :=: c) -> (b <= a) -> (c <= a)
{-# INLINE substituteL #-}
-- Matching on both constructors forces the argument proofs instead of
-- discarding them unevaluated.
substituteL Eq Lte = Lte

{- | Replace the right-hand side of a non-strict inequality
with an equal number: from @b = c@ and @a <= b@ conclude @a <= c@.
-}
substituteR :: (b :=: c) -> (a <= b) -> (a <= c)
{-# INLINE substituteR #-}
-- Matching on both constructors forces the argument proofs instead of
-- discarding them unevaluated.
substituteR Eq Lte = Lte

-- | Add two inequalities: from @a <= b@ and @c <= d@ conclude
-- @a + c <= b + d@.
plus :: (a <= b) -> (c <= d) -> (a + c <= b + d)
{-# INLINE plus #-}
plus Lte Lte = Lte

-- | Variant of 'plus' for unlifted ('<=#') proofs.
plus# :: (a <=# b) -> (c <=# d) -> (a + c <=# b + d)
{-# INLINE plus# #-}
plus# _ _ = Lte# (# #)

-- | Compose two inequalities using transitivity: from @a <= b@ and
-- @b <= c@ conclude @a <= c@.
transitive :: (a <= b) -> (b <= c) -> (a <= c)
{-# INLINE transitive #-}
transitive Lte Lte = Lte

-- | Variant of 'transitive' for unlifted ('<=#') proofs.
transitive# :: (a <=# b) -> (b <=# c) -> (a <=# c)
{-# INLINE transitive# #-}
transitive# _ _ = Lte# (# #)

-- | Any number is less-than-or-equal-to itself.
reflexive :: a <= a
{-# INLINE reflexive #-}
reflexive = Lte

-- | Variant of 'reflexive' producing an unlifted ('<=#') proof.
-- The @(# #)@ argument exists only to make this a function.
reflexive# :: (# #) -> a <=# a
{-# INLINE reflexive# #-}
reflexive# _ = Lte# (# #)

{- | Add a constant to both sides of the inequality, placing it on the
left of each sum: from @a <= b@ conclude @c + a <= c + b@.
The constant @c@ is the first type variable, so it can be supplied
with a type application.
-}
incrementL ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <= b) ->
  (c + a <= c + b)
{-# INLINE incrementL #-}
incrementL Lte = Lte

-- | Variant of 'incrementL' for unlifted ('<=#') proofs.
incrementL# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <=# b) ->
  (c + a <=# c + b)
{-# INLINE incrementL# #-}
incrementL# _ = Lte# (# #)

{- | Add a constant to both sides of the inequality, placing it on the
right of each sum: from @a <= b@ conclude @a + c <= b + c@.
The constant @c@ is the first type variable, so it can be supplied
with a type application.
-}
incrementR ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <= b) ->
  (a + c <= b + c)
{-# INLINE incrementR #-}
incrementR Lte = Lte

-- | Variant of 'incrementR' for unlifted ('<=#') proofs.
incrementR# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <=# b) ->
  (a + c <=# b + c)
{-# INLINE incrementR# #-}
incrementR# _ = Lte# (# #)

{- | Add a constant to the left of the right-hand side of
the inequality: from @a <= b@ conclude @a <= c + b@.
-}
weakenL ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <= b) ->
  (a <= c + b)
{-# INLINE weakenL #-}
weakenL Lte = Lte

-- | Variant of 'weakenL' for unlifted ('<=#') proofs.
weakenL# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <=# b) ->
  (a <=# c + b)
{-# INLINE weakenL# #-}
weakenL# _ = Lte# (# #)

{- | Add a constant to the right of the right-hand side of
the inequality: from @a <= b@ conclude @a <= b + c@.
-}
weakenR ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <= b) ->
  (a <= b + c)
{-# INLINE weakenR #-}
weakenR Lte = Lte

-- | Variant of 'weakenR' for unlifted ('<=#') proofs.
weakenR# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <=# b) ->
  (a <=# b + c)
{-# INLINE weakenR# #-}
weakenR# _ = Lte# (# #)

{- | Subtract a constant from the left of both sides of
the inequality: from @c + a <= c + b@ conclude @a <= b@.
This is the opposite of 'incrementL'.
-}
decrementL ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (c + a <= c + b) ->
  (a <= b)
{-# INLINE decrementL #-}
decrementL Lte = Lte

-- | Variant of 'decrementL' for unlifted ('<=#') proofs.
decrementL# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (c + a <=# c + b) ->
  (a <=# b)
{-# INLINE decrementL# #-}
decrementL# _ = Lte# (# #)

{- | Subtract a constant from the right of both sides of
the inequality: from @a + c <= b + c@ conclude @a <= b@.
This is the opposite of 'incrementR'.
-}
decrementR ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a + c <= b + c) ->
  (a <= b)
{-# INLINE decrementR #-}
decrementR Lte = Lte

-- | Variant of 'decrementR' for unlifted ('<=#') proofs.
decrementR# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a + c <=# b + c) ->
  (a <=# b)
{-# INLINE decrementR# #-}
decrementR# _ = Lte# (# #)

-- | Weaken a strict inequality to a non-strict inequality:
-- from @a < b@ conclude @a <= b@.
fromStrict :: (a < b) -> (a <= b)
{-# INLINE fromStrict #-}
fromStrict Lt = Lte

-- | Variant of 'fromStrict' for unlifted ('<#' / '<=#') proofs.
fromStrict# :: (a <# b) -> (a <=# b)
{-# INLINE fromStrict# #-}
fromStrict# _ = Lte# (# #)

{- | Convert a strict inequality to a non-strict inequality, incrementing
the left-hand argument by one: from @a < b@ conclude @a + 1 <= b@.
-}
fromStrictSucc :: (a < b) -> (a + 1 <= b)
{-# INLINE fromStrictSucc #-}
fromStrictSucc Lt = Lte

-- | Variant of 'fromStrictSucc' for unlifted ('<#' / '<=#') proofs.
fromStrictSucc# :: (a <# b) -> (a + 1 <=# b)
{-# INLINE fromStrictSucc# #-}
fromStrictSucc# _ = Lte# (# #)

-- | Zero is less-than-or-equal-to any number.
zero :: 0 <= a
{-# INLINE zero #-}
zero = Lte

{- | Use GHC's built-in type-level arithmetic to prove
that one number is less-than-or-equal-to another. The type-checker
only reduces 'CmpNat' if both arguments are constants.
The proof is accepted only when @CmpNat a b@ reduces to @'LT@ or @'EQ@.
-}
constant :: forall a b. (IsLte (CmpNat a b) ~ 'True) => (a <= b)
{-# INLINE constant #-}
constant = Lte

-- | Type-level check that an 'Ordering' means "less than or equal":
-- @'LT@ and @'EQ@ map to @'True@, @'GT@ maps to @'False@.
-- The equations do not overlap, so their order does not matter.
type family IsLte (ordering :: Ordering) :: Bool where
  IsLte 'LT = 'True
  IsLte 'EQ = 'True
  IsLte 'GT = 'False

-- | Convert a lifted proof of @a <= b@ into its unlifted ('<=#')
-- counterpart.
unlift :: (a <= b) -> (a <=# b)
{-# INLINE unlift #-}
unlift _ = Lte# (# #)

-- | Convert an unlifted ('<=#') proof of @a <= b@ into its lifted
-- counterpart.
lift :: (a <=# b) -> (a <= b)
{-# INLINE lift #-}
lift _ = Lte