{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Haskus.Number.BitNat
( NatVal (..)
, Widen
, widen
, Narrow
, narrow
, IsBitNat
, BitNat
, pattern BitNat
, unsafeMakeBitNat
, safeMakeBitNat
, bitNat
, bitNatZero
, bitNatOne
, extractW
, compareW
, (.+.)
, (.-.)
, (.*.)
, (./.)
, BitNatShiftLeft
, BitNatShiftRight
, (.<<.)
, (.>>.)
, bitNatTestBit
, bitNatXor
, bitNatAnd
, bitNatOr
, BitNatWord
, MakeBitNat
, bitNatToNatural
)
where
import Haskus.Number.Word
import Haskus.Binary.Bits
import Haskus.Utils.Types
import Numeric.Natural
-- | A natural number stored on @n@ bits.
--
-- The payload is held in 'BitNatWord' @n@: the smallest of Word8/16/32/64
-- able to hold @n@ bits, or 'Natural' for wider numbers. The data
-- constructor is not exported; values are built with the 'BitNat' pattern,
-- 'bitNat', 'safeMakeBitNat' or 'unsafeMakeBitNat'.
newtype BitNat (n :: Nat) = BitNat' (BitNatWord n)
-- | Bidirectional view of a 'BitNat' as a 'Natural'.
--
-- Matching converts the stored word with 'bitNatToNatural'. Constructing goes
-- through 'makeW', which calls 'error' when 'safeMakeBitNat' rejects the value.
pattern BitNat :: forall (n :: Nat). (Integral (BitNatWord n), MakeBitNat n) => Natural -> BitNat n
-- The view pattern below never fails, so this single pattern covers every value.
{-# COMPLETE BitNat #-}
pattern BitNat x <- (bitNatToNatural -> x)
   where
      -- Builder: range-checked construction at width @n@.
      BitNat x = makeW @n x
-- | Build a 'BitNat' from a type-level literal @v@.
--
-- The width is computed from the literal via 'NatBitCount'.
bitNat :: forall (v :: Nat) (n :: Nat).
   ( n ~ NatBitCount v
   , Integral (BitNatWord n)
   , MakeBitNat n
   , KnownNat v
   ) => BitNat n
bitNat = BitNat @n literal
   where
      -- The literal reflected to the term level.
      literal :: Natural
      literal = natValue @v
-- | Apply a function to the backing word, keeping the same width.
-- No masking is performed on the result.
mapW :: (BitNatWord a -> BitNatWord a) -> BitNat a -> BitNat a
mapW f = BitNat' . f . extractW
-- | Combine the backing words of two same-width naturals into a natural of
-- width @b@. No masking is performed on the result.
zipWithW :: (BitNatWord a -> BitNatWord a -> BitNatWord b) -> BitNat a -> BitNat a -> BitNat b
zipWithW f l r = BitNat' (f (extractW l) (extractW r))
-- | Renders as @BitNat \@8 5@ (bit width, then value).
-- The output is parenthesised whenever the surrounding precedence is non-zero.
instance (KnownNat b, Integral (BitNatWord b)) => Show (BitNat b) where
   showsPrec d x = showParen (d /= 0) rendered
      where
         rendered = showString "BitNat @"
                  . shows (natValue' @b)
                  . showChar ' '
                  . shows (bitNatToNatural x)
-- | Backing word type used to store a natural number of @b@ bits.
--
-- Zero-bit naturals are rejected with a custom type error. Every other width
-- is dispatched to 'BitNatWord'' by comparing @b@ against each word size.
type family BitNatWord b where
   -- Closed family: this equation must stay first so 0 is caught before the
   -- general case.
   BitNatWord 0 = TypeError ('Text "Naturals encoded on 0 bits are not allowed")
   BitNatWord b = BitNatWord' (b <=? 8) (b <=? 16) (b <=? 32) (b <=? 64)
-- | Select a word type from the four "fits in N bits" flags.
--
-- This is a closed type family, so equations are tried top to bottom. The
-- first 'True flag (8, then 16, 32, 64 bits) selects the smallest fitting
-- word. Widths above 64 bits fall back to arbitrary-precision 'Natural'.
type family BitNatWord' b8 b16 b32 b64 where
   BitNatWord' 'True _ _ _ = Word8
   BitNatWord' _ 'True _ _ = Word16
   BitNatWord' _ _ 'True _ = Word32
   BitNatWord' _ _ _ 'True = Word64
   BitNatWord' _ _ _ _ = Natural
-- | Arithmetic and bitwise capabilities required on the backing word of a
-- @b@-bit natural (used by the bitwise combinators below).
type IsBitNat b =
   ( Integral (BitNatWord b)
   , Num (BitNatWord b)
   , IndexableBits (BitNatWord b)
   , Bitwise (BitNatWord b)
   )
-- | The natural 0, at any width.
bitNatZero :: Num (BitNatWord a) => BitNat a
bitNatZero = BitNat' (fromInteger 0)
-- | The natural 1, at any (non-zero) width.
bitNatOne :: Num (BitNatWord a) => BitNat a
bitNatOne = BitNat' (fromInteger 1)
-- | Convert the stored value to an arbitrary-precision 'Natural'.
bitNatToNatural :: Integral (BitNatWord a) => BitNat a -> Natural
bitNatToNatural = fromIntegral . extractW
-- | Wrap a raw backing word, passing it through 'mask' @a first.
--
-- No range check is performed.
-- NOTE(review): 'mask' is assumed to clear the bits above position @a@, so
-- oversized inputs are truncated rather than rejected; confirm in
-- Haskus.Binary.Bits.
unsafeMakeBitNat :: forall a. (Maskable a (BitNatWord a)) => BitNatWord a -> BitNat a
unsafeMakeBitNat = BitNat' . mask @a
-- | Constraints needed to build an @a@-bit natural from a 'Natural'
-- (see 'safeMakeBitNat' and the 'BitNat' pattern).
type MakeBitNat a =
   ( Num (BitNatWord a)
   , Eq (BitNatWord a)
   , Show (BitNatWord a)
   , ShiftableBits (BitNatWord a)
   , Maskable a (BitNatWord a)
   )
-- | Build a 'BitNat' from a 'Natural'.
--
-- Returns 'Nothing' when the value does not fit in @a@ bits, i.e. when it is
-- not strictly below @2^a@.
--
-- The range check is done on the 'Natural' itself, before converting to the
-- backing word. Converting first silently truncates oversized inputs: 259
-- becomes 3 in a 'Word8', and would then be accepted for @a = 5@.
-- NOTE(review): shifting a word right by its full bit size (when @a@ is
-- 8/16/32/64) is presumably undefined for 'uncheckedShiftR', like GHC's
-- unchecked shift primops, so the old shift-based check was unreliable there.
safeMakeBitNat :: forall a. MakeBitNat a => Natural -> Maybe (BitNat a)
safeMakeBitNat x
   | x < 2 ^ natValue' @a = Just (unsafeMakeBitNat @a (fromIntegral x))
   | otherwise            = Nothing
-- | Build a 'BitNat' from a 'Natural'.
--
-- Calls 'error' when 'safeMakeBitNat' rejects the value. The message reports
-- the value and the range @[0 .. 2^a - 1]@.
makeW :: forall a. MakeBitNat a => Natural -> BitNat a
makeW x = maybe outOfRange id (safeMakeBitNat x)
   where
      outOfRange = error $ concat
         [ "`"
         , show x
         , "` is out of the range of values that can be encoded by a "
         , show (natValue' @a)
         , "-bit natural number: [0.."
         , show (2 ^ natValue' @a - 1 :: Natural)
         , "]"
         ]
-- | Unwrap the backing word.
extractW :: BitNat a -> BitNatWord a
extractW n = case n of
   BitNat' w -> w
-- | Re-encode an @a@-bit natural on @b >= a@ bits.
-- The target width is the first type application: @widen \@16 x@.
widen :: forall b a. Widen a b => BitNat a -> BitNat b
widen = BitNat' . fromIntegral . extractW
-- | Constraints allowing an @a@-bit natural to be widened to @b@ bits.
-- A custom type error is raised when @b < a@.
type Widen a b =
   ( Integral (BitNatWord a)
   , Integral (BitNatWord b)
   , Assert (a <=? b) (() :: Constraint)
      ('Text "Can't widen a natural of "
         ':<>: 'ShowType a
         ':<>: 'Text " bits into a natural of "
         ':<>: 'ShowType b
         ':<>: 'Text " bits"
      )
   )
-- | Re-encode an @a@-bit natural on @b <= a@ bits.
--
-- The value is converted to the narrower word and then passed through
-- 'mask' @b (via 'unsafeMakeBitNat'). High bits are dropped, not checked.
-- The target width is the first type application.
narrow :: forall b a. Narrow a b => BitNat a -> BitNat b
narrow = unsafeMakeBitNat @b . fromIntegral . extractW
-- | Constraints allowing an @a@-bit natural to be narrowed to @b@ bits.
-- A custom type error is raised when @b > a@.
type Narrow a b =
   ( Integral (BitNatWord a)
   , Integral (BitNatWord b)
   , Maskable b (BitNatWord b)
   , Assert (b <=? a) (() :: Constraint)
      ('Text "Can't narrow a natural of "
         ':<>: 'ShowType a
         ':<>: 'Text " bits into a natural of "
         ':<>: 'ShowType b
         ':<>: 'Text " bits"
      )
   )
-- | Compare naturals of possibly different widths.
-- Both operands are first widened to @Max a b@ bits.
compareW :: forall a b.
   ( Ord (BitNatWord (Max a b))
   , Widen a (Max a b)
   , Widen b (Max a b)
   ) => BitNat a -> BitNat b -> Ordering
compareW x y =
   compare (extractW (widen @(Max a b) x))
           (extractW (widen @(Max a b) y))
-- | Equality of the backing words.
instance Eq (BitNatWord a) => Eq (BitNat a) where
   l == r = extractW l == extractW r
-- | Ordering of the backing words.
instance Ord (BitNatWord a) => Ord (BitNat a) where
   compare l r = extractW l `compare` extractW r
-- | Addition without overflow.
-- The result has one more bit than the wider operand, so the sum always fits.
(.+.) :: forall a b m.
   ( m ~ (Max a b + 1)
   , Widen a m
   , Widen b m
   , Num (BitNatWord m)
   ) => BitNat a -> BitNat b -> BitNat m
x .+. y =
   let wx = widen @m x
       wy = widen @m y
   in zipWithW (+) wx wy
-- | Subtraction on @Max a b@ bits.
-- Returns 'Nothing' when the result would be negative (@x < y@).
(.-.) :: forall a b m.
   ( m ~ Max a b
   , Widen a m
   , Widen b m
   , Num (BitNatWord m)
   ) => BitNat a -> BitNat b -> Maybe (BitNat m)
x .-. y
   | wx < wy   = Nothing
   | wx == wy  = Just bitNatZero
   | otherwise = Just (zipWithW (-) wx wy)
   where
      wx = widen @m x
      wy = widen @m y
-- | Multiplication without overflow.
-- An @a@-bit by @b@-bit product always fits in @a + b@ bits.
(.*.) :: forall a b m.
   ( m ~ (a + b)
   , Widen a m
   , Widen b m
   , Num (BitNatWord m)
   ) => BitNat a -> BitNat b -> BitNat m
x .*. y = zipWithW (*) wx wy
   where
      wx = widen @m x
      wy = widen @m y
-- | Euclidean division: returns @(quotient, remainder)@, or 'Nothing' when
-- the divisor is zero.
--
-- The quotient never exceeds the dividend, so it fits in @a@ bits. The
-- remainder is below both operands, so it fits in @Min a b@ bits.
(./.) :: forall a b m.
   ( m ~ Max a b
   , Widen a m
   , Widen b m
   , Num (BitNatWord (Min a b))
   ) => BitNat a -> BitNat b -> Maybe (BitNat a,BitNat (Min a b))
x ./. y
   | y == bitNatZero = Nothing
   | otherwise       =
      let (q, r) = quotRem (extractW (widen @m x)) (extractW (widen @m y))
      in Just (BitNat' (fromIntegral q), BitNat' (fromIntegral r))
-- | Constraints for shifting an @a@-bit natural right by @s@ bits,
-- producing an @(a - s)@-bit result.
type BitNatShiftRight a s =
   ( KnownNat s
   , ShiftableBits (BitNatWord a)
   , Narrow a (a-s)
   )
-- | Constraints for shifting an @a@-bit natural left by @s@ bits,
-- producing an @(a + s)@-bit result.
type BitNatShiftLeft a s =
   ( KnownNat s
   , ShiftableBits (BitNatWord (a+s))
   , Widen a (a+s)
   )
-- | Shift left by the type-level amount @s@.
-- The value is widened to @a + s@ bits first, so no bit is lost.
(.<<.) :: forall (s :: Nat) a.
   ( BitNatShiftLeft a s
   ) => BitNat a -> NatVal s -> BitNat (a + s)
(.<<.) x _ =
   let widened = widen @(a+s) x
   in mapW (\w -> uncheckedShiftL w (natValue @s)) widened
-- | Shift right by the type-level amount @s@.
-- The shifted value is then narrowed to @a - s@ bits.
(.>>.) :: forall (s :: Nat) a.
   ( BitNatShiftRight a s
   ) => BitNat a -> NatVal s -> BitNat (a - s)
(.>>.) x _ = narrow @(a-s) shifted
   where
      shifted = mapW (\w -> uncheckedShiftR w (natValue @s)) x
-- | Test the bit at index @i@ of the backing word.
-- The index is not checked against the width.
bitNatTestBit ::
   ( IndexableBits (BitNatWord a)
   ) => BitNat a -> Word -> Bool
bitNatTestBit n i = testBit (extractW n) i
-- | Bitwise exclusive or of two same-width naturals.
bitNatXor :: forall a.
   ( IsBitNat a
   ) => BitNat a -> BitNat a -> BitNat a
bitNatXor = zipWithW xor
-- | Bitwise and of two same-width naturals.
bitNatAnd :: forall a.
   ( IsBitNat a
   ) => BitNat a -> BitNat a -> BitNat a
bitNatAnd = zipWithW (.&.)
-- | Bitwise or of two same-width naturals.
bitNatOr :: forall a.
   ( IsBitNat a
   ) => BitNat a -> BitNat a -> BitNat a
bitNatOr = zipWithW (.|.)