{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
-- | Posit numbers: an @nbits@-bit encoding with at most @es@ exponent bits,
-- stored as a raw bit pattern in an 'IntN'.
module Haskus.Number.Posit
   ( Posit (..)
   , PositKind (..)
   , PositK (..)
   -- 'positKind' returns a 'SomePosit', so it must be exported for callers
   -- to be able to inspect the result.
   , SomePosit (..)
   , PositValue
   , positKind
   , isZero
   , isInfinity
   , isPositive
   , isNegative
   , positAbs
   , PositEncoding (..)
   , PositFields (..)
   , positEncoding
   , positFields
   , positToRational
   , positFromRational
   , positApproxFactor
   , positDecimalError
   , positDecimalAccuracy
   , positBinaryError
   , positBinaryAccuracy
   , floatBinaryAccuracy
   )
where
import Haskus.Number.Int
import Haskus.Binary.Bits
import Haskus.Utils.Types
import Haskus.Utils.Tuple
import Haskus.Utils.Flow
import Data.Ratio
import qualified GHC.Real as Ratio
-- | A posit number, kept as its raw bit pattern in a signed 'IntN'.
--
-- @nbits@ is the total width in bits; @es@ is the maximum number of exponent
-- bits (fewer may be present when the regime is long).
newtype Posit (nbits :: Nat) (es :: Nat) = Posit (IntN nbits)
-- | Render a posit as @"0"@, @"Infinity"@, or the 'show' of the 'Rational'
-- produced by 'positToRational'.
instance
   ( Bits (IntN n)
   , FiniteBits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => Show (Posit n es)
   where
      show = render . positKind
         where
            render :: SomePosit n es -> String
            render (SomePosit Zero)      = "0"
            render (SomePosit Infinity)  = "Infinity"
            render (SomePosit (Value v)) = show (positToRational v)
-- | Classification of a posit bit pattern. Promoted (via DataKinds) to index
-- 'PositK'.
data PositKind
   = ZeroK      -- ^ the all-zero pattern
   | InfinityK  -- ^ the pattern with only the sign bit set
   | NormalK    -- ^ any other pattern
   deriving (Show,Eq)
-- | A posit indexed at the type level by its 'PositKind'.
--
-- Only 'Value' carries the posit itself. When built through 'positKind', a
-- 'Value' is neither zero nor infinity.
data PositK k nbits es where
   -- | The zero posit.
   Zero     :: PositK 'ZeroK nbits es
   -- | The infinity posit.
   Infinity :: PositK 'InfinityK nbits es
   -- | Any other posit.
   Value    :: Posit nbits es -> PositK 'NormalK nbits es
-- | A 'PositK' whose kind index is hidden; this is what 'positKind' returns.
-- Pattern-match on the wrapped constructor to recover the kind.
data SomePosit n es where
   SomePosit :: PositK k n es -> SomePosit n es
-- | A posit tagged at the type level as 'NormalK' (only the 'Value'
-- constructor has this type).
type PositValue n es = PositK 'NormalK n es
-- | Classify a posit as zero, infinity, or a normal value, tagging the result
-- with the matching 'PositKind' index.
positKind :: forall n es.
   ( Bits (IntN n)
   , KnownNat n
   , Eq (IntN n)
   ) => Posit n es -> SomePosit n es
positKind p =
   if | isZero p     -> SomePosit Zero
      | isInfinity p -> SomePosit Infinity
      | otherwise    -> SomePosit (Value p)
-- | Whether the posit is the all-zero bit pattern.
isZero :: forall n es.
   ( Bits (IntN n)
   , Eq (IntN n)
   , KnownNat n
   ) => Posit n es -> Bool
{-# INLINABLE isZero #-}
isZero (Posit bits) = zeroBits == bits
-- | Whether the posit is the infinity pattern: only the most significant
-- bit (bit @n - 1@) is set.
isInfinity :: forall n es.
   ( Bits (IntN n)
   , Eq (IntN n)
   , KnownNat n
   ) => Posit n es -> Bool
{-# INLINABLE isInfinity #-}
isInfinity (Posit bits) = bit (natValue @n - 1) == bits
-- | Whether a normal posit is strictly positive (its signed bit pattern is
-- greater than zero).
isPositive :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , KnownNat n
   ) => PositValue n es -> Bool
{-# INLINABLE isPositive #-}
isPositive (Value (Posit bits)) = zeroBits < bits
-- | Whether a normal posit is strictly negative (its signed bit pattern is
-- less than zero).
isNegative :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , KnownNat n
   ) => PositValue n es -> Bool
{-# INLINABLE isNegative #-}
isNegative (Value (Posit bits)) = zeroBits > bits
-- | Absolute value of a normal posit, obtained by taking 'abs' of its signed
-- bit pattern.
positAbs :: forall n es.
   ( Num (IntN n)
   , KnownNat n
   ) => PositValue n es -> PositValue n es
positAbs (Value (Posit bits)) = abs bits |> Posit |> Value
-- | Decoded fields of a normal posit, as produced by 'positFields'.
--
-- Bit counts and field values describe the absolute value of the posit; the
-- sign is recorded separately in 'positNegative'.
data PositFields = PositFields
   { positNegative         :: Bool  -- ^ whether the posit is negative
   , positRegimeBitCount   :: Word  -- ^ regime bits, including the terminating bit when there is room for it
   , positExponentBitCount :: Word  -- ^ exponent bits actually present (may be fewer than @es@)
   , positFractionBitCount :: Word  -- ^ fraction bits present
   , positRegime           :: Int   -- ^ signed regime value
   , positExponent         :: Word  -- ^ exponent bits exactly as stored (not shifted when truncated)
   , positFraction         :: Word  -- ^ fraction bits, without the implicit leading 1
   }
   deriving (Show)
-- | Structural decoding of a posit bit pattern, as returned by
-- 'positEncoding'.
data PositEncoding
   = PositInfinity              -- ^ the infinity pattern
   | PositZero                  -- ^ the zero pattern
   | PositEncoding PositFields  -- ^ any other pattern, with its decoded fields
   deriving (Show)
-- | Decode a posit into 'PositZero', 'PositInfinity', or its 'PositFields'.
positEncoding :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => Posit n es -> PositEncoding
positEncoding p = encode (positKind p)
   where
      encode :: SomePosit n es -> PositEncoding
      encode (SomePosit Zero)          = PositZero
      encode (SomePosit Infinity)      = PositInfinity
      encode (SomePosit v@(Value _))   = PositEncoding (positFields v)
-- | Decode the sign, regime, exponent and fraction fields of a normal posit.
--
-- The fields are read from the absolute value of the posit ('positAbs'), so
-- the bit counts and field values are those of @|p|@. Only 'positNegative'
-- reflects the sign.
--
-- 'positExponent' holds the exponent bits exactly as stored. When the regime
-- leaves fewer than @es@ bits, 'positExponentBitCount' is smaller than @es@,
-- and the value is not shifted to account for the missing low-order bits.
positFields :: forall n es.
   ( Bits (IntN n)
   , Ord (IntN n)
   , Num (IntN n)
   , KnownNat n
   , KnownNat es
   , Integral (IntN n)
   ) => PositValue n es -> PositFields
positFields p = PositFields
   { positNegative         = isNegative p
   , positRegimeBitCount   = rs
   , positExponentBitCount = es
   , positFractionBitCount = fs
   , positRegime           = regime
   , positExponent         = expo
   , positFraction         = frac
   }
   where
      -- Work on the absolute value so the sign bit (bit n-1) is clear.
      Value (Posit v) = positAbs p
      -- The regime is the run of identical bits starting at bit n-2.
      -- Bit n-2 set: a run of 1s. Complement turns it into leading zeros;
      -- clearBit undoes the sign bit that the complement just set.
      -- Bit n-2 clear: a run of 0s, already leading zeros.
      -- In both cases countLeadingZeros also counts the (zero) sign bit,
      -- hence the "- 1".
      (negativeRegime,regimeLen) =
         if v `testBit` (natValue @n - 2)
            then (False, countLeadingZeros (complement v `clearBit` (natValue @n - 1)) - 1)
            else (True, countLeadingZeros v - 1)
      -- A run of k 1s encodes regime k-1; a run of k 0s encodes regime -k.
      regime = if negativeRegime
         then negate (fromIntegral regimeLen)
         else fromIntegral regimeLen - 1
      -- Regime width: the run plus its terminating bit, capped at the n-1 bits
      -- that follow the sign bit (no terminator when the run fills them all).
      rs = min (natValue @n - 1) (regimeLen + 1)
      -- Exponent bits that still fit after the sign and regime, at most es.
      es = min (natValue @n - rs - 1) (natValue @es)
      -- The remaining low-order bits are fraction bits.
      fs = natValue @n - es - rs - 1
      -- maskDyn presumably keeps the given number of low bits (haskus Bits
      -- helper; confirm). The exponent is read just above the fraction and
      -- is not shifted left when truncated.
      expo = fromIntegral (maskDyn es (v `shiftR` fs))
      frac = fromIntegral (maskDyn fs v)
-- | Convert a posit into an exact 'Rational'.
--
-- The zero pattern maps to @0@ and the infinity pattern to 'Ratio.infinity'.
-- Any other posit is decoded as
--
-- @sign * useed^regime * 2^exponent * (1 + fraction / 2^fractionBits)@
--
-- where @useed = 2^(2^es)@. When the regime leaves fewer than @es@ exponent
-- bits, the bits that are present are the most significant ones and the
-- missing low-order bits count as zero. This matches how 'positFromRational'
-- emits exponent bits, most significant first.
positToRational :: forall n es.
   ( KnownNat n
   , KnownNat es
   , Eq (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   ) => Posit n es -> Rational
positToRational p
   | isZero p     = 0 Ratio.:% 1
   | isInfinity p = Ratio.infinity
   | otherwise    = applySign ((fromIntegral useed ^^ r) * (2 ^^ e) * (1 + (f % fd)))
   where
      fields = positFields (Value p)
      -- positFields decodes |p|, so the sign has to be reapplied here.
      -- Without it every negative posit would convert to a positive value.
      applySign x
         | positNegative fields = negate x
         | otherwise            = x
      r      = positRegime fields
      -- Number of low-order exponent bits cut off by a long regime.
      missingExpoBits = natValue @es - positExponentBitCount fields
      -- Realign the stored (high-order) exponent bits to their true weight.
      e      = positExponent fields `shiftL` missingExpoBits
      f      = fromIntegral (positFraction fields)
      fd     = 1 `shiftL` positFractionBitCount fields
      useed  = 1 `shiftL` (1 `shiftL` natValue @es) :: Integer
-- | Convert a 'Rational' into a posit.
--
-- * @0@ maps to the zero pattern.
-- * 'Ratio.infinity' maps to the infinity pattern (only the sign bit set).
-- * Any other value is encoded from its magnitude, bit by bit: regime, then
--   exponent (most significant bit first), then fraction. Bits go into an
--   accumulator holding the @nbits - 1@ body bits plus one extra guard bit.
--   'computeRounding' then drops the guard bit, rounding to nearest with ties
--   to even, and the result is negated for negative inputs.
--
-- The regime loops stop once the bit budget is exhausted. Very large
-- magnitudes therefore end up at the largest pattern, and very small non-zero
-- ones at the smallest non-zero pattern.
positFromRational :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Posit n es
positFromRational x =
   if | x == 0              -> Posit 0
      | x == Ratio.infinity -> Posit (bit (natValue @n - 1))
      | otherwise           -> computeRegime
                                 |> uncurry3 computeExponent
                                 |> uncurry3 computeFraction
                                 |> uncurry computeRounding
                                 |> computeSign
                                 |> Posit
   where
      useed = fromIntegral (1 `shiftL` (1 `shiftL` es) :: Integer)
      nbits = natValue @n
      es    = natValue @es

      -- Regime bits. Returns (remaining magnitude, accumulated bits, index of
      -- the next bit to fill).
      computeRegime
         | absx >= 1 = regime111 absx 1 2
         | otherwise = regime000 absx 1
         where
            absx = abs x
            -- Run of 1s (one per factor of useed), then a terminating 0.
            regime111 y p i
               | y >= useed && i < nbits = regime111 (y / useed) ((p `uncheckedShiftL` 1) .|. 1) (i+1)
               | otherwise               = (y, p `uncheckedShiftL` 1, i+1)
            -- Run of 0s (leading zeros of p), then a terminating 1; saturates
            -- to the smallest non-zero pattern when bits run out.
            regime000 y i
               | y < 1 && i <= nbits = regime000 (y*useed) (i+1)
               | i >= nbits          = (y,2,nbits+1)
               | otherwise           = (y,1,i+1)

      -- Exponent bits, most significant first; e is the weight of the bit
      -- being decided. Stops early when the bit budget is exhausted.
      computeExponent
         | es == 0   = (,,)
         | otherwise = go (1 `shiftL` (es - 1))
         where
            go e y p i
               | i > nbits || e == 0 = (y,p,i)
               | y >= pow2e          = go (e `uncheckedShiftR` 1) (y / pow2e) ((p `uncheckedShiftL` 1) .|. 1) (i+1)
               | otherwise           = go (e `uncheckedShiftR` 1) y (p `uncheckedShiftL` 1) (i+1)
               where
                  pow2e = fromIntegral (1 `shiftL` e :: Integer)

      -- Fraction bits. y is the not-yet-encoded part of the significand
      -- (after removing the implicit 1) and stays in [0, 1). Testing
      -- y2 >= 1 (not > 1) makes an exact half emit a 1 bit and leave a zero
      -- remainder. When that bit is the guard bit, computeRounding sees a
      -- genuine tie and rounds to even instead of truncating.
      computeFraction y' = go (y'-1)
         where
            go y p i
               | i > nbits = (y,p)
               | y <= 0    = (y, p `shiftL` (nbits+1-i))
               | y2 >= 1   = go (y2-1) (p `shiftL` 1 + 1) (i+1)
               | otherwise = go y2 (p `shiftL` 1) (i+1)
               where
                  y2 = 2*y

      -- Drop the guard bit (bit 0 of p). Guard 0: truncate. Guard 1 with no
      -- remainder: tie, round to even. Otherwise: round up.
      -- NOTE(review): y == 1 can only arrive here when the fraction loop was
      -- skipped (bit budget used up by regime/exponent); confirm that treating
      -- it as a tie is intended in that case.
      computeRounding y p =
         let p' = p `uncheckedShiftR` 1
         in if | not (p `testBit` 0) -> p'
               | y == 1 || y == 0    -> p' + (if p' `testBit` 0 then 1 else 0)
               | otherwise           -> p' + 1

      -- Negative values are the negation of the positive encoding.
      computeSign p
         | x < 0     = negate p
         | otherwise = p
-- | Ratio between the posit approximation of @r@ and @r@ itself: @1@ means
-- exact. Computed as @positToRational (positFromRational r) / r@, so @r@
-- must be non-zero.
positApproxFactor :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positApproxFactor r = fromRational (approx / r)
   where
      approx = positToRational (positFromRational r :: p)
-- | Decimal error of the posit approximation: @|log10 factor|@, where
-- @factor@ comes from 'positApproxFactor'.
positDecimalError :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positDecimalError = abs . logBase 10 . positApproxFactor @p
-- | Decimal accuracy (number of correct decimal digits) of the posit
-- approximation: @-log10 ('positDecimalError' r)@. Exact values give
-- @Infinity@.
positDecimalAccuracy :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positDecimalAccuracy = negate . logBase 10 . positDecimalError @p
-- | Binary error of the posit approximation: @|log2 factor|@, where
-- @factor@ comes from 'positApproxFactor'.
positBinaryError :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positBinaryError = abs . logBase 2 . positApproxFactor @p
-- | Binary accuracy (number of correct bits) of the posit approximation:
-- @-log2 ('positBinaryError' r)@. Exact values give @Infinity@.
positBinaryAccuracy :: forall p n es.
   ( Posit n es ~ p
   , Num (IntN n)
   , Bits (IntN n)
   , Integral (IntN n)
   , KnownNat es
   , KnownNat n
   ) => Rational -> Double
positBinaryAccuracy = negate . logBase 2 . positBinaryError @p
-- | Binary accuracy (number of correct bits) of representing @r@ in the
-- floating-point type @f@: @-log2 |log2 (fromRational r / r)|@.
-- Exactly representable values give @Infinity@; @r@ must be non-zero.
floatBinaryAccuracy :: forall f.
   ( Fractional f
   , Real f
   ) => Rational -> Double
floatBinaryAccuracy r = negate (logBase 2 relativeError)
   where
      rounded       = fromRational r :: f
      factor        = fromRational (toRational rounded / r) :: Double
      relativeError = abs (logBase 2 factor)