{-# language DataKinds #-}
{-# language MagicHash #-}
{-# language ExplicitNamespaces #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language RankNTypes #-}
{-# language TypeOperators #-}

module Arithmetic.Types
  ( Nat
  , Nat#
  , WithNat(..)
  , Difference(..)
  , Fin(..)
  , Fin#
  , type (<)
  , type (<=)
  , type (:=:)
  ) where

import Arithmetic.Unsafe (Fin#,Nat#,Nat(getNat), type (<=))
import Arithmetic.Unsafe (type (<), type (:=:))
import Data.Kind (type Type)
import GHC.TypeNats (type (+))

import qualified GHC.TypeNats as GHC

-- | Hides a type-level natural behind an existential. The constructor
-- stores a 'Nat' for the hidden index next to a payload of type @f n@.
-- Pattern matching on it therefore recovers both values at the same
-- (unknown) @n@.
data WithNat :: (GHC.Nat -> Type) -> Type where
  WithNat
    :: {-# UNPACK #-} !(Nat n) -- value-level witness for the hidden index
    -> f n                     -- payload indexed by that same natural
    -> WithNat f

-- | A finite set of @n@ elements: @Fin n = { 0 .. n - 1 }@.
--
-- A value pairs an index @m@ (kept as a 'Nat') with a proof of @m < n@.
-- The index type @m@ is existentially quantified, so the fields are
-- intended for pattern matching and construction rather than use as
-- standalone selector functions.
data Fin :: GHC.Nat -> Type where
  Fin :: forall m n.
    { index :: !(Nat m)
      -- ^ The element's position within the set.
    , proof :: !(m < n)
      -- ^ Evidence that the position is below the bound @n@.
    } -> Fin n

-- | Evidence that the first index equals the second index plus some
-- other natural. A @Difference a b@ carries that other natural @c@ as a
-- 'Nat', together with a proof of type @c + b :=: a@.
data Difference :: GHC.Nat -> GHC.Nat -> Type where
  -- Applying this constructor directly is safe for library users.
  -- In practice, the useful values are the ones produced by
  -- Arithmetic.Nat.monus, which is a primitive.
  Difference
    :: forall a b c.
       Nat c
    -> (c + b :=: a)
    -> Difference a b

-- | Renders a 'Fin' as @Fin k@, where @k@ is its index (the proof is not
-- shown). This follows the standard 'showsPrec' convention for
-- constructor applications: the output is parenthesised when it appears
-- in an argument position (precedence above 10). For example,
-- @show (Just f)@ yields @"Just (Fin 3)"@ rather than @"Just Fin 3"@.
-- At top level, @show f@ is still @"Fin 3"@.
instance Show (Fin n) where
  showsPrec p (Fin i _) =
    showParen (p > 10) (showString "Fin " . showsPrec 11 (getNat i))

-- | Two 'Fin's are equal exactly when their indices, converted with
-- 'getNat', are equal. The proof fields are not inspected.
instance Eq (Fin n) where
  Fin x _ == Fin y _ = getNat x == getNat y

-- | Orders 'Fin's by their indices, converted with 'getNat'. This agrees
-- with the 'Eq' instance, which also compares only the indices. The proof
-- fields are not inspected.
instance Ord (Fin n) where
  compare (Fin x _) (Fin y _) = compare (getNat x) (getNat y)