{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}

-- |
-- Module      :   Grisette.Core.Data.Class.BitVector
-- Copyright   :   (c) Sirui Lu 2021-2023
-- License     :   BSD-3-Clause (see the LICENSE file)
--
-- Maintainer  :   siruilu@cs.washington.edu
-- Stability   :   Experimental
-- Portability :   GHC only
module Grisette.Core.Data.Class.BitVector
  ( -- * Bit vector operations
    SomeBV (..),
    someBVZext',
    someBVSext',
    someBVExt',
    someBVSelect',
    someBVExtract,
    someBVExtract',
    SizedBV (..),
    sizedBVExtract,
  )
where

import Data.Proxy
import GHC.TypeNats
import Grisette.Utils.Parameterized

-- $setup
-- >>> import Grisette.Core
-- >>> import Grisette.IR.SymPrim
-- >>> import Grisette.Utils.Parameterized
-- >>> :set -XDataKinds
-- >>> :set -XBinaryLiterals
-- >>> :set -XFlexibleContexts
-- >>> :set -XFlexibleInstances
-- >>> :set -XFunctionalDependencies

-- | Operations on bit vectors whose width is not tracked in the type.
-- Provides concatenation ('someBVConcat'), extension ('someBVZext',
-- 'someBVSext', 'someBVExt'), and selection ('someBVSelect').
--
-- Widths and indices are passed as type-level naturals through proxy
-- arguments, so any @p :: Nat -> Type@ carrier can be used.
class SomeBV bv where
  -- | Concatenation of two bit vectors. As the example shows, the first
  -- argument supplies the most significant bits of the result.
  --
  -- >>> someBVConcat (SomeSymWordN (0b101 :: SymWordN 3)) (SomeSymWordN (0b010 :: SymWordN 3))
  -- 0b101010
  someBVConcat :: bv -> bv -> bv

  -- | Zero extension of a bit vector: the result is padded with zero bits
  -- up to the requested length.
  --
  -- >>> someBVZext (Proxy @6) (SomeSymWordN (0b101 :: SymWordN 3))
  -- 0b000101
  someBVZext ::
    forall p l.
    KnownNat l =>
    -- | Desired output length
    p l ->
    -- | Bit vector to extend
    bv ->
    bv

  -- | Sign extension of a bit vector: the result is padded with copies of
  -- the most significant bit up to the requested length.
  --
  -- >>> someBVSext (Proxy @6) (SomeSymWordN (0b101 :: SymWordN 3))
  -- 0b111101
  someBVSext ::
    forall p l.
    KnownNat l =>
    -- | Desired output length
    p l ->
    -- | Bit vector to extend
    bv ->
    bv

  -- | Extension of a bit vector.
  -- Signedness is determined by the input bit vector type: in the examples
  -- below the signed values are sign-extended and the unsigned values are
  -- zero-extended.
  --
  -- >>> someBVExt (Proxy @6) (SomeSymIntN (0b101 :: SymIntN 3))
  -- 0b111101
  -- >>> someBVExt (Proxy @6) (SomeSymIntN (0b001 :: SymIntN 3))
  -- 0b000001
  -- >>> someBVExt (Proxy @6) (SomeSymWordN (0b101 :: SymWordN 3))
  -- 0b000101
  -- >>> someBVExt (Proxy @6) (SomeSymWordN (0b001 :: SymWordN 3))
  -- 0b000001
  someBVExt ::
    forall p l.
    KnownNat l =>
    -- | Desired output length
    p l ->
    -- | Bit vector to extend
    bv ->
    bv

  -- | Slicing out a smaller bit vector from a larger one,
  -- selecting a slice with width @w@ starting from index @ix@.
  --
  -- The least significant bit is indexed as 0.
  --
  -- >>> someBVSelect (Proxy @1) (Proxy @3) (SomeSymIntN (0b001010 :: SymIntN 6))
  -- 0b101
  someBVSelect ::
    forall p ix q w.
    (KnownNat ix, KnownNat w) =>
    -- | Index of the least significant bit of the slice
    p ix ->
    -- | Desired output width, @ix + w <= n@ must hold where @n@ is
    -- the size of the input bit vector
    q w ->
    -- | Bit vector to select from
    bv ->
    bv

-- | Zero extension of a bit vector.
--
-- Variant of 'someBVZext' that takes the target length as a 'NatRepr' value
-- rather than a proxy. The 'KnownNat' evidence that 'someBVZext' requires is
-- obtained from the representation via 'hasRepr' and 'withKnownProof'.
--
-- >>> someBVZext' (natRepr @6) (SomeSymWordN (0b101 :: SymWordN 3))
-- 0b000101
someBVZext' ::
  forall l bv.
  SomeBV bv =>
  -- | Desired output length
  NatRepr l ->
  -- | Bit vector to extend
  bv ->
  bv
