{-|
Copyright  :  (C) 2013-2016, University of Twente,
                  2016     , Myrtle Software Ltd
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TemplateHaskell #-}

{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise       #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module Clash.Promoted.Nat
  ( -- * Singleton natural numbers
    -- ** Data type
    SNat (..)
    -- ** Construction
  , snatProxy
  , withSNat
    -- ** Conversion
  , snatToInteger, snatToNatural, snatToNum
    -- ** Conversion (ambiguous types)
  , natToInteger, natToNatural, natToNum
    -- ** Arithmetic
  , addSNat, mulSNat, powSNat, minSNat, maxSNat, succSNat
    -- *** Partial
  , subSNat, divSNat, modSNat, flogBaseSNat, clogBaseSNat, logBaseSNat, predSNat
    -- *** Specialised
  , pow2SNat
    -- *** Comparison
  , SNatLE (..), compareSNat
    -- * Unary/Peano-encoded natural numbers
    -- ** Data type
  , UNat (..)
    -- ** Construction
  , toUNat
    -- ** Conversion
  , fromUNat
    -- ** Arithmetic
  , addUNat, mulUNat, powUNat
    -- *** Partial
  , predUNat, subUNat
    -- * Base-2 encoded natural numbers
    -- ** Data type
  , BNat (..)
    -- ** Construction
  , toBNat
    -- ** Conversion
  , fromBNat
    -- ** Pretty printing base-2 encoded natural numbers
  , showBNat
    -- ** Arithmetic
  , succBNat, addBNat, mulBNat, powBNat
    -- *** Partial
  , predBNat, div2BNat, div2Sub1BNat, log2BNat
    -- ** Normalisation
  , stripZeros
    -- * Constraints on natural numbers
  , leToPlus
  , leToPlusKN
  )
where

import Data.Kind          (Type)
import GHC.Show           (appPrec)
import GHC.TypeLits       (KnownNat, Nat, type (+), type (-), type (*),
                           type (^), type (<=), natVal)
import GHC.TypeLits.Extra (CLog, FLog, Div, Log, Mod, Min, Max)
import GHC.Natural        (naturalFromInteger)
import Language.Haskell.TH (appT, conT, litT, numTyLit, sigE)
import Language.Haskell.TH.Syntax (Lift (..))
#if MIN_VERSION_template_haskell(2,16,0)
import Language.Haskell.TH.Compat
#endif
import Numeric.Natural    (Natural)
import Unsafe.Coerce      (unsafeCoerce)
import Clash.XException   (ShowX (..), showsPrecXWith)

{- $setup
>>> :set -XBinaryLiterals
>>> import Clash.Promoted.Nat.Literals (d789)
-}

-- | Singleton value for a type-level natural number @n@
--
-- * "Clash.Promoted.Nat.Literals" contains a list of predefined 'SNat' literals
-- * "Clash.Promoted.Nat.TH" has functions to easily create large ranges of new
--   'SNat' literals
data SNat (n :: Nat) where
  -- | The only constructor; it carries the @'KnownNat' n@ evidence, so
  -- pattern matching on it brings that constraint into scope.
  SNat :: KnownNat n => SNat n

-- | An 'SNat' is lifted to the expression @SNat@ annotated with the type
-- @SNat \<literal\>@, where the literal is the value obtained through
-- 'snatToInteger'.
instance Lift (SNat n) where
  -- The type annotation carries the concrete literal into the splice, so the
  -- spliced 'SNat' constructor is used at a fixed type-level natural.
  lift s = sigE [| SNat |]
                (appT (conT ''SNat) (litT $ numTyLit (snatToInteger s)))
#if MIN_VERSION_template_haskell(2,16,0)
  liftTyped = liftTypedFromUntyped
#endif

-- | Create an @'SNat' n@ from a proxy for /n/
--
-- The proxy argument is ignored at the term level; only its type index is
-- used.
snatProxy :: KnownNat n => proxy n -> SNat n
snatProxy _ = SNat

-- | Values up to and including 1024 are shown as @d@ followed by the number
-- (e.g. @d789@). Larger values are shown as @SNat \@n@, wrapped in parentheses
-- when the surrounding precedence is higher than function application.
instance Show (SNat n) where
  showsPrec d p@SNat | n <= 1024 = showChar 'd' . shows n
                     | otherwise = showParen (d > appPrec) $
                                     showString "SNat @" . shows n
   where
    n = snatToInteger p

-- | Defined in terms of the 'Show' instance, via 'showsPrecXWith'.
instance ShowX (SNat n) where
  showsPrecX = showsPrecXWith showsPrec

-- | Supply a function with a singleton natural @n@ according to the context
withSNat :: KnownNat n => (SNat n -> a) -> a
withSNat f = f SNat
{-# INLINE withSNat #-}

-- | Same as 'snatToInteger' and 'GHC.TypeLits.natVal', but doesn't take term
-- arguments. Example usage:
--
-- >>> natToInteger @5
-- 5
natToInteger :: forall n . KnownNat n => Integer
natToInteger = snatToInteger (SNat @n)
{-# INLINE natToInteger #-}

-- | Reify the type-level 'Nat' @n@ to its term-level 'Integer' representation.
snatToInteger :: SNat n -> Integer
-- Matching on the constructor brings 'KnownNat' into scope for 'natVal'; the
-- 'SNat' value itself serves as the proxy.
snatToInteger p@SNat = natVal p
{-# INLINE snatToInteger #-}

-- | Same as 'snatToNatural' and 'GHC.TypeNats.natVal', but doesn't take term
-- arguments. Example usage:
--
-- >>> natToNatural @5
-- 5
natToNatural :: forall n . KnownNat n => Natural
natToNatural = snatToNatural (SNat @n)
{-# INLINE natToNatural #-}

-- | Reify the type-level 'Nat' @n@ to its term-level 'Natural'.
snatToNatural :: SNat n -> Natural
snatToNatural = naturalFromInteger . snatToInteger
{-# INLINE snatToNatural #-}

-- | Same as 'snatToNum', but doesn't take term arguments. Example usage:
--
-- >>> natToNum @5 @Int
-- 5
natToNum :: forall n a . (Num a, KnownNat n) => a
natToNum = snatToNum (SNat @n)
{-# INLINE natToNum #-}

-- | Reify the type-level 'Nat' @n@ to its term-level 'Num'ber.
--
-- The conversion goes through 'snatToInteger' and 'fromInteger'.
snatToNum :: forall a n . Num a => SNat n -> a
snatToNum p@SNat = fromInteger (snatToInteger p)
{-# INLINE snatToNum #-}

-- | Unary representation of a type-level natural
--
-- __NB__: Not synthesizable
data UNat :: Nat -> Type where
  -- | Zero
  UZero :: UNat 0
  -- | Successor of the wrapped natural number
  USucc :: UNat n -> UNat (n + 1)

-- | Shown as @u@ followed by the decimal value of @n@. The value is taken
-- from the 'KnownNat' constraint, not by traversing the constructors.
instance KnownNat n => Show (UNat n) where
  show x = 'u':show (natVal x)

-- | Defined in terms of the 'Show' instance, via 'showsPrecXWith'.
instance KnownNat n => ShowX (UNat n) where
  showsPrecX = showsPrecXWith showsPrec

-- | Convert a singleton natural number to its unary representation
--
-- __NB__: Not synthesizable
toUNat :: forall n . SNat n -> UNat n
toUNat p@SNat = fromI @n (snatToInteger p)
  where
    -- Builds the unary number from the term-level value. The type checker
    -- cannot relate the 'Integer' argument to the type index @m@, hence the
    -- coercions. They rely on the invariant that the argument equals the value
    -- of @m@: the initial call passes the value of @n@ at index @n@, and every
    -- recursive call decrements both the argument and the index by one.
    fromI :: forall m . Integer -> UNat m
    fromI 0 = unsafeCoerce @(UNat 0) @(UNat m) UZero
    fromI n = unsafeCoerce @(UNat ((m-1)+1)) @(UNat m) (USucc (fromI @(m-1) (n - 1)))

-- | Convert a unary-encoded natural number to its singleton representation
--
-- __NB__: Not synthesizable
fromUNat :: UNat n -> SNat n
fromUNat UZero     = SNat :: SNat 0
fromUNat (USucc x) = addSNat (fromUNat x) (SNat :: SNat 1)

-- | Add two unary-encoded natural numbers
--
-- __NB__: Not synthesizable
addUNat :: UNat n -> UNat m -> UNat (n + m)
addUNat UZero     y     = y
addUNat x         UZero = x
addUNat (USucc x) y     = USucc (addUNat x y)

-- | Multiply two unary-encoded natural numbers
--
-- __NB__: Not synthesizable
mulUNat :: UNat n -> UNat m -> UNat (n * m)
mulUNat UZero     _     = UZero
mulUNat _         UZero = UZero
-- (x + 1) * y = y + x * y
mulUNat (USucc x) y     = addUNat y (mulUNat x y)

-- | Power of two unary-encoded natural numbers
--
-- __NB__: Not synthesizable
powUNat :: UNat n -> UNat m -> UNat (n ^ m)
-- x ^ 0 = 1
powUNat _ UZero     = USucc UZero
-- x ^ (y + 1) = x * x ^ y
powUNat x (USucc y) = mulUNat x (powUNat x y)

-- | Predecessor of a unary-encoded natural number
--
-- __NB__: Not synthesizable
predUNat :: UNat (n+1) -> UNat n
predUNat (USucc x) = x
-- The argument type @UNat (n+1)@ rules out zero; this equation only exists to
-- give a descriptive error should it be reached regardless.
predUNat UZero     =
  error "predUNat: impossible: 0 minus 1, -1 is not a natural number"

-- | Subtract two unary-encoded natural numbers
--
-- __NB__: Not synthesizable
subUNat :: UNat (m+n) -> UNat n -> UNat m
subUNat x         UZero     = x
subUNat (USucc x) (USucc y) = subUNat x y
-- The first equation handles a zero subtrahend, so here the subtrahend is a
-- successor while the minuend is zero, which the type @UNat (m+n)@ rules out.
subUNat UZero     _         = error "subUNat: impossible: 0 + (n + 1) ~ 0"

-- | Predecessor of a singleton natural number
predSNat :: SNat (a+1) -> SNat a
predSNat SNat = SNat
{-# INLINE predSNat #-}

-- | Successor of a singleton natural number
succSNat :: SNat a -> SNat (a+1)
succSNat SNat = SNat
{-# INLINE succSNat #-}

-- | Add two singleton natural numbers
addSNat :: SNat a -> SNat b -> SNat (a+b)
addSNat SNat SNat = SNat
{-# INLINE addSNat #-}
infixl 6 `addSNat`

-- | Subtract two singleton natural numbers
subSNat :: SNat (a+b) -> SNat b -> SNat a
subSNat SNat SNat = SNat
{-# INLINE subSNat #-}
infixl 6 `subSNat`

-- | Multiply two singleton natural numbers
mulSNat :: SNat a -> SNat b -> SNat (a*b)
mulSNat SNat SNat = SNat
{-# INLINE mulSNat #-}
infixl 7 `mulSNat`

-- | Power of two singleton natural numbers
powSNat :: SNat a -> SNat b -> SNat (a^b)
powSNat SNat SNat = SNat
{-# NOINLINE powSNat #-}
infixr 8 `powSNat`

-- | Division of two singleton natural numbers
--
-- The divisor must be at least 1.
divSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Div a b)
divSNat SNat SNat = SNat
{-# INLINE divSNat #-}
infixl 7 `divSNat`

-- | Modulo of two singleton natural numbers
--
-- The divisor must be at least 1.
modSNat :: (1 <= b) => SNat a -> SNat b -> SNat (Mod a b)
modSNat SNat SNat = SNat
{-# INLINE modSNat #-}
infixl 7 `modSNat`

-- | Minimum of two singleton natural numbers
minSNat :: SNat a -> SNat b -> SNat (Min a b)
minSNat SNat SNat = SNat

-- | Maximum of two singleton natural numbers
maxSNat :: SNat a -> SNat b -> SNat (Max a b)
maxSNat SNat SNat = SNat

-- | Floor of the logarithm of a natural number
flogBaseSNat :: (2 <= base, 1 <= x)
             => SNat base -- ^ Base
             -> SNat x
             -> SNat (FLog base x)
flogBaseSNat SNat SNat = SNat
{-# NOINLINE flogBaseSNat #-}

-- | Ceiling of the logarithm of a natural number
clogBaseSNat :: (2 <= base, 1 <= x)
             => SNat base -- ^ Base
             -> SNat x
             -> SNat (CLog base x)
clogBaseSNat SNat SNat = SNat
{-# NOINLINE clogBaseSNat #-}

-- | Exact integer logarithm of a natural number
--
-- __NB__: Only works when the argument is a power of the base; this is
-- expressed by requiring the floor and ceiling logarithms to coincide.
logBaseSNat :: (FLog base x ~ CLog base x)
            => SNat base -- ^ Base
            -> SNat x
            -> SNat (Log base x)
logBaseSNat SNat SNat = SNat
{-# NOINLINE logBaseSNat #-}

-- | Power of two of a singleton natural number
pow2SNat :: SNat a -> SNat (2^a)
pow2SNat SNat = SNat
{-# INLINE pow2SNat #-}

-- | Ordering relation between two Nats
--
-- Pattern matching on a constructor brings the corresponding constraint into
-- scope.
data SNatLE a b where
  -- | Evidence that @a@ is less than or equal to @b@
  SNatLE :: forall a b . a <= b => SNatLE a b
  -- | Evidence that @a@ is strictly greater than @b@, stated as
  -- @(b+1) <= a@
  SNatGT :: forall a b . (b+1) <= a => SNatLE a b

-- | Get an ordering relation between two SNats
--
-- The comparison is performed on the term-level values; the result carries
-- the matching type-level evidence.
compareSNat :: forall a b . SNat a -> SNat b -> SNatLE a b
compareSNat a b =
  -- The type checker cannot turn the outcome of the term-level comparison
  -- into a type-level constraint. Each witness is therefore built at concrete
  -- indices for which its constraint holds (0 <= 0, and 0 + 1 <= 1), and then
  -- coerced to the requested indices.
  if snatToInteger a <= snatToInteger b
     then unsafeCoerce (SNatLE @0 @0)
     else unsafeCoerce (SNatGT @1 @0)

-- | Base-2 encoded natural number
--
--    * __NB__: Not synthesizable
--    * __NB__: The LSB is the left/outer-most constructor, for example:
--
-- >>> B0 (B1 (B1 BT))
-- b6
--
-- == Constructors
--
-- * Starting/Terminating element:
--
--      @
--      __BT__ :: 'BNat' 0
--      @
--
-- * Append a zero (/0/):
--
--      @
--      __B0__ :: 'BNat' n -> 'BNat' (2 'GHC.TypeNats.*' n)
--      @
--
-- * Append a one (/1/):
--
--      @
--      __B1__ :: 'BNat' n -> 'BNat' ((2 'GHC.TypeNats.*' n) 'GHC.TypeNats.+' 1)
--      @
data BNat :: Nat -> Type where
  -- Terminator; on its own it represents zero.
  BT :: BNat 0
  -- A zero bit: doubles the value of the remaining (more significant) bits.
  B0 :: BNat n -> BNat (2*n)
  -- A one bit: doubles the value of the remaining bits and adds one.
  B1 :: BNat n -> BNat ((2*n) + 1)

-- | Shown as @b@ followed by the decimal value of @n@. The value is taken
-- from the 'KnownNat' constraint, not by traversing the constructors; use
-- 'showBNat' to see the individual bits.
instance KnownNat n => Show (BNat n) where
  show x = 'b':show (natVal x)

-- | Defined in terms of the 'Show' instance, via 'showsPrecXWith'.
instance KnownNat n => ShowX (BNat n) where
  showsPrecX = showsPrecXWith showsPrec

-- | Show a base-2 encoded natural as a binary literal
--
-- __NB__: The LSB is shown as the right-most bit
--
-- >>> d789
-- d789
-- >>> toBNat d789
-- b789
-- >>> showBNat (toBNat d789)
-- "0b1100010101"
-- >>> 0b1100010101 :: Integer
-- 789
showBNat :: BNat n -> String
showBNat = go []
  where
    -- The constructors are visited LSB-first and each bit is prepended to the
    -- accumulator, so the LSB ends up right-most. The @0b@ prefix is added
    -- once the terminator is reached.
    go :: String -> BNat m -> String
    go xs BT     = "0b" ++ xs
    go xs (B0 x) = go ('0':xs) x
    go xs (B1 x) = go ('1':xs) x

-- | Convert a singleton natural number to its base-2 representation
--
-- __NB__: Not synthesizable
toBNat :: SNat n -> BNat n
toBNat s@SNat = toBNat' (snatToInteger s)
  where
    -- Builds the bits from the term-level value, LSB first. The type checker
    -- cannot relate the 'Integer' argument to the type index @m@, hence the
    -- coercions. They rely on the invariant that the argument equals the value
    -- of @m@: the initial call passes the value of @n@, and the recursive
    -- calls halve the argument and the index in the same way.
    toBNat' :: forall m . Integer -> BNat m
    toBNat' 0 = unsafeCoerce BT
    toBNat' n = case n `divMod` 2 of
      (n',1) -> unsafeCoerce (B1 (toBNat' @(Div (m-1) 2) n'))
      (n',_) -> unsafeCoerce (B0 (toBNat' @(Div m 2) n'))

-- | Convert a base-2 encoded natural number to its singleton representation
--
-- __NB__: Not synthesizable
fromBNat :: BNat n -> SNat n
fromBNat BT     = SNat :: SNat 0
fromBNat (B0 x) = mulSNat (SNat :: SNat 2) (fromBNat x)
fromBNat (B1 x) = addSNat (mulSNat (SNat :: SNat 2) (fromBNat x))
                          (SNat :: SNat 1)

-- | Add two base-2 encoded natural numbers
--
-- __NB__: Not synthesizable
addBNat :: BNat n -> BNat m -> BNat (n+m)
addBNat (B0 a) (B0 b) = B0 (addBNat a b)
addBNat (B0 a) (B1 b) = B1 (addBNat a b)
addBNat (B1 a) (B0 b) = B1 (addBNat a b)
-- 1 + 1 produces a zero bit and a carry into the more significant bits
addBNat (B1 a) (B1 b) = B0 (succBNat (addBNat a b))
addBNat BT     b      = b
addBNat a      BT     = a

-- | Multiply two base-2 encoded natural numbers
--
-- __NB__: Not synthesizable
mulBNat :: BNat n -> BNat m -> BNat (n*m)
mulBNat BT      _  = BT
mulBNat _       BT = BT
-- (2 * a) * b = 2 * (a * b)
mulBNat (B0 a)  b  = B0 (mulBNat a b)
-- (2 * a + 1) * b = 2 * (a * b) + b
mulBNat (B1 a)  b  = addBNat (B0 (mulBNat a b)) b

-- | Power of two base-2 encoded natural numbers
--
-- __NB__: Not synthesizable
powBNat :: BNat n -> BNat m -> BNat (n^m)
-- a ^ 0 = 1
powBNat _  BT      = B1 BT
-- a ^ (2 * b) = (a ^ b) * (a ^ b)
powBNat a  (B0 b)  = let z = powBNat a b
                     in  mulBNat z z
-- a ^ (2 * b + 1) = a * ((a ^ b) * (a ^ b))
powBNat a  (B1 b)  = let z = powBNat a b
                     in  mulBNat a (mulBNat z z)

-- | Successor of a base-2 encoded natural number
--
-- __NB__: Not synthesizable
succBNat :: BNat n -> BNat (n+1)
succBNat BT     = B1 BT
succBNat (B0 a) = B1 a
-- A one bit flips to zero and the carry propagates into the remaining bits
succBNat (B1 a) = B0 (succBNat a)

-- | Predecessor of a base-2 encoded natural number
--
-- __NB__: Not synthesizable
predBNat :: (1 <= n) => BNat n -> BNat (n-1)
-- Clearing the low one bit gives the predecessor. The remaining bits are
-- normalised first, so that a remainder that is zero yields 'BT' rather than
-- a zero bit followed by the terminator.
predBNat (B1 a) = case stripZeros a of
  BT -> BT
  a' -> B0 a'
-- A zero bit flips to one and the borrow propagates into the remaining bits
predBNat (B0 x) = B1 (predBNat x)

-- | Divide a base-2 encoded natural number by 2
--
-- __NB__: Not synthesizable
div2BNat :: BNat (2*n) -> BNat n
div2BNat BT     = BT
div2BNat (B0 x) = x
-- An odd number cannot have the type @BNat (2*n)@; this equation only exists
-- to give a descriptive error should it be reached regardless.
div2BNat (B1 _) = error "div2BNat: impossible: 2*n ~ 2*n+1"

-- | Subtract 1 and divide a base-2 encoded natural number by 2
--
-- __NB__: Not synthesizable
div2Sub1BNat :: BNat (2*n+1) -> BNat n
div2Sub1BNat (B1 x) = x
-- An even number cannot have the type @BNat (2*n+1)@; this equation only
-- exists to give a descriptive error should it be reached regardless.
div2Sub1BNat _      = error "div2Sub1BNat: impossible: 2*n+1 ~ 2*n"

-- | Get the log2 of a base-2 encoded natural number
--
-- __NB__: Not synthesizable
log2BNat :: BNat (2^n) -> BNat n
log2BNat BT = error "log2BNat: log2(0) not defined"
-- An odd power of two can only be 1, in which case the remaining bits
-- represent zero and the result is 0.
log2BNat (B1 x) = case stripZeros x of
  BT -> BT
  _  -> error "log2BNat: impossible: 2^n ~ 2x+1"
-- log2 (2 * x) = log2 x + 1
log2BNat (B0 x) = succBNat (log2BNat x)

-- | Strip non-contributing zeros from a base-2 encoded natural number
--
-- >>> B1 (B0 (B0 (B0 BT)))
-- b1
-- >>> showBNat (B1 (B0 (B0 (B0 BT))))
-- "0b0001"
-- >>> showBNat (stripZeros (B1 (B0 (B0 (B0 BT)))))
-- "0b1"
-- >>> stripZeros (B1 (B0 (B0 (B0 BT))))
-- b1
--
-- __NB__: Not synthesizable
stripZeros :: BNat n -> BNat n
stripZeros BT      = BT
stripZeros (B1 x)  = B1 (stripZeros x)
stripZeros (B0 BT) = BT
-- A zero bit is kept only when the more significant bits, once stripped, are
-- not themselves zero.
stripZeros (B0 x)  = case stripZeros x of
  BT -> BT
  k  -> B0 k

-- | Change a function that has an argument with an @(n ~ (k + m))@ constraint to a
-- function with an argument that has an @(k <= n)@ constraint.
--
-- === __Examples__
--
-- Example 1
--
-- @
-- f :: Index (n+1) -> Index (n + 1) -> Bool
--
-- g :: forall n. (1 'GHC.TypeNats.<=' n) => Index n -> Index n -> Bool
-- g a b = 'leToPlus' \@1 \@n (f a b)
-- @
--
-- Example 2
--
-- @
-- head :: Vec (n + 1) a -> a
--
-- head' :: forall n a. (1 'GHC.TypeNats.<=' n) => Vec n a -> a
-- head' = 'leToPlus' \@1 \@n head
-- @
leToPlus
  :: forall (k :: Nat) (n :: Nat) r
   . ( k <= n
     )
  => (forall m . (n ~ (k + m)) => r)
  -- ^ Context with the @(n ~ (k + m))@ constraint
  -> r
-- The continuation is instantiated with @m ~ (n - k)@.
leToPlus r = r @(n - k)
{-# INLINE leToPlus #-}

-- | Same as 'leToPlus' with added 'KnownNat' constraints
leToPlusKN
  :: forall (k :: Nat) (n :: Nat) r
   . ( k <= n
     , KnownNat k
     , KnownNat n
     )
  => (forall m . (n ~ (k + m), KnownNat m) => r)
  -- ^ Context with the @(n ~ (k + m))@ and @'KnownNat' m@ constraints
  -> r
-- The continuation is instantiated with @m ~ (n - k)@.
leToPlusKN r = r @(n - k)
{-# INLINE leToPlusKN #-}