{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

-- |
-- Module      : GHC.TypeLits.Compare
-- Copyright   : (c) Justin Le 2024
-- License     : MIT
-- Maintainer  : justin@jle.im
-- Stability   : unstable
-- Portability : non-portable
--
--
-- This module provides the ability to refine given 'KnownNat' instances
-- using "GHC.TypeLits"'s comparison API, and also the ability to prove
-- inequalities and upper/lower limits.
--
-- If a library function requires a @1 '<=' n@ constraint, but only
-- @'KnownNat' n@ is available:
--
-- @
-- foo :: (KnownNat n, 1 '<=' n) => 'Data.Proxy.Proxy' n -> Int
--
-- bar :: KnownNat n => Proxy n -> Int
-- bar n = case (Proxy :: Proxy 1) '%<=?' n of
--           'LE'  'Refl' -> foo n
--           'NLE' _    -> 0
-- @
--
-- @foo@ requires that @1 <= n@, but @bar@ has to handle all cases of @n@.
-- @%<=?@ lets you compare the 'KnownNat's in two 'Data.Proxy.Proxy's and returns
-- a @:<=?@, which has two constructors, 'LE' and 'NLE'.
--
-- If you pattern match on the result, in the 'LE' branch, the constraint
-- @1 <= n@ will be satisfied according to GHC, so @bar@ can safely call
-- @foo@, and GHC will recognize that @1 <= n@.
--
-- In the 'NLE' branch, the constraint that @1 > n@ is satisfied, so any
-- functions that require that constraint would be callable.
--
-- For convenience, 'isLE' and 'isNLE' are also offered:
--
-- @
-- bar :: KnownNat n => Proxy n -> Int
-- bar n = case 'isLE' (Proxy :: Proxy 1) n of
--           'Just' Refl -> foo n
--           'Nothing'   -> 0
-- @
--
-- Similarly, if a library function requires something involving 'CmpNat',
-- you can use 'cmpNat' and the 'SCmpNat' type:
--
-- @
-- foo1 :: (KnownNat n, 'CmpNat' 5 n ~ LT) => Proxy n -> Int
-- foo2 :: (KnownNat n, CmpNat 5 n ~ GT) => Proxy n -> Int
--
-- bar :: KnownNat n => Proxy n -> Int
-- bar n = case 'cmpNat' (Proxy :: Proxy 5) n of
--           'CLT' Refl -> foo1 n
--           'CEQ' Refl -> 0
--           'CGT' Refl -> foo2 n
-- @
--
-- You can use the 'Refl' that 'cmpNat' gives you with 'flipCmpNat' and
-- 'cmpNatLE' to "flip" the inequality or turn it into something compatible
-- with '<=?' (useful for when you have to work with libraries that mix the
-- two methods) or 'cmpNatEq' and 'eqCmpNat' to get to/from witnesses for
-- equality of the two 'Nat's.
--
-- This module is useful for helping bridge between libraries that use
-- different 'Nat'-based comparison systems in their type constraints.
module GHC.TypeLits.Compare (
  -- * '<=' and '<=?'
  (:<=?) (..),
  (%<=?),

  -- ** Convenience functions
  isLE,
  isNLE,

  -- * 'CmpNat'
  SCmpNat (..),
  GHC.TypeLits.Compare.cmpNat,

  -- ** Manipulating witnesses
  flipCmpNat,
  cmpNatEq,
  eqCmpNat,
  reflCmpNat,
  cmpNatLE,
  cmpNatGOrdering,
)
where

import Data.GADT.Compare
import Data.Kind
import Data.Type.Equality
import GHC.TypeLits (
  CmpNat,
  KnownNat,
  Nat,
  natVal,
  type (<=?),
 )
import Unsafe.Coerce

-- | Simplified version of '%<=?': check if @m@ is less than or equal to
-- @n@.  If it is, match on @'Just' 'Refl'@ to get GHC to believe it,
-- within the body of the pattern match.
--
-- Returns 'Nothing' when @m@ is not less than or equal to @n@.
isLE ::
  (KnownNat m, KnownNat n) =>
  p m ->
  q n ->
  Maybe ((m <=? n) :~: 'True)
isLE m n = case m %<=? n of
  -- Matching 'Refl' brings @(m <=? n) ~ 'True@ into scope for 'Just' 'Refl'.
  LE Refl -> Just Refl
  NLE _ _ -> Nothing

-- | Simplified version of '%<=?': check if @m@ is not less than or equal
-- to @n@.  If it is, match on @'Just' 'Refl'@ to get GHC to believe it,
-- within the body of the pattern match.
--
-- Returns 'Nothing' when @m@ is less than or equal to @n@.
isNLE ::
  (KnownNat m, KnownNat n) =>
  p m ->
  q n ->
  Maybe ((m <=? n) :~: 'False)
isNLE m n = case m %<=? n of
  -- Matching the first 'Refl' brings @(m <=? n) ~ 'False@ into scope.
  NLE Refl Refl -> Just Refl
  LE _ -> Nothing

-- | Two possible ordered relationships between two natural numbers,
-- each carrying the '<=?' evidence that justifies it.
data (m :: Nat) :<=? (n :: Nat) where
  -- | @m@ is less than or equal to @n@.
  LE :: ((m <=? n) :~: 'True) -> m :<=? n
  -- | @m@ is not less than or equal to @n@; consequently @n '<=?' m@ holds.
  NLE :: ((m <=? n) :~: 'False) -> ((n <=? m) :~: 'True) -> m :<=? n

-- | Compare @m@ and @n@, classifying their relationship into some
-- constructor of ':<=?'.
--
-- The comparison is done at runtime on the 'natVal's of the two
-- arguments.  GHC cannot connect that runtime result to the type-level
-- '<=?', so the evidence is asserted with 'unsafeCoerce'.  Each branch
-- only asserts what its guard established:
--
-- * @natVal m <= natVal n@ gives @(m <=? n) ~ 'True@;
-- * otherwise @m > n@, giving @(m <=? n) ~ 'False@ and @(n <=? m) ~ 'True@.
(%<=?) ::
  (KnownNat m, KnownNat n) =>
  p m ->
  q n ->
  (m :<=? n)
m %<=? n
  | natVal m <= natVal n = LE (unsafeCoerce Refl)
  | otherwise = NLE (unsafeCoerce Refl) (unsafeCoerce Refl)

-- | Three possible ordered relationships between two natural numbers,
-- each carrying the 'CmpNat' evidence that justifies it.
data SCmpNat (m :: Nat) (n :: Nat) where
  -- | @m@ is strictly less than @n@.
  CLT :: (CmpNat m n :~: 'LT) -> SCmpNat m n
  -- | @m@ equals @n@; carries both the 'CmpNat' result and the equality.
  CEQ :: (CmpNat m n :~: 'EQ) -> (m :~: n) -> SCmpNat m n
  -- | @m@ is strictly greater than @n@.
  CGT :: (CmpNat m n :~: 'GT) -> SCmpNat m n

-- | Compare @m@ and @n@, classifying their relationship into some
-- constructor of 'SCmpNat'.
--
-- The decision is made by 'compare'-ing the runtime 'natVal's.  GHC
-- cannot see that this determines @'CmpNat' m n@, so the evidence in each
-- branch is asserted with 'unsafeCoerce'.  Each branch asserts exactly the
-- ordering that 'compare' returned, plus @m ~ n@ in the 'EQ' case.
cmpNat ::
  (KnownNat m, KnownNat n) =>
  p m ->
  q n ->
  SCmpNat m n
cmpNat m n = case compare (natVal m) (natVal n) of
  LT -> CLT (unsafeCoerce Refl)
  EQ -> CEQ (unsafeCoerce Refl) (unsafeCoerce Refl)
  GT -> CGT (unsafeCoerce Refl)

-- | Flip an inequality: a comparison of @m@ against @n@ becomes the
-- reversed comparison of @n@ against @m@ ('CLT' and 'CGT' swap; 'CEQ'
-- stays 'CEQ').
flipCmpNat :: SCmpNat m n -> SCmpNat n m
flipCmpNat = \case
  -- GHC cannot derive @CmpNat n m@ from @CmpNat m n@ for the strict
  -- orderings, so the flipped evidence is asserted.  The input proof is
  -- matched first so a bottom witness is never coerced.
  CLT Refl -> CGT (unsafeCoerce Refl)
  -- With @m ~ n@ and @CmpNat m n ~ 'EQ@ in scope, GHC can check
  -- @CmpNat n m ~ 'EQ@ itself, so no coercion is needed here.
  CEQ Refl Refl -> CEQ Refl Refl
  CGT Refl -> CLT (unsafeCoerce Refl)

-- | @'CmpNat' m n@ being 'EQ' implies that @m@ is equal to @n@.
--
-- GHC cannot derive @m ~ n@ from the 'CmpNat' result, so the equality is
-- asserted with 'unsafeCoerce'.  The input is pattern-matched on 'Refl'
-- first, so an undefined witness diverges instead of being turned into a
-- bogus proof.
cmpNatEq :: (CmpNat m n :~: 'EQ) -> (m :~: n)
cmpNatEq = \case Refl -> unsafeCoerce Refl

-- | A witness of equality implies that @'CmpNat' m n@ is 'EQ'.
--
-- Once the 'Refl' is matched, @m ~ n@ is in scope and GHC reduces
-- @'CmpNat' m m@ to @'EQ@ on its own, so the result is a checked 'Refl'
-- and no coercion is required.
eqCmpNat :: (m :~: n) -> (CmpNat m n :~: 'EQ)
eqCmpNat = \case Refl -> Refl

-- | Inject a witness of equality into an 'SCmpNat' at 'CEQ'.
--
-- The 'CmpNat' half of the 'CEQ' evidence is derived from the same
-- witness via 'eqCmpNat'.
reflCmpNat :: (m :~: n) -> SCmpNat m n
reflCmpNat r = CEQ (eqCmpNat r) r

-- | Convert to ':<=?'.
--
-- 'CLT' and 'CEQ' map to 'LE'; 'CGT' maps to 'NLE'.  The '<=?' evidence
-- is asserted with 'unsafeCoerce', justified by the 'CmpNat' witness in
-- the input.  Each input proof is matched on 'Refl' before coercing, so a
-- bottom witness is never converted into a bogus proof.
cmpNatLE :: SCmpNat m n -> (m :<=? n)
cmpNatLE = \case
  CLT Refl -> LE (unsafeCoerce Refl)
  CEQ Refl Refl -> LE (unsafeCoerce Refl)
  -- @m > n@ gives both @(m <=? n) ~ 'False@ and @(n <=? m) ~ 'True@.
  CGT Refl -> NLE (unsafeCoerce Refl) (unsafeCoerce Refl)

-- | Convert to 'GOrdering'.
--
-- 'CLT', 'CEQ' and 'CGT' map to 'GLT', 'GEQ' and 'GGT' respectively.
--
-- @since 0.4.0.0
cmpNatGOrdering :: SCmpNat n m -> GOrdering n m
cmpNatGOrdering = \case
  CLT Refl -> GLT
  -- Matching the equality proof brings @n ~ m@ into scope, which is what
  -- lets 'GEQ' (a @GOrdering a a@) typecheck at @GOrdering n m@.
  CEQ Refl Refl -> GEQ
  CGT Refl -> GGT