{-# LANGUAGE
    TypeApplications
  , ScopedTypeVariables
  , LambdaCase
  , NumericUnderscores
#-}

module Atrophy.LongDivision
  ( module X
  , module Atrophy.LongDivision
  )
  where

import Data.Word
import Atrophy.Internal.LongDivision as X
import Atrophy.Internal
import qualified Data.Primitive.Contiguous as Contiguous
import Data.Primitive.Contiguous (PrimArray, Mutable, Sliced)
import Control.Monad.ST.Strict (ST)
import Data.STRef.Strict (newSTRef, readSTRef, writeSTRef)
import Data.Bits

{-# NOINLINE longDivision #-}
-- | Divide a multi-limb unsigned number by a single strength-reduced 64-bit
-- divisor using schoolbook long division, writing one quotient limb per
-- numerator limb.
--
-- Each element of @numeratorSlice@ is one 64-bit limb. A running remainder
-- (initially 0) is carried from one limb to the next. The quotient limb
-- computed for index @i@ is written to index @i@ of @quotient@. The remainder
-- left over after the final limb is discarded.
--
-- Preconditions (not checked here):
--
-- * @quotient@ must have at least as many elements as @numeratorSlice@,
--   because every index of the slice is written into it.
--
-- * Limbs are consumed in the order 'Contiguous.itraverse_' visits them
--   (ascending index). Long division must consume the most significant limb
--   first, so this assumes index 0 holds the most significant limb.
--   NOTE(review): confirm the limb order used by callers.
--
-- * When the running remainder is non-zero it is shifted left by 32 bits,
--   which discards its upper 32 bits. The quotient is therefore only exact
--   when every remainder fits in 32 bits, presumably meaning the divisor must
--   be below 2^32. TODO confirm at the call sites.
longDivision
  :: forall s. Sliced PrimArray Word64
  -> StrengthReducedW64
  -> Mutable PrimArray s Word64
  -> ST s ()
longDivision numeratorSlice reducedDivisor quotient = do
  remainder <- newSTRef 0
  flip Contiguous.itraverse_ numeratorSlice $ \i numerator ->
    readSTRef remainder >>= \case
      0 -> do
        -- The remainder is zero, which means we can take a shortcut and only
        -- do a single division of the whole 64-bit limb!
        let (digitQuotient, digitRemainder) = divRem numerator reducedDivisor

        Contiguous.write quotient i digitQuotient
        writeSTRef remainder digitRemainder

      remainder' -> do
        -- Do one division that includes the running remainder and the upper
        -- half of this numerator element, then a second division for the
        -- first division's remainder combined with the lower half.
        let upperNumerator :: Word64
            upperNumerator =
              (remainder' `unsafeShiftL` 32) .|. (numerator `unsafeShiftR` 32)
        let (upperQuotient, upperRemainder) = divRem upperNumerator reducedDivisor

        let lowerNumerator :: Word64
            lowerNumerator =
              (upperRemainder `unsafeShiftL` 32) .|. (0x00000000_ffffffff .&. numerator)
        let (lowerQuotient, lowerRemainder) = divRem lowerNumerator reducedDivisor

        -- Reassemble the two 32-bit half-quotients into one 64-bit limb.
        Contiguous.write quotient i $ (upperQuotient `unsafeShiftL` 32) .|. lowerQuotient
        writeSTRef remainder lowerRemainder