someBVZext' p = withKnownProof (hasRepr p) $ someBVZext (Proxy @l)
{-# INLINE someBVZext' #-}

-- | Sign extension of a bit vector.
--
-- Variant of 'someBVSext' that takes the target length as a 'NatRepr' value
-- rather than a proxy. The 'KnownNat' evidence that 'someBVSext' requires is
-- obtained from the representation via 'hasRepr' and 'withKnownProof'.
--
-- >>> someBVSext' (natRepr @6) (SomeSymWordN (0b101 :: SymWordN 3))
-- 0b111101
someBVSext' ::
  forall l bv.
  SomeBV bv =>
  -- | Desired output length
  NatRepr l ->
  -- | Bit vector to extend
  bv ->
  bv
someBVSext' p = withKnownProof (hasRepr p) $ someBVSext (Proxy @l)
{-# INLINE someBVSext' #-}

-- | Extension of a bit vector.
-- Signedness is determined by the input bit vector type.
--
-- Variant of 'someBVExt' that takes the target length as a 'NatRepr' value
-- rather than a proxy. The 'KnownNat' evidence that 'someBVExt' requires is
-- obtained from the representation via 'hasRepr' and 'withKnownProof'.
--
-- >>> someBVExt' (natRepr @6) (SomeSymIntN (0b101 :: SymIntN 3))
-- 0b111101
-- >>> someBVExt' (natRepr @6) (SomeSymIntN (0b001 :: SymIntN 3))
-- 0b000001
-- >>> someBVExt' (natRepr @6) (SomeSymWordN (0b101 :: SymWordN 3))
-- 0b000101
-- >>> someBVExt' (natRepr @6) (SomeSymWordN (0b001 :: SymWordN 3))
-- 0b000001
someBVExt' ::
  forall l bv.
  SomeBV bv =>
  -- | Desired output length
  NatRepr l ->
  -- | Bit vector to extend
  bv ->
  bv
someBVExt' p = withKnownProof (hasRepr p) $ someBVExt (Proxy @l)
{-# INLINE someBVExt' #-}

-- | Slicing out a smaller bit vector from a larger one,
-- selecting a slice with width @w@ starting from index @ix@.
--
-- The least significant bit is indexed as 0.
--
-- Variant of 'someBVSelect' that takes the index and width as 'NatRepr'
-- values. The 'KnownNat' evidence for both is obtained from the
-- representations, which are then passed to 'someBVSelect' directly as its
-- proxy arguments.
--
-- >>> someBVSelect' (natRepr @1) (natRepr @3) (SomeSymIntN (0b001010 :: SymIntN 6))
-- 0b101
someBVSelect' ::
  forall ix w bv.
  SomeBV bv =>
  -- | Index of the least significant bit of the slice
  NatRepr ix ->
  -- | Desired output width, @ix + w <= n@ must hold where @n@ is
  -- the size of the input bit vector
  NatRepr w ->
  -- | Bit vector to select from
  bv ->
  bv
someBVSelect' p q =
  withKnownProof (hasRepr p) $
    withKnownProof (hasRepr q) $
      someBVSelect p q
{-# INLINE someBVSelect' #-}

-- | Slicing out a smaller bit vector from a larger one, extract a slice from
-- bit @i@ down to @j@.
--
-- The least significant bit is indexed as 0.
--
-- Implemented with 'someBVSelect', starting at index @j@ with width
-- @i - j + 1@.
--
-- >>> someBVExtract (Proxy @4) (Proxy @2) (SomeSymIntN (0b010100 :: SymIntN 6))
-- 0b101
someBVExtract ::
  forall p (i :: Nat) q (j :: Nat) bv.
  (SomeBV bv, KnownNat i, KnownNat j) =>
  -- | The start position to extract from, @i < n@ must hold where @n@ is
  -- the size of the input bit vector
  p i ->
  -- | The end position to extract from, @j <= i@ must hold
  q j ->
  -- | Bit vector to extract from
  bv ->
  bv
someBVExtract _ _ =
  -- The width @i - j + 1@ has no 'KnownNat' instance in scope, so the evidence
  -- is built from the run-time values of @i@ and @j@.
  withKnownProof
    ( unsafeKnownProof @(i - j + 1)
        ( fromIntegral (natVal (Proxy @i))
            - fromIntegral (natVal (Proxy @j))
            + 1
        )
    )
    $ someBVSelect (Proxy @j) (Proxy @(i - j + 1))
{-# INLINE someBVExtract #-}

-- | Slicing out a smaller bit vector from a larger one, extract a slice from
-- bit @i@ down to @j@.
--
-- The least significant bit is indexed as 0.
--
-- Variant of 'someBVExtract' that takes both positions as 'NatRepr' values.
-- The 'KnownNat' evidence for both is obtained from the representations,
-- which are then passed to 'someBVExtract' directly as its proxy arguments.
--
-- >>> someBVExtract' (natRepr @4) (natRepr @2) (SomeSymIntN (0b010100 :: SymIntN 6))
-- 0b101
someBVExtract' ::
  forall (i :: Nat) (j :: Nat) bv.
  SomeBV bv =>
  -- | The start position to extract from, @i < n@ must hold where @n@ is
  -- the size of the input bit vector
  NatRepr i ->
  -- | The end position to extract from, @j <= i@ must hold
  NatRepr j ->
  -- | Bit vector to extract from
  bv ->
  bv
someBVExtract' p q =
  withKnownProof (hasRepr p) $
    withKnownProof (hasRepr q) $
      someBVExtract p q
{-# INLINE someBVExtract' #-}

-- | Sized bit vector operations. Including concatenation ('sizedBVConcat'),
-- extension ('sizedBVZext', 'sizedBVSext', 'sizedBVExt'), and selection
-- ('sizedBVSelect').
--
-- The bit vector type is indexed by its width, so the width requirements of
-- each operation are expressed as type-level constraints.
class SizedBV bv where
  -- | Concatenation of two bit vectors. As the example shows, the first
  -- argument supplies the most significant bits of the result.
  --
  -- >>> sizedBVConcat (0b101 :: SymIntN 3) (0b010 :: SymIntN 3)
  -- 0b101010
  sizedBVConcat :: (KnownNat l, KnownNat r, 1 <= l, 1 <= r) => bv l -> bv r -> bv (l + r)

  -- | Zero extension of a bit vector.
  --
  -- >>> sizedBVZext (Proxy @6) (0b101 :: SymIntN 3)
  -- 0b000101
  sizedBVZext ::
    (KnownNat l, KnownNat r, 1 <= l, l <= r) =>
    -- | Desired output width
    proxy r ->
    -- | Bit vector to extend
    bv l ->
    bv r

  -- | Signed extension of a bit vector.
  --
  -- >>> sizedBVSext (Proxy @6) (0b101 :: SymIntN 3)
  -- 0b111101
  sizedBVSext ::
    (KnownNat l, KnownNat r, 1 <= l, l <= r) =>
    -- | Desired output width
    proxy r ->
    -- | Bit vector to extend
    bv l ->
    bv r

  -- | Extension of a bit vector.
  -- Signedness is determined by the input bit vector type.
  --
  -- >>> sizedBVExt (Proxy @6) (0b101 :: SymIntN 3)
  -- 0b111101
  -- >>> sizedBVExt (Proxy @6) (0b001 :: SymIntN 3)
  -- 0b000001
  -- >>> sizedBVExt (Proxy @6) (0b101 :: SymWordN 3)
  -- 0b000101
  -- >>> sizedBVExt (Proxy @6) (0b001 :: SymWordN 3)
  -- 0b000001
  sizedBVExt ::
    (KnownNat l, KnownNat r, 1 <= l, l <= r) =>
    -- | Desired output width
    proxy r ->
    -- | Bit vector to extend
    bv l ->
    bv r

  -- | Slicing out a smaller bit vector from a larger one, selecting a slice with
  -- width @w@ starting from index @ix@.
  --
  -- The least significant bit is indexed as 0.
  --
  -- >>> sizedBVSelect (Proxy @2) (Proxy @3) (con 0b010100 :: SymIntN 6)
  -- 0b101
  sizedBVSelect ::
    (KnownNat n, KnownNat ix, KnownNat w, 1 <= n, 1 <= w, ix + w <= n) =>
    -- | Index of the least significant bit of the slice
    proxy ix ->
    -- | Desired output width, @ix + w <= n@ must hold where @n@ is
    -- the size of the input bit vector
    proxy w ->
    -- | Bit vector to select from
    bv n ->
    bv w

-- | Slicing out a smaller bit vector from a larger one, extract a slice from
-- bit @i@ down to @j@.
--
-- The least significant bit is indexed as 0.
--
-- Implemented with 'sizedBVSelect', starting at index @j@ with width
-- @i - j + 1@.
--
-- >>> sizedBVExtract (Proxy @4) (Proxy @2) (con 0b010100 :: SymIntN 6)
-- 0b101
sizedBVExtract ::
  forall proxy i j n bv.
  (SizedBV bv, KnownNat n, KnownNat i, KnownNat j, 1 <= n, i + 1 <= n, j <= i) =>
  -- | The start position to extract from, @i < n@ must hold where @n@ is
  -- the size of the input bit vector
  proxy i ->
  -- | The end position to extract from, @j <= i@ must hold
  proxy j ->
  -- | Bit vector to extract from
  bv n ->
  bv (i - j + 1)
sizedBVExtract _ _ =
  -- 'sizedBVSelect' needs three facts about the width @i - j + 1@ that are not
  -- directly available from the constraints in the signature:
  --   * it is a known natural (built from the representations of @i@, @j@, @1@);
  --   * @j + (i - j + 1) <= n@ (asserted with 'unsafeLeqProof');
  --   * @1 <= i - j + 1@ (asserted with 'unsafeLeqProof').
  -- Matching on the proof constructors brings them into scope.
  case ( hasRepr (addNat (subNat (natRepr @i) (natRepr @j)) (natRepr @1)),
         unsafeLeqProof @(j + (i - j + 1)) @n,
         unsafeLeqProof @1 @(i - j + 1)
       ) of
    (KnownProof, LeqProof, LeqProof) ->
      sizedBVSelect (Proxy @j) (Proxy @(i - j + 1))
{-# INLINE sizedBVExtract #-}