-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.BitPrecise.PEXT_PDEP
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
--
-- Models the x86 [PEXT](https://www.felixcloutier.com/x86/pext) and [PDEP](https://www.felixcloutier.com/x86/pdep) instructions.
--
-- The pseudo-code implementation given by Intel for PEXT (parallel extract) is:
--
-- @
--    TEMP := SRC1;
--    MASK := SRC2;
--    DEST := 0 ;
--    m := 0, k := 0;
--    DO WHILE m < OperandSize
--        IF MASK[m] = 1 THEN
--            DEST[k] := TEMP[m];
--            k := k+ 1;
--        FI
--        m := m+ 1;
--    OD
-- @
--
-- PDEP (parallel deposit) is similar, except the assignment is:
--
-- @
--    DEST[m] := TEMP[k]
-- @
--
-- In PEXT, we grab the values of the source corresponding to the mask, and pile them into the destination from the bottom. In PDEP, we
-- do the reverse: We distribute the bits from the bottom of the source to the destination according to the mask.
-----------------------------------------------------------------------------

{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.BitPrecise.PEXT_PDEP where

import Data.SBV
import GHC.TypeLits (KnownNat)

-- $setup
-- >>> -- For doctest purposes only:
-- >>> import Data.SBV
-- >>> :set -XDataKinds

--------------------------------------------------------------------------------------------------
-- * Parallel extraction
--------------------------------------------------------------------------------------------------

-- | Parallel extraction: Given a source value and a mask, extract the bits in the source that are
-- pointed to by the mask, and put them in the destination starting from the bottom.
--
-- >>> satWith z3{printBase = 16} $ \r -> r .== pext (0xAA :: SWord 8) 0xAA
-- Satisfiable. Model:
--   s0 = 0x0f :: Word8
-- >>> prove $ \x -> pext @8 x 0 .== 0
-- Q.E.D.
-- >>> prove $ \x -> pext @8 x (complement 0) .== x
-- Q.E.D.
pext :: forall n. (KnownNat n, BVIsNonZero n) => SWord n -> SWord n -> SWord n
pext src mask = walk 0 src 0 (blastLE mask)
  where -- walk dest x idx maskBits:
        --   dest     : destination accumulated so far
        --   x        : source, shifted so that its least-significant bit lines up with the
        --              mask bit currently being examined
        --   idx      : next destination position to fill
        --   maskBits : remaining mask bits, as produced by blastLE
        walk dest _ _   []     = dest
        -- The source advances on every step; the destination is written, and its index
        -- bumped, only when the current mask bit is set.
        walk dest x idx (m:ms) = walk (ite m (sSetBitTo dest idx (lsb x)) dest)
                                      (x `shiftR` 1)
                                      (ite m (idx + 1) idx)
                                      ms

--------------------------------------------------------------------------------------------------
-- * Parallel deposit
--------------------------------------------------------------------------------------------------

-- | Parallel deposit: Given a source value and a mask, write into the destination positions that are
-- allowed by the mask, grabbing the bits from the source starting from the bottom.
--
-- >>> satWith z3{printBase = 16} $ \r -> r .== pdep (0xFF :: SWord 8) 0xAA
-- Satisfiable. Model:
--   s0 = 0xaa :: Word8
-- >>> prove $ \x -> pdep @8 x 0 .== 0
-- Q.E.D.
-- >>> prove $ \x -> pdep @8 x (complement 0) .== x
-- Q.E.D.
pdep :: forall n. (KnownNat n, BVIsNonZero n) => SWord n -> SWord n -> SWord n
pdep src mask = walk 0 src 0 (blastLE mask)
  where -- walk dest x idx maskBits:
        --   dest     : destination accumulated so far
        --   x        : source bits not yet consumed; the next one to deposit is its
        --              least-significant bit
        --   idx      : destination position matching the mask bit currently being examined
        --   maskBits : remaining mask bits, as produced by blastLE
        walk dest _ _   []     = dest
        -- The destination index advances on every step; a source bit is written, and
        -- consumed, only when the current mask bit is set.
        walk dest x idx (m:ms) = walk (ite m (sSetBitTo dest idx (lsb x)) dest)
                                      (ite m (x `shiftR` 1) x)
                                      (idx + 1)
                                      ms
--------------------------------------------------------------------------------------------------
-- * Round-trip property
--------------------------------------------------------------------------------------------------

-- | Prove that extraction and depositing with the same mask restore the source in all masked positions:
--
-- >>> extractThenDeposit
-- Q.E.D.
extractThenDeposit :: IO ThmResult
extractThenDeposit = prove $ do x :: SWord 8 <- sWord "x"
                                m :: SWord 8 <- sWord "m"
                                -- Only the bits selected by the mask survive the round trip,
                                -- hence the comparison against x .&. m rather than x itself.
                                pure $ (x .&. m) .== pdep (pext x m) m

-- | Prove that depositing and extracting with the same mask will preserve the bottom
-- n-bits of the source, where n is the number of bits set in the mask.
--
-- >>> depositThenExtract
-- Q.E.D.
depositThenExtract :: IO ThmResult
depositThenExtract = prove $ do x :: SWord 8 <- sWord "x"
                                m :: SWord 8 <- sWord "m"
                                -- A mask with the bottom (popcount m) bits set, i.e. 2^n - 1.
                                -- When all 8 mask bits are set, 2^8 wraps around to 0 at this
                                -- width, so the subtraction yields 0xFF, which is the mask we want.
                                let preserved = 2 .^ sPopCount m - 1
                                pure $ (x .&. preserved) .== pext (pdep x m) m

--------------------------------------------------------------------------------------------------
-- * Code generation
--------------------------------------------------------------------------------------------------

-- | We can generate the code for these functions if they need to be used in SMTLib. Below
-- is an example at 2-bits, which can be adjusted to produce any bit-size.
--
-- >>> putStrLn =<< sbv2smt pext_2
-- ; Automatically generated by SBV. Do not modify!
-- ; pext_2 :: SWord 2 -> SWord 2 -> SWord 2
-- (define-fun pext_2 ((l1_s0 (_ BitVec 2)) (l1_s1 (_ BitVec 2))) (_ BitVec 2)
--   (let ((l1_s3 #b0))
--   (let ((l1_s7 #b01))
--   (let ((l1_s8 #b00))
--   (let ((l1_s20 #b10))
--   (let ((l1_s2 ((_ extract 1 1) l1_s1)))
--   (let ((l1_s4 (distinct l1_s2 l1_s3)))
--   (let ((l1_s5 ((_ extract 0 0) l1_s1)))
--   (let ((l1_s6 (distinct l1_s3 l1_s5)))
--   (let ((l1_s9 (ite l1_s6 l1_s7 l1_s8)))
--   (let ((l1_s10 (= l1_s7 l1_s9)))
--   (let ((l1_s11 (bvlshr l1_s0 l1_s7)))
--   (let ((l1_s12 ((_ extract 0 0) l1_s11)))
--   (let ((l1_s13 (distinct l1_s3 l1_s12)))
--   (let ((l1_s14 (= l1_s8 l1_s9)))
--   (let ((l1_s15 ((_ extract 0 0) l1_s0)))
--   (let ((l1_s16 (distinct l1_s3 l1_s15)))
--   (let ((l1_s17 (ite l1_s16 l1_s7 l1_s8)))
--   (let ((l1_s18 (ite l1_s6 l1_s17 l1_s8)))
--   (let ((l1_s19 (bvor l1_s7 l1_s18)))
--   (let ((l1_s21 (bvand l1_s18 l1_s20)))
--   (let ((l1_s22 (ite l1_s13 l1_s19 l1_s21)))
--   (let ((l1_s23 (ite l1_s14 l1_s22 l1_s18)))
--   (let ((l1_s24 (bvor l1_s20 l1_s23)))
--   (let ((l1_s25 (bvand l1_s7 l1_s23)))
--   (let ((l1_s26 (ite l1_s13 l1_s24 l1_s25)))
--   (let ((l1_s27 (ite l1_s10 l1_s26 l1_s23)))
--   (let ((l1_s28 (ite l1_s4 l1_s27 l1_s18)))
--   l1_s28))))))))))))))))))))))))))))
pext_2 :: SWord 2 -> SWord 2 -> SWord 2
-- 'pext' instantiated at 2 bits and registered under the SMT-level name "pext_2";
-- change the type application (and the signature) to target another bit-size.
pext_2 = smtFunction "pext_2" (pext @2)