{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralisedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}

module Clash.Class.Counter.Internal where

import Clash.CPP (maxTupleSize)

import Clash.Class.Counter.TH (genTupleInstances)
import Clash.Sized.BitVector (BitVector, Bit)
import Clash.Sized.Index (Index)
import Clash.Sized.Signed (Signed)
import Clash.Sized.Unsigned (Unsigned)
import Clash.Sized.Vector as Vec (Vec, repeat, mapAccumR)

import Data.Bifunctor (bimap)
import Data.Functor.Identity (Identity(..))
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Word (Word8, Word16, Word32, Word64)
import GHC.TypeLits (KnownNat, type (<=))

-- $setup
-- >>> :m -Prelude
-- >>> import Clash.Prelude
-- >>> import Clash.Class.Counter
-- >>> import Clash.Sized.BitVector (BitVector)
-- >>> import Clash.Sized.Index (Index)
-- >>> import Clash.Sized.Signed (Signed)
-- >>> import Clash.Sized.Unsigned (Unsigned)
-- >>> import Clash.Sized.Vector (Vec(..), iterate)

-- | t'Clash.Class.Counter.Counter' is a class that composes multiple counters
-- into a single one. It is similar to odometers found in old cars:
-- once all counters reach their maximum they reset to zero - i.e. odometer
-- rollover. See 'Clash.Class.Counter.countSucc' and 'Clash.Class.Counter.countPred'
-- for API usage examples.
--
-- Example use case: when driving a monitor through VGA you would like to keep
-- track of at least two counters: one counting a horizontal position, and one
-- counting a vertical position. Perhaps a fancy VGA driver would also like to
-- keep track of the number of drawn frames. To do so, the three counters are
-- set up with different types. On each /round/ of the horizontal counter the
-- vertical counter should be increased. On each /round/ of the vertical
-- counter the frame counter should be increased. With this class you could
-- simply use the type:
--
-- @
-- (FrameCount, VerticalCount, HorizontalCount)
-- @
--
-- and have 'Clash.Class.Counter.countSucc' work as described.
--
class Counter a where
  -- | Value the counter wraps around to when 'countSuccOverflow' overflows.
  -- For 'Bounded' types this defaults to 'minBound'.
  countMin :: a
  default countMin :: Bounded a => a
  countMin = minBound

  -- | Value the counter wraps around to when 'countPredOverflow' underflows.
  -- For 'Bounded' types this defaults to 'maxBound'.
  countMax :: a
  default countMax :: Bounded a => a
  countMax = maxBound

  -- | Gets the successor of @a@. If it overflows, the first part of the tuple
  -- will be set to True and the second part wraps around to 'countMin'.
  countSuccOverflow :: a -> (Bool, a)
  default countSuccOverflow :: (Eq a, Enum a, Bounded a) => a -> (Bool, a)
  countSuccOverflow a
    -- Compare against 'maxBound' before calling 'succ': per the 'Enum'
    -- documentation, @succ maxBound@ should be a runtime error.
    | a == maxBound = (True, countMin)
    | otherwise = (False, succ a)

  -- | Gets the predecessor of @a@. If it underflows, the first part of the
  -- tuple will be set to True and the second part wraps around to 'countMax'.
  countPredOverflow :: a -> (Bool, a)
  default countPredOverflow :: (Eq a, Enum a, Bounded a) => a -> (Bool, a)
  countPredOverflow a
    -- Compare against 'minBound' before calling 'pred', for the same reason
    -- as in 'countSuccOverflow'.
    | a == minBound = (True, countMax)
    | otherwise = (False, pred a)

-- Clash's sized numeric types. None of these instances define any methods;
-- everything comes from the class's 'Bounded'/'Enum'/'Eq' based defaults.
instance (KnownNat n, 1 <= n) => Counter (Index n)
instance KnownNat n => Counter (Unsigned n)
instance KnownNat n => Counter (Signed n)
instance KnownNat n => Counter (BitVector n)

-- Single-bit types, again relying purely on the default methods.

-- | @since 1.8.2
instance Counter Bool
-- | @since 1.8.2
instance Counter Bit

-- Fixed-width signed machine integers.

-- | @since 1.8.2
instance Counter Int
-- | @since 1.8.2
instance Counter Int8
-- | @since 1.8.2
instance Counter Int16
-- | @since 1.8.2
instance Counter Int32
-- | @since 1.8.2
instance Counter Int64

-- Fixed-width unsigned machine integers.

-- | @since 1.8.2
instance Counter Word
-- | @since 1.8.2
instance Counter Word8
-- | @since 1.8.2
instance Counter Word16
-- | @since 1.8.2
instance Counter Word32
-- | @since 1.8.2
instance Counter Word64

-- 'Identity' reuses the wrapped type's instance unchanged.

-- | @since 1.8.2
deriving newtype instance Counter a => Counter (Identity a)

-- | 'Nothing' is considered the minimum value, while @'Just' 'countMax'@ is
-- considered the maximum value.
--
-- @since 1.8.2
instance Counter a => Counter (Maybe a) where
  countMin = Nothing
  countMax = Just countMax

  -- Counting up from 'Nothing' enters the 'Just' range at its minimum.
  -- Overflowing the inner counter wraps back to 'Nothing' and reports an
  -- overflow of the whole 'Maybe'.
  countSuccOverflow = \case
    Nothing -> (False, Just countMin)
    Just a0 ->
      case countSuccOverflow a0 of
        (True, _) -> (True, Nothing)
        (False, a1) -> (False, Just a1)

  -- Counting down from 'Nothing' is the underflow case and wraps to
  -- @'Just' 'countMax'@. Underflowing the inner counter lands on 'Nothing'
  -- without reporting an underflow.
  countPredOverflow = \case
    Nothing -> (True, Just countMax)
    Just a0 ->
      case countPredOverflow a0 of
        (True, _) -> (False, Nothing)
        (False, a1) -> (False, Just a1)

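-- | Counter instance that counts through all 'Left' values before moving on
-- to the 'Right' values, wrapping back to the first 'Left' value after the
-- last 'Right' value. Examples: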
--
-- >>> type T = Either (Index 2) (Unsigned 2)
-- >>> countSucc @T (Left 0)
-- Left 1
-- >>> countSucc @T (Left 1)
-- Right 0
-- >>> countSucc @T (Right 0)
-- Right 1
instance (Counter a, Counter b) => Counter (Either a b) where
  countMin = Left countMin
  countMax = Right countMax

  -- All 'Left' values come before all 'Right' values. Overflowing the 'Left'
  -- counter moves on to the first 'Right' value. Only overflowing the 'Right'
  -- counter reports an overflow of the whole 'Either'.
  countSuccOverflow e =
    case bimap countSuccOverflow countSuccOverflow e of
      Left (overflow, a) ->
        (False, if overflow then Right countMin else Left a)
      Right (overflow, b) ->
        (overflow, if overflow then Left countMin else Right b)

  -- Mirror image of the above. Underflowing the 'Right' counter moves to the
  -- last 'Left' value. Only underflowing the 'Left' counter reports an
  -- underflow of the whole 'Either'.
  countPredOverflow e =
    case bimap countPredOverflow countPredOverflow e of
      Left (overflow, a) ->
        (overflow, if overflow then Right countMax else Left a)
      Right (overflow, b) ->
        (False, if overflow then Left countMax else Right b)

-- | Counters on tuples increment from right-to-left. This makes sense from the
-- perspective of LSB/MSB; MSB is on the left-hand-side and LSB is on the
-- right-hand-side in other Clash types.
--
-- >>> type T = (Unsigned 2, Index 2, Index 2)
-- >>> countSucc @T (0, 0, 0)
-- (0,0,1)
-- >>> countSucc @T (0, 0, 1)
-- (0,1,0)
-- >>> countSucc @T (0, 1, 0)
-- (0,1,1)
-- >>> countSucc @T (0, 1, 1)
-- (1,0,0)
--
-- __NB__: The documentation only shows the instances up to /3/-tuples. By
-- default, instances up to and including /12/-tuples will exist. If the flag
-- @large-tuples@ is set instances up to the GHC imposed limit will exist. The
-- GHC imposed limit is either 62 or 64 depending on the GHC version.
instance (Counter a0, Counter a1) => Counter (a0, a1) where
  -- a0/a1 instead of a/b to be consistent with TH generated instances
  countMin = (countMin, countMin)
  countMax = (countMax, countMax)

  -- The right-hand element is the least significant one: it always steps.
  -- The left-hand element only steps when the right-hand one overflows.
  -- Because 'where' pattern bindings are lazy, @(overflowA, a1)@ is only
  -- evaluated in that case.
  countSuccOverflow (a0, b0) =
    if overflowB
    then (overflowA, (a1, b1))
    else (overflowB, (a0, b1))
   where
    (overflowB, b1) = countSuccOverflow b0
    (overflowA, a1) = countSuccOverflow a0

  -- Same structure as 'countSuccOverflow', borrowing instead of carrying.
  countPredOverflow (a0, b0) =
    if overflowB
    then (overflowA, (a1, b1))
    else (overflowB, (a0, b1))
   where
    (overflowB, b1) = countPredOverflow b0
    (overflowA, a1) = countPredOverflow a0

-- Tuple instances beyond the hand-written pair instance are produced by
-- Template Haskell (see "Clash.Class.Counter.TH"); 'maxTupleSize' presumably
-- bounds the largest generated arity.
$(genTupleInstances maxTupleSize)

-- | Ripple a carry through a vector from right to left. The rightmost element
-- is always stepped with @f@. Each element to its left is only stepped when
-- the element to its right reported a carry (the 'Bool' returned by @f@).
-- The returned 'Bool' is the carry out of the leftmost element, i.e. whether
-- the vector as a whole wrapped around.
rippleR :: (a -> (Bool, a)) -> Vec n a -> (Bool, Vec n a)
rippleR f = mapAccumR step True
  where
    -- Once the carry has been absorbed, the remaining (more significant)
    -- elements pass through unchanged.
    step carry x = if carry then f x else (False, x)

-- | Counters on vectors increment from right to left.
--
-- >>> type T = Vec 2 (Index 10)
-- >>> countSucc @T (0 :> 0 :> Nil)
-- 0 :> 1 :> Nil
-- >>> countSucc @T (0 :> 1 :> Nil)
-- 0 :> 2 :> Nil
-- >>> countSucc @T (0 :> 9 :> Nil)
-- 1 :> 0 :> Nil
-- >>> iterate (SNat @5) (countSucc @T) (9 :> 8 :> Nil)
-- (9 :> 8 :> Nil) :> (9 :> 9 :> Nil) :> (0 :> 0 :> Nil) :> (0 :> 1 :> Nil) :> (0 :> 2 :> Nil) :> Nil
instance (Counter a, KnownNat n, 1 <= n) => Counter (Vec n a) where
    countMin = Vec.repeat countMin
    countMax = Vec.repeat countMax

    -- The last element is the least significant digit, so carries and
    -- borrows ripple from right to left through 'rippleR'.
    countSuccOverflow = rippleR countSuccOverflow
    countPredOverflow = rippleR countPredOverflow