{-# LANGUAGE
    TypeApplications
  , ScopedTypeVariables
  , LambdaCase
  , NumericUnderscores
#-}

module Atrophy.Internal.LongDivision where

import Data.WideWord.Word128
import Data.Word
import Data.Bits

-- | Divide a 128-bit numerator by a 64-bit divisor, returning the quotient as
-- a 64-bit number.
--
-- The numerator is supplied as two 64-bit words: @numeratorHi@ (upper half)
-- and @numeratorLo'@ (lower half).
--
-- Assumes that the divisor and numerator have both already been bit-shifted so
-- that @countLeadingZeros divisor == 0@. This is not checked: a divisor whose
-- upper 32 bits are all zero makes @divisorHi@ zero, and the `div` below then
-- raises a divide-by-zero exception.
--
-- NOTE(review): each quotient half is clamped to 32 bits and the halves are
-- OR-ed together, so callers presumably guarantee @numeratorHi < divisor@
-- (i.e. the true quotient fits in a 'Word64') -- confirm against call sites.
{-# INLINE divide128By64Preshifted #-}
divide128By64Preshifted :: Word64 -> Word64 -> Word64 -> Word64
divide128By64Preshifted numeratorHi numeratorLo' divisor =
  let
    -- Upper and lower 32-bit halves of the low numerator word, widened.
    numeratorMid :: Word128
    numeratorMid = fromIntegral @Word64 @Word128 (numeratorLo' `unsafeShiftR` 32)
    numeratorLo :: Word128
    numeratorLo = fromIntegral @Word32 @Word128 (fromIntegral @Word64 @Word32 numeratorLo')
    divisorFull128 :: Word128
    divisorFull128 = fromIntegral @Word64 @Word128 divisor
    divisorHi :: Word64
    divisorHi = divisor `unsafeShiftR` 32

    -- To get the upper 32 bits of the quotient, we want to divide
    -- 'fullUpperNumerator' by 'divisor'. The problem is that
    -- 'fullUpperNumerator' is a 96-bit number, meaning we would need a Word128
    -- division to do it all at once, and the whole point of this function is
    -- to avoid 128-bit division because it is slow.
    -- So instead we shift both the numerator and divisor right by 32, giving
    -- us a 64-bit / 32-bit division. This won't give us the exact quotient,
    -- but it will be close.
    fullUpperNumerator :: Word128
    fullUpperNumerator = (Word128 0 numeratorHi `unsafeShiftL` 32) .|. numeratorMid
    quotientHi :: Word64
    quotientHi = min (numeratorHi `div` divisorHi) (fromIntegral (maxBound @Word32))
    productHi :: Word128
    productHi = Word128 0 quotientHi * divisorFull128

    -- quotientHi contains our guess at what the quotient is! The problem is
    -- that we got this by ignoring the lower 32 bits of the divisor. When we
    -- account for that, the quotient might be slightly lower.
    -- We know our quotient is too high if quotient * divisor > numerator.
    -- If it is, decrement until it's in range.
    (productHi', quotientHi') =
      clampToFull productHi quotientHi divisorFull128 fullUpperNumerator

    remainderHi :: Word128
    remainderHi = fullUpperNumerator - productHi'

    -- Repeat the process using the lower half of the numerator.
    fullLowerNumerator :: Word128
    fullLowerNumerator = (remainderHi `unsafeShiftL` 32) .|. numeratorLo

    quotientLo :: Word64
    quotientLo =
      min (fromIntegral @_ @Word64 remainderHi `div` divisorHi) (fromIntegral (maxBound @Word32))
    productLo :: Word128
    productLo = Word128 0 quotientLo * divisorFull128

    -- Again, quotientLo is just a guess at this point; it might be slightly
    -- too large.
    (_, quotientLo') =
      clampToFull productLo quotientLo divisorFull128 fullLowerNumerator

  -- We now have our separate 32-bit quotients; combine them into one word.
  in (quotientHi' `unsafeShiftL` 32) .|. quotientLo'

-- | Divide the all-ones 128-bit value (every bit of the numerator set) by a
-- 64-bit divisor, returning the full 128-bit quotient.
--
-- The upper 64 bits of the quotient come from a plain 64-bit division. The
-- lower 64 bits are computed in one of two ways, depending on the divisor:
--
-- * If the divisor fits in 32 bits (@countLeadingZeros divisor >= 32@), two
--   further 64-bit divisions are chained, each bringing down 32 one-bits of
--   the numerator.
-- * Otherwise the remainder and divisor are shifted left so the divisor has
--   no leading zeros, and 'divide128By64Preshifted' does the work.
--
-- A @divisor@ of zero raises a divide-by-zero exception from the first `div`.
divide128MaxBy64 :: Word64 -> Word128
divide128MaxBy64 divisor =
  let
    quotientHi :: Word64
    quotientHi = maxBound @Word64 `div` divisor
    remainderHi :: Word64
    remainderHi = maxBound @Word64 - quotientHi * divisor

    leadingZeros :: Int
    leadingZeros = countLeadingZeros divisor

    quotientLo :: Word64
    quotientLo =
      if leadingZeros >= 32
        then
          let
            -- remainderHi < divisor < 2^32 here, so the shift cannot overflow.
            numeratorMid :: Word64
            numeratorMid = (remainderHi `unsafeShiftL` 32) .|. fromIntegral (maxBound @Word32)
            quotientMid :: Word64
            quotientMid = numeratorMid `div` divisor
            remainderMid :: Word64
            remainderMid = numeratorMid - quotientMid * divisor

            numeratorLo :: Word64
            numeratorLo = (remainderMid `unsafeShiftL` 32) .|. fromIntegral (maxBound @Word32)
            quotientLo' :: Word64
            quotientLo' = numeratorLo `div` divisor
          in (quotientMid `unsafeShiftL` 32) .|. quotientLo'
        else
          let
            -- Shift the (remainderHi, all-ones) numerator left by leadingZeros.
            -- The zero case is split out so we never shift right by 64.
            numeratorHi :: Word64
            numeratorHi =
              if leadingZeros > 0
                then (remainderHi `unsafeShiftL` leadingZeros)
                       .|. (maxBound @Word64 `unsafeShiftR` (64 - leadingZeros))
                else remainderHi
            numeratorLo :: Word64
            numeratorLo = maxBound @Word64 `unsafeShiftL` leadingZeros
          in divide128By64Preshifted numeratorHi numeratorLo (divisor `unsafeShiftL` leadingZeros)
  in (fromIntegral quotientHi `unsafeShiftL` 64) .|. fromIntegral quotientLo

-- | Correct an over-estimated quotient digit.
--
-- Given a candidate @quotient'@ together with @product'@ (the caller passes
-- @quotient' * divisorFull128@), repeatedly subtract @divisorFull128@ from the
-- product and @1@ from the quotient for as long as the product is strictly
-- greater than @fullUpperNumerator@. Returns the adjusted
-- @(product, quotient)@ pair; if the product is already in range the inputs
-- are returned unchanged.
--
-- Despite its name, @fullUpperNumerator@ is simply the numerator being
-- compared against; 'divide128By64Preshifted' also calls this with its lower
-- numerator.
clampToFull :: Word128 -> Word64 -> Word128 -> Word128 -> (Word128, Word64)
clampToFull product' quotient' divisorFull128 fullUpperNumerator = go product' quotient'
  where
    go :: Word128 -> Word64 -> (Word128, Word64)
    go prod quotient
      | prod > fullUpperNumerator = go (prod - divisorFull128) (quotient - 1)
      | otherwise = (prod, quotient)