{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language ExplicitNamespaces #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language MagicHash #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeOperators #-}
module Arithmetic.Fin
  ( -- * Modification
    incrementL
  , incrementR
  , weaken
  , weakenL
  , weakenR
    -- * Traverse
    -- | These use the terms @ascend@ and @descend@ rather than the
    -- more popular @l@ (left) and @r@ (right) that pervade the Haskell
    -- ecosystem. The general rule is that ascending functions pair
    -- the initial accumulator with zero, while descending functions
    -- pair the initial accumulator with the last index.
  , ascend
  , ascend'
  , ascendFrom'
  , ascendFrom'#
  , ascendM
  , ascendM#
  , ascendM_
  , ascendM_#
  , descend
  , descend'
  , descendM
  , descendM_
  , ascending
  , descending
  , ascendingSlice
  , descendingSlice
    -- * Absurdities
  , absurd
    -- * Demote
  , demote
    -- * Deconstruct
  , with
  , with#
    -- * Construct
  , construct#
    -- * Lift and Unlift
  , lift
  , unlift
  ) where

import Prelude hiding (last)

import Arithmetic.Nat ((<?))
import Arithmetic.Types (Fin(..),Fin#,Difference(..),Nat,Nat#,type (<), type (<=), type (:=:))
import GHC.Exts (Int(I#))
import GHC.TypeNats (type (+))

import qualified Arithmetic.Lt as Lt
import qualified Arithmetic.Lte as Lte
import qualified Arithmetic.Equal as Eq
import qualified Arithmetic.Nat as Nat
import qualified Arithmetic.Plus as Plus
import qualified Arithmetic.Unsafe as Unsafe

-- | Shift a 'Fin' upward by @m@: the index grows by @m@ and so does
-- the bound, with @m@ appended on the right-hand side of @n@.
incrementR :: forall n m. Nat m -> Fin n -> Fin (n + m)
{-# inline incrementR #-}
incrementR delta fin = case fin of
  Fin ix ixLtN -> Fin (Nat.plus ix delta) (Lt.incrementR @m ixLtN)

-- | Shift a 'Fin' upward by @m@: the index grows by @m@ and so does
-- the bound, with @m@ prepended on the left-hand side of @n@.
incrementL :: forall n m. Nat m -> Fin n -> Fin (m + n)
{-# inline incrementL #-}
incrementL delta fin = case fin of
  Fin ix ixLtN -> Fin (Nat.plus delta ix) (Lt.incrementL @m ixLtN)

-- | Loosen the bound by @m@, placing it on the left-hand side of the
-- existing bound. The index itself is left untouched.
weakenL :: forall n m. Fin n -> Fin (m + n)
{-# inline weakenL #-}
weakenL fin = case fin of
  Fin ix ixLtN ->
    -- First obtain ix + 0 < n + m, then rewrite the bound to m + n.
    let widened = Lt.plus ixLtN (Lte.zero @m)
     in Fin ix (Lt.substituteR (Plus.commutative @n @m) widened)

-- | Loosen the bound by @m@, placing it on the right-hand side of the
-- existing bound. The index itself is left untouched.
weakenR :: forall n m. Fin n -> Fin (n + m)
{-# inline weakenR #-}
weakenR fin = case fin of
  -- ix + 0 < n + m, from ix < n and 0 <= m.
  Fin ix ixLtN -> Fin ix (Lt.plus ixLtN (Lte.zero @m))

-- | Replace the bound with any number that is greater than or equal
-- to it. The index itself is left untouched.
weaken :: forall n m. (n <= m) -> Fin n -> Fin m
{-# inline weaken #-}
weaken nLteM fin = case fin of
  -- ix < n and n <= m give ix < m.
  Fin ix ixLtN -> Fin ix (Lt.transitiveNonstrictR ixLtN nLteM)

-- | A finite set with no values cannot be inhabited: its proof would
-- need an index strictly below zero, so any result type may be
-- produced from it.
absurd :: Fin 0 -> void
{-# inline absurd #-}
absurd fin = case fin of
  Fin _ ixLtZero -> Lt.absurd ixLtZero

-- | Fold over the numbers bounded by @n@ in descending order. This is
-- lazy in the accumulator: the callback receives the result of
-- folding over all higher indices without it being forced. For
-- convenience, this differs from @foldr@ in the order of the
-- parameters.
--
-- > descend 4 z f = f 0 (f 1 (f 2 (f 3 z)))
descend :: forall a n.
     Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Update accumulator
  -> a
{-# inline descend #-}
descend !n b0 f = step Nat.zero
  where
    -- Walk upward from zero; the innermost call (at the last index)
    -- is the one that sees the initial accumulator.
    step :: Nat k -> a
    step !k = maybe b0
      (\kLtN -> f (Fin k kLtN) (step (Nat.succ k)))
      (k <? n)

-- | Fold over the numbers bounded by @n@ in descending order. This is
-- strict in the accumulator. For convenience, this differs from
-- @foldr'@ in the order of the parameters.
--
-- > descend' 4 z f = f 0 (f 1 (f 2 (f 3 z)))
descend' :: forall a n.
     Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Update accumulator
  -> a
{-# inline descend' #-}
descend' !n !b0 f = go n Lte.reflexive b0
  where
    -- Count down from @p@. The proof @p <= n@ lets each predecessor be
    -- packaged as a 'Fin' @n@ without comparing it against @n@ again.
    go :: Nat p -> (p <= n) -> a -> a
    go !m pLteEn !b = case Nat.monus m Nat.one of
      Nothing -> b
      Just (Difference (mpred :: Nat c) cPlusOneEqP) ->
        let !cLtEn = descendLemma cPlusOneEqP pLteEn
         in go mpred (Lte.fromStrict cLtEn) (f (Fin mpred cLtEn) b)

-- | Fold over the numbers bounded by @n@ in ascending order. This is
-- lazy in the accumulator; only the initial accumulator is forced (to
-- weak head normal form) on entry.
--
-- > ascend 4 z f = f 3 (f 2 (f 1 (f 0 z)))
ascend :: forall a n.
     Nat n
  -> a
  -> (Fin n -> a -> a)
  -> a
{-# inline ascend #-}
ascend !n !b0 f = go n Lte.reflexive
  where
    -- Start at @n@ and peel off predecessors; the outermost call to @f@
    -- sees the last index, and the recursion bottoms out at @b0@.
    go :: Nat p -> (p <= n) -> a
    go !cur curLteN = case Nat.monus cur Nat.one of
      Nothing -> b0
      Just (Difference (prev :: Nat c) prevSuccEqCur) ->
        let !prevLtN = descendLemma prevSuccEqCur curLteN
         in f (Fin prev prevLtN) (go prev (Lte.fromStrict prevLtN))

-- | Strict fold over the numbers bounded by @n@ in ascending order.
-- For convenience, this differs from @foldl'@ in the order of the
-- parameters.
--
-- > ascend' 4 z f = f 3 (f 2 (f 1 (f 0 z)))
ascend' :: forall a n.
     Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Update accumulator
  -> a
{-# inline ascend' #-}
ascend' !n !b0 f = loop Nat.zero b0
  where
    -- The accumulator is forced at every step by the bang pattern.
    loop :: Nat k -> a -> a
    loop !k !acc = case k <? n of
      Nothing -> acc
      Just kLtN -> loop (Nat.succ k) (f (Fin k kLtN) acc)

-- | Generalization of @ascend'@ that lets the caller choose the first
-- index. It visits @m, m + 1, ..., m + n - 1@:
--
-- > ascend' === ascendFrom' 0
ascendFrom' :: forall a m n.
     Nat m -- ^ Index to start at
  -> Nat n -- ^ Number of steps to take
  -> a -- ^ Initial accumulator
  -> (Fin (m + n) -> a -> a) -- ^ Update accumulator
  -> a
{-# inline ascendFrom' #-}
ascendFrom' !m0 !n !b0 f = loop m0 b0
  where
    -- Exclusive upper limit of the traversal.
    end :: Nat (m + n)
    end = Nat.plus m0 n

    loop :: Nat k -> a -> a
    loop !k !acc = case k <? end of
      Nothing -> acc
      Just kLtEnd -> loop (Nat.succ k) (f (Fin k kLtEnd) acc)

-- | Variant of @ascendFrom'@ that takes unboxed naturals and hands an
-- unboxed 'Fin#' to the callback.
ascendFrom'# :: forall a m n.
     Nat# m -- ^ Index to start at
  -> Nat# n -- ^ Number of steps to take
  -> a -- ^ Initial accumulator
  -> (Fin# (m + n) -> a -> a) -- ^ Update accumulator
  -> a
{-# inline ascendFrom'# #-}
ascendFrom'# !m0 !n !b0 f =
  -- A lambda (not @f . unlift@) is required: 'Fin#' is unlifted, and
  -- '(.)' only composes through lifted types.
  let step ix acc = f (unlift ix) acc
   in ascendFrom' (Nat.lift m0) (Nat.lift n) b0 step

-- | Strict monadic left fold over the numbers bounded by @n@ in
-- ascending order. Roughly:
--
-- > ascendM 4 z0 f =
-- >   f 0 z0 >>= \z1 ->
-- >   f 1 z1 >>= \z2 ->
-- >   f 2 z2 >>= \z3 ->
-- >   f 3 z3
ascendM :: forall m a n. Monad m
  => Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> m a) -- ^ Update accumulator
  -> m a
{-# inline ascendM #-}
ascendM !n !b0 f = loop Nat.zero b0
  where
    -- Each intermediate accumulator is forced by the bang on @acc@
    -- before the next effect runs.
    loop :: Nat k -> a -> m a
    loop !k !acc = case k <? n of
      Nothing -> pure acc
      Just kLtN -> f (Fin k kLtN) acc >>= loop (Nat.succ k)

-- | Variant of @ascendM@ that takes an unboxed 'Nat#' and passes an
-- unboxed 'Fin#' to the callback.
ascendM# :: forall m a n. Monad m
  => Nat# n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin# n -> a -> m a) -- ^ Update accumulator
  -> m a
{-# inline ascendM# #-}
ascendM# n !a0 f = ascendM (Nat.lift n) a0 step
  where
    -- Re-box each index on its way out to the caller's callback.
    step ix acc = f (unlift ix) acc

-- | Effectful traversal of the numbers bounded by @n@ in ascending
-- order, discarding the results.
--
-- > ascendM_ 4 f = f 0 *> f 1 *> f 2 *> f 3
ascendM_ :: forall m a n. Applicative m
  => Nat n -- ^ Upper bound
  -> (Fin n -> m a) -- ^ Effectful interpretation
  -> m ()
{-# inline ascendM_ #-}
ascendM_ !n f = loop Nat.zero
  where
    loop :: Nat k -> m ()
    loop !k = case k <? n of
      Nothing -> pure ()
      Just kLtN -> f (Fin k kLtN) *> loop (Nat.succ k)

-- | Variant of @ascendM_@ that takes an unboxed 'Nat#' and passes an
-- unboxed 'Fin#' to the callback. Like @ascendM_@, this only needs an
-- 'Applicative' context (every 'Monad' still qualifies).
ascendM_# :: forall m a n. Applicative m
  => Nat# n -- ^ Upper bound
  -> (Fin# n -> m a) -- ^ Effectful interpretation
  -> m ()
{-# inline ascendM_# #-}
ascendM_# n f = ascendM_ (Nat.lift n) (\ix -> f (unlift ix))

-- | If @a + 1 = b@ and @b <= c@, then @a < c@. The descending
-- traversals in this module use it to show that the predecessor
-- obtained from 'Nat.monus' is still below their upper bound.
descendLemma :: forall a b c. a + 1 :=: b -> b <= c -> a < c
{-# inline descendLemma #-}
descendLemma !aPlusOneEqB !bLteC = Lt.transitiveNonstrictR
  -- a < a + 1: add a to both sides of 0 < 1, then commute 1 + a.
  (Lt.substituteR (Plus.commutative @1 @a) (Lt.plus Lt.zero (Lte.reflexive @a)))
  -- a + 1 <= c: rewrite b to a + 1 in b <= c.
  (Lte.substituteL (Eq.symmetric aPlusOneEqB) bLteC)

-- | Strict monadic left fold over the numbers bounded by @n@ in
-- descending order. Roughly:
--
-- > descendM 4 z0 f =
-- >   f 3 z0 >>= \z1 ->
-- >   f 2 z1 >>= \z2 ->
-- >   f 1 z2 >>= \z3 ->
-- >   f 0 z3
descendM :: forall m a n. Monad m
  => Nat n
  -> a
  -> (Fin n -> a -> m a)
  -> m a
{-# inline descendM #-}
descendM !n !b0 f = countDown n Lte.reflexive b0
  where
    -- Step from @cur@ to its predecessor, carrying the proof that
    -- @cur@ never exceeds @n@.
    countDown :: Nat p -> (p <= n) -> a -> m a
    countDown !cur curLteN !acc = case Nat.monus cur Nat.one of
      Nothing -> pure acc
      Just (Difference (prev :: Nat c) prevSuccEqCur) ->
        let !prevLtN = descendLemma prevSuccEqCur curLteN
         in f (Fin prev prevLtN) acc >>= countDown prev (Lte.fromStrict prevLtN)

-- | Effectful traversal of the numbers bounded by @n@ in descending
-- order, discarding the results.
--
-- > descendM_ 4 f = f 3 *> f 2 *> f 1 *> f 0
descendM_ :: forall m a n. Applicative m
  => Nat n -- ^ Upper bound
  -> (Fin n -> m a) -- ^ Effectful interpretation
  -> m ()
{-# inline descendM_ #-}
descendM_ !n f = countDown n Lte.reflexive
  where
    countDown :: Nat p -> (p <= n) -> m ()
    countDown !cur !curLteN = case Nat.monus cur Nat.one of
      Nothing -> pure ()
      Just (Difference (prev :: Nat c) prevSuccEqCur) ->
        let !prevLtN = descendLemma prevSuccEqCur curLteN
         in f (Fin prev prevLtN) *> countDown prev (Lte.fromStrict prevLtN)

-- | List every value of the finite set, smallest first.
--
-- >>> ascending (Nat.constant @3)
-- [Fin 0,Fin 1,Fin 2]
ascending :: forall n. Nat n -> [Fin n]
ascending !n = countUp Nat.zero
  where
    countUp :: Nat k -> [Fin n]
    countUp !k = case k <? n of
      Nothing -> []
      Just kLtN -> Fin k kLtN : countUp (Nat.succ k)

-- | List every value of the finite set, largest first.
--
-- >>> descending (Nat.constant @3)
-- [Fin 2,Fin 1,Fin 0]
descending :: forall n. Nat n -> [Fin n]
descending !n = countDown n Lte.reflexive
  where
    -- Emit the predecessor of @cur@ (if any) and continue below it,
    -- threading the proof that @cur@ never exceeds @n@.
    countDown :: Nat p -> (p <= n) -> [Fin n]
    countDown !cur !curLteN = case Nat.monus cur Nat.one of
      Nothing -> []
      Just (Difference (prev :: Nat c) prevSuccEqCur) ->
        let !prevLtN = descendLemma prevSuccEqCur curLteN
         in Fin prev prevLtN : countDown prev (Lte.fromStrict prevLtN)

-- | Generate @len@ values in ascending order, starting at @off@ and
-- ending at @off + len - 1@.
--
-- >>> ascendingSlice (Nat.constant @2) (Nat.constant @3) (Lte.constant @_ @6)
-- [Fin 2,Fin 3,Fin 4]
ascendingSlice
  :: forall n off len
  .  Nat off
  -> Nat len
  -> off + len <= n
  -> [Fin n]
{-# inline ascendingSlice #-}
ascendingSlice off len !offPlusLenLteN = countUp Nat.zero
  where
    -- @k@ is the position inside the slice; the emitted index is off + k.
    countUp :: Nat k -> [Fin n]
    countUp !k = case k <? len of
      Nothing -> []
      Just kLtLen ->
        let !offPlusKLtOffPlusLen = Lt.incrementL @off kLtLen
            !offPlusKLtN =
              Lt.transitiveNonstrictR offPlusKLtOffPlusLen offPlusLenLteN
         in Fin (Nat.plus off k) offPlusKLtN : countUp (Nat.succ k)

-- | Generate @len@ values in descending order, starting at
-- @off + len - 1@ and ending at @off@.
--
-- >>> descendingSlice (Nat.constant @2) (Nat.constant @3) (Lte.constant @_ @6)
-- [Fin 4,Fin 3,Fin 2]
descendingSlice
  :: forall n off len
  .  Nat off
  -> Nat len
  -> off + len <= n
  -> [Fin n]
{-# inline descendingSlice #-}
descendingSlice !off !len !offPlusLenLteEn = go len Lte.reflexive
  where
    -- @m@ counts the slice positions still to emit; the emitted index
    -- is (m - 1) + off.
    go :: Nat m -> (m <= len) -> [Fin n]
    go !m !mLteLen = case Nat.monus m Nat.one of
      Nothing -> []
      Just (Difference (mpred :: Nat c) cPlusOneEqEm) ->
        -- c < len: the same fact every descending traversal in this
        -- module derives, so reuse the shared lemma.
        let !cLtLen = descendLemma cPlusOneEqEm mLteLen
            !cPlusOffLtEn = Lt.transitiveNonstrictR
              (Lt.substituteR
                (Plus.commutative @len @off)
                (Lt.plus cLtLen (Lte.reflexive @off)))
              -- c + off < off + len, and off + len <= n
              offPlusLenLteEn
         in Fin (mpred `Nat.plus` off) cPlusOffLtEn : go mpred (Lte.fromStrict cLtLen)

-- | Extract the 'Int' index from a 'Fin' @n@, discarding the proof.
-- This is meant for the boundary where a safe interface meets the
-- unsafe primitives it is built on.
demote :: Fin n -> Int
{-# inline demote #-}
demote fin = case fin of
  Fin ix _ -> Nat.demote ix

-- | Box an unboxed 'Unsafe.Fin#'. The proof attached to the result is
-- 'Unsafe.Lt', so this trusts that the index stored in the 'Fin#' is
-- already in bounds.
lift :: Unsafe.Fin# n -> Fin n
{-# inline lift #-}
lift packed = case packed of
  Unsafe.Fin# raw -> Fin (Unsafe.Nat (I# raw)) Unsafe.Lt

-- | Unbox a 'Fin' into an 'Unsafe.Fin#', keeping the raw index and
-- dropping the proof.
unlift :: Fin n -> Unsafe.Fin# n
{-# inline unlift #-}
unlift fin = case fin of
  Fin (Unsafe.Nat (I# raw)) _ -> Unsafe.Fin# raw

-- | Open a 'Fin', handing its bound proof and its index to a
-- continuation that is polymorphic in the (hidden) index.
with :: Fin n -> (forall i. (i < n) -> Nat i -> a) -> a
{-# inline with #-}
with fin k = case fin of
  Fin ix ixLtN -> k ixLtN ix

-- | Variant of 'with' for an unboxed 'Fin#'. The continuation receives
-- 'Unsafe.Lt' as the bound proof and the raw index as a 'Nat#'.
with# :: Fin# n -> (forall i. (i < n) -> Nat# i -> a) -> a
{-# inline with# #-}
with# packed k = case packed of
  Unsafe.Fin# raw -> k Unsafe.Lt (Unsafe.Nat# raw)

-- | Build an unboxed 'Fin#' from an index and a proof that it lies
-- below the bound. The proof is required by the type but is not
-- inspected at runtime.
construct# :: (i < n) -> Nat# i -> Fin# n
{-# inline construct# #-}
construct# _ ix = case ix of
  Unsafe.Nat# raw -> Unsafe.Fin# raw