{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}

module Arithmetic.Lt
  ( -- * Special Inequalities
    zero
  , zero#

    -- * Substitution
  , substituteL
  , substituteR

    -- * Increment
  , incrementL
  , incrementL#
  , incrementR
  , incrementR#

    -- * Decrement
  , decrementL
  , decrementL#
  , decrementR
  , decrementR#

    -- * Weaken
  , weakenL
  , weakenL#
  , weakenR
  , weakenR#
  , weakenLhsL#
  , weakenLhsR#

    -- * Composition
  , plus
  , plus#
  , transitive
  , transitive#
  , transitiveNonstrictL
  , transitiveNonstrictL#
  , transitiveNonstrictR
  , transitiveNonstrictR#

    -- * Multiplication and Division
  , reciprocalA
  , reciprocalB

    -- * Convert to Inequality
  , toLteL
  , toLteR

    -- * Absurdities
  , absurd

    -- * Integration with GHC solver
  , constant
  , constant#

    -- * Lift and Unlift
  , lift
  , unlift
  ) where

import Arithmetic.Unsafe (type (:=:) (Eq), type (<) (Lt), type (<#) (Lt#), type (<=) (Lte), type (<=#))
import GHC.TypeNats (CmpNat, type (+))

import qualified GHC.TypeNats as GHC

{- | Convert a strict inequality into the equivalent nonstrict one,
adding one on the right of the smaller side: from @a < b@, conclude
@a + 1 <= b@. See 'toLteL' for the @1 + a@ form.
-}
toLteR :: (a < b) -> ((a + 1) <= b)
{-# INLINE toLteR #-}
toLteR Lt = Lte

{- | Convert a strict inequality into the equivalent nonstrict one,
adding one on the left of the smaller side: from @a < b@, conclude
@1 + a <= b@. See 'toLteR' for the @a + 1@ form.
-}
toLteL :: (a < b) -> ((1 + a) <= b)
{-# INLINE toLteL #-}
toLteL Lt = Lte

{- | Replace the left-hand side of a strict inequality with an equal
number: from @b = c@ and @b < a@, conclude @c < a@.
-}
substituteL :: (b :=: c) -> (b < a) -> (c < a)
{-# INLINE substituteL #-}
substituteL Eq Lt = Lt

{- | Replace the right-hand side of a strict inequality with an equal
number: from @b = c@ and @a < b@, conclude @a < c@.
-}
substituteR :: (b :=: c) -> (a < b) -> (a < c)
{-# INLINE substituteR #-}
substituteR Eq Lt = Lt

{- | Add a strict inequality to a nonstrict inequality: from @a < b@
and @c <= d@, conclude @a + c < b + d@.
-}
plus :: (a < b) -> (c <= d) -> ((a + c) < (b + d))
{-# INLINE plus #-}
plus Lt Lte = Lt

{- | Unboxed variant of 'plus': from @a < b@ and @c <= d@, conclude
@a + c < b + d@. Both evidence arguments are ignored; the result is
'Lt#' applied to the unboxed unit.
-}
plus# :: (a <# b) -> (c <=# d) -> ((a + c) <# (b + d))
{-# INLINE plus# #-}
plus# _ _ = Lt# (# #)

{- | Add a constant to the left of both sides of a strict inequality:
from @a < b@, conclude @c + a < c + b@. The constant @c@ comes first
in the @forall@, which lets callers fix it with a type application.
-}
incrementL ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a < b) ->
  ((c + a) < (c + b))
{-# INLINE incrementL #-}
incrementL Lt = Lt

{- | Unboxed variant of 'incrementL': from @a < b@, conclude
@c + a < c + b@. The evidence argument is ignored.
-}
incrementL# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <# b) ->
  ((c + a) <# (c + b))
{-# INLINE incrementL# #-}
incrementL# _ = Lt# (# #)

{- | Add a constant to the right of both sides of a strict inequality:
from @a < b@, conclude @a + c < b + c@. The constant @c@ comes first
in the @forall@, which lets callers fix it with a type application.
-}
incrementR ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a < b) ->
  ((a + c) < (b + c))
{-# INLINE incrementR #-}
incrementR Lt = Lt

{- | Unboxed variant of 'incrementR': from @a < b@, conclude
@a + c < b + c@. The evidence argument is ignored.
-}
incrementR# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <# b) ->
  ((a + c) <# (b + c))
{-# INLINE incrementR# #-}
incrementR# _ = Lt# (# #)

{- | Cancel a constant added on the left of both sides of a strict
inequality: from @c + a < c + b@, conclude @a < b@. This is the
opposite of 'incrementL'. Because @c@ occurs only as an argument of
the type family @+@, it is normally supplied with a type application
(it comes first in the @forall@).
-}
decrementL ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  ((c + a) < (c + b)) ->
  (a < b)
{-# INLINE decrementL #-}
decrementL Lt = Lt

{- | Unboxed variant of 'decrementL': from @c + a < c + b@, conclude
@a < b@. The evidence argument is ignored.
-}
decrementL# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  ((c + a) <# (c + b)) ->
  (a <# b)
{-# INLINE decrementL# #-}
decrementL# _ = Lt# (# #)

{- | Cancel a constant added on the right of both sides of a strict
inequality: from @a + c < b + c@, conclude @a < b@. This is the
opposite of 'incrementR'. Because @c@ occurs only as an argument of
the type family @+@, it is normally supplied with a type application
(it comes first in the @forall@).
-}
decrementR ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  ((a + c) < (b + c)) ->
  (a < b)
{-# INLINE decrementR #-}
decrementR Lt = Lt

{- | Unboxed variant of 'decrementR': from @a + c < b + c@, conclude
@a < b@. The evidence argument is ignored.
-}
decrementR# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  ((a + c) <# (b + c)) ->
  (a <# b)
{-# INLINE decrementR# #-}
decrementR# _ = Lt# (# #)

{- | Enlarge the right-hand side of a strict inequality by adding a
constant on its left: from @a < b@, conclude @a < c + b@.
-}
weakenL ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a < b) ->
  (a < (c + b))
{-# INLINE weakenL #-}
weakenL Lt = Lt

{- | Unboxed variant of 'weakenL': from @a < b@, conclude @a < c + b@.
The evidence argument is ignored.
-}
weakenL# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <# b) ->
  (a <# (c + b))
{-# INLINE weakenL# #-}
weakenL# _ = Lt# (# #)

{- | Shrink the left-hand side of an unboxed strict inequality by
dropping a constant added on its left: from @c + a < b@, conclude
@a < b@. The evidence argument is ignored.
-}
weakenLhsL# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  ((c + a) <# b) ->
  (a <# b)
{-# INLINE weakenLhsL# #-}
weakenLhsL# _ = Lt# (# #)

{- | Shrink the left-hand side of an unboxed strict inequality by
dropping a constant added on its right: from @a + c < b@, conclude
@a < b@. The evidence argument is ignored.
-}
weakenLhsR# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  ((a + c) <# b) ->
  (a <# b)
{-# INLINE weakenLhsR# #-}
weakenLhsR# _ = Lt# (# #)

{- | Enlarge the right-hand side of a strict inequality by adding a
constant on its right: from @a < b@, conclude @a < b + c@.
-}
weakenR ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a < b) ->
  (a < (b + c))
{-# INLINE weakenR #-}
weakenR Lt = Lt

{- | Unboxed variant of 'weakenR': from @a < b@, conclude @a < b + c@.
The evidence argument is ignored.
-}
weakenR# ::
  forall (c :: GHC.Nat) (a :: GHC.Nat) (b :: GHC.Nat).
  (a <# b) ->
  (a <# (b + c))
{-# INLINE weakenR# #-}
weakenR# _ = Lt# (# #)

{- | Compose two strict inequalities using transitivity: from @a < b@
and @b < c@, conclude @a < c@.
-}
transitive :: (a < b) -> (b < c) -> (a < c)
{-# INLINE transitive #-}
transitive Lt Lt = Lt

{- | Unboxed variant of 'transitive': from @a < b@ and @b < c@,
conclude @a < c@. Both evidence arguments are ignored.
-}
transitive# :: (a <# b) -> (b <# c) -> (a <# c)
{-# INLINE transitive# #-}
transitive# _ _ = Lt# (# #)

{- | Compose a strict inequality (the first argument) with a nonstrict
inequality (the second argument): from @a < b@ and @b <= c@, conclude
@a < c@.
-}
transitiveNonstrictR :: (a < b) -> (b <= c) -> (a < c)
{-# INLINE transitiveNonstrictR #-}
transitiveNonstrictR Lt Lte = Lt

{- | Unboxed variant of 'transitiveNonstrictR': from @a < b@ and
@b <= c@, conclude @a < c@. Both evidence arguments are ignored.
-}
transitiveNonstrictR# :: (a <# b) -> (b <=# c) -> (a <# c)
{-# INLINE transitiveNonstrictR# #-}
transitiveNonstrictR# _ _ = Lt# (# #)

{- | Compose a nonstrict inequality (the first argument) with a strict
inequality (the second argument): from @a <= b@ and @b < c@, conclude
@a < c@.
-}
transitiveNonstrictL :: (a <= b) -> (b < c) -> (a < c)
{-# INLINE transitiveNonstrictL #-}
transitiveNonstrictL Lte Lt = Lt

{- | Unboxed variant of 'transitiveNonstrictL': from @a <= b@ and
@b < c@, conclude @a < c@. Both evidence arguments are ignored.
-}
transitiveNonstrictL# :: (a <=# b) -> (b <# c) -> (a <# c)
{-# INLINE transitiveNonstrictL# #-}
transitiveNonstrictL# _ _ = Lt# (# #)

-- | Zero is less than one.
zero :: 0 < 1
{-# INLINE zero #-}
zero = Lt

{- | Unboxed variant of 'zero'. The unboxed-unit argument is ignored;
it only makes this a function rather than a top-level unlifted value.
-}
zero# :: (# #) -> 0 <# 1
{-# INLINE zero# #-}
zero# _ = Lt# (# #)

{- | Nothing is less than zero, so a proof of @n < 0@ can be turned into
a value of any type. Evaluating this with such a proof means an
invalid proof was constructed (for example via "Arithmetic.Unsafe");
it then fails with an error naming this function.
-}
absurd :: (n < 0) -> void
{-# INLINE absurd #-}
absurd Lt = errorWithoutStackTrace "Arithmetic.Lt.absurd: n < 0"

{- | Use GHC's built-in type-level arithmetic to prove that one number
is less than another. The type-checker only reduces 'CmpNat' if both
arguments are constants.
-}
constant :: forall a b. (CmpNat a b ~ 'LT) => (a < b)
{-# INLINE constant #-}
constant = Lt

{- | Unboxed variant of 'constant'. The proof is discharged by GHC's
'CmpNat' reduction, which only fires when both arguments are
constants. The unboxed-unit argument is ignored.
-}
constant# :: forall a b. (CmpNat a b ~ 'LT) => (# #) -> (a <# b)
{-# INLINE constant# #-}
constant# _ = Lt# (# #)

{- | Given that @m < n/p@ (where 'GHC.Div' rounds down), we know that
@p*m < n@. GHC does not check this step: the argument is ignored and
the result is built directly from the 'Lt' constructor.
-}
reciprocalA ::
  forall (m :: GHC.Nat) (n :: GHC.Nat) (p :: GHC.Nat).
  (m < GHC.Div n p) ->
  ((p GHC.* m) < n)
{-# INLINE reciprocalA #-}
reciprocalA _ = Lt

{- | Given that @m < roundUp(n/p)@, we know that @p*m < n@. Rounding up
is written as @Div (n - 1) p + 1@, which equals the ceiling of @n/p@
only when @n >= 1@. GHC does not check this step: the argument is
ignored and the result is built directly from the 'Lt' constructor.
-}
reciprocalB ::
  forall (m :: GHC.Nat) (n :: GHC.Nat) (p :: GHC.Nat).
  (m < (GHC.Div (n GHC.- 1) p + 1)) ->
  ((p GHC.* m) < n)
{-# INLINE reciprocalB #-}
reciprocalB _ = Lt

{- | Convert a boxed proof of @a < b@ into its unboxed counterpart.
The argument is ignored.
-}
unlift :: (a < b) -> (a <# b)
{-# INLINE unlift #-}
unlift _ = Lt# (# #)

{- | Convert an unboxed proof of @a < b@ into its boxed counterpart.
The argument is ignored.
-}
lift :: (a <# b) -> (a < b)
{-# INLINE lift #-}
lift _ = Lt