-- | Multi-precision unsigned integer arithmetic on vectors of 32-bit words.
--
-- A number is a 'Vector' of 'Int32' words stored least significant word
-- first (index 0 is the lowest word). Every word is interpreted as an
-- /unsigned/ 32-bit value, even though it is stored as 'Int32'. Trailing
-- (most significant) zero words are allowed and ignored by the comparisons.
module Data.Digest.WebMoney.Algebra where

import Data.Bits (Bits, bitSize, shiftL, shiftR, testBit, (.&.), (.|.))
import Data.Int (Int32, Int64)
import Data.List (foldl')
import Data.Word (Word32, Word64)

import Control.Lens (ix, (&), (.~))
import Data.Vector (Vector, singleton, (!))
import qualified Data.Vector as V (drop, init, last, length, null, replicate, take, (++))

-- | Mask selecting the low 32 bits of an 'Int64'. It is used to read an
-- 'Int32' word as an unsigned value.
longMask :: Int64
longMask = 0xFFFFFFFF

-- | Number of bits in one word of a multi-precision number.
intSize :: Int
intSize = 32

-- | Logical (zero-filling) right shift, performed through 'Word64'.
logicalShiftR :: Integral a => a -> Int -> a
logicalShiftR x i = fromIntegral ((fromIntegral x :: Word64) `shiftR` i)

-- | Logical (zero-filling) right shift of a single 32-bit word.
logicalShiftRight :: Int32 -> Int -> Int32
logicalShiftRight x i = fromIntegral ((fromIntegral x :: Word32) `shiftR` i)

-- | Number of significant bits in one word: the index of its highest set
-- bit plus one, or 0 for a zero word.
--
-- This is computed relative to 'intSize', so it is only meaningful for
-- 32-bit types.
getBitsNumber :: Bits a => a -> Int
getBitsNumber x = intSize - numberOfLeadingZeros x

-- | Count of zero bits above the highest set bit of @x@ (all bits for 0).
numberOfLeadingZeros :: Bits a => a -> Int
numberOfLeadingZeros x =
  length $ takeWhile (not . testBit x) [size - 1, size - 2 .. 0]
  where
    size = bitSize x

-- | Total number of significant bits in a multi-precision number.
--
-- Returns 0 for an empty or all-zero vector. Previously this case indexed
-- the vector at position -1 and crashed.
getBitsCount :: (Bits a, Num a) => Vector a -> Int
getBitsCount xs
  | wordsUsed == 0 = 0
  | otherwise = (wordsUsed - 1) * intSize + getBitsNumber (xs ! (wordsUsed - 1))
  where
    wordsUsed = significance xs

-- | Compare two numbers as unsigned multi-precision integers.
--
-- Trailing zero words are ignored, so vectors of different physical length
-- compare by value.
compareLists :: Vector Int32 -> Vector Int32 -> Ordering
compareLists lhs rhs
  | lhsLength /= rhsLength = compare lhsLength rhsLength
  | otherwise = compareFrom (lhsLength - 1)
  where
    lhsLength = significance lhs
    rhsLength = significance rhs

    -- Walk from the most significant word down; the first difference wins.
    compareFrom :: Int -> Ordering
    compareFrom i
      | i < 0 = EQ
      | otherwise =
          case compare (unsigned (lhs ! i)) (unsigned (rhs ! i)) of
            EQ -> compareFrom (i - 1)
            order -> order

    unsigned :: Int32 -> Word32
    unsigned = fromIntegral

-- | Number of words up to and including the most significant non-zero
-- word. Returns 0 for an empty or all-zero vector.
significance :: (Eq a, Bits a, Num a) => Vector a -> Int
significance xs
  | V.null xs = 0
  | V.last xs == 0 = significance (V.init xs)
  | otherwise = V.length xs

-- | Shift a number by @rhs@ bits.
--
-- * @rhs > 0@ shifts towards the more significant words (multiplies by
--   @2^rhs@).
-- * @rhs < 0@ shifts towards the less significant words (divides by
--   @2^(-rhs)@, discarding the shifted-out bits).
-- * @rhs == 0@ returns a copy.
--
-- Returns @singleton 0@ when every bit is shifted out. Otherwise the result
-- has as many words as the larger of the input's and the output's
-- significant word counts.
shift :: Vector Int32 -> Int -> Vector Int32
shift lhs rhs
  | outWordsCount <= 0 = singleton 0
  | rhs > 0 && shiftBits == 0 =
      -- Word-aligned left shift: prepend zero words.
      V.replicate shiftWords 0 V.++ V.take (outWordsCount - shiftWords) lhs
  | rhs > 0 =
      let (res, carry) = foldl' shiftUp (r0, 0) [0 .. inWordsCount - 1]
          top = inWordsCount + shiftWords
      in if top < outWordsCount
           -- The bits carried out of the highest input word land in one
           -- extra output word.
           then res & ix top .~ ((res ! top) .|. carry)
           else res
  | shiftBits == 0 =
      -- Word-aligned right shift (or rhs == 0): drop the low words and pad
      -- back to the input length. This case used to be `error "3"`, which
      -- made even `shift x 0` (and therefore `remainder`) crash.
      V.take outWordsCount (V.drop shiftWords lhs)
        V.++ V.replicate (inWordsCount - outWordsCount) 0
  | otherwise =
      let carry
            | outWordsCount + shiftWords < inWordsCount =
                (lhs ! (outWordsCount + shiftWords)) `shiftL` (intSize - shiftBits)
            | otherwise = 0
      -- Fill output words from the top down; output word i is built from
      -- input words i + shiftWords and i + shiftWords + 1.
      in fst $ foldl' shiftDown (r0, carry) [outWordsCount - 1, outWordsCount - 2 .. 0]
  where
    shiftBits, shiftWords, inBitsCount, inWordsCount, outBitsCount, outWordsCount :: Int
    shiftBits = abs rhs `mod` intSize
    shiftWords = abs rhs `div` intSize
    inBitsCount = getBitsCount lhs
    inWordsCount = wordsFor inBitsCount
    outBitsCount = inBitsCount + rhs
    outWordsCount = wordsFor outBitsCount

    -- Number of words needed to hold the given number of bits (rounded up).
    wordsFor :: Int -> Int
    wordsFor bits = bits `div` intSize + (if bits `mod` intSize > 0 then 1 else 0)

    r0 :: Vector Int32
    r0 = V.replicate (max inWordsCount outWordsCount) 0

    shiftUp, shiftDown :: (Vector Int32, Int32) -> Int -> (Vector Int32, Int32)
    shiftUp (res, carry) pos = (res & ix (pos + shiftWords) .~ val, nextCarry)
      where
        temp = lhs ! pos
        val = (temp `shiftL` shiftBits) .|. carry
        nextCarry = temp `logicalShiftRight` (intSize - shiftBits)
    shiftDown (res, carry) pos = (res & ix pos .~ val, nextCarry)
      where
        temp = lhs ! (pos + shiftWords)
        val = (temp `logicalShiftRight` shiftBits) .|. carry
        nextCarry = temp `shiftL` (intSize - shiftBits)

-- | Shift a number right by one bit.
--
-- The vector keeps its physical length. Only words up to the most
-- significant non-zero word are rewritten.
shiftRight :: Vector Int32 -> Vector Int32
shiftRight value = fst $ foldl' step (value, 0) [len - 1, len - 2 .. 0]
  where
    len = significance value

    step :: (Vector Int32, Int64) -> Int -> (Vector Int32, Int64)
    step (v, carry) pos = (v & ix pos .~ fromIntegral val, nextCarry)
      where
        temp, nextCarry, val :: Int64
        temp = fromIntegral (v ! pos) .&. longMask
        -- The bit shifted out of this word becomes the top bit of the next
        -- lower word.
        nextCarry = (temp .&. 1) `shiftL` (intSize - 1) .&. longMask
        val = ((temp `logicalShiftR` 1) .|. carry) .&. longMask

-- | Subtract @rhs@ from @lhs@ as unsigned multi-precision integers.
--
-- The result has the same physical length as @lhs@. Calls 'error' when the
-- difference would be negative.
sub :: Vector Int32 -> Vector Int32 -> Vector Int32
sub lhs rhs
  | lhsLength < rhsLength = error "Difference should not be negative."
  | otherwise =
      case foldl' step (lhs, 0) [0 .. lhsLength - 1] of
        (_, 1) -> error "Difference should not be negative."
        (result, _) -> result
  where
    lhsLength = significance lhs
    rhsLength = significance rhs

    unsignedAt :: Vector Int32 -> Int -> Int64
    unsignedAt v pos = fromIntegral (v ! pos) .&. longMask

    -- Words of rhs beyond its significant length are zero.
    rhsWord :: Int -> Int64
    rhsWord pos
      | pos < rhsLength = unsignedAt rhs pos
      | otherwise = 0

    step :: (Vector Int32, Int32) -> Int -> (Vector Int32, Int32)
    step (acc, borrow) pos = (acc & ix pos .~ fromIntegral diff, nextBorrow)
      where
        diff = unsignedAt acc pos - rhsWord pos - fromIntegral borrow
        -- Both operands are below 2^32, so a negative difference (a borrow)
        -- shows up as bit 32 being set.
        nextBorrow = if diff .&. (1 `shiftL` intSize) /= 0 then 1 else 0

-- | @lhs mod rhs@ for unsigned multi-precision integers, computed by
-- shift-and-subtract.
--
-- Calls 'error' when @rhs@ is zero. The result may keep trailing zero
-- words; see 'normalize'.
remainder :: Vector Int32 -> Vector Int32 -> Vector Int32
remainder lhs rhs
  -- Without this check a zero divisor never makes progress.
  | rhsBitsCount == 0 = error "Division by zero."
  | otherwise = reduce lhs
  where
    rhsBitsCount = getBitsCount rhs

    reduce :: Vector Int32 -> Vector Int32
    reduce l
      | compareLists l rhs == LT = l
      | otherwise = reduce (subtractWhilePossible l divisor)
      where
        lhsBitsCount = getBitsCount l
        -- Align rhs's highest bit with l's. If that overshoots l, back off
        -- by one bit, so that divisor <= l.
        aligned = shift rhs (lhsBitsCount - rhsBitsCount)
        divisor = if compareLists l aligned == LT then shiftRight aligned else aligned

    subtractWhilePossible :: Vector Int32 -> Vector Int32 -> Vector Int32
    subtractWhilePossible l t
      | compareLists l t /= LT = subtractWhilePossible (sub l t) t
      | otherwise = l

-- | Truncate or zero-pad a vector to exactly @l@ words.
--
-- Calls 'error' for a negative length.
resize :: Vector Int32 -> Int -> Vector Int32
resize v l
  | l < 0 = error "Invalid value for length"
  | vLength < l = v V.++ V.replicate (l - vLength) 0
  | otherwise = V.take l v
  where
    vLength = V.length v

-- | Drop trailing (most significant) zero words.
normalize :: Vector Int32 -> Vector Int32
normalize x = resize x (significance x)