{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
module Haskus.Format.Number.BitNat
( NatVal (..)
, Widen
, widen
, Narrow
, narrow
, BitNat
, pattern BitNat
, nat
, unsafeMakeW
, safeMakeW
, zeroW
, oneW
, extractW
, compareW
, (.+.)
, (.-.)
, (.*.)
, (./.)
, (.<<.)
, (.>>.)
, BitNatWord
, MakeW
, toNaturalW
)
where
import Haskus.Format.Binary.Word
import Haskus.Format.Binary.Bits
import Haskus.Utils.Types
import Numeric.Natural
-- | A natural number meant to be representable on @b@ bits, stored in the
-- backing word type chosen by 'BitNatWord' for that width.
--
-- The raw constructor 'BitNat'' is not exported; values are built through the
-- 'BitNat' pattern, 'safeMakeW', 'unsafeMakeW' or the arithmetic operators.
newtype BitNat (b :: Nat) = BitNat' (BitNatWord b)
-- | Bidirectional view of a 'BitNat' as a 'Natural'.
--
-- Matching yields the stored value converted with 'toNaturalW'. Constructing
-- goes through 'makeW', which calls 'error' when the value does not fit on
-- @n@ bits.
pattern BitNat :: forall (n :: Nat). (Integral (BitNatWord n), MakeW n) => Natural -> BitNat n
pattern BitNat v <- (toNaturalW -> v)
   where
      BitNat v = makeW @n v

{-# COMPLETE BitNat #-}
-- | Lift the type-level natural @v@ into a 'BitNat' whose width is
-- @NatBitCount v@ (from "Haskus.Utils.Types"; presumably the number of bits
-- needed to encode @v@ — verify there).
--
-- Usage: @nat \@5@.
nat :: forall (v :: Nat) (n :: Nat).
   ( n ~ NatBitCount v
   , Integral (BitNatWord n)
   , MakeW n
   , KnownNat v
   ) => BitNat n
nat = BitNat @n value
   where
      value :: Natural
      value = natValue @v
-- | Apply a function to the raw backing word of a 'BitNat', keeping its
-- width. No masking is applied to the result.
mapW :: (BitNatWord a -> BitNatWord a) -> BitNat a -> BitNat a
mapW f bn = case bn of
   BitNat' w -> BitNat' (f w)
-- | Combine the raw backing words of two same-width 'BitNat's into a 'BitNat'
-- of (possibly) another width. No masking is applied to the result.
zipWithW :: (BitNatWord a -> BitNatWord a -> BitNatWord b) -> BitNat a -> BitNat a -> BitNat b
zipWithW f = \(BitNat' l) (BitNat' r) -> BitNat' (f l r)
-- | Renders a value as @BitNat \@b v@, i.e. as the expression that rebuilds
-- it with the 'BitNat' pattern and a type application.
--
-- Parentheses are added only when the value is the argument of a function
-- application (context precedence above 10, the convention of derived 'Show'
-- instances). Lower-precedence contexts such as infix-operator operands,
-- lists or tuples need none, since application binds tighter than any
-- operator.
instance (KnownNat b, Integral (BitNatWord b)) => Show (BitNat b) where
   showsPrec d x = showParen (d > 10)
      $ showString "BitNat @"
      . showsPrec 0 (natValue' @b)
      . showString " "
      . showsPrec 0 (toNaturalW x)
-- | Backing word type used to store a natural encoded on @b@ bits.
--
-- Combined with 'BitNatWord'', this selects the narrowest of 'Word8',
-- 'Word16', 'Word32' and 'Word64' that has at least @b@ bits, and falls back
-- to 'Natural' above 64 bits. A width of 0 is rejected with a custom type
-- error.
type family BitNatWord b where
   BitNatWord 0 = TypeError ('Text "Naturals encoded on 0 bits are not allowed")
   -- pass the four width comparisons to the order-sensitive helper
   BitNatWord b = BitNatWord' (b <=? 8) (b <=? 16) (b <=? 32) (b <=? 64)
-- | Helper for 'BitNatWord'. The arguments are the results of comparing the
-- bit count against 8, 16, 32 and 64 respectively.
--
-- Closed-family equations are tried top to bottom, so the first 'True'
-- selects the narrowest fitting word type. 'Natural' is used when none
-- holds. Do not reorder the equations.
type family BitNatWord' b8 b16 b32 b64 where
   BitNatWord' 'True _ _ _ = Word8
   BitNatWord' _ 'True _ _ = Word16
   BitNatWord' _ _ 'True _ = Word32
   BitNatWord' _ _ _ 'True = Word64
   BitNatWord' _ _ _ _ = Natural
-- | The value 0 at any width.
zeroW :: Num (BitNatWord a) => BitNat a
zeroW = BitNat' (fromInteger 0)
-- | The value 1 at any width.
oneW :: Num (BitNatWord a) => BitNat a
oneW = BitNat' (fromInteger 1)
-- | Convert a 'BitNat' to an unbounded 'Natural'.
toNaturalW :: Integral (BitNatWord a) => BitNat a -> Natural
toNaturalW bn = case bn of
   BitNat' w -> fromIntegral w
-- | Wrap a raw backing word as a 'BitNat' @a@ without a range check.
--
-- The word is passed through @mask \@a@ (from "Haskus.Format.Binary.Bits";
-- presumably keeps only its low @a@ bits — verify there), so out-of-range
-- input is altered rather than rejected.
unsafeMakeW :: forall a. (Maskable a (BitNatWord a)) => BitNatWord a -> BitNat a
unsafeMakeW = BitNat' . mask @a
-- | Constraints required on the backing word of a width-@a@ 'BitNat' by the
-- checked constructors ('safeMakeW', the 'BitNat' pattern builder).
type MakeW a =
   ( Num (BitNatWord a)
   , Eq (BitNatWord a)
   , Show (BitNatWord a)
   , ShiftableBits (BitNatWord a)
   , Maskable a (BitNatWord a)
   )
-- | Build a 'BitNat' @a@ from a 'Natural'. Returns 'Nothing' when the value
-- does not fit on @a@ bits, i.e. when it is @>= 2^a@.
--
-- The range check is done on the 'Natural' itself, before converting to the
-- backing word, for two reasons:
--
-- * Converting first would silently truncate out-of-range inputs (e.g. 256
--   becomes 0 in a 'Word8' and would then be accepted).
-- * Checking by shifting the backing word right by @a@ would be a shift by
--   the full word width when @a@ is 8, 16, 32 or 64, which is outside the
--   contract of an unchecked shift.
safeMakeW :: forall a. MakeW a => Natural -> Maybe (BitNat a)
safeMakeW x
   | x < limit = Just (unsafeMakeW @a (fromIntegral x))
   | otherwise = Nothing
   where
      -- smallest value that does NOT fit on @a@ bits
      limit = 2 ^ natValue' @a :: Natural
-- | Partial version of 'safeMakeW'. Calls 'error', with a message giving the
-- valid range @[0..2^a-1]@, when the value does not fit on @a@ bits.
makeW :: forall a. MakeW a => Natural -> BitNat a
makeW x = maybe outOfRange id (safeMakeW @a x)
   where
      bits = natValue' @a
      outOfRange = error $ concat
         [ "`", show x
         , "` is out of the range of values that can be encoded by a "
         , show bits
         , "-bit natural number: [0.."
         , show (2 ^ bits - 1 :: Natural)
         , "]"
         ]
-- | Unwrap the raw backing word of a 'BitNat'.
extractW :: BitNat a -> BitNatWord a
extractW bn = case bn of
   BitNat' w -> w
-- | Re-encode a 'BitNat' @a@ on @b >= a@ bits, converting the backing word
-- with 'fromIntegral'. The target width comes first for type application:
-- @widen \@16 x@.
widen :: forall b a. Widen a b => BitNat a -> BitNat b
widen bn = case bn of
   BitNat' w -> BitNat' (fromIntegral w)
-- | Constraints for widening a natural from @a@ bits to @b@ bits.
--
-- 'Assert' (from "Haskus.Utils.Types") is given the condition @a <=? b@ and a
-- custom type error. Presumably it reduces to @()@ when the condition holds
-- and reports the error otherwise — verify there. Both backing word types
-- must be 'Integral' so 'fromIntegral' can convert between them.
type Widen a b =
   ( Assert (a <=? b) (() :: Constraint)
      ('Text "Can't widen a natural of "
      ':<>: 'ShowType a
      ':<>: 'Text " bits into a natural of "
      ':<>: 'ShowType b
      ':<>: 'Text " bits"
      )
   , Integral (BitNatWord a)
   , Integral (BitNatWord b)
   )
-- | Re-encode a 'BitNat' @a@ on @b <= a@ bits. The backing word is converted
-- with 'fromIntegral' and then passed through 'unsafeMakeW', so high bits that
-- do not fit are dropped rather than reported.
-- Target width first: @narrow \@4 x@.
narrow :: forall b a. Narrow a b => BitNat a -> BitNat b
narrow bn = case bn of
   BitNat' w -> unsafeMakeW @b (fromIntegral w)
-- | Constraints for narrowing a natural from @a@ bits to @b@ bits.
--
-- The 'Assert' enforces @b <=? a@ with a custom type error (see 'Assert' in
-- "Haskus.Utils.Types" for its exact semantics). Both backing words must be
-- 'Integral' for 'fromIntegral'. The 'Maskable' constraint is required by
-- 'unsafeMakeW', which 'narrow' uses to build the result.
type Narrow a b =
   ( Assert (b <=? a) (() :: Constraint)
      ('Text "Can't narrow a natural of "
      ':<>: 'ShowType a
      ':<>: 'Text " bits into a natural of "
      ':<>: 'ShowType b
      ':<>: 'Text " bits"
      )
   , Integral (BitNatWord a)
   , Integral (BitNatWord b)
   , Maskable b (BitNatWord b)
   )
-- | Compare two naturals of possibly different widths by widening both to
-- the larger width @Max a b@ first.
compareW :: forall a b.
   ( Ord (BitNatWord (Max a b))
   , Widen a (Max a b)
   , Widen b (Max a b)
   ) => BitNat a -> BitNat b -> Ordering
compareW x y = widen @(Max a b) x `compare` widen @(Max a b) y
-- | Equality of same-width naturals, delegated to the backing word.
instance Eq (BitNatWord a) => Eq (BitNat a) where
   (==) (BitNat' l) (BitNat' r) = l == r
-- | Ordering of same-width naturals, delegated to the backing word. Use
-- 'compareW' for operands of different widths.
instance Ord (BitNatWord a) => Ord (BitNat a) where
   compare (BitNat' l) (BitNat' r) = l `compare` r
-- | Addition. The result has one bit more than the wider operand
-- (@Max a b + 1@), which is enough to hold any sum of the two.
(.+.) :: forall a b m.
   ( m ~ (Max a b + 1)
   , Widen a m
   , Widen b m
   , Num (BitNatWord m)
   ) => BitNat a -> BitNat b -> BitNat m
x .+. y = zipWithW (+) x' y'
   where
      x' = widen @m x
      y' = widen @m y
-- | Subtraction on the wider width @Max a b@. Returns 'Nothing' when the
-- result would be negative (second operand greater than the first).
(.-.) :: forall a b m.
   ( m ~ Max a b
   , Widen a m
   , Widen b m
   , Num (BitNatWord m)
   ) => BitNat a -> BitNat b -> Maybe (BitNat m)
(.-.) x y = case compare x' y' of
      LT -> Nothing
      EQ -> Just zeroW
      GT -> Just (zipWithW (-) x' y')
   where
      x' = widen @m x
      y' = widen @m y
-- | Multiplication. The result width is @a + b@, which is enough to hold any
-- product of the two operands.
(.*.) :: forall a b m.
   ( m ~ (a + b)
   , Widen a m
   , Widen b m
   , Num (BitNatWord m)
   ) => BitNat a -> BitNat b -> BitNat m
x .*. y =
   let
      lhs = widen @m x
      rhs = widen @m y
   in zipWithW (*) lhs rhs
-- | Integer division, returning @(quotient, remainder)@ as computed by
-- 'quotRem' on the common width @Max a b@.
--
-- Returns 'Nothing' when the divisor is zero. The quotient is at most the
-- dividend, so it fits on @a@ bits. The remainder is below both the divisor
-- and the dividend, so it fits on @Min a b@ bits.
(./.) :: forall a b m.
   ( m ~ Max a b
   , Widen a m
   , Widen b m
   , Num (BitNatWord (Min a b))
   ) => BitNat a -> BitNat b -> Maybe (BitNat a,BitNat (Min a b))
(./.) x y =
   if y == zeroW
      then Nothing
      else
         let
            BitNat' n = widen @m x
            BitNat' d = widen @m y
            (q, r)    = n `quotRem` d
         in Just (BitNat' (fromIntegral q), BitNat' (fromIntegral r))
-- | Left shift by the type-level amount @s@. The operand is widened to
-- @a + s@ bits first, so no set bit is lost. The 'NatVal' argument only
-- fixes @s@; its value is ignored.
(.<<.) :: forall (s :: Nat) a.
   ( ShiftableBits (BitNatWord (a + s))
   , KnownNat s
   , Widen a (a+s)
   ) => BitNat a -> NatVal s -> BitNat (a + s)
x .<<. _ = mapW (\w -> uncheckedShiftL w (natValue @s)) (widen @(a+s) x)
-- | Right shift by the type-level amount @s@. The result is narrowed to
-- @a - s@ bits, the most the shifted value can occupy. The 'NatVal' argument
-- only fixes @s@; its value is ignored.
(.>>.) :: forall (s :: Nat) a.
   ( ShiftableBits (BitNatWord a)
   , KnownNat s
   , Narrow a (a-s)
   ) => BitNat a -> NatVal s -> BitNat (a - s)
x .>>. _ = narrow @(a-s) $ mapW (\w -> uncheckedShiftR w (natValue @s)) x