-- |
-- Module:      Math.NumberTheory.Moduli.Class
-- Copyright:   (c) 2017 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- Safe modular arithmetic with modulo on type level.
--

{-# LANGUAGE BangPatterns               #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures             #-}
{-# LANGUAGE LambdaCase                 #-}
{-# LANGUAGE MagicHash                  #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE StandaloneDeriving         #-}
{-# LANGUAGE UnboxedTuples              #-}

module Math.NumberTheory.Moduli.Class
  ( -- * Known modulo
    Mod
  , getVal
  , getNatVal
  , getMod
  , getNatMod
  , invertMod
  , powMod
  , (^%)
  -- * Multiplicative group
  , MultMod
  , multElement
  , isMultElement
  , invertGroup
  -- * Unknown modulo
  , SomeMod(..)
  , modulo
  , invertSomeMod
  , powSomeMod
  -- * Re-exported from GHC.TypeNats.Compat
  , KnownNat
  ) where

import Data.Proxy
import Data.Ratio
import Data.Semigroup
import Data.Type.Equality
import GHC.Exts
import GHC.Integer.GMP.Internals
import GHC.Natural (Natural(..), powModNatural)
import GHC.TypeNats.Compat

-- | A residue class modulo @m@, where the modulo @m@ lives on the type level.
--
-- For instance, @Mod 3 :: Mod 10@ denotes every integer congruent to 3 modulo 10
-- (…−17, −7, 3, 13, 23…). Because the modulo is part of the type, residues
-- with different moduli cannot be combined by accident:
--
-- >>> :set -XDataKinds
-- >>> (3 :: Mod 10) + (4 :: Mod 12)
-- error: Couldn't match type ‘12’ with ‘10’...
-- >>> (3 :: Mod 10) + 8
-- (1 `modulo` 10)
--
-- The modulo is a type-level 'Nat', so it can never be negative.
-- Note that the derived 'Enum' instance operates on the underlying 'Natural'
-- and does not wrap around: 'succ' of the largest residue is not reduced.
newtype Mod (m :: Nat) = Mod Natural
  deriving (Ord, Eq, Enum)

-- | Renders a residue as @(value `modulo` m)@.
instance KnownNat m => Show (Mod m) where
  show mx = concat ["(", show (getVal mx), " `modulo` ", show (getMod mx), ")"]

-- | Canonical representatives range from @0@ to @m - 1@.
instance KnownNat m => Bounded (Mod m) where
  minBound = Mod 0
  maxBound = top
    where
      -- 'top' mentions itself only so that 'getNatMod' can see the type @m@.
      top = Mod (getNatMod top - 1)

-- | Ring operations modulo @m@. Every operation keeps its result in the
-- canonical range @[0, m - 1]@.
instance KnownNat m => Num (Mod m) where
  -- Both summands are below m, so one conditional subtraction suffices.
  mx@(Mod x) + Mod y =
    Mod $ if xy >= m then xy - m else xy
    where
      xy = x + y
      m = getNatMod mx
  {-# INLINE (+) #-}
  -- 'Natural' cannot go negative, so add m before subtracting when x < y.
  mx@(Mod x) - Mod y =
    Mod $ if x >= y then x - y else m + x - y
    where
      m = getNatMod mx
  {-# INLINE (-) #-}
  negate mx@(Mod x) =
    Mod $ if x == 0 then 0 else getNatMod mx - x
  {-# INLINE negate #-}

  -- If modulo is small and fits into one machine word,
  -- there is no need to use long arithmetic at all
  -- and we can save some allocations.
  mx@(Mod (NatS# x#)) * (Mod (NatS# y#)) = case getNatMod mx of
    NatS# m# -> let !(# z1#, z2# #) = timesWord2# x# y# in
                let !(# _, r# #) = quotRemWord2# z1# z2# m# in
                Mod (NatS# r#)
    NatJ# b# -> let !(# z1#, z2# #) = timesWord2# x# y# in
                let r# = wordToBigNat2 z1# z2# `remBigNat` b# in
                Mod $ if isTrue# (sizeofBigNat# r# ==# 1#)
                  then NatS# (bigNatToWord r#)
                  else NatJ# r#

  mx@(Mod !x) * (Mod !y) =
    Mod $ x * y `rem` getNatMod mx
    -- `rem` is slightly faster than `mod`
  {-# INLINE (*) #-}

  abs = id
  {-# INLINE abs #-}
  -- Every residue has "sign" one, so @abs x * signum x == x@. The literal is
  -- built through 'fromInteger', which reduces it modulo m. For m = 1 this
  -- yields the canonical @Mod 0@ rather than @Mod 1@, which lies outside
  -- [0, m - 1] and would compare unequal to every other residue.
  signum = const 1
  {-# INLINE signum #-}
  fromInteger x = mx
    where
      mx = Mod $ fromInteger $ x `mod` getMod mx
  {-# INLINE fromInteger #-}

-- | Division is multiplication by the modular inverse. Dividing by a residue
-- that shares a factor with the modulo throws a runtime error;
-- use 'invertMod' to get a 'Maybe' instead.
instance KnownNat m => Fractional (Mod m) where
  fromRational q
    | d == 1    = n
    | otherwise = n / fromInteger d
    where
      n = fromInteger (numerator q)
      d = denominator q
  {-# INLINE fromRational #-}
  recip mx = maybe (error $ "recip{Mod}: residue is not coprime with modulo") id (invertMod mx)
  {-# INLINE recip #-}

-- | Reflect the type-level modulo @m@ to the term level as an 'Integer'.
getMod :: KnownNat m => Mod m -> Integer
getMod mx = toInteger (natVal mx)
{-# INLINE getMod #-}

-- | Reflect the type-level modulo @m@ to the term level as a 'Natural'.
getNatMod :: KnownNat m => Mod m -> Natural
getNatMod mx = natVal mx
{-# INLINE getNatMod #-}

-- | The representative of the residue class as an 'Integer',
-- lying in the range @[0, m - 1]@.
getVal :: Mod m -> Integer
getVal = toInteger . getNatVal
{-# INLINE getVal #-}

-- | The representative of the residue class as a 'Natural',
-- lying in the range @[0, m - 1]@.
getNatVal :: Mod m -> Natural
getNatVal (Mod residue) = residue
{-# INLINE getNatVal #-}

-- | Computes the modular inverse, if the residue is coprime with the modulo.
--
-- >>> :set -XDataKinds
-- >>> invertMod (3 :: Mod 10)
-- Just (7 `modulo` 10) -- because 3 * 7 = 1 :: Mod 10
-- >>> invertMod (4 :: Mod 10)
-- Nothing
--
-- Modulo 1 every integer is congruent to 0, so @0 * 0 = 1 :: Mod 1@ and the
-- only residue is its own inverse.
invertMod :: KnownNat m => Mod m -> Maybe (Mod m)
invertMod mx
  -- A non-positive result of 'recipModInteger' is treated as failure below.
  -- For m >= 2 a genuine inverse is never 0, but for m = 1 it is exactly 0,
  -- so that case is answered up front.
  | getNatMod mx == 1 = Just (Mod 0)
  | y <= 0            = Nothing
  | otherwise         = Just $ Mod $ fromInteger y
  where
    -- first argument of recipModInteger is guaranteed to be non-negative
    y = recipModInteger (getVal mx) (getMod mx)
{-# INLINABLE invertMod #-}

-- | Raise a residue to a non-negative power. This is a drop-in replacement for
-- 'Prelude.^' with much better performance, because the exponentiation is
-- delegated to 'powModNatural'.
--
-- >>> :set -XDataKinds
-- >>> powMod (3 :: Mod 10) 4
-- (1 `modulo` 10)
--
-- A negative exponent results in a runtime error.
powMod :: (KnownNat m, Integral a) => Mod m -> a -> Mod m
powMod mx a
  | a >= 0    = Mod (powModNatural base ex modulus)
  | otherwise = error $ "^{Mod}: negative exponent"
  where
    base    = getNatVal mx
    ex      = fromIntegral a
    modulus = getNatMod mx
{-# INLINABLE [1] powMod #-}

{-# SPECIALISE [1] powMod ::
  KnownNat m => Mod m -> Integer -> Mod m,
  KnownNat m => Mod m -> Natural -> Mod m,
  KnownNat m => Mod m -> Int     -> Mod m,
  KnownNat m => Mod m -> Word    -> Mod m #-}

-- Exponents 2 and 3 of type 'Integer', 'Int' and 'Word' are rewritten into
-- plain multiplications, so these cases do not go through 'powModNatural'.
-- The @let u = x@ binds the base once, so it is shared by the multiplications.
{-# RULES
"powMod/2/Integer"     forall x. powMod x (2 :: Integer) = let u = x in u*u
"powMod/3/Integer"     forall x. powMod x (3 :: Integer) = let u = x in u*u*u
"powMod/2/Int"         forall x. powMod x (2 :: Int)     = let u = x in u*u
"powMod/3/Int"         forall x. powMod x (3 :: Int)     = let u = x in u*u*u
"powMod/2/Word"        forall x. powMod x (2 :: Word)    = let u = x in u*u
"powMod/3/Word"        forall x. powMod x (3 :: Word)    = let u = x in u*u*u
#-}

-- | Operator form of 'powMod'. It has the same fixity as 'Prelude.^'
-- (right-associative, precedence 8).
(^%) :: (KnownNat m, Integral a) => Mod m -> a -> Mod m
x ^% p = powMod x p
{-# INLINE (^%) #-}

infixr 8 ^%

-- Unfortunately, a rule like the one below never fires, due to technical
-- details of how type classes are represented in Core.
-- {-# RULES "^%Mod" forall (x :: KnownNat m => Mod m) p. x ^ p = x ^% p #-}

-- | This type represents elements of the multiplicative group mod m, i.e.
-- those elements which are coprime to m. Use 'isMultElement' to construct.
newtype MultMod m = MultMod {
  multElement :: Mod m -- ^ Unwrap a residue.
  } deriving (Eq, Ord, Show)

instance KnownNat m => Semigroup (MultMod m) where
  MultMod a <> MultMod b = MultMod (a * b)
  -- This semigroup is in fact a group, so 'stimes' also accepts a negative
  -- count: it then inverts the corresponding positive power.
  stimes k g@(MultMod x)
    | k < 0     = invertGroup (stimes (negate k) g)
    | otherwise = MultMod (powMod x k)

instance KnownNat m => Monoid (MultMod m) where
  -- The residue 1 is the identity of multiplication.
  mempty = MultMod 1
  mappend x y = x <> y

-- | The bounds are the residues @1@ and @-1@, both of which are coprime to @m@.
instance KnownNat m => Bounded (MultMod m) where
  minBound = mempty
  maxBound = MultMod (negate 1)

-- | Wrap a residue as an element of the multiplicative group if it is
-- coprime to the modulo; otherwise return 'Nothing'.
isMultElement :: KnownNat m => Mod m -> Maybe (MultMod m)
isMultElement a
  | coprime   = Just (MultMod a)
  | otherwise = Nothing
  where
    coprime = gcd (getNatVal a) (getNatMod a) == 1

-- | Invert an element of the multiplicative group. Its residue is meant to be
-- coprime with the modulo, so 'invertMod' is expected to succeed; if it does
-- not, a runtime error is raised.
invertGroup :: KnownNat m => MultMod m -> MultMod m
invertGroup (MultMod a) =
  maybe (error "Math.NumberTheory.Moduli.invertGroup: failed to invert element") MultMod (invertMod a)

-- | Residues whose modulo is known only at runtime, together with rational
-- numbers. Values of this type can be mixed freely in arithmetic expressions,
-- at the price of recomputing the modulo on every operation:
--
-- >>> 2 `modulo` 10 + 4 `modulo` 15
-- (1 `modulo` 5)
-- >>> (2 `modulo` 10) * (4 `modulo` 15)
-- (3 `modulo` 5)
-- >>> 2 `modulo` 10 + fromRational (3 % 7)
-- (1 `modulo` 10)
-- >>> 2 `modulo` 10 * fromRational (3 % 7)
-- (8 `modulo` 10)
--
-- When performance matters, pattern match to recover a @Mod m@ and work
-- with it directly, e. g.,
--
-- > case modulo n m of
-- >   SomeMod k -> process k -- Here k has type Mod m
-- >   InfMod{}  -> error "impossible"
data SomeMod where
  SomeMod :: forall m. KnownNat m => Mod m -> SomeMod
  InfMod  :: Rational -> SomeMod

-- | Two residues are equal when both their moduli and representatives agree;
-- a residue never equals a rational.
instance Eq SomeMod where
  SomeMod a == SomeMod b = (getMod a, getVal a) == (getMod b, getVal b)
  InfMod  p == InfMod  q = p == q
  _         == _         = False

-- | Delegates to the 'Show' instance of the wrapped residue or rational.
instance Show SomeMod where
  show (SomeMod m) = show m
  show (InfMod  r) = show r

-- | Build a residue from a representative and a runtime modulo.
-- The result can be used directly (via 'Num' and 'Fractional') or taken
-- apart by pattern matching. 'modulo' never produces 'InfMod'.
modulo :: Integer -> Natural -> SomeMod
modulo n m = case someNatVal m of
  SomeNat proxy -> SomeMod (reduceAt proxy)
  where
    -- Reduce n modulo the type-level number carried by the proxy.
    reduceAt :: KnownNat t => Proxy t -> Mod t
    reduceAt _ = fromInteger n
{-# INLINABLE modulo #-}
infixl 7 `modulo`

-- | Lift a unary operation to 'SomeMod'. The first function handles residues
-- of any known modulo; the second handles rationals.
liftUnOp
  :: (forall k. KnownNat k => Mod k -> Mod k)
  -> (Rational -> Rational)
  -> SomeMod
  -> SomeMod
liftUnOp fm _  (SomeMod m) = SomeMod (fm m)
liftUnOp _  fr (InfMod  r) = InfMod  (fr r)
{-# INLINEABLE liftUnOp #-}

-- | Combine residues of two possibly different moduli by first reducing both
-- modulo the gcd of the moduli, then applying the operation there.
liftBinOpMod
  :: (KnownNat m, KnownNat n)
  => (forall k. KnownNat k => Mod k -> Mod k -> Mod k)
  -> Mod m
  -> Mod n
  -> SomeMod
liftBinOpMod f mx@(Mod x) my@(Mod y) =
  case someNatVal g of
    SomeNat (_ :: Proxy t) ->
      let x' = Mod (x `mod` g) :: Mod t
          y' = Mod (y `mod` g) :: Mod t
      in  SomeMod (f x' y')
  where
    g = natVal mx `gcd` natVal my

-- | Lift a binary operation to 'SomeMod'.
--
-- * Two rationals are combined with the rational operation.
-- * A rational mixed with a residue is first converted via 'fromRational'.
-- * Two residues of the same modulo are combined directly; residues of
--   different moduli go through 'liftBinOpMod'.
liftBinOp
  :: (forall k. KnownNat k => Mod k -> Mod k -> Mod k)
  -> (Rational -> Rational -> Rational)
  -> SomeMod
  -> SomeMod
  -> SomeMod
liftBinOp fm fr lhs rhs = case (lhs, rhs) of
  (SomeMod (mx :: Mod m), SomeMod (my :: Mod n)) ->
    case sameNat (Proxy :: Proxy m) (Proxy :: Proxy n) of
      Just Refl -> SomeMod (fm mx my)
      Nothing   -> liftBinOpMod fm mx my
  (SomeMod mx, InfMod ry) -> SomeMod (fm mx (fromRational ry))
  (InfMod rx, SomeMod my) -> SomeMod (fm (fromRational rx) my)
  (InfMod rx, InfMod ry)  -> InfMod (fr rx ry)

-- | Arithmetic on residues with possibly different moduli and on rationals.
-- Integer literals become 'InfMod' rationals.
instance Num SomeMod where
  (+)    = liftBinOp (+) (+)
  -- Both the residue and the rational operation must be subtraction;
  -- using (+) for rationals made @3 - 1 :: SomeMod@ evaluate to 4.
  (-)    = liftBinOp (-) (-)
  negate = liftUnOp negate negate
  {-# INLINE negate #-}
  (*)    = liftBinOp (*) (*)
  abs    = id
  {-# INLINE abs #-}
  signum = const 1
  {-# INLINE signum #-}
  fromInteger = InfMod . fromInteger
  {-# INLINE fromInteger #-}

-- | Rational literals become 'InfMod'. Taking the reciprocal of a value that
-- has no inverse (e.g. a residue not coprime with its modulo) throws a
-- runtime error; prefer 'invertSomeMod', which returns a 'Maybe'.
instance Fractional SomeMod where
  fromRational = InfMod
  {-# INLINE fromRational #-}
  recip = maybe (error $ "recip{SomeMod}: residue is not coprime with modulo") id . invertSomeMod

-- | Computes the inverse value, if it exists.
--
-- >>> invertSomeMod (3 `modulo` 10)
-- Just (7 `modulo` 10) -- because 3 * 7 = 1 :: Mod 10
-- >>> invertSomeMod (4 `modulo` 10)
-- Nothing
-- >>> invertSomeMod (fromRational (2 % 5))
-- Just 5 % 2
-- >>> invertSomeMod (fromRational 0)
-- Nothing
invertSomeMod :: SomeMod -> Maybe SomeMod
invertSomeMod = \case
  SomeMod m -> fmap SomeMod (invertMod m)
  -- Zero has no reciprocal. Without this check the result would be a 'Just'
  -- wrapping @recip 0@, which throws as soon as it is evaluated.
  InfMod  r
    | r == 0    -> Nothing
    | otherwise -> Just (InfMod (recip r))
{-# INLINABLE [1] invertSomeMod #-}

{-# SPECIALISE [1] powSomeMod ::
  SomeMod -> Integer -> SomeMod,
  SomeMod -> Natural -> SomeMod,
  SomeMod -> Int     -> SomeMod,
  SomeMod -> Word    -> SomeMod #-}

-- | Raise a residue or a rational to an integral power. This is a drop-in
-- replacement for 'Prelude.^' with much better performance on residues.
-- When -O is enabled, a rewrite rule specialises 'Prelude.^' to 'powSomeMod'.
--
-- >>> powSomeMod (3 `modulo` 10) 4
-- (1 `modulo` 10)
powSomeMod :: Integral a => SomeMod -> a -> SomeMod
powSomeMod x a = case x of
  SomeMod m -> SomeMod (powMod m a)
  InfMod  r -> InfMod  (r ^ a)
{-# INLINABLE [1] powSomeMod #-}

-- Rewrite 'Prelude.^' into 'powSomeMod'. The right-hand side fixes the type
-- of @x@ to 'SomeMod', so the rule only concerns 'SomeMod' bases.
{-# RULES "^%SomeMod" forall x p. x ^ p = powSomeMod x p #-}