{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language ExplicitNamespaces #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language MagicHash #-}
{-# language RankNTypes #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeOperators #-}

module Arithmetic.Fin
  ( -- * Modification
    incrementL
  , incrementR
  , weaken
  , weakenL
  , weakenR
    -- * Traverse
    -- | These use the terms @ascend@ and @descend@ rather than the
    -- more popular @l@ (left) and @r@ (right) that pervade the Haskell
    -- ecosystem. The general rule is that ascending functions pair
    -- the initial accumulator with zero while descending functions
    -- pair the initial accumulator with the last index.
  , ascend
  , ascend'
  , ascendFrom'
  , ascendFrom'#
  , ascendM
  , ascendM#
  , ascendM_
  , ascendM_#
  , descend
  , descend'
  , descendM
  , descendM_
  , ascending
  , descending
  , ascendingSlice
  , descendingSlice
    -- * Absurdities
  , absurd
    -- * Demote
  , demote
    -- * Deconstruct
  , with
  , with#
    -- * Construct
  , construct#
    -- * Lift and Unlift
  , lift
  , unlift
  ) where

import Prelude hiding (last)

-- NOTE(review): the original import list was unreadable in the reviewed
-- copy of this file. The imports below were reconstructed from the names
-- this module uses; confirm them against version control.
import Arithmetic.Nat ((<?))
import Arithmetic.Types (Difference(..), Fin(..), Nat)
import Arithmetic.Types (type (<), type (<=), type (:=:))
import Arithmetic.Unsafe (Fin#, Nat#)
import GHC.Exts (Int(I#))
import GHC.TypeNats (type (+))

import qualified Arithmetic.Equal as Eq
import qualified Arithmetic.Lt as Lt
import qualified Arithmetic.Lte as Lte
import qualified Arithmetic.Nat as Nat
import qualified Arithmetic.Plus as Plus
import qualified Arithmetic.Unsafe as Unsafe

-- | Raise the index by @m@ and weaken the bound by @m@, adding
-- @m@ to the right-hand side of @n@.
incrementR :: forall n m. Nat m -> Fin n -> Fin (n + m)
{-# inline incrementR #-}
incrementR m (Fin i pf) = Fin (Nat.plus i m) (Lt.incrementR @m pf)

-- | Raise the index by @m@ and weaken the bound by @m@, adding
-- @m@ to the left-hand side of @n@.
incrementL :: forall n m. Nat m -> Fin n -> Fin (m + n)
{-# inline incrementL #-}
incrementL m (Fin i pf) = Fin (Nat.plus m i) (Lt.incrementL @m pf)

-- | Weaken the bound by @m@, adding it to the left-hand side of
-- the existing bound. This does not change the index.
weakenL :: forall n m. Fin n -> Fin (m + n)
{-# inline weakenL #-}
weakenL (Fin i pf) = Fin i
  ( Lt.substituteR
      (Plus.commutative @n @m)
      (Lt.plus pf (Lte.zero @m))
  )

-- | Weaken the bound by @m@, adding it to the right-hand side of
-- the existing bound. This does not change the index.
weakenR :: forall n m. Fin n -> Fin (n + m)
{-# inline weakenR #-}
weakenR (Fin i pf) = Fin i (Lt.plus pf Lte.zero)

-- | Weaken the bound, replacing it by another number greater than
-- or equal to itself. This does not change the index.
weaken :: forall n m. (n <= m) -> Fin n -> Fin m
{-# inline weaken #-}
weaken lt (Fin i pf) = Fin i (Lt.transitiveNonstrictR pf lt)

-- | A finite set of no values is impossible.
absurd :: Fin 0 -> void
{-# inline absurd #-}
absurd (Fin _ pf) = Lt.absurd pf

-- | Fold over the numbers bounded by @n@ in descending
-- order. This is lazy in the accumulator. For convenience,
-- this differs from @foldr@ in the order of the parameters.
--
-- > descend 4 z f = f 0 (f 1 (f 2 (f 3 z)))
descend :: forall a n.
     Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Update accumulator
  -> a
{-# inline descend #-}
descend !n b0 f = go Nat.zero
  where
    go :: Nat m -> a
    go !m = case m <? n of
      Nothing -> b0
      Just lt -> f (Fin m lt) (go (Nat.succ m))

-- | Fold over the numbers bounded by @n@ in descending
-- order. This is strict in the accumulator. For convenience,
-- this differs from @foldr'@ in the order of the parameters.
--
-- > descend' 4 z f = f 0 (f 1 (f 2 (f 3 z)))
descend' :: forall a n.
     Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Update accumulator
  -> a
{-# inline descend' #-}
descend' !n !b0 f = go n Lte.reflexive b0
  where
    -- Counts down from n; each step peels one off p to get the next index.
    go :: Nat p -> p <= n -> a -> a
    go !m pLteEn !b = case Nat.monus m Nat.one of
      Nothing -> b
      Just (Difference (mpred :: Nat c) cPlusOneEqP) ->
        let !cLtEn = descendLemma cPlusOneEqP pLteEn
         in go mpred (Lte.fromStrict cLtEn) (f (Fin mpred cLtEn) b)

-- | Fold over the numbers bounded by @n@ in ascending order. This
-- is lazy in the accumulator, although the initial accumulator
-- itself is evaluated to weak head normal form.
--
-- > ascend 4 z f = f 3 (f 2 (f 1 (f 0 z)))
ascend :: forall a n.
     Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Update accumulator
  -> a
{-# inline ascend #-}
ascend !n !b0 f = go n Lte.reflexive
  where
    go :: Nat p -> (p <= n) -> a
    go !m pLteEn = case Nat.monus m Nat.one of
      Nothing -> b0
      Just (Difference (mpred :: Nat c) cPlusOneEqP) ->
        let !cLtEn = descendLemma cPlusOneEqP pLteEn
         in f (Fin mpred cLtEn) (go mpred (Lte.fromStrict cLtEn))

-- | Strict fold over the numbers bounded by @n@ in ascending
-- order. For convenience, this differs from @foldl'@ in the
-- order of the parameters.
--
-- > ascend' 4 z f = f 3 (f 2 (f 1 (f 0 z)))
ascend' :: forall a n.
     Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> a) -- ^ Update accumulator
  -> a
{-# inline ascend' #-}
ascend' !n !b0 f = go Nat.zero b0
  where
    go :: Nat m -> a -> a
    go !m !b = case m <? n of
      Nothing -> b
      Just lt -> go (Nat.succ m) (f (Fin m lt) b)

-- | Generalization of @ascend'@ that lets the caller pick the starting
-- index. The indices visited are @m@, @m + 1@, ..., @m + n - 1@.
--
-- > ascend' === ascendFrom' 0
ascendFrom' :: forall a m n.
     Nat m -- ^ Index to start at
  -> Nat n -- ^ Number of steps to take
  -> a -- ^ Initial accumulator
  -> (Fin (m + n) -> a -> a) -- ^ Update accumulator
  -> a
{-# inline ascendFrom' #-}
ascendFrom' !m0 !n !b0 f = go m0 b0
  where
    -- Exclusive upper bound on the visited indices.
    end = Nat.plus m0 n
    go :: Nat k -> a -> a
    go !m !b = case m <? end of
      Nothing -> b
      Just lt -> go (Nat.succ m) (f (Fin m lt) b)

-- | Variant of @ascendFrom'@ with unboxed arguments.
ascendFrom'# :: forall a m n.
     Nat# m -- ^ Index to start at
  -> Nat# n -- ^ Number of steps to take
  -> a -- ^ Initial accumulator
  -> (Fin# (m + n) -> a -> a) -- ^ Update accumulator
  -> a
{-# inline ascendFrom'# #-}
ascendFrom'# !m0 !n !b0 f =
  ascendFrom' (Nat.lift m0) (Nat.lift n) b0 (\ix -> f (unlift ix))

-- | Strict monadic left fold over the numbers bounded by @n@
-- in ascending order. Roughly:
--
-- > ascendM 4 z0 f =
-- >   f 0 z0 >>= \z1 ->
-- >   f 1 z1 >>= \z2 ->
-- >   f 2 z2 >>= \z3 ->
-- >   f 3 z3
ascendM :: forall m a n.
     Monad m
  => Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> m a) -- ^ Update accumulator
  -> m a
{-# inline ascendM #-}
ascendM !n !b0 f = go Nat.zero b0
  where
    go :: Nat p -> a -> m a
    go !m !b = case m <? n of
      Nothing -> pure b
      Just lt -> go (Nat.succ m) =<< f (Fin m lt) b

-- | Variant of @ascendM@ that takes an unboxed Nat and provides
-- an unboxed Fin to the callback.
ascendM# :: forall m a n.
     Monad m
  => Nat# n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin# n -> a -> m a) -- ^ Update accumulator
  -> m a
{-# inline ascendM# #-}
ascendM# n !a0 f = ascendM (Nat.lift n) a0 (\ix a -> f (unlift ix) a)

-- | Monadic traversal of the numbers bounded by @n@
-- in ascending order.
--
-- > ascendM_ 4 f = f 0 *> f 1 *> f 2 *> f 3
ascendM_ :: forall m a n.
     Applicative m
  => Nat n -- ^ Upper bound
  -> (Fin n -> m a) -- ^ Effectful interpretation
  -> m ()
{-# inline ascendM_ #-}
ascendM_ !n f = go Nat.zero
  where
    go :: Nat p -> m ()
    go !m = case m <? n of
      Nothing -> pure ()
      Just lt -> f (Fin m lt) *> go (Nat.succ m)

-- | Variant of @ascendM_@ that takes an unboxed Nat and provides
-- an unboxed Fin to the callback. Like @ascendM_@, this only
-- needs an 'Applicative' constraint.
ascendM_# :: forall m a n.
     Applicative m
  => Nat# n -- ^ Upper bound
  -> (Fin# n -> m a) -- ^ Effectful interpretation
  -> m ()
{-# inline ascendM_# #-}
ascendM_# n f = ascendM_ (Nat.lift n) (\ix -> f (unlift ix))

-- Internal lemma used by the descending traversals: from @a + 1 = b@
-- and @b <= c@, conclude @a < c@.
descendLemma :: forall a b c. a + 1 :=: b -> b <= c -> a < c
{-# inline descendLemma #-}
descendLemma !aPlusOneEqB !bLteC = Lt.transitiveNonstrictR
  -- a < a + 1
  (Lt.substituteR (Plus.commutative @1 @a) (Lt.plus Lt.zero Lte.reflexive))
  -- a + 1 <= c
  (Lte.substituteL (Eq.symmetric aPlusOneEqB) bLteC)

-- | Strict monadic left fold over the numbers bounded by @n@
-- in descending order. Roughly:
--
-- > descendM 4 z0 f =
-- >   f 3 z0 >>= \z1 ->
-- >   f 2 z1 >>= \z2 ->
-- >   f 1 z2 >>= \z3 ->
-- >   f 0 z3
descendM :: forall m a n.
     Monad m
  => Nat n -- ^ Upper bound
  -> a -- ^ Initial accumulator
  -> (Fin n -> a -> m a) -- ^ Update accumulator
  -> m a
{-# inline descendM #-}
descendM !n !b0 f = go n Lte.reflexive b0
  where
    go :: Nat p -> p <= n -> a -> m a
    go !m pLteEn !b = case Nat.monus m Nat.one of
      Nothing -> pure b
      Just (Difference (mpred :: Nat c) cPlusOneEqP) ->
        let !cLtEn = descendLemma cPlusOneEqP pLteEn
         in go mpred (Lte.fromStrict cLtEn) =<< f (Fin mpred cLtEn) b

-- | Monadic traversal of the numbers bounded by @n@
-- in descending order.
--
-- > descendM_ 4 f = f 3 *> f 2 *> f 1 *> f 0
descendM_ :: forall m a n.
     Applicative m
  => Nat n -- ^ Upper bound
  -> (Fin n -> m a) -- ^ Effectful interpretation
  -> m ()
{-# inline descendM_ #-}
descendM_ !n f = go n Lte.reflexive
  where
    go :: Nat p -> p <= n -> m ()
    go !m !pLteEn = case Nat.monus m Nat.one of
      Nothing -> pure ()
      Just (Difference (mpred :: Nat c) cPlusOneEqP) ->
        let !cLtEn = descendLemma cPlusOneEqP pLteEn
         in f (Fin mpred cLtEn) *> go mpred (Lte.fromStrict cLtEn)

-- | Generate all values of a finite set in ascending order.
--
-- >>> ascending (Nat.constant @3)
-- [Fin 0,Fin 1,Fin 2]
ascending :: forall n. Nat n -> [Fin n]
ascending !n = go Nat.zero
  where
    go :: Nat m -> [Fin n]
    go !m = case m <? n of
      Nothing -> []
      Just lt -> Fin m lt : go (Nat.succ m)

-- | Generate all values of a finite set in descending order.
--
-- >>> descending (Nat.constant @3)
-- [Fin 2,Fin 1,Fin 0]
descending :: forall n. Nat n -> [Fin n]
descending !n = go n Lte.reflexive
  where
    go :: Nat p -> (p <= n) -> [Fin n]
    go !m !pLteEn = case Nat.monus m Nat.one of
      Nothing -> []
      Just (Difference (mpred :: Nat c) cPlusOneEqP) ->
        let !cLtEn = descendLemma cPlusOneEqP pLteEn
         in Fin mpred cLtEn : go mpred (Lte.fromStrict cLtEn)

-- | Generate @len@ values starting from @off@ in ascending order.
--
-- >>> ascendingSlice (Nat.constant @2) (Nat.constant @3) (Lte.constant @_ @6)
-- [Fin 2,Fin 3,Fin 4]
ascendingSlice :: forall n off len.
     Nat off -- ^ First index produced
  -> Nat len -- ^ Number of indices produced
  -> off + len <= n -- ^ Proof that the slice fits in the bound
  -> [Fin n]
{-# inline ascendingSlice #-}
ascendingSlice off len !offPlusLenLteEn = go Nat.zero
  where
    go :: Nat m -> [Fin n]
    go !m = case m <? len of
      Nothing -> []
      Just emLtLen ->
        let !offPlusEmLtOffPlusLen = Lt.incrementL @off emLtLen
            !offPlusEmLtEn =
              Lt.transitiveNonstrictR offPlusEmLtOffPlusLen offPlusLenLteEn
         in Fin (Nat.plus off m) offPlusEmLtEn : go (Nat.succ m)

-- | Generate @len@ values starting from @off + len - 1@ in descending order.
--
-- >>> descendingSlice (Nat.constant @2) (Nat.constant @3) (Lte.constant @_ @6)
-- [Fin 4,Fin 3,Fin 2]
descendingSlice :: forall n off len.
     Nat off -- ^ Smallest index produced
  -> Nat len -- ^ Number of indices produced
  -> off + len <= n -- ^ Proof that the slice fits in the bound
  -> [Fin n]
{-# inline descendingSlice #-}
descendingSlice !off !len !offPlusLenLteEn = go len Lte.reflexive
  where
    go :: Nat m -> m <= len -> [Fin n]
    go !m !mLteLen = case Nat.monus m Nat.one of
      Nothing -> []
      Just (Difference (mpred :: Nat c) cPlusOneEqEm) ->
        let -- c < len, because c < c + 1 and c + 1 = m <= len
            !cLtLen = Lt.transitiveNonstrictR
              (Lt.substituteR (Plus.commutative @1 @c) (Lt.plus Lt.zero Lte.reflexive))
              (Lte.substituteL (Eq.symmetric cPlusOneEqEm) mLteLen)
            -- c + off < n, because c + off < off + len <= n
            !cPlusOffLtEn = Lt.transitiveNonstrictR
              (Lt.substituteR (Plus.commutative @len @off) (Lt.plus cLtLen (Lte.reflexive @off)))
              offPlusLenLteEn
         in Fin (mpred `Nat.plus` off) cPlusOffLtEn : go mpred (Lte.fromStrict cLtLen)

-- | Extract the 'Int' from a 'Fin n'. This is intended to be used
-- at a boundary where a safe interface meets the unsafe primitives
-- on top of which it is built.
demote :: Fin n -> Int
{-# inline demote #-}
demote (Fin i _) = Nat.demote i

-- | Convert an unboxed 'Fin#' into a boxed 'Fin'. The bound proof is
-- manufactured with 'Unsafe.Lt' rather than checked.
lift :: Unsafe.Fin# n -> Fin n
{-# inline lift #-}
lift (Unsafe.Fin# i) = Fin (Unsafe.Nat (I# i)) Unsafe.Lt

-- | Convert a boxed 'Fin' into an unboxed 'Fin#', discarding the proof.
unlift :: Fin n -> Unsafe.Fin# n
{-# inline unlift #-}
unlift (Fin (Unsafe.Nat (I# i)) _) = Unsafe.Fin# i

-- | Consume the natural number and the proof in the Fin by passing
-- them to a continuation.
with :: Fin n -> (forall i. (i < n) -> Nat i -> a) -> a
{-# inline with #-}
with (Fin i lt) f = f lt i

-- | Variant of 'with' for unboxed argument and result types. The
-- proof given to the continuation is manufactured with 'Unsafe.Lt'.
with# :: Fin# n -> (forall i. (i < n) -> Nat# i -> a) -> a
{-# inline with# #-}
with# (Unsafe.Fin# i) f = f Unsafe.Lt (Unsafe.Nat# i)

-- | Build an unboxed 'Fin#' from an unboxed natural and a proof that
-- it is less than the bound. The proof argument is not inspected.
construct# :: (i < n) -> Nat# i -> Fin# n
{-# inline construct# #-}
construct# _ (Unsafe.Nat# x) = Unsafe.Fin# x