{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UnboxedTuples #-}
module Math.NumberTheory.Moduli.Class
(
Mod
, getVal
, getNatVal
, getMod
, getNatMod
, invertMod
, powMod
, (^%)
, MultMod
, multElement
, isMultElement
, invertGroup
, SomeMod(..)
, modulo
, invertSomeMod
, powSomeMod
, KnownNat
) where
import Data.Proxy
import Data.Ratio
import Data.Semigroup
import Data.Type.Equality
import GHC.Exts
import GHC.Integer.GMP.Internals
import GHC.Natural (Natural(..), powModNatural)
import GHC.TypeNats.Compat
-- | A residue whose modulus @m@ is tracked at the type level.
-- The representative is stored as a 'Natural'.
newtype Mod (m :: Nat)
  = Mod Natural
  deriving (Eq, Ord, Enum)
-- | Renders a residue as @(value \`modulo\` modulus)@.
instance KnownNat m => Show (Mod m) where
  show m =
    concat
      [ "("
      , show (getVal m)
      , " `modulo` "
      , show (getMod m)
      , ")"
      ]
-- | Bounds of the residue range: zero up to one less than the modulus.
instance KnownNat m => Bounded (Mod m) where
  minBound = Mod 0
  -- The definition refers to itself only to obtain the modulus for its
  -- own type.
  maxBound = top
    where
      top = Mod (getNatMod top - 1)
-- | Modular arithmetic on residues.
--
-- NOTE(review): addition, subtraction and negation presumably rely on both
-- operands already being reduced below the modulus; nothing here re-checks
-- that.
instance KnownNat m => Num (Mod m) where
  lhs@(Mod a) + Mod b
    | total >= modulus = Mod (total - modulus)
    | otherwise        = Mod total
    where
      modulus = getNatMod lhs
      total   = a + b
  {-# INLINE (+) #-}

  lhs@(Mod a) - Mod b
    | a >= b    = Mod (a - b)
    -- Add the modulus before subtracting so the 'Natural' subtraction
    -- does not go below zero.
    | otherwise = Mod (getNatMod lhs + a - b)
  {-# INLINE (-) #-}

  negate lhs@(Mod a)
    | a == 0    = Mod 0
    | otherwise = Mod (getNatMod lhs - a)
  {-# INLINE negate #-}

  -- Fast path: both operands use the single-word 'NatS#' representation,
  -- so the product is formed as a two-word value with 'timesWord2#'.
  lhs@(Mod (NatS# x#)) * Mod (NatS# y#) =
    case timesWord2# x# y# of
      (# hi#, lo# #) -> case getNatMod lhs of
        -- Single-word modulus: reduce the two-word product directly.
        NatS# m# -> case quotRemWord2# hi# lo# m# of
          (# _, r# #) -> Mod (NatS# r#)
        -- Multi-word modulus: reduce via 'BigNat', then pick the
        -- constructor according to the size of the remainder.
        NatJ# b# ->
          let remainder = wordToBigNat2 hi# lo# `remBigNat` b#
          in Mod $
               if isTrue# (sizeofBigNat# remainder ==# 1#)
                 then NatS# (bigNatToWord remainder)
                 else NatJ# remainder
  -- General path for every other combination of representations.
  lhs@(Mod !a) * (Mod !b) =
    Mod ((a * b) `rem` getNatMod lhs)
  {-# INLINE (*) #-}

  abs = id
  {-# INLINE abs #-}

  signum _ = Mod 1
  {-# INLINE signum #-}

  -- The result refers to itself only to obtain the modulus for its type.
  fromInteger n = result
    where
      result = Mod (fromInteger (n `mod` getMod result))
  {-# INLINE fromInteger #-}
-- | Division of residues. 'recip' raises an error when 'invertMod' finds
-- no inverse.
instance KnownNat m => Fractional (Mod m) where
  fromRational r
    -- Integral rationals need no division at all.
    | den == 1  = num
    | otherwise = num / fromInteger den
    where
      den = denominator r
      num = fromInteger (numerator r)
  {-# INLINE fromRational #-}

  recip mx =
    maybe
      (error "recip{Mod}: residue is not coprime with modulo")
      id
      (invertMod mx)
  {-# INLINE recip #-}
-- | The modulus of a residue, as an 'Integer'. It is read from the type
-- via 'natVal'.
getMod :: KnownNat m => Mod m -> Integer
getMod mx = toInteger (natVal mx)
{-# INLINE getMod #-}
-- | The modulus of a residue, as a 'Natural'. It is read from the type
-- via 'natVal'.
getNatMod :: KnownNat m => Mod m -> Natural
getNatMod mx = natVal mx
{-# INLINE getNatMod #-}
-- | The stored representative of a residue, as an 'Integer'.
getVal :: Mod m -> Integer
getVal mx = case mx of
  Mod x -> toInteger x
{-# INLINE getVal #-}
-- | The stored representative of a residue, as a 'Natural'.
getNatVal :: Mod m -> Natural
getNatVal mx = case mx of
  Mod x -> x
{-# INLINE getNatVal #-}
-- | Multiplicative inverse of a residue, if there is one.
--
-- The candidate comes from 'recipModInteger'; a non-positive candidate is
-- reported as 'Nothing'.
invertMod :: KnownNat m => Mod m -> Maybe (Mod m)
invertMod mx
  | candidate > 0 = Just (Mod (fromInteger candidate))
  | otherwise     = Nothing
  where
    candidate = recipModInteger (getVal mx) (getMod mx)
{-# INLINABLE invertMod #-}
-- | Raise a residue to a non-negative power using 'powModNatural'.
--
-- Calls 'error' when the exponent is negative.
powMod :: (KnownNat m, Integral a) => Mod m -> a -> Mod m
powMod mx a =
  if a < 0
    then error "^{Mod}: negative exponent"
    else Mod (powModNatural (getNatVal mx) (fromIntegral a) (getNatMod mx))
{-# INLINABLE [1] powMod #-}
{-# SPECIALISE [1] powMod ::
  KnownNat m => Mod m -> Integer -> Mod m,
  KnownNat m => Mod m -> Natural -> Mod m,
  KnownNat m => Mod m -> Int     -> Mod m,
  KnownNat m => Mod m -> Word    -> Mod m #-}

-- Literal exponents 2 and 3 are rewritten into plain multiplications.
-- The rules stay at column 0 so that layout separates them.
{-# RULES
"powMod/2/Integer" forall x. powMod x (2 :: Integer) = let u = x in u*u
"powMod/3/Integer" forall x. powMod x (3 :: Integer) = let u = x in u*u*u
"powMod/2/Int"     forall x. powMod x (2 :: Int)     = let u = x in u*u
"powMod/3/Int"     forall x. powMod x (3 :: Int)     = let u = x in u*u*u
"powMod/2/Word"    forall x. powMod x (2 :: Word)    = let u = x in u*u
"powMod/3/Word"    forall x. powMod x (3 :: Word)    = let u = x in u*u*u  #-}
infixr 8 ^%

-- | Infix synonym for 'powMod'.
(^%) :: (KnownNat m, Integral a) => Mod m -> a -> Mod m
base ^% e = powMod base e
{-# INLINE (^%) #-}
-- | Wrapper around a residue for use with the multiplicative 'Semigroup'
-- and 'Monoid' instances.
newtype MultMod m = MultMod
  { multElement :: Mod m -- ^ The wrapped residue.
  }
  deriving (Eq, Ord, Show)
-- | Combination is multiplication of the wrapped residues.
instance KnownNat m => Semigroup (MultMod m) where
  MultMod x <> MultMod y = MultMod (x * y)

  -- Non-negative counts go straight to 'powMod'; negative counts raise to
  -- the negated count and then invert the result.
  stimes k g@(MultMod x) =
    if k >= 0
      then MultMod (powMod x k)
      else invertGroup (stimes (negate k) g)
-- | The identity is the residue obtained from the literal @1@.
instance KnownNat m => Monoid (MultMod m) where
  mempty = MultMod 1
  mappend x y = x <> y
-- | Bounds are the residues obtained from @1@ and from @-1@.
instance KnownNat m => Bounded (MultMod m) where
  minBound = MultMod 1
  maxBound = MultMod (negate 1)
-- | Wrap a residue as a 'MultMod' when its representative and the modulus
-- have greatest common divisor 1; otherwise return 'Nothing'.
isMultElement :: KnownNat m => Mod m -> Maybe (MultMod m)
isMultElement a
  | gcd (getNatVal a) (getNatMod a) == 1 = Just (MultMod a)
  | otherwise                            = Nothing
-- | Inverse of a 'MultMod' element, computed with 'invertMod'.
--
-- Calls 'error' if 'invertMod' reports no inverse.
invertGroup :: KnownNat m => MultMod m -> MultMod m
invertGroup (MultMod a) =
  maybe
    (error "Math.NumberTheory.Moduli.invertGroup: failed to invert element")
    MultMod
    (invertMod a)
-- | Either a residue whose modulus is hidden inside the constructor, or a
-- plain rational number with no modulus attached.
data SomeMod where
  -- | A residue together with the 'KnownNat' evidence for its modulus.
  SomeMod :: KnownNat m => Mod m -> SomeMod
  -- | A rational value that carries no modulus.
  InfMod  :: Rational -> SomeMod
-- | Two residues are equal when both modulus and representative agree;
-- two rationals compare as rationals; a residue never equals a rational.
instance Eq SomeMod where
  lhs == rhs = case (lhs, rhs) of
    (SomeMod mx, SomeMod my) ->
      getMod mx == getMod my && getVal mx == getVal my
    (InfMod rx, InfMod ry) -> rx == ry
    _ -> False
-- | Shows the wrapped residue or the wrapped rational, with no marker
-- for the constructor itself.
instance Show SomeMod where
  show (SomeMod m) = show m
  show (InfMod r)  = show r
infixl 7 `modulo`

-- | Build a 'SomeMod' from an integer and a run-time modulus.
--
-- The modulus is promoted to the type level with 'someNatVal', and the
-- integer is converted with 'fromInteger' at that modulus.
-- NOTE(review): a zero modulus looks like it would make that conversion
-- divide by zero; confirm whether callers may pass 0.
modulo :: Integer -> Natural -> SomeMod
modulo n m = case someNatVal m of
  SomeNat proxy -> SomeMod (residueAt proxy)
  where
    -- The proxy only selects the modulus; its value is ignored.
    residueAt :: KnownNat t => Proxy t -> Mod t
    residueAt _ = fromInteger n
{-# INLINABLE modulo #-}
-- | Apply a unary operation to a 'SomeMod': the first function handles a
-- residue of any modulus, the second handles a rational.
liftUnOp
  :: (forall k. KnownNat k => Mod k -> Mod k)
  -> (Rational -> Rational)
  -> SomeMod
  -> SomeMod
liftUnOp onResidue _ (SomeMod m) = SomeMod (onResidue m)
liftUnOp _ onRational (InfMod r) = InfMod (onRational r)
{-# INLINEABLE liftUnOp #-}
-- | Combine two residues that may have different moduli.
--
-- Both representatives are reduced modulo the greatest common divisor of
-- the two moduli, and the operation is applied at that modulus.
liftBinOpMod
  :: (KnownNat m, KnownNat n)
  => (forall k. KnownNat k => Mod k -> Mod k -> Mod k)
  -> Mod m
  -> Mod n
  -> SomeMod
liftBinOpMod op mx@(Mod x) my@(Mod y) = case someNatVal common of
  SomeNat proxy -> SomeMod (combineAt proxy)
  where
    common = gcd (natVal mx) (natVal my)

    -- The proxy only selects the modulus; its value is ignored.
    combineAt :: KnownNat t => Proxy t -> Mod t
    combineAt _ = Mod (x `mod` common) `op` Mod (y `mod` common)
-- | Apply a binary operation to two 'SomeMod' values.
--
-- * Two rationals use the rational operation.
-- * A rational paired with a residue is first converted with
--   'fromRational' at the residue's modulus.
-- * Two residues with the same modulus are combined directly; with
--   different moduli they go through 'liftBinOpMod'.
liftBinOp
  :: (forall k. KnownNat k => Mod k -> Mod k -> Mod k)
  -> (Rational -> Rational -> Rational)
  -> SomeMod
  -> SomeMod
  -> SomeMod
liftBinOp fm fr lhs rhs = case (lhs, rhs) of
  (InfMod rx, InfMod ry)  -> InfMod (fr rx ry)
  (InfMod rx, SomeMod my) -> SomeMod (fm (fromRational rx) my)
  (SomeMod mx, InfMod ry) -> SomeMod (fm mx (fromRational ry))
  (SomeMod (mx :: Mod m), SomeMod (my :: Mod n)) ->
    case sameNat (Proxy :: Proxy m) (Proxy :: Proxy n) of
      Just Refl -> SomeMod (fm mx my)
      Nothing   -> liftBinOpMod fm mx my
-- | Arithmetic on 'SomeMod'. Each binary operator supplies 'liftBinOp'
-- with the operation on residues and the matching operation on rationals.
-- Integer literals become rationals via 'InfMod'.
instance Num SomeMod where
  (+) = liftBinOp (+) (+)
  -- The rational operation must also be subtraction; otherwise subtracting
  -- two 'InfMod' values would add them.
  (-) = liftBinOp (-) (-)
  negate = liftUnOp negate negate
  {-# INLINE negate #-}
  (*) = liftBinOp (*) (*)
  abs = id
  {-# INLINE abs #-}
  signum = const 1
  {-# INLINE signum #-}
  fromInteger = InfMod . fromInteger
  {-# INLINE fromInteger #-}
-- | Rationals are wrapped with 'InfMod'. 'recip' raises an error when
-- 'invertSomeMod' reports no inverse.
instance Fractional SomeMod where
  fromRational r = InfMod r
  {-# INLINE fromRational #-}

  recip x =
    maybe
      (error "recip{SomeMod}: residue is not coprime with modulo")
      id
      (invertSomeMod x)
-- | Multiplicative inverse of a 'SomeMod', if there is one.
--
-- A residue is inverted with 'invertMod'. A rational is inverted with
-- 'recip', except that zero has no inverse and yields 'Nothing' rather
-- than a 'Just' wrapping a division-by-zero error.
invertSomeMod :: SomeMod -> Maybe SomeMod
invertSomeMod = \case
  SomeMod m -> fmap SomeMod (invertMod m)
  InfMod r
    | r == 0    -> Nothing
    | otherwise -> Just (InfMod (recip r))
{-# INLINABLE [1] invertSomeMod #-}
{-# SPECIALISE [1] powSomeMod ::
  SomeMod -> Integer -> SomeMod,
  SomeMod -> Natural -> SomeMod,
  SomeMod -> Int     -> SomeMod,
  SomeMod -> Word    -> SomeMod #-}

-- | Raise a 'SomeMod' to a power: residues use '^%', rationals use '^'.
powSomeMod :: Integral a => SomeMod -> a -> SomeMod
powSomeMod base a = case base of
  SomeMod m -> SomeMod (m ^% a)
  InfMod r  -> InfMod (r ^ a)
{-# INLINABLE [1] powSomeMod #-}

-- Redirect the generic '^' to 'powSomeMod' when the base is a 'SomeMod'.
{-# RULES "^%SomeMod" forall x p. x ^ p = powSomeMod x p #-}