{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}

-- | Posit (type III unum)
--
-- A posit with @nbits@ bits and at most @es@ exponent bits is stored in an
-- 'IntN' (negative posits use two's complement). From the most significant
-- bit, its fields are: sign, regime (a run of identical bits ended by an
-- opposite stop bit or by the end of the word), exponent (up to @es@ bits)
-- and fraction (the remaining bits).
module Haskus.Number.Posit
   ( Posit (..)
   , PositKind (..)
   , PositK (..)
   , SomePosit (..)
   , PositValue
   , positKind
   , isZero
   , isInfinity
   , isPositive
   , isNegative
   , positAbs
   , PositEncoding (..)
   , PositFields (..)
   , positEncoding
   , positFields
   , positToRational
   , positFromRational
   , positApproxFactor
   , positDecimalError
   , positDecimalAccuracy
   , positBinaryError
   , positBinaryAccuracy
   , floatBinaryAccuracy
   )
where

import Haskus.Number.Int
import Haskus.Binary.Bits
import Haskus.Utils.Types
import Haskus.Utils.Tuple
import Haskus.Utils.Flow

import Data.Ratio
import qualified GHC.Real as Ratio

-- | A posit of @nbits@ bits with at most @es@ exponent bits, stored as its raw
-- bit pattern
newtype Posit (nbits :: Nat) (es :: Nat) = Posit (IntN nbits)

-- | Show posit: "0", "Infinity" or the exact rational value
instance
   ( Bits (IntN n)
   , FiniteBits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => Show (Posit n es)
   where
      show p = case positKind p of
         SomePosit Zero      -> "0"
         SomePosit Infinity  -> "Infinity"
         SomePosit (Value v) -> show (positToRational v)

-- | Kind of a posit value
data PositKind
   = ZeroK
   | InfinityK
   | NormalK
   deriving (Show,Eq)

-- | Kinded Posit
--
-- GADT that can be used to ensure at the type level that we deal with
-- non-infinite/non-zero Posit values
data PositK k nbits es where
   Zero     :: PositK 'ZeroK nbits es
   Infinity :: PositK 'InfinityK nbits es
   Value    :: Posit nbits es -> PositK 'NormalK nbits es

-- | A kinded posit whose kind is only known at runtime (see 'positKind')
data SomePosit n es where
   SomePosit :: PositK k n es -> SomePosit n es

-- | A posit statically known to be neither zero nor infinity
type PositValue n es = PositK 'NormalK n es

-- | Get the kind of the posit at the type level
positKind :: forall n es.
   ( Bits (IntN n)
   , KnownNat n
   , Eq (IntN n)
   ) => Posit n es -> SomePosit n es
positKind p
   | isZero p     = SomePosit Zero
   | isInfinity p = SomePosit Infinity
   | otherwise    = SomePosit (Value p)

-- | Check if a posit is zero (all bits cleared)
isZero :: forall n es.
   ( Bits (IntN n)
   , Eq (IntN n)
   , KnownNat n
   ) => Posit n es -> Bool
{-# INLINABLE isZero #-}
isZero (Posit i) = i == zeroBits

-- | Check if a posit is infinity (only the sign bit set)
isInfinity :: forall n es.
   ( Bits (IntN n)
   , Eq (IntN n)
   , KnownNat n
   ) => Posit n es -> Bool
{-# INLINABLE isInfinity #-}
isInfinity (Posit i) = i == bit (natValue @n - 1)

-- | Check if a posit is positive
isPositive :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , KnownNat n
   ) => PositValue n es -> Bool
{-# INLINABLE isPositive #-}
isPositive (Value (Posit i)) = i > zeroBits

-- | Check if a posit is negative
isNegative :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , KnownNat n
   ) => PositValue n es -> Bool
{-# INLINABLE isNegative #-}
isNegative (Value (Posit i)) = i < zeroBits

-- | Posit absolute value
--
-- Infinity (the pattern with only the sign bit set) is excluded by the
-- 'PositValue' type, so taking the two's complement magnitude is safe.
positAbs :: forall n es.
   ( Num (IntN n)
   , KnownNat n
   ) => PositValue n es -> PositValue n es
positAbs (Value (Posit i)) = Value (Posit (abs i))

-- | Decoded fields of a posit that is neither zero nor infinity
data PositFields = PositFields
   { positNegative         :: Bool -- ^ Whether the posit is negative
   , positRegimeBitCount   :: Word -- ^ Number of regime bits, including the stop bit when there is room for it
   , positExponentBitCount :: Word -- ^ Number of exponent bits present (can be less than @es@ when the regime is long)
   , positFractionBitCount :: Word -- ^ Number of fraction bits
   , positRegime           :: Int  -- ^ Regime value
   , positExponent         :: Word -- ^ Raw value of the exponent bits present; truncated low-order exponent bits are implicit zeros not reflected here
   , positFraction         :: Word -- ^ Raw value of the fraction bits (hidden bit excluded)
   } deriving (Show)

-- | Complete decoding of a posit
data PositEncoding
   = PositInfinity
   | PositZero
   | PositEncoding PositFields
   deriving (Show)

-- | Decode a posit, handling zero and infinity specially
positEncoding :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => Posit n es -> PositEncoding
positEncoding p = case positKind p of
   SomePosit Zero         -> PositZero
   SomePosit Infinity     -> PositInfinity
   SomePosit v@(Value _)  -> PositEncoding (positFields v)

-- | Decode posit fields
positFields :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => PositValue n es -> PositFields
positFields p = PositFields
   { positNegative         = isNegative p
   , positRegimeBitCount   = rs
   , positExponentBitCount = es
   , positFractionBitCount = fs
   , positRegime           = regime
   , positExponent         = expo
   , positFraction         = frac
   }
   where
      -- decode the magnitude: negative posits are stored in two's complement
      Value (Posit v) = positAbs p

      -- whether the regime is a run of 0s, and the length of the run
      -- (stop bit excluded)
      (negativeRegime,regimeLen) =
         if v `testBit` (natValue @n - 2)
            -- regime has shape 111...[0|end of word], subtract 1 for sign bit
            then (False, countLeadingZeros (complement v `clearBit` (natValue @n - 1)) - 1)
            -- regime has shape 00000...[1|end of word], subtract 1 for sign bit
            else (True, countLeadingZeros v - 1)

      -- a run of k 1s encodes k-1 (so that we encode the 0 regime);
      -- a run of k 0s encodes -k
      regime = if negativeRegime
         then negate (fromIntegral regimeLen)
         else fromIntegral regimeLen - 1

      -- length of regime bits (with stop bit, if there is room for it)
      rs = min (natValue @n - 1) (regimeLen + 1)
      -- real exponent size (regime bits can reduce the size of the exponent)
      es = min (natValue @n - rs - 1) (natValue @es)
      -- fraction size
      fs = natValue @n - es - rs - 1

      expo = fromIntegral (maskDyn es (v `shiftR` fs))
      frac = fromIntegral (maskDyn fs v)

-- | Convert a Posit into a Rational
--
-- Zero gives 0 and infinity gives 'Ratio.infinity'. Any other posit gives its
-- exact value: sign * useed^regime * 2^exponent * (1 + fraction), with
-- useed = 2^(2^es).
positToRational :: forall n es.
   ( KnownNat n
   , KnownNat es
   , Eq (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   ) => Posit n es -> Rational
positToRational p
   | isZero p     = 0
   | isInfinity p = Ratio.infinity
   | otherwise    = applySign ((fromIntegral useed ^^ r) * (2 ^^ e) * (1 + (f % fd)))
   where
      fields = positFields (Value p)
      -- fields are decoded from the absolute value: restore the sign
      applySign v
         | positNegative fields = negate v
         | otherwise            = v
      r  = positRegime fields
      -- exponent bits cut off by a long regime are implicit low-order zeros
      -- (the encoder emits the exponent most significant bit first): realign
      -- the bits that are present
      e  = positExponent fields `shiftL` (natValue @es - positExponentBitCount fields)
      f  = fromIntegral (positFraction fields)
      fd = 1 `shiftL` positFractionBitCount fields
      useed = 1 `shiftL` (1 `shiftL` natValue @es) :: Integer -- 2^(2^es)

-- | Convert a rational into the nearest Posit
--
-- The infinitely precise bit string is rounded to nearest, ties to the even
-- bit pattern. Non-zero values saturate at the smallest/largest positive
-- magnitudes instead of rounding to zero or infinity.
positFromRational :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Posit n es
positFromRational x =
   if | x == 0              -> Posit 0
      | x == Ratio.infinity -> Posit (bit (natValue @n - 1))
      | otherwise           -> computeRegime
                               |> uncurry3 computeExponent
                               |> uncurry3 computeFraction
                               |> uncurry computeRounding
                               |> computeSign
                               |> Posit
   where
      useed = fromIntegral (1 `shiftL` (1 `shiftL` es) :: Integer) -- 2^(2^es)
      nbits = natValue @n
      es    = natValue @es

      -- The steps below thread (y,p,i):
      --    y: remaining magnitude to encode
      --    p: posit bits emitted so far (sign bit excluded)
      --    i: number of bits emitted so far, plus one for the sign bit
      -- Emission stops once i > nbits: p then holds the nbits-1 posit bits
      -- followed by one extra bit used for rounding.

      -- compute regime bits; afterwards y is in [1,useed) if there are
      -- enough available bits
      computeRegime
         | absx >= 1 = regime111 absx 1 2
         | otherwise = regime000 absx 1
         where
            absx = abs x

      -- push regime bits 111..1110 (one 1 per factor of useed)
      regime111 y p i
         | y >= useed && i < nbits = regime111 (y / useed) ((p `uncheckedShiftL` 1) .|. 1) (i+1)
         | otherwise               = (y, p `uncheckedShiftL` 1, i+1)

      -- push regime bits 000..0001 (one 0 per factor of 1/useed). If the
      -- regime doesn't fit, return 000...00010: the smallest positive posit
      -- followed by a 0 rounding bit, so non-zero values never become zero.
      regime000 y i
         | y < 1 && i <= nbits = regime000 (y*useed) (i+1)
         | i >= nbits          = (y,2,nbits+1)
         | otherwise           = (y,1,i+1)

      -- compute exponent bits, most significant first (the bit of weight e
      -- stands for a factor 2^e); afterwards y is in [1,2) if there are
      -- enough available bits
      computeExponent
         | es == 0   = (,,)
         | otherwise = go (1 `shiftL` (es - 1))
         where
            go e y p i
               | i > nbits || e == 0 = (y,p,i)
               | y >= pow2e          = go (e `uncheckedShiftR` 1) (y / pow2e) ((p `uncheckedShiftL` 1) .|. 1) (i+1)
               | otherwise           = go (e `uncheckedShiftR` 1) y (p `uncheckedShiftL` 1) (i+1)
               where
                  pow2e = fromIntegral (1 `shiftL` e :: Integer)

      -- compute fraction bits; return (y,p) where y is the part of the value
      -- below the last emitted bit: y == 0 iff nothing non-zero remains.
      -- The hidden bit is subtracted first, so y is in [0,1) inside the loop.
      computeFraction y' = go (y'-1)
         where
            go y p i
               | i > nbits = (y,p)
               | y <= 0    = (y, p `shiftL` (nbits+1-i)) -- exact: pad with 0 fraction bits
               | y2 >= 1   = go (y2-1) (p `shiftL` 1 + 1) (i+1)
               | otherwise = go y2 (p `shiftL` 1) (i+1)
               where
                  y2 = 2*y

      -- at this stage, p contains an additional (rounding) bit. We remove it
      -- and round to nearest, ties to even:
      --    rounding bit 0                   -> round down
      --    rounding bit 1, nothing below    -> tie: round to even
      --    rounding bit 1, something below  -> round up
      computeRounding y p
         | not (p `testBit` 0) = p'
         | y == 0              = if p' `testBit` 0 then p' + 1 else p'
         | otherwise           = p' + 1
         where
            p' = p `uncheckedShiftR` 1

      -- fixup the sign bit (and use 2's complement for the other bits)
      computeSign p
         | x < 0     = negate p
         | otherwise = p

-- | Factor of approximation for a given Rational when encoded as a Posit.
-- The closer to 1, the better. The Rational must be non-zero (the factor is
-- obtained by dividing by it).
--
-- Usage:
--
--    positApproxFactor @(Posit 8 2) (52 % 137)
--
positApproxFactor :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positApproxFactor r = fromRational ((positToRational (positFromRational r :: p)) / r)

-- | Compute the decimal error if the given Rational is encoded as a Posit.
--
-- Usage:
--
--    positDecimalError @(Posit 8 2) (52 % 137)
--
positDecimalError :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positDecimalError r = abs (logBase 10 (positApproxFactor @p r))

-- | Compute the number of decimals of accuracy if the given Rational is
-- encoded as a Posit.
--
-- Usage:
--
--    positDecimalAccuracy @(Posit 8 2) (52 % 137)
--
positDecimalAccuracy :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positDecimalAccuracy r = -1 * logBase 10 (positDecimalError @p r)

-- | Compute the binary error if the given Rational is encoded as a Posit.
--
-- Usage:
--
--    positBinaryError @(Posit 8 2) (52 % 137)
--
positBinaryError :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positBinaryError r = abs (logBase 2 (positApproxFactor @p r))

-- | Compute the number of bits of accuracy if the given Rational is encoded
-- as a Posit.
--
-- Usage:
--
--    positBinaryAccuracy @(Posit 8 2) (52 % 137)
--
positBinaryAccuracy :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positBinaryAccuracy r = -1 * logBase 2 (positBinaryError @p r)

-- | Compute the number of bits of accuracy if the given Rational is encoded
-- as a Float/Double. The Rational must be non-zero.
--
-- Usage:
--
--    floatBinaryAccuracy @Double (52 % 137)
--
floatBinaryAccuracy :: forall f.
   ( Fractional f
   , Real f
   ) => Rational -> Double
floatBinaryAccuracy r = -1 * logBase 2 floatError
   where
      floatApprox = fromRational (toRational (fromRational r :: f) / r)
      floatError  = abs (logBase 2 floatApprox)