{-|
Copyright   : (C) 2021-2024, QBayLogic B.V.
License     : BSD2 (see the file LICENSE)
Maintainer  : QBayLogic B.V. <devops@qbaylogic.com>

Random generation of BitVector.
-}

{-# OPTIONS_GHC -fplugin=GHC.TypeLits.KnownNat.Solver #-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RankNTypes #-}

module Clash.Hedgehog.Sized.BitVector
  ( genDefinedBit
  , genBit
  , genDefinedBitVector
  , genBitVector
  , SomeBitVector(..)
  , genSomeBitVector
  ) where

#if !MIN_VERSION_base(4,16,0)
import GHC.Natural (Natural)
#endif
import GHC.TypeNats
#if MIN_VERSION_base(4,18,0)
  hiding (SNat)
#endif
import Hedgehog (MonadGen, Range)
import Hedgehog.Internal.Range (constantBounded, constant)
import qualified Hedgehog.Gen as Gen

import Clash.Class.BitPack (pack)
import Clash.Promoted.Nat
import Clash.Sized.Internal.BitVector
import Clash.XException (errorX)

import Clash.Hedgehog.Sized.Unsigned

-- | Generate a bit which is guaranteed to be defined.
-- The result is either 'low' or 'high', each equally likely. When a
-- property fails, Hedgehog shrinks the generated bit towards 'low'.
--
genDefinedBit :: (MonadGen m) => m Bit
genDefinedBit = Gen.element [low, high]

-- | Generate a bit which is not guaranteed to be defined.
-- The result is 'low', 'high', or an undefined bit, each equally likely.
-- The undefined bit is built with 'errorX', so it only throws an
-- @XException@ when it is evaluated.
--
genBit :: (MonadGen m) => m Bit
genBit = Gen.element [low, high, errorX "X"]

-- | Generate a bit vector where all bits are defined.
-- An @n@-bit 'Unsigned' is drawn from its full bounded range and then
-- reinterpreted as a 'BitVector' with 'pack'.
--
genDefinedBitVector :: (MonadGen m, KnownNat n) => m (BitVector n)
genDefinedBitVector = pack <$> genUnsigned constantBounded

-- | Generate a bit vector where some bits may be undefined.
--
-- The generated values are weighted as follows:
--
--   * 70%: a 'BV' whose two 'Natural' fields (undefinedness mask and value)
--     are each drawn independently from @[0, 2^n - 1]@;
--
--   * 10%: 'minBound';
--
--   * 10%: 'maxBound';
--
--   * 10%: 'undefined#'.
--
genBitVector :: forall m n . (MonadGen m, KnownNat n) => m (BitVector n)
genBitVector =
  Gen.frequency
    [ (70, BV <$> genNatural <*> genNatural)
    , (10, Gen.constant minBound)
    , (10, Gen.constant maxBound)
    , (10, Gen.constant undefined#)
    ]
 where
  -- Largest natural number that fits in @n@ bits. The 'Natural' subtraction
  -- cannot underflow because @2^n >= 1@ for every @n@, including @n = 0@.
  maxNatural :: Natural
  maxNatural = 2 ^ natToNatural @n - 1

  -- Shared generator for both fields of 'BV'; shrinks towards 0.
  genNatural :: m Natural
  genNatural = Gen.integral (constant 0 maxNatural)

-- | A 'BitVector' whose width is only known at runtime but is at least
-- @atLeast@ bits. The 'SNat' witnesses the number of bits beyond @atLeast@,
-- so the wrapped vector is @atLeast + n@ bits wide.
data SomeBitVector atLeast where
  SomeBitVector
    :: forall n atLeast
     . SNat n
    -> BitVector (atLeast + n)
    -> SomeBitVector atLeast

-- | Shows only the wrapped 'BitVector'; the number of extra bits is not
-- printed separately.
instance KnownNat atLeast => Show (SomeBitVector atLeast) where
  -- Matching on the 'SNat' pattern brings @KnownNat n@ into scope, which is
  -- needed to discharge @Show (BitVector (atLeast + n))@.
  show (SomeBitVector SNat bv) = show bv

-- | Generate a bit vector whose width is chosen at runtime but is at least
-- @atLeast@ bits.
--
-- The number of extra bits is drawn from the given range. The vector itself,
-- of width @atLeast + extra@, is produced by the supplied generator (for
-- example 'genBitVector' or 'genDefinedBitVector').
genSomeBitVector
  :: forall m atLeast
   . (MonadGen m, KnownNat atLeast)
  => Range Natural
  -- ^ Range for the number of bits beyond @atLeast@
  -> (forall n. KnownNat n => m (BitVector n))
  -- ^ Generator for a bit vector of any known width
  -> m (SomeBitVector atLeast)
genSomeBitVector rangeBv genBv = do
  numExtra <- Gen.integral rangeBv

  -- Lift the runtime width to the type level so that 'genBv' can be
  -- instantiated at width @atLeast + n@.
  case someNatVal numExtra of
    SomeNat proxy -> SomeBitVector (snatProxy proxy) <$> genBv