{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS_GHC -Wno-missing-signatures #-}
module Numeric.QuoteQuot
(
quoteQuot
, quoteRem
, quoteQuotRem
, astQuot
, AST(..)
, interpretAST
, quoteAST
, assumeNonNegArg
, MulHi(..)
) where
#include "MachDeps.h"
import Prelude
import Data.Bits
import Data.Int
import Data.Word
import GHC.Exts
import Language.Haskell.TH.Syntax
-- | Quote integer division ('quot') by a divisor known at compile time.
--
-- The divisor is translated by 'astQuot' into an 'AST' of
-- multiplications, additions, subtractions, shifts and comparisons,
-- which 'quoteAST' then turns into code, so no division is emitted.
--
-- > {-# LANGUAGE TemplateHaskell #-}
-- > quot10 :: Word -> Word
-- > quot10 = $$(quoteQuot 10)
--
-- The divisor must be positive; otherwise generating the splice fails
-- with an 'error' raised by 'astQuot'.
quoteQuot ::
#if MIN_VERSION_template_haskell(2,17,0)
  (MulHi a, Lift a, Quote m) => a -> Code m (a -> a)
#else
  (MulHi a, Lift a) => a -> Q (TExp (a -> a))
#endif
quoteQuot d = quoteAST (astQuot d)
-- | Like 'quoteQuot', but quotes 'rem' instead of 'quot'.
--
-- The remainder is taken from the pair generated by 'quoteQuotRem',
-- so the same restriction applies: the divisor must be positive.
quoteRem ::
#if MIN_VERSION_template_haskell(2,17,0)
  (MulHi a, Lift a, Quote m) => a -> Code m (a -> a)
#else
  (MulHi a, Lift a) => a -> Q (TExp (a -> a))
#endif
quoteRem d = [|| snd . $$(quoteQuotRem d) ||]
-- | Like 'quoteQuot', but quotes 'quotRem'. It returns both the quotient
-- and the remainder.
--
-- Only the quotient @q@ is computed by the generated division routine.
-- The remainder is then recovered as @w - d * q@, the identity that
-- 'quotRem' satisfies.
quoteQuotRem ::
#if MIN_VERSION_template_haskell(2,17,0)
  (MulHi a, Lift a, Quote m) => a -> Code m (a -> (a, a))
#else
  (MulHi a, Lift a) => a -> Q (TExp (a -> (a, a)))
#endif
quoteQuotRem d = [|| \w -> let q = $$(quoteQuot d) w in (q, w - d * q) ||]
-- | Fixed-width integral types that support a widening multiplication.
-- The result keeps only the high half of the double-width product.
class (Integral t, FiniteBits t) => MulHi t where
  -- | @'mulHi' x y@ is the upper 'finiteBitSize' bits of the
  -- double-width product of @x@ and @y@.
  mulHi :: t -> t -> t
-- | Widen both operands to 'Word16', multiply and keep the top 8 bits.
-- The exact product (at most 255 * 255) always fits in 'Word16'.
instance MulHi Word8 where
  mulHi x y = fromIntegral ((fromIntegral x * fromIntegral y :: Word16) `shiftR` 8)
-- | Widen both operands to 'Word32', multiply and keep the top 16 bits.
-- The exact product of two 'Word16' values always fits in 'Word32'.
instance MulHi Word16 where
  mulHi x y = fromIntegral ((fromIntegral x * fromIntegral y :: Word32) `shiftR` 16)
-- | Widen both operands to 'Word64', multiply and keep the top 32 bits.
-- The exact product of two 'Word32' values always fits in 'Word64'.
instance MulHi Word32 where
  mulHi x y = fromIntegral ((fromIntegral x * fromIntegral y :: Word64) `shiftR` 32)
#if WORD_SIZE_IN_BITS == 64
-- | On 64-bit platforms 'Word64' has the same width as 'Word'.
-- The operands are therefore converted and the 'Word' instance is reused.
instance MulHi Word64 where
  mulHi x y = fromIntegral (fromIntegral x `mulHi` fromIntegral y :: Word)
#endif
-- | Uses the @timesWord2#@ primop. It yields the high and low words of
-- the double-word product, and only the high word is kept.
instance MulHi Word where
  mulHi (W# x) (W# y) = let !(# hi, _ #) = timesWord2# x y in W# hi
-- | Sign-extend both operands to 'Int16', multiply and keep the top 8 bits.
-- The exact product (magnitude at most 128 * 128) always fits in 'Int16'.
-- 'shiftR' on a signed type is arithmetic, so the result is the signed
-- high byte.
instance MulHi Int8 where
  mulHi x y = fromIntegral ((fromIntegral x * fromIntegral y :: Int16) `shiftR` 8)
-- | Sign-extend both operands to 'Int32', multiply and keep the top 16
-- bits. The exact product of two 'Int16' values always fits in 'Int32'.
instance MulHi Int16 where
  mulHi x y = fromIntegral ((fromIntegral x * fromIntegral y :: Int32) `shiftR` 16)
-- | Sign-extend both operands to 'Int64', multiply and keep the top 32
-- bits. The exact product of two 'Int32' values always fits in 'Int64'.
instance MulHi Int32 where
  mulHi x y = fromIntegral ((fromIntegral x * fromIntegral y :: Int64) `shiftR` 32)
#if MIN_VERSION_base(4,15,0)
#if WORD_SIZE_IN_BITS == 64
-- | On 64-bit platforms 'Int64' has the same width as 'Int'.
-- The operands are therefore converted and the 'Int' instance is reused.
instance MulHi Int64 where
  mulHi a b = fromIntegral (mulHi (fromIntegral a) (fromIntegral b :: Int))
#endif
-- | Uses the @timesInt2#@ primop, which this guard assumes to be
-- available from base 4.15 on. The middle component of its result is
-- the high word of the product; the other two components are ignored.
instance MulHi Int where
  mulHi (I# a) (I# b) = case timesInt2# a b of
    (# _, hi, _ #) -> I# hi
#endif
-- | Abstract syntax tree of a function of one argument. It is built from
-- the primitive operations that the division routines need.
--
-- 'interpretAST' evaluates a tree directly, and 'quoteAST' turns it into
-- a typed Template Haskell expression.
data AST a
  = Arg
  -- ^ The argument of the function.
  | MulHi (AST a) a
  -- ^ Widening multiplication by a constant, keeping the high word
  -- (see 'mulHi').
  | MulLo (AST a) a
  -- ^ Ordinary multiplication ('*') by a constant.
  | Add (AST a) (AST a)
  -- ^ Sum of two subexpressions.
  | Sub (AST a) (AST a)
  -- ^ Difference of two subexpressions.
  | Shl (AST a) Int
  -- ^ Left shift by a constant number of bits.
  | Shr (AST a) Int
  -- ^ Right shift ('shiftR') by a constant number of bits; this
  -- sign-extends for signed types.
  | CmpGE (AST a) a
  -- ^ 1 if the subexpression is greater than or equal to the constant,
  -- 0 otherwise.
  | CmpLT (AST a) a
  -- ^ 1 if the subexpression is less than the constant, 0 otherwise.
  deriving (Show)
-- | Simplify an 'AST' under the assumption that 'Arg' is non-negative.
--
-- If the argument is non-negative, @CmpLT Arg n@ with @n <= 0@ is always
-- 0. An outermost correction term built from it can then be dropped:
--
-- * added directly (@Add x (CmpLT Arg n)@),
-- * subtracted (@Sub x (CmpLT Arg n)@), or
-- * added after scaling (@Add x (MulLo (CmpLT Arg n) _)@).
--
-- These are the shapes that 'signedQuot' produces for its sign
-- correction. Only the top-level node is inspected; every other tree is
-- returned unchanged.
assumeNonNegArg :: (Ord a, Num a) => AST a -> AST a
assumeNonNegArg = \case
  Add x (CmpLT Arg n)
    | n <= 0 -> x
  Sub x (CmpLT Arg n)
    | n <= 0 -> x
  Add x (MulLo (CmpLT Arg n) _)
    | n <= 0 -> x
  -- A failed guard falls through to here, so the tree is kept as is.
  e -> e
-- | Reference interpreter for 'AST': evaluate a tree for a concrete
-- argument.
--
-- It is meant for testing rather than production use, because 'MulHi'
-- nodes are evaluated through 'Integer'.
--
-- >>> interpretAST (astQuot (10 :: Data.Word.Word8)) 123
-- 12
interpretAST :: (Integral a, FiniteBits a) => AST a -> (a -> a)
interpretAST ast n = go ast
  where
    go = \case
      Arg       -> n
      MulHi x k ->
        -- Form the exact product in 'Integer', then drop the low
        -- 'finiteBitSize' bits. 'shiftR' on 'Integer' rounds toward
        -- minus infinity, which matches the high word of a signed product.
        fromInteger ((toInteger (go x) * toInteger k) `shiftR` finiteBitSize k)
      MulLo x k -> go x * k
      Add x y   -> go x + go y
      Sub x y   -> go x - go y
      Shl x k   -> go x `shiftL` k
      Shr x k   -> go x `shiftR` k
      CmpGE x k -> if go x >= k then 1 else 0
      CmpLT x k -> if go x <  k then 1 else 0
-- | Embed an 'AST' into a typed Template Haskell expression.
--
-- Each node becomes a function of the argument:
--
-- * unary nodes compose their operation after the quoted subtree;
-- * 'Add' and 'Sub' apply both quoted subtrees to the same argument.
--
-- Comparisons convert the resulting 'Bool' to 0 or 1 through its
-- constructor tag ('False' has tag 0, 'True' has tag 1). This is
-- presumably done to avoid a conditional branch in the generated code.
quoteAST ::
#if MIN_VERSION_template_haskell(2,17,0)
  (MulHi a, Lift a, Quote m) => AST a -> Code m (a -> a)
#else
  (MulHi a, Lift a) => AST a -> Q (TExp (a -> a))
#endif
quoteAST = \case
  Arg       -> [|| id ||]
  Shr x k   -> [|| (`shiftR` k) . $$(quoteAST x) ||]
  Shl x k   -> [|| (`shiftL` k) . $$(quoteAST x) ||]
  MulHi x k -> [|| (`mulHi` k) . $$(quoteAST x) ||]
  MulLo x k -> [|| (* k) . $$(quoteAST x) ||]
  Add x y   -> [|| \w -> $$(quoteAST x) w + $$(quoteAST y) w ||]
  Sub x y   -> [|| \w -> $$(quoteAST x) w - $$(quoteAST y) w ||]
  CmpGE x k -> [|| (\w -> fromIntegral (I# (dataToTag# (w >= k)))) . $$(quoteAST x) ||]
  CmpLT x k -> [|| (\w -> fromIntegral (I# (dataToTag# (w < k)))) . $$(quoteAST x) ||]
-- | @'astQuot' d@ builds an 'AST' for a function equivalent to
-- @(`quot` d)@. It uses only the operations available in 'AST', never a
-- division.
--
-- >>> astQuot (10 :: Data.Word.Word8)
-- Shr (MulHi Arg 205) 3
--
-- That is, to divide a 'Data.Word.Word8' by 10, multiply by 205, keep the
-- high byte and shift it right by 3. A less direct example:
--
-- >>> astQuot (7 :: Data.Word.Word8)
-- Shr (Add (Shr (Sub Arg (MulHi Arg 37)) 1) (MulHi Arg 37)) 2
--
-- Signed and unsigned types are handled by separate algorithms. The
-- divisor must be positive; otherwise this calls 'error'.
astQuot :: (Integral a, FiniteBits a) => a -> AST a
astQuot d
  | isSigned d = signedQuot d
  | otherwise  = unsignedQuot d
-- | 'astQuot' for unsigned types.
--
-- The divisor @k'@ is split as @k * 2^kZeros@ with @k@ odd, and the
-- first matching case applies:
--
-- * @k' == 1@: the identity.
-- * @k == 1@ (@k'@ is a power of two): a plain right shift by @kZeros@.
-- * @k' >= 2^(fbs-1)@: the quotient can only be 0 or 1, so a single
--   comparison suffices.
-- * @k >= 2^shft@: multiply by @magic@, keep the high word, and shift
--   right by @shft + kZeros@.
-- * otherwise @magic@ does not fit in @fbs@ bits and has been truncated.
--   The high word @t@ is combined with the usual fix-up
--   @((n - t) `shiftR` 1 + t)@ before the final shift (see e.g.
--   Hacker's Delight, ch. 10).
--
-- Here @magic = 2^(fbs+shft) `quot` k + 1@, and @shft@ is the smallest
-- @s@ with @k - (2^(fbs+s) `rem` k) < 2^s@.
unsignedQuot :: (Integral a, FiniteBits a) => a -> AST a
unsignedQuot k'
  | isSigned k
  = error "unsignedQuot works for unsigned types only"
  | k' == 0
  = error "divisor must be positive"
  | k' == 1
  = Arg
  | k == 1
  = shr Arg kZeros
  | k' >= 1 `shiftL` (fbs - 1)
  = CmpGE Arg k'
  | k >= 1 `shiftL` shft
  = shr (MulHi Arg magic) (shft + kZeros)
  | otherwise
  = shr (Add (shr (Sub Arg (MulHi Arg magic)) 1) (MulHi Arg magic)) (shft - 1 + kZeros)
  where
    fbs = finiteBitSize k'
    kZeros = countTrailingZeros k'
    -- Odd part of the divisor.
    k = k' `shiftR` kZeros
    -- 2^fbs `rem` k, computed in Integer because 2^fbs does not fit in fbs bits.
    r0 = fromInteger ((1 `shiftL` fbs) `rem` toInteger k)
    shft = go r0 0
    -- fromInteger truncates to fbs bits when the exact value is larger.
    magic = fromInteger ((1 `shiftL` (fbs + shft)) `quot` toInteger k + 1)
    -- r is 2^(fbs+s) `rem` k. Here k < 2^(fbs-1), because the CmpGE
    -- case above is checked first, so doubling r cannot overflow.
    go r s
      | (k - r) < 1 `shiftL` s = s
      | otherwise = go (r `shiftL` 1 `rem` k) (s + 1)
-- | 'astQuot' for signed types; like 'quot', it rounds toward zero.
--
-- The divisor @k'@ is split as @k * 2^kZeros@ with @k@ odd, and the
-- first matching case applies:
--
-- * @k' == 1@: the identity.
-- * @k == 1@ (@k'@ is a power of two): add @k' - 1@ to negative arguments,
--   so that the arithmetic shift by @kZeros@ rounds toward zero instead of
--   toward minus infinity.
-- * @k' >= 2^(fbs-2)@: the quotient can only be -1, 0 or 1, so it is
--   computed from two comparisons.
-- * otherwise: multiply by @magic@, keep the high word, shift right by
--   @shft + kZeros@, and add 1 for negative arguments. When the exact
--   multiplier @M@ does not fit, @magic@ holds @M - 2^fbs@ (it is negative).
--   Since @n * M / 2^fbs = n * magic / 2^fbs + n@, the argument is then
--   added back to the high word before shifting.
signedQuot :: (Integral a, FiniteBits a) => a -> AST a
signedQuot k'
  | not (isSigned k)
  = error "signedQuot works for signed types only"
  | k' <= 0
  = error "divisor must be positive"
  | k' == 1
  = Arg
  | k == 1
  = shr (Add Arg (MulLo (CmpLT Arg 0) (k' - 1))) kZeros
  | k' >= 1 `shiftL` (fbs - 2)
  = Sub (CmpGE Arg k') (CmpLT Arg (1 - k'))
  | magic >= 0
  = Add (shr (MulHi Arg magic) (shft + kZeros)) (CmpLT Arg 0)
  | otherwise
  = Add (shr (Add Arg (MulHi Arg magic)) (shft + kZeros)) (CmpLT Arg 0)
  where
    fbs = finiteBitSize k'
    kZeros = countTrailingZeros k'
    -- Odd part of the divisor.
    k = k' `shiftR` kZeros
    -- 2^fbs `rem` k, computed in Integer because 2^fbs does not fit in fbs bits.
    r0 = fromInteger ((1 `shiftL` fbs) `rem` toInteger k)
    shft = go r0 0
    -- fromInteger wraps, so a multiplier >= 2^(fbs-1) comes out negative.
    magic = fromInteger ((1 `shiftL` (fbs + shft)) `quot` toInteger k + 1)
    -- Like the unsigned search, but the bound is 2^(s+1), not 2^s.
    go r s
      | (k - r) < 1 `shiftL` (s + 1) = s
      | otherwise = go (r `shiftL` 1 `rem` k) (s + 1)
-- | Smart constructor for 'Shr': a shift by zero bits is omitted, and any
-- other amount produces a 'Shr' node.
shr :: AST a -> Int -> AST a
shr x 0 = x
shr x k = Shr x k