{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeApplications #-}
#if MIN_VERSION_base(4,9,0)
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
#endif
module Data.Parameterized.NatRepr
( NatRepr
, natValue
, knownNat
, withKnownNat
, IsZeroNat(..)
, isZeroNat
, NatComparison(..)
, compareNat
, decNat
, predNat
, incNat
, addNat
, subNat
, divNat
, halfNat
, withDivModNat
, natMultiply
, someNat
, maxNat
, natRec
, natForEach
, NatCases(..)
, testNatCases
, widthVal
, minUnsigned
, maxUnsigned
, minSigned
, maxSigned
, toUnsigned
, toSigned
, unsignedClamp
, signedClamp
, LeqProof(..)
, testLeq
, testStrictLeq
, leqRefl
, leqTrans
, leqAdd2
, leqSub2
, leqMulCongr
, leqProof
, withLeqProof
, isPosNat
, leqAdd
, leqSub
, leqMulPos
, leqAddPos
, addIsLeq
, withAddLeq
, addPrefixIsLeq
, withAddPrefixLeq
, addIsLeqLeft1
, dblPosIsPos
, leqMulMono
, plusComm
, mulComm
, plusMinusCancel
, minusPlusCancel
, addMulDistribRight
, withAddMulDistribRight
, withSubMulDistribRight
, mulCancelR
, mul2Plus
, type (+)
, type (-)
, type (*)
, type (<=)
, Equality.TestEquality(..)
, (Equality.:~:)(..)
, Data.Parameterized.Some.Some
) where
import Data.Bits ((.&.))
import Data.Hashable
import Data.Proxy as Proxy
import Data.Type.Equality as Equality
import GHC.TypeLits as TypeLits
import Unsafe.Coerce
import Data.Parameterized.Classes
import Data.Parameterized.Some
-- | The largest 'Int', widened to 'Integer'.  Serves as the upper bound
-- for naturals that must convert back to 'Int' ('widthVal', 'someNat').
maxInt :: Integer
maxInt = fromIntegral (maxBound :: Int)
-- | A runtime representative of the type-level natural @n@.  The wrapped
-- 'Integer' is the value of @n@.
--
-- Only the type is exported, not the 'NatRepr' constructor.  'natValue'
-- reads the wrapped value.
newtype NatRepr (n :: Nat) = NatRepr { natValue :: Integer }
  deriving (Hashable)
-- | Return the value of a 'NatRepr' as an 'Int'.
--
-- Calls 'error' when the value is larger than @maxBound :: Int@.  The
-- bound is inclusive: @maxBound :: Int@ itself is representable, and
-- 'someNat' accepts it, so it must convert here as well.
widthVal :: NatRepr n -> Int
widthVal (NatRepr i)
  | i <= maxInt = fromInteger i
  | otherwise   = error "Width is too large."
-- | Equality at a single 'NatRepr' type.  This instance treats the value
-- as fully determined by the type index, so any two values of the same
-- type compare equal without being inspected.
instance Eq (NatRepr m) where
  (==) _ _ = True
-- | Decide index equality by comparing the runtime values.  When the
-- values agree, the equality proof is manufactured with 'unsafeCoerce'.
instance TestEquality NatRepr where
  testEquality (NatRepr x) (NatRepr y)
    | x /= y    = Nothing
    | otherwise = Just (unsafeCoerce Refl)
-- | Result of comparing two naturals @m@ and @n@, with type-level evidence.
--
-- * 'NatLT': @n = m + (y + 1)@, and the payload is @y@ (the gap minus one).
-- * 'NatEQ': @m = n@.
-- * 'NatGT': @m = n + (y + 1)@, and the payload is @y@.
--
-- NOTE: 'compareNat' builds these with @NatLT \@0 \@0@ / @NatGT \@0 \@0@.
-- Those type applications instantiate @x@ and @y@ positionally, so keep the
-- type variables of these constructors as they are.
data NatComparison m n where
  -- | First argument is smaller; the 'NatRepr' is the distance minus one.
  NatLT :: x+1 <= x+(y+1) => !(NatRepr y) -> NatComparison x (x+(y+1))
  -- | Both arguments are equal.
  NatEQ :: NatComparison x x
  -- | First argument is larger; the 'NatRepr' is the distance minus one.
  NatGT :: x+1 <= x+(y+1) => !(NatRepr y) -> NatComparison (x+(y+1)) x
-- | Compare two 'NatRepr's and return the ordering with type-level
-- evidence.
--
-- For 'NatLT' and 'NatGT' the carried 'NatRepr' is the absolute
-- difference minus one.  The evidence comes from the runtime comparison
-- via 'unsafeCoerce'.
compareNat :: NatRepr m -> NatRepr n -> NatComparison m n
compareNat m n =
  case compare mv nv of
    LT -> unsafeCoerce (NatLT @0 @0) (NatRepr (nv - mv - 1))
    EQ -> unsafeCoerce NatEQ
    GT -> unsafeCoerce (NatGT @0 @0) (NatRepr (mv - nv - 1))
  where
    mv = natValue m
    nv = natValue n
-- | Ordering with index evidence, derived from 'compareNat'.
instance OrdF NatRepr where
  compareF lhs rhs =
    case compareNat lhs rhs of
      NatEQ   -> EQF
      NatLT _ -> LTF
      NatGT _ -> GTF
-- | Heterogeneous equality.  Lifts the index equality found by
-- 'testEquality' to an equality between the two 'NatRepr' types.
instance PolyEq (NatRepr m) (NatRepr n) where
  polyEqF x y = (\Refl -> Refl) <$> testEquality x y
-- | Render a 'NatRepr' as its underlying number.
instance Show (NatRepr n) where
  show = show . natValue
-- | Uses only the class's default methods.  These presumably delegate to
-- the 'Show' instance for 'NatRepr'; confirm against
-- "Data.Parameterized.Classes".
instance ShowF NatRepr
-- | Hash a 'NatRepr' with its derived 'Hashable' instance.
instance HashableF NatRepr where
  hashWithSaltF salt r = hashWithSalt salt r
-- | Build the 'NatRepr' for a natural known at compile time, reading its
-- value through 'natVal'.
knownNat :: forall n . KnownNat n => NatRepr n
knownNat = NatRepr value
  where
    value = natVal (Proxy :: Proxy n)
-- | Every 'KnownNat' has a canonical representative, given by 'knownNat'.
instance (KnownNat n) => KnownRepr NatRepr n where
  knownRepr = knownNat @n
{-# DEPRECATED withKnownNat "This function is potentially unsafe and is scheduled to be removed." #-}
-- | Run a computation that needs a 'KnownNat' constraint for the index of
-- the given 'NatRepr'.
--
-- The 'KnownNat' dictionary is obtained from 'someNatVal' on the runtime
-- value.  The resulting type is then equated with @n@ via 'unsafeCoerce'.
-- Calls 'error' if 'someNatVal' rejects the value (i.e. it is negative).
withKnownNat :: forall n r. NatRepr n -> (KnownNat n => r) -> r
withKnownNat (NatRepr nVal) v =
  case someNatVal nVal of
    Just (SomeNat (Proxy :: Proxy n')) ->
      -- The runtime value is n's value by construction, so n ~ n'.
      case unsafeCoerce (Refl :: 0 :~: 0) :: n :~: n' of
        Refl -> v
    Nothing -> error "withKnownNat: inner value in NatRepr is not a natural"
-- | Evidence of whether a natural is zero.  In the non-zero case the index
-- is exposed as a successor @k + 1@.
data IsZeroNat n where
  ZeroNat    :: IsZeroNat 0
  NonZeroNat :: IsZeroNat (k + 1)
-- | Decide whether a 'NatRepr' is zero.  The evidence comes from the
-- runtime value via 'unsafeCoerce'.
isZeroNat :: NatRepr n -> IsZeroNat n
isZeroNat (NatRepr v)
  | v == 0    = unsafeCoerce ZeroNat
  | otherwise = unsafeCoerce NonZeroNat
-- | Decrement a 'NatRepr' known to be at least one.
decNat :: (1 <= n) => NatRepr n -> NatRepr (n-1)
decNat r = NatRepr (natValue r - 1)
-- | Predecessor of a 'NatRepr' whose index is syntactically a successor.
predNat :: NatRepr (n+1) -> NatRepr n
predNat = NatRepr . subtract 1 . natValue
-- | Successor of a 'NatRepr'.
incNat :: NatRepr n -> NatRepr (n+1)
incNat = NatRepr . (+ 1) . natValue
-- | Halve a 'NatRepr' whose index has the form @n + n@, using integer
-- division by two.
halfNat :: NatRepr (n+n) -> NatRepr n
halfNat r = NatRepr (natValue r `div` 2)
-- | Sum of two 'NatRepr's.
addNat :: NatRepr m -> NatRepr n -> NatRepr (m+n)
addNat x y = NatRepr (natValue x + natValue y)
-- | Difference of two 'NatRepr's.  The constraint @n <= m@ keeps the
-- result non-negative.
subNat :: (n <= m) => NatRepr m -> NatRepr n -> NatRepr (m-n)
subNat x y = NatRepr (natValue x - natValue y)
-- | Divide a product @m * n@ by its non-zero factor @n@ to recover @m@.
divNat :: (1 <= n) => NatRepr (m * n) -> NatRepr n -> NatRepr m
divNat x y = NatRepr (natValue x `div` natValue y)
-- | Split @n@ into quotient and remainder by @m@.  Pass both, together
-- with the type-level equation @n ~ div * m + mod@, to the continuation.
--
-- The equation is asserted with 'unsafeCoerce', not proven.  The values
-- come from 'divMod', so a zero divisor makes evaluation of either
-- result fail with a division-by-zero error.
withDivModNat :: forall n m a.
                 NatRepr n
              -> NatRepr m
              -> (forall div mod. (n ~ ((div * m) + mod)) =>
                  NatRepr div -> NatRepr mod -> a)
              -> a
withDivModNat n m f =
  case (Some (NatRepr q), Some (NatRepr r)) of
    (Some (qRepr :: NatRepr div), Some (rRepr :: NatRepr mod)) ->
      -- Bind the existential indices, then assert the division equation.
      case unsafeCoerce (Refl :: 0 :~: 0) of
        (Refl :: (n :~: ((div * m) + mod))) -> f qRepr rRepr
  where
    (q, r) = natValue n `divMod` natValue m
-- | Product of two 'NatRepr's.
natMultiply :: NatRepr n -> NatRepr m -> NatRepr (n * m)
natMultiply x y = NatRepr (natValue x * natValue y)
-- | Smallest unsigned value of width @w@, which is always zero.
minUnsigned :: NatRepr w -> Integer
minUnsigned = const 0
-- | Largest unsigned value of width @w@: @2^w - 1@.
maxUnsigned :: NatRepr w -> Integer
maxUnsigned w = pred (2 ^ natValue w)
-- | Smallest two's-complement value of width @w@: @-2^(w-1)@.
minSigned :: (1 <= w) => NatRepr w -> Integer
minSigned w = -(2 ^ pred (natValue w))
-- | Largest two's-complement value of width @w@: @2^(w-1) - 1@.
maxSigned :: (1 <= w) => NatRepr w -> Integer
maxSigned w = pred (2 ^ pred (natValue w))
-- | Keep only the low @w@ bits of an integer.  The result lies in
-- @[0, 2^w - 1]@.
toUnsigned :: NatRepr w -> Integer -> Integer
toUnsigned w i = i .&. maxUnsigned w
-- | Keep the low @w@ bits of an integer and read them as a
-- two's-complement signed value.
toSigned :: (1 <= w) => NatRepr w -> Integer -> Integer
toSigned w i0 =
  let low = maxUnsigned w .&. i0
  in if low > maxSigned w then low - 2 ^ natValue w else low
-- | Saturate an integer to the unsigned range of width @w@.
unsignedClamp :: NatRepr w -> Integer -> Integer
unsignedClamp w = max (minUnsigned w) . min (maxUnsigned w)
-- | Saturate an integer to the two's-complement range of width @w@.
signedClamp :: (1 <= w) => NatRepr w -> Integer -> Integer
signedClamp w = max (minSigned w) . min (maxSigned w)
-- | Wrap an 'Integer' as some 'NatRepr'.  Returns 'Nothing' if the value
-- is negative or larger than @maxBound :: Int@.
someNat :: Integer -> Maybe (Some NatRepr)
someNat n
  | n < 0 || n > maxInt = Nothing
  | otherwise           = Just (Some (NatRepr n))
-- | The larger of two 'NatRepr's.  On a tie the first argument is chosen.
maxNat :: NatRepr m -> NatRepr n -> Some NatRepr
maxNat x y = if natValue x < natValue y then Some y else Some x
-- | Commutativity of addition (asserted with 'unsafeCoerce', not checked).
plusComm :: forall f m g n . f m -> g n -> m+n :~: n+m
plusComm _ _ = unsafeCoerce (Refl :: (n + m) :~: (n + m))
-- | Commutativity of multiplication (asserted with 'unsafeCoerce').
mulComm :: forall f m g n. f m -> g n -> (m * n) :~: (n * m)
mulComm _ _ = unsafeCoerce (Refl :: (m * n) :~: (m * n))
-- | @n + n = 2 * n@.  Obtained from 'addMulDistribRight' applied to
-- @1 * n + 1 * n@.
mul2Plus :: forall f n. f n -> (n + n) :~: (2 * n)
mul2Plus n =
  case addMulDistribRight one one n of
    Refl -> Refl
  where
    one :: Proxy 1
    one = Proxy
-- | Adding then subtracting @n@ is the identity (asserted with
-- 'unsafeCoerce').
plusMinusCancel :: forall f m g n . f m -> g n -> (m + n) - n :~: m
plusMinusCancel _ _ = unsafeCoerce (Refl :: ((m + n) - n) :~: ((m + n) - n))
-- | Subtracting then adding @n@ is the identity when @n <= m@ (asserted
-- with 'unsafeCoerce').
minusPlusCancel :: forall f m g n . (n <= m) => f m -> g n -> (m - n) + n :~: m
minusPlusCancel _ _ = unsafeCoerce (Refl :: ((m - n) + n) :~: ((m - n) + n))
-- | Right distributivity of multiplication over addition (asserted with
-- 'unsafeCoerce').
addMulDistribRight :: forall n m p f g h. f n -> g m -> h p
                   -> ((n * p) + (m * p)) :~: ((n + m) * p)
addMulDistribRight _ _ _ = unsafeCoerce (Refl :: ((n + m) * p) :~: ((n + m) * p))
-- | Continuation-passing form of 'addMulDistribRight'.
withAddMulDistribRight :: forall n m p f g h a. f n -> g m -> h p
                       -> ( (((n * p) + (m * p)) ~ ((n + m) * p)) => a) -> a
withAddMulDistribRight x y z k =
  case addMulDistribRight x y z of
    Refl -> k
-- | Right distributivity of multiplication over subtraction when
-- @m <= n@.  Asserted with 'unsafeCoerce' and handed to the continuation.
withSubMulDistribRight :: forall n m p f g h a. (m <= n) => f n -> g m -> h p
                       -> ( (((n * p) - (m * p)) ~ ((n - m) * p)) => a) -> a
withSubMulDistribRight _ _ _ k =
  case unsafeCoerce (Refl :: 0 :~: 0) :: ((n * p) - (m * p)) :~: ((n - m) * p) of
    Refl -> k
-- | A value-level proof that @m <= n@.  Pattern matching on 'LeqProof'
-- brings the constraint into scope.
data LeqProof m n where
  LeqProof :: (lo <= hi) => LeqProof lo hi
-- | Given @m <= n@, decide whether the inequality is strict
-- (@m + 1 <= n@) or the two are equal.  The evidence comes from the
-- runtime comparison via 'unsafeCoerce'.
testStrictLeq :: forall m n
               . (m <= n)
              => NatRepr m
              -> NatRepr n
              -> Either (LeqProof (m+1) n) (m :~: n)
testStrictLeq (NatRepr x) (NatRepr y)
  | x >= y    = Right (unsafeCoerce (Refl :: m :~: m))
  | otherwise = Left (unsafeCoerce (LeqProof :: LeqProof 0 0))
{-# NOINLINE testStrictLeq #-}
-- | Three-way comparison of @m@ and @n@, each case carrying evidence.
data NatCases m n where
  -- | @m < n@, witnessed as @m + 1 <= n@.
  NatCaseLT :: LeqProof (a + 1) b -> NatCases a b
  -- | @m = n@.
  NatCaseEQ :: NatCases a a
  -- | @m > n@, witnessed as @n + 1 <= m@.
  NatCaseGT :: LeqProof (b + 1) a -> NatCases a b
-- | Compare two 'NatRepr's and return the matching 'NatCases'.  The
-- evidence comes from the runtime comparison via 'unsafeCoerce'.
testNatCases :: forall m n
              . NatRepr m
             -> NatRepr n
             -> NatCases m n
testNatCases x y =
  case natValue x `compare` natValue y of
    LT -> NatCaseLT (unsafeCoerce (LeqProof :: LeqProof 0 0))
    EQ -> unsafeCoerce (NatCaseEQ :: NatCases m m)
    GT -> NatCaseGT (unsafeCoerce (LeqProof :: LeqProof 0 0))
{-# NOINLINE testNatCases #-}
-- | Produce a proof of @m <= n@ when the runtime values satisfy it, and
-- 'Nothing' otherwise.
testLeq :: forall m n . NatRepr m -> NatRepr n -> Maybe (LeqProof m n)
testLeq (NatRepr x) (NatRepr y)
  | x > y     = Nothing
  | otherwise = Just (unsafeCoerce (LeqProof :: LeqProof 0 0))
{-# NOINLINE testLeq #-}
-- | Reflexivity: @n <= n@.
leqRefl :: forall f n . f n -> LeqProof n n
leqRefl = const LeqProof
-- | Transitivity of @<=@.  Both input proofs are forced, and the result
-- is asserted with 'unsafeCoerce'.
leqTrans :: LeqProof m n -> LeqProof n p -> LeqProof m p
leqTrans mn np =
  case (mn, np) of
    (LeqProof, LeqProof) -> unsafeCoerce (LeqProof :: LeqProof 0 0)
{-# NOINLINE leqTrans #-}
-- | Add two inequalities side by side.  Both proofs are forced before
-- the result is asserted with 'unsafeCoerce'.
leqAdd2 :: LeqProof x_l x_h -> LeqProof y_l y_h -> LeqProof (x_l + y_l) (x_h + y_h)
leqAdd2 x y = x `seq` y `seq` unsafeCoerce (LeqProof :: LeqProof 0 0)
{-# NOINLINE leqAdd2 #-}
-- | Subtract interval bounds.  From @x_l <= x_h@ and @y_l <= y_h@, conclude
-- @x_l - y_h <= x_h - y_l@.  Both proofs are forced first; the result is
-- asserted with 'unsafeCoerce'.
leqSub2 :: LeqProof x_l x_h
        -> LeqProof y_l y_h
        -> LeqProof (x_l-y_h) (x_h-y_l)
leqSub2 xp yp =
  case (xp, yp) of
    (LeqProof, LeqProof) -> unsafeCoerce (LeqProof :: LeqProof 0 0)
{-# NOINLINE leqSub2 #-}
-- | Reify an in-scope @m <= n@ constraint as a 'LeqProof'.
leqProof :: (m <= n) => f m -> g n -> LeqProof m n
leqProof = \_ _ -> LeqProof
-- | Bring the constraint witnessed by a 'LeqProof' into scope for a
-- computation.
withLeqProof :: LeqProof m n -> ((m <= n) => a) -> a
withLeqProof LeqProof x = x
-- | Prove that a 'NatRepr' is at least one, if it is.
isPosNat :: NatRepr n -> Maybe (LeqProof 1 n)
isPosNat = testLeq (knownNat @1)
-- | Multiply two inequalities: from @a <= x@ and @b <= y@, conclude
-- @a * b <= x * y@.  Both proofs are forced first; the result is asserted
-- with 'unsafeCoerce'.
leqMulCongr :: LeqProof a x
            -> LeqProof b y
            -> LeqProof (a*b) (x*y)
leqMulCongr ax by =
  case (ax, by) of
    (LeqProof, LeqProof) -> unsafeCoerce (LeqProof :: LeqProof 0 0)
{-# NOINLINE leqMulCongr #-}
-- | The product of two positive naturals is positive.
leqMulPos :: forall p q x y
           . (1 <= x, 1 <= y)
          => p x
          -> q y
          -> LeqProof 1 (x*y)
leqMulPos _ _ = leqMulCongr oneLeqX oneLeqY
  where
    oneLeqX :: LeqProof 1 x
    oneLeqX = LeqProof
    oneLeqY :: LeqProof 1 y
    oneLeqY = LeqProof
-- | Multiplying by a positive @x@ does not decrease @y@.
leqMulMono :: (1 <= x) => p x -> q y -> LeqProof y (x * y)
leqMulMono x y = leqMulCongr oneLeqX (leqRefl y)
  where
    oneLeqX = leqProof (Proxy :: Proxy 1) x
-- | Weaken the upper bound of an inequality by adding @p@.
leqAdd :: forall f m n p . LeqProof m n -> f p -> LeqProof m (n+p)
leqAdd x _ = leqAdd2 x zeroLeqP
  where
    zeroLeqP :: LeqProof 0 p
    zeroLeqP = LeqProof
-- | The sum of two positive naturals is positive.
leqAddPos :: (1 <= m, 1 <= n) => p m -> q n -> LeqProof 1 (m + n)
leqAddPos m = leqAdd (leqProof (Proxy :: Proxy 1) m)
-- | Subtracting @p@ (with @p <= m@) from the lower bound preserves the
-- inequality.
leqSub :: forall m n p . LeqProof m n -> LeqProof p m -> LeqProof (m-p) n
leqSub x _ = leqSub2 x zeroLeqP
  where
    zeroLeqP :: LeqProof 0 p
    zeroLeqP = LeqProof
-- | @n <= n + m@.
addIsLeq :: f n -> g m -> LeqProof n (n + m)
addIsLeq n = leqAdd (leqRefl n)
-- | @n <= m + n@, derived from 'addIsLeq' by commuting the sum.
addPrefixIsLeq :: f m -> g n -> LeqProof n (m + n)
addPrefixIsLeq prefix x =
  case plusComm x prefix of
    Refl -> addIsLeq x prefix
-- | Doubling a positive natural keeps it positive.
dblPosIsPos :: forall n . LeqProof 1 n -> LeqProof 1 (n+n)
dblPosIsPos x = leqAdd x (Proxy :: Proxy n)
-- | If @n + n' <= m@ then @n <= m@.  Subtracts @n'@ from the lower bound
-- and simplifies @(n + n') - n'@ to @n@.
addIsLeqLeft1 :: forall n n' m . LeqProof (n + n') m -> LeqProof n m
addIsLeqLeft1 p =
  case plusMinusCancel pn pn' of
    Refl -> leqSub p suffixLeq
  where
    pn :: Proxy n
    pn = Proxy
    pn' :: Proxy n'
    pn' = Proxy
    suffixLeq :: LeqProof n' (n + n')
    suffixLeq = addPrefixIsLeq pn pn'
{-# INLINE withAddPrefixLeq #-}
-- | Run a computation with @m <= n + m@ in scope.
withAddPrefixLeq :: NatRepr n -> NatRepr m -> ((m <= n + m) => a) -> a
withAddPrefixLeq n m k = withLeqProof (addPrefixIsLeq n m) k
-- | Pass @n + m@ to a continuation that also receives @n <= n + m@ as a
-- constraint.
withAddLeq :: forall n m a. NatRepr n -> NatRepr m -> ((n <= n + m) => NatRepr (n + m) -> a) -> a
withAddLeq n m f =
  let total = addNat n m
  in withLeqProof (addIsLeq n m) (f total)
-- | Worker for 'natForEach'.  Applies @f@ to every @n@ with
-- @l <= n <= h@ in increasing order, passing the bounds as explicit
-- 'LeqProof's.  Returns @[]@ when @l > h@.
natForEach' :: forall l h a
             . NatRepr l
            -> NatRepr h
            -> (forall n. LeqProof l n -> LeqProof n h -> NatRepr n -> a)
            -> [a]
natForEach' l h f =
  case testLeq l h of
    Nothing       -> []
    Just LeqProof -> f LeqProof LeqProof l : natForEach' (incNat l) h f'
  where
    -- Retarget f at the next lower bound l + 1, weakening the incoming
    -- proof back to l <= n.
    f' :: forall n. LeqProof (l + 1) n -> LeqProof n h -> NatRepr n -> a
    f' lp hp = f (addIsLeqLeft1 lp) hp
-- | Apply @f@ to every @n@ with @l <= n <= h@, in increasing order.
natForEach :: forall l h a
            . NatRepr l
           -> NatRepr h
           -> (forall n. (l <= n, n <= h) => NatRepr n -> a)
           -> [a]
natForEach l h f = natForEach' l h withBounds
  where
    -- Unpack the explicit proofs into the constraints f expects.
    withBounds :: forall n. LeqProof l n -> LeqProof n h -> NatRepr n -> a
    withBounds LeqProof LeqProof = f
-- | Primitive recursion on a 'NatRepr'.  @natRec n f0 ih@ returns @f0@ at
-- zero.  Otherwise it applies @ih@ to the predecessor and the recursive
-- result for that predecessor.
natRec :: forall m f
        . NatRepr m
       -> f 0
       -> (forall n. NatRepr n -> f n -> f (n + 1))
       -> f m
natRec n f0 ih = loop n
  where
    loop :: forall k. NatRepr k -> f k
    loop r =
      case isZeroNat r of
        ZeroNat    -> f0
        NonZeroNat ->
          let r' = predNat r
          in ih r' (loop r')
-- | Right cancellation: if @n1 * c = n2 * c@ with @c >= 1@, then
-- @n1 = n2@ (asserted with 'unsafeCoerce').
mulCancelR ::
  (1 <= c, (n1 * c) ~ (n2 * c)) => f1 n1 -> f2 n2 -> f3 c -> (n1 :~: n2)
mulCancelR = \_ _ _ -> unsafeCoerce Refl