{-|
Copyright  :  (C) 2013-2016, University of Twente
License    :  BSD2 (see the file LICENSE)
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>
-}

{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MagicHash        #-}
{-# LANGUAGE TypeOperators    #-}
{-# LANGUAGE TypeFamilies     #-}

{-# LANGUAGE Trustworthy #-}

{-# OPTIONS_HADDOCK show-extensions #-}

module Clash.Prelude.BitIndex where

import GHC.TypeLits                   (KnownNat, type (+), type (-))

import Clash.Class.BitPack            (BitPack (..))
import Clash.Promoted.Nat             (SNat)
import Clash.Sized.Internal.BitVector (BitVector, Bit, index#, lsb#, msb#,
                                       replaceBit#, setSlice#, slice#, split#)

{- $setup
>>> :set -XDataKinds
>>> import Clash.Prelude
-}

{-# INLINE (!) #-}
-- | Get the bit at the specified bit index.
--
-- The value is first converted to its 'BitVector' representation with 'pack',
-- and the index is converted to an 'Int' with 'fromEnum'.
--
-- __NB:__ Bit indices are __DESCENDING__.
--
-- >>> pack (7 :: Unsigned 6)
-- 00_0111
-- >>> (7 :: Unsigned 6) ! 1
-- 1
-- >>> (7 :: Unsigned 6) ! 5
-- 0
-- >>> (7 :: Unsigned 6) ! 6
-- *** Exception: (!): 6 is out of range [5..0]
-- ...
(!) :: (BitPack a, KnownNat (BitSize a), Enum i) => a -> i -> Bit
(!) v i = index# (pack v) (fromEnum i)

{-# INLINE slice #-}
-- | Get a slice between bit index @m@ and bit index @n@.
--
-- The slice is taken from the 'pack'ed representation of the value, and the
-- result is @m + 1 - n@ bits wide.
--
-- __NB:__ Bit indices are __DESCENDING__.
--
-- >>> pack (7 :: Unsigned 6)
-- 00_0111
-- >>> slice d4 d2 (7 :: Unsigned 6)
-- 001
-- >>> slice d6 d4 (7 :: Unsigned 6)
-- <BLANKLINE>
-- <interactive>:...
--     • Couldn't match type ‘7 + i0’ with ‘6’
--         arising from a use of ‘slice’
--       The type variable ‘i0’ is ambiguous
--     • In the expression: slice d6 d4 (7 :: Unsigned 6)
--       In an equation for ‘it’: it = slice d6 d4 (7 :: Unsigned 6)
slice :: (BitPack a, BitSize a ~ ((m + 1) + i)) => SNat m -> SNat n -> a
      -> BitVector (m + 1 - n)
slice m n v = slice# (pack v) m n

{-# INLINE split #-}
-- | Split a value of a bit size @m + n@ into a tuple of values with size @m@
-- and size @n@.
--
-- The value is converted with 'pack' before it is split; as the example below
-- shows, the first component holds the leading (most significant) @m@ bits.
--
-- >>> pack (7 :: Unsigned 6)
-- 00_0111
-- >>> split (7 :: Unsigned 6) :: (BitVector 2, BitVector 4)
-- (00,0111)
split :: (BitPack a, BitSize a ~ (m + n), KnownNat n) => a
      -> (BitVector m, BitVector n)
split v = split# (pack v)

{-# INLINE replaceBit #-}
-- | Set the bit at the specified index
--
-- The value is converted with 'pack', the bit at the index (converted with
-- 'fromEnum') is replaced, and the result is converted back with 'unpack'.
--
-- __NB:__ Bit indices are __DESCENDING__.
--
-- >>> pack (-5 :: Signed 6)
-- 11_1011
-- >>> replaceBit 4 0 (-5 :: Signed 6)
-- -21
-- >>> pack (-21 :: Signed 6)
-- 10_1011
-- >>> replaceBit 5 0 (-5 :: Signed 6)
-- 27
-- >>> pack (27 :: Signed 6)
-- 01_1011
-- >>> replaceBit 6 0 (-5 :: Signed 6)
-- *** Exception: replaceBit: 6 is out of range [5..0]
-- ...
replaceBit :: (BitPack a, KnownNat (BitSize a), Enum i) => i -> Bit -> a
           -> a
replaceBit i b v = unpack (replaceBit# (pack v) (fromEnum i) b)

{-# INLINE setSlice #-}
-- | Set the bits between bit index @m@ and bit index @n@.
--
-- The value is converted with 'pack', the slice is overwritten with the given
-- 'BitVector' of width @m + 1 - n@, and the result is converted back with
-- 'unpack'.
--
-- __NB:__ Bit indices are __DESCENDING__.
--
-- >>> pack (-5 :: Signed 6)
-- 11_1011
-- >>> setSlice d4 d3 0 (-5 :: Signed 6)
-- -29
-- >>> pack (-29 :: Signed 6)
-- 10_0011
-- >>> setSlice d6 d5 0 (-5 :: Signed 6)
-- <BLANKLINE>
-- <interactive>:...
--     • Couldn't match type ‘7 + i0’ with ‘6’
--         arising from a use of ‘setSlice’
--       The type variable ‘i0’ is ambiguous
--     • In the expression: setSlice d6 d5 0 (- 5 :: Signed 6)
--       In an equation for ‘it’: it = setSlice d6 d5 0 (- 5 :: Signed 6)
setSlice :: (BitPack a, BitSize a ~ ((m + 1) + i)) => SNat m -> SNat n
         -> BitVector (m + 1 - n) -> a -> a
setSlice m n w v = unpack (setSlice# (pack v) m n w)

{-# INLINE msb #-}
-- | Get the most significant bit.
--
-- The bit is taken from the 'pack'ed representation of the value.
--
-- >>> pack (-4 :: Signed 6)
-- 11_1100
-- >>> msb (-4 :: Signed 6)
-- 1
-- >>> pack (4 :: Signed 6)
-- 00_0100
-- >>> msb (4 :: Signed 6)
-- 0
msb :: (BitPack a, KnownNat (BitSize a)) => a -> Bit
msb v = msb# (pack v)

{-# INLINE lsb #-}
-- | Get the least significant bit.
--
-- The bit is taken from the 'pack'ed representation of the value.
--
-- >>> pack (-9 :: Signed 6)
-- 11_0111
-- >>> lsb (-9 :: Signed 6)
-- 1
-- >>> pack (-8 :: Signed 6)
-- 11_1000
-- >>> lsb (-8 :: Signed 6)
-- 0
lsb :: BitPack a => a -> Bit
lsb v = lsb# (pack v)