{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UnboxedTuples #-}

module Arithmetic.Nat
  ( -- * Addition
    plus
  , plus#

    -- * Subtraction
  , monus

    -- * Division
  , divide
  , divideRoundingUp

    -- * Multiplication
  , times

    -- * Successor
  , succ
  , succ#

    -- * Compare
  , testEqual
  , testEqual#
  , testLessThan
  , testLessThan#
  , testLessThanEqual
  , testZero
  , testZero#
  , (=?)
  , (<?)
  , (<?#)
  , (<=?)

    -- * Constants
  , zero
  , one
  , two
  , three
  , constant
  , constant#

    -- * Unboxed Constants
  , zero#
  , one#

    -- * Unboxed Pattern Synonyms
  , pattern N0#
  , pattern N1#
  , pattern N2#
  , pattern N3#
  , pattern N4#
  , pattern N5#
  , pattern N6#
  , pattern N7#
  , pattern N8#
  , pattern N16#
  , pattern N32#
  , pattern N64#
  , pattern N128#
  , pattern N256#
  , pattern N512#
  , pattern N1024#
  , pattern N2048#
  , pattern N4096#

    -- * Convert
  , demote
  , demote#
  , unlift
  , lift
  , with
  , with#
  ) where

import Prelude hiding (succ)

import Arithmetic.Types
import Arithmetic.Unsafe (Nat (Nat), Nat# (Nat#), (:=:) (Eq), (:=:#) (Eq#), type (<) (Lt), type (<#) (Lt#), type (<=) (Lte))
import Data.Either.Void (EitherVoid#, pattern LeftVoid#, pattern RightVoid#)
import Data.Maybe.Void (MaybeVoid#, pattern JustVoid#, pattern NothingVoid#)
import GHC.Exts (Int#, Proxy#, proxy#, (+#), (<#), (==#))
import GHC.Int (Int (I#))
import GHC.TypeNats (Div, KnownNat, natVal', type (+), type (-))

import qualified GHC.TypeNats as GHC

{- | Infix synonym of 'testLessThan': yields a proof of @a < b@ when
the left operand is strictly smaller than the right one.
-}
(<?) :: Nat a -> Nat b -> Maybe (a < b)
{-# INLINE (<?) #-}
x <? y = testLessThan x y

{- | Infix synonym of 'testLessThanEqual': yields a proof of @a <= b@
when the left operand does not exceed the right one.
-}
(<=?) :: Nat a -> Nat b -> Maybe (a <= b)
{-# INLINE (<=?) #-}
x <=? y = testLessThanEqual x y

{- | Infix synonym of 'testEqual': yields a proof of @a :=: b@ when
both operands hold the same number.
-}
(=?) :: Nat a -> Nat b -> Maybe (a :=: b)
{-# INLINE (=?) #-}
x =? y = testEqual x y

{- | Infix synonym of 'testLessThan#', the unboxed counterpart of
'(<?)'.
-}
(<?#) :: Nat# a -> Nat# b -> MaybeVoid# (a <# b)
{-# INLINE (<?#) #-}
(<?#) x y = testLessThan# x y

{- | Is the first argument strictly less than the second
argument? Returns a proof of @a < b@ if so, and 'Nothing'
otherwise.
-}
testLessThan :: Nat a -> Nat b -> Maybe (a < b)
{-# INLINE testLessThan #-}
testLessThan (Nat x) (Nat y)
  | x < y = Just Lt
  | otherwise = Nothing

{- | Unboxed variant of 'testLessThan'. Yields 'JustVoid#' carrying
a proof of @a <# b@ when the first argument is strictly less than
the second, and 'NothingVoid#' otherwise.
-}
testLessThan# :: Nat# a -> Nat# b -> MaybeVoid# (a <# b)
{-# INLINE testLessThan# #-}
testLessThan# (Nat# x) (Nat# y) = case x <# y of
  -- The primop '<#' answers with 1# for true and 0# for false.
  1# -> JustVoid# (Lt# (# #))
  _ -> NothingVoid#

{- | Is the first argument less-than-or-equal-to the second
argument? Returns a proof of @a <= b@ if so, and 'Nothing'
otherwise.
-}
testLessThanEqual :: Nat a -> Nat b -> Maybe (a <= b)
{-# INLINE testLessThanEqual #-}
testLessThanEqual (Nat x) (Nat y)
  | x <= y = Just Lte
  | otherwise = Nothing

{- | Are the two arguments equal to one another? Returns a proof of
@a :=: b@ if so, and 'Nothing' otherwise.
-}
testEqual :: Nat a -> Nat b -> Maybe (a :=: b)
{-# INLINE testEqual #-}
testEqual (Nat x) (Nat y)
  | x == y = Just Eq
  | otherwise = Nothing

{- | Unboxed variant of 'testEqual'. Yields 'JustVoid#' carrying a
proof of @a :=:# b@ when both arguments are equal, and
'NothingVoid#' otherwise.
-}
testEqual# :: Nat# a -> Nat# b -> MaybeVoid# (a :=:# b)
{-# INLINE testEqual# #-}
testEqual# (Nat# x) (Nat# y) = case x ==# y of
  -- The primop '==#' answers with 1# for true and 0# for false.
  1# -> JustVoid# (Eq# (# #))
  _ -> NothingVoid#

{- | Is zero equal to this number or less than it? Returns 'Left'
with a proof of @0 :=: a@ when the argument is zero, and 'Right'
with a proof of @0 < a@ otherwise.
-}
testZero :: Nat a -> Either (0 :=: a) (0 < a)
{-# INLINE testZero #-}
testZero (Nat x)
  | x == 0 = Left Eq
  | otherwise = Right Lt

{- | Unboxed variant of 'testZero'. Yields 'LeftVoid#' carrying a
proof of @0 :=:# a@ when the argument is zero, and 'RightVoid#'
carrying a proof of @0 <# a@ otherwise.
-}
testZero# :: Nat# a -> EitherVoid# (0 :=:# a) (0 <# a)
-- INLINE for consistency with 'testZero' and the other comparison
-- helpers, all of which carry this pragma.
{-# INLINE testZero# #-}
testZero# (Nat# x) = case x of
  0# -> LeftVoid# (Eq# (# #))
  _ -> RightVoid# (Lt# (# #))

{- | Add two numbers. The sum is computed with ordinary 'Int'
addition and is not checked for overflow.
-}
plus :: Nat a -> Nat b -> Nat (a + b)
{-# INLINE plus #-}
plus (Nat x) (Nat y) = Nat total
  where
    total = x + y

{- | Variant of 'plus' for unboxed nats. As with 'plus', the
addition is not checked for overflow.
-}
plus# :: Nat# a -> Nat# b -> Nat# (a + b)
{-# INLINE plus# #-}
plus# (Nat# x) (Nat# y) = case x +# y of
  total -> Nat# total

{- | Divide two numbers. Rounds down (towards zero); for natural
operands the two rounding modes coincide. A zero divisor makes
'div' raise a divide-by-zero exception when the result is forced.
-}
divide :: Nat a -> Nat b -> Nat (Div a b)
{-# INLINE divide #-}
divide (Nat x) (Nat y) = Nat (x `div` y)

{- | Divide two numbers. Rounds up (away from zero). A zero divisor
makes 'div' raise a divide-by-zero exception when the result is
forced.
-}
divideRoundingUp :: Nat a -> Nat b -> Nat (Div (a - 1) b + 1)
{-# INLINE divideRoundingUp #-}
divideRoundingUp (Nat x) (Nat y) = Nat (lowered + 1)
  where
    -- Implementation note: this must be 'div', not 'quot'. When x = 0
    -- the dividend is -1, and 'div' rounds it down to -1, so adding 1
    -- gives 0. ('quot' would round -1 up to 0 and the result would be 1.)
    lowered = (x - 1) `div` y

{- | Multiply two numbers. The product is computed with ordinary
'Int' multiplication and is not checked for overflow.
-}
times :: Nat a -> Nat b -> Nat (a GHC.* b)
{-# INLINE times #-}
times (Nat x) (Nat y) = Nat prod
  where
    prod = x * y

-- | The successor of a number, obtained by adding 'one' to it.
succ :: Nat a -> Nat (a + 1)
{-# INLINE succ #-}
succ n = n `plus` one

-- | Unlifted variant of 'succ', obtained by adding 'one#' to the argument.
succ# :: Nat# a -> Nat# (a + 1)
{-# INLINE succ# #-}
succ# n = n `plus#` one# (# #)

{- | Subtract the second argument from the first argument. Returns
'Nothing' when the second argument is larger than the first, since
the difference would not be a natural number. Otherwise returns the
difference together with a proof that adding the second argument
back to it gives the first.
-}
monus :: Nat a -> Nat b -> Maybe (Difference a b)
{-# INLINE monus #-}
monus (Nat x) (Nat y)
  | d >= 0 = Just (Difference (Nat d) Eq)
  | otherwise = Nothing
  where
    d = x - y

-- | The number zero, as a boxed witness.
zero :: Nat 0
{-# INLINE zero #-}
zero = Nat 0

-- | The number one, as a boxed witness.
one :: Nat 1
{-# INLINE one #-}
one = Nat 1

-- | The number two, as a boxed witness.
two :: Nat 2
{-# INLINE two #-}
two = Nat 2

-- | The number three, as a boxed witness.
three :: Nat 3
{-# INLINE three #-}
three = Nat 3

{- | Use GHC's built-in type-level arithmetic to create a witness
of a type-level number. This only reduces if the number is a
constant.

The 'KnownNat' value is converted to 'Int' with 'fromIntegral'
and is not range-checked. A type-level number larger than
@maxBound :: Int@ is therefore not represented faithfully.
-}
constant :: forall n. (KnownNat n) => Nat n
{-# INLINE constant #-}
constant = Nat (fromIntegral reflected)
  where
    reflected = natVal' (proxy# :: Proxy# n)

{- | Unboxed variant of 'constant'. The value is converted with
'fromIntegral' without a range check, so type-level numbers above
@maxBound :: Int@ are not represented faithfully.
-}
constant# :: forall n. (KnownNat n) => (# #) -> Nat# n
{-# INLINE constant# #-}
constant# (# #) =
  let boxed = fromIntegral (natVal' (proxy# :: Proxy# n)) :: Int
   in case boxed of
        I# raw -> Nat# raw

-- | The number zero. Unboxed.
zero# :: (# #) -> Nat# 0
-- INLINE for consistency with the boxed constants ('zero', 'one', ...).
{-# INLINE zero# #-}
zero# _ = Nat# 0#

-- | The number one. Unboxed.
one# :: (# #) -> Nat# 1
-- INLINE for consistency with the boxed constants ('zero', 'one', ...).
{-# INLINE one# #-}
one# _ = Nat# 1#

{- | Extract the 'Int' from a 'Nat'. This is intended to be used
at a boundary where a safe interface meets the unsafe primitives
on top of which it is built.
-}
demote :: Nat n -> Int
{-# INLINE demote #-}
demote w = case w of
  Nat i -> i

-- | Unboxed variant of 'demote': extract the 'Int#' from a 'Nat#'.
demote# :: Nat# n -> Int#
{-# INLINE demote# #-}
demote# w = case w of
  Nat# i -> i

{- | Run a computation on a witness of a type-level number. The
argument 'Int' must be greater than or equal to zero. This is
not checked. Failure to uphold this invariant will lead to a
segfault.
-}
with :: Int -> (forall n. Nat n -> a) -> a
{-# INLINE with #-}
with i f = f (Nat i)

{- | Unboxed variant of 'with'. The argument 'Int#' must be greater
than or equal to zero; this is not checked.
-}
with# :: Int# -> (forall n. Nat# n -> a) -> a
{-# INLINE with# #-}
with# raw k = k (Nat# raw)

{- | Convert a boxed 'Nat' into the unboxed 'Nat#' with the same
type-level index.
-}
unlift :: Nat n -> Nat# n
{-# INLINE unlift #-}
unlift (Nat boxed) = case boxed of
  I# raw -> Nat# raw

{- | Convert an unboxed 'Nat#' into the boxed 'Nat' with the same
type-level index.
-}
lift :: Nat# n -> Nat n
{-# INLINE lift #-}
lift (Nat# raw) = Nat (I# raw)

-- | Unboxed pattern for the number 0.
pattern N0# :: Nat# 0
pattern N0# = Nat# 0#

-- | Unboxed pattern for the number 1.
pattern N1# :: Nat# 1
pattern N1# = Nat# 1#

-- | Unboxed pattern for the number 2.
pattern N2# :: Nat# 2
pattern N2# = Nat# 2#

-- | Unboxed pattern for the number 3.
pattern N3# :: Nat# 3
pattern N3# = Nat# 3#

-- | Unboxed pattern for the number 4.
pattern N4# :: Nat# 4
pattern N4# = Nat# 4#

-- | Unboxed pattern for the number 5.
pattern N5# :: Nat# 5
pattern N5# = Nat# 5#

-- | Unboxed pattern for the number 6.
pattern N6# :: Nat# 6
pattern N6# = Nat# 6#

-- | Unboxed pattern for the number 7.
pattern N7# :: Nat# 7
pattern N7# = Nat# 7#

-- | Unboxed pattern for the number 8.
pattern N8# :: Nat# 8
pattern N8# = Nat# 8#

-- | Unboxed pattern for the number 16.
pattern N16# :: Nat# 16
pattern N16# = Nat# 16#

-- | Unboxed pattern for the number 32.
pattern N32# :: Nat# 32
pattern N32# = Nat# 32#

-- | Unboxed pattern for the number 64.
pattern N64# :: Nat# 64
pattern N64# = Nat# 64#

-- | Unboxed pattern for the number 128.
pattern N128# :: Nat# 128
pattern N128# = Nat# 128#

-- | Unboxed pattern for the number 256.
pattern N256# :: Nat# 256
pattern N256# = Nat# 256#

-- | Unboxed pattern for the number 512.
pattern N512# :: Nat# 512
pattern N512# = Nat# 512#

-- | Unboxed pattern for the number 1024.
pattern N1024# :: Nat# 1024
pattern N1024# = Nat# 1024#

-- | Unboxed pattern for the number 2048.
pattern N2048# :: Nat# 2048
pattern N2048# = Nat# 2048#

-- | Unboxed pattern for the number 4096.
pattern N4096# :: Nat# 4096
pattern N4096# = Nat# 4096#