{-# LANGUAGE UndecidableInstances #-}
module Data.ByteString.IsoBaseFileFormat.Util.BitRecords where

import Data.Kind
import Data.Word
import Data.Type.Bool
import GHC.TypeLits
import Data.Bits
import Data.Proxy
import Test.TypeSpecCrazy

-- | A bit field that is @size@ bits wide. It has no constructors and is
-- only used at the type level.
data Field (size :: Nat)

-- | Attach a label (of any kind) to a field description.
data (:=>) (label :: k) (field :: Type)
infixr 6 :=>

-- | Put two field descriptions next to each other; the sizes add up.
data (:*:) (front :: Type) (back :: Type)
infixl 5 :*:

-- | A pair of bit indices: the index of the first and of the last bit of a
-- field.
type FieldPosition = (Nat, Nat)

-- | A field that is exactly one bit wide.
type Flag = Field 1

-- | Path separator for addressing nested fields: @label :/ rest@ selects the
-- field labelled @label@ and then continues with @rest@ inside of it.
data (:/) (label :: Symbol) (rest :: k)
infixr 7 :/


-- | Compute the width in bits of a field description:
--
-- * a labelled field is as wide as the field it labels,
-- * @Field n@ is @n@ bits wide,
-- * a sequence built with @:*:@ is as wide as both parts together.
type family GetFieldSize (f :: l) :: Nat where
  GetFieldSize (name :=> inner)  = GetFieldSize inner
  GetFieldSize (Field bits)      = bits
  GetFieldSize (front :*: back)  = GetFieldSize front + GetFieldSize back

-- | Decide whether the field description @f@ contains a field addressed by
-- @l@, where @l@ is either a plain label or a nested path built with @:/@.
--
-- The equations are tried top to bottom; the last one is the fall-through
-- for everything that did not match before.
type family HasField (f :: fk) (l :: lk) :: Bool where
  -- the label itself
  HasField (name :=> inner) name           = 'True
  -- first path component matches: continue below the label
  HasField (name :=> inner) (name :/ rest) = HasField inner rest
  -- a sequence has the field if either side has it
  HasField (front :*: back) path           = HasField front path || HasField back path
  HasField other            path           = 'False

-- | A constraint that is satisfiable only if @field@ contains @label@
-- according to 'HasField'. Otherwise a custom type error naming the label
-- and the field description is produced.
--
-- Note the argument order: label first, field second (the reverse of
-- 'HasField').
type family HasFieldConstraint (label :: lk) (field :: fk) :: Constraint where
  HasFieldConstraint path record =
    If (HasField record path)
       (HasField record path ~ 'True)
       (TypeError
         (     'Text "Label not found: '"
         ':<>: 'ShowType path
         ':<>: 'Text "' in:"
         ':$$: 'ShowType record ))

-- | Select the sub-field of @f@ addressed by @l@.
--
-- Yields @'Right@ with the selected field if 'HasField' finds the label,
-- and @'Left@ with an error message otherwise.
-- NOTE(review): @Result@ comes from "Test.TypeSpecCrazy"; presumably it is an
-- @Either@ over type-level error messages - confirm there.
type family FocusOn (l :: lk) (f :: fk) :: Result fk where
  FocusOn path record =
    If (HasField record path)
       ('Right (FocusOnUnsafe path record))
       ('Left
         (     'Text "Label not found. Cannot focus '"
         ':<>: 'ShowType path
         ':<>: 'Text "' in:"
         ':$$: 'ShowType record ))

-- | Select the sub-field of @f@ addressed by @l@ without checking that the
-- label exists. There is no equation for a missing label, so in that case
-- the type family application simply does not reduce.
type family FocusOnUnsafe (l :: lk) (f :: fk) :: fk where
  -- found the label: the result is the labelled field
  FocusOnUnsafe name           (name :=> inner) = inner
  -- first path component matches: descend with the rest of the path
  FocusOnUnsafe (name :/ rest) (name :=> inner) = FocusOnUnsafe rest inner
  -- in a sequence, continue in whichever side contains the label
  FocusOnUnsafe path           (front :*: back) =
    FocusOnUnsafe path (If (HasField front path) front back)

-- field location and access

-- | Look up the bit range (first and last bit index) of the field addressed
-- by @l@ inside the field description @f@.
--
-- Yields @'Right@ with the position if 'HasField' finds the label, and
-- @'Left@ with an error message otherwise.
type family GetFieldPosition (f :: field) (l :: label) :: Result FieldPosition where
  GetFieldPosition record path =
    If (HasField record path)
       ('Right (GetFieldPositionUnsafe record path))
       ('Left
         (     'Text "Label not found. Cannot get bit range for '"
         ':<>: 'ShowType path
         ':<>: 'Text "' in:"
         ':$$: 'ShowType record ))

-- | Compute the bit range of the field addressed by @l@ inside @f@ without
-- checking that the label exists.
--
-- Bit indices are counted from 0 at the start of @f@; in a sequence
-- @front :*: back@ the indices of @back@ are shifted by the size of @front@.
type family GetFieldPositionUnsafe (f :: field) (l :: label) :: FieldPosition where
  -- the labelled field starts at 0 and spans its whole size
  GetFieldPositionUnsafe (name :=> inner)  name           =
    '(0, GetFieldSize inner - 1)
  -- first path component matches: descend with the rest of the path
  GetFieldPositionUnsafe (name :=> inner)  (name :/ rest) =
    GetFieldPositionUnsafe inner rest
  GetFieldPositionUnsafe (front :*: back)  path           =
    If (HasField front path)
       (GetFieldPositionUnsafe front path)
       (AddToFieldPosition (GetFieldSize front)
                           (GetFieldPositionUnsafe back path))

-- | Shift both indices of a bit range by the offset @v@.
type family AddToFieldPosition (v :: Nat) (e :: (Nat, Nat)) :: (Nat, Nat) where
  AddToFieldPosition offset '(lo, hi) = '(lo + offset, hi + offset)

-- | A constraint that holds for a well-formed bit range: the first index
-- must not be greater than the last one, and both indices must be known
-- naturals. A range with @first > last@ produces a custom type error.
type family IsFieldPostition (pos :: FieldPosition) :: Constraint where
  IsFieldPostition '(lo, hi) =
    If (lo <=? hi)
       (lo <= hi, KnownNat lo, KnownNat hi)
       (TypeError
         (     'Text "Bad field position: " ':<>: 'ShowType '(lo, hi)
         ':$$: 'Text "First index greater than last: "
         ':<>: 'ShowType lo
         ':<>: 'Text " > "
         ':<>: 'ShowType hi ))

-- | Enumerate all bit indices of a bit range, from the first to the last
-- index (both included).
--
-- The recursion only stops once both indices are equal, so a range whose
-- first index is greater than its last index never reaches the end.
type family FieldPostitionToList (pos :: FieldPosition) :: [Nat] where
  FieldPostitionToList '(ix, ix) = '[ix]
  FieldPostitionToList '(lo, hi) = lo ': FieldPostitionToList '(lo + 1, hi)

-- | Pad the field description @f@ so that its size becomes a multiple of the
-- alignment @a@ (in bits).
--
-- An alignment of 0 is rejected with @'Left@. Otherwise the number of
-- padding bits is @(a - size `Rem` a) `Rem` a@; the outer 'Rem' turns a
-- padding of @a@ bits (size already aligned) into no padding at all.
type family AlignField (a :: Nat) (f :: field) :: Result field where
  AlignField 0     record = 'Left ('Text "Invalid alignment of 0")
  AlignField align record =
    'Right
      (AddPadding (Rem (align - Rem (GetFieldSize record) align) align)
                  record)

-- | Append an unlabelled field of @n@ bits to the field description @f@.
-- With @n@ equal to 0 the description is returned unchanged.
type family AddPadding (n :: Nat) (f :: field) :: field where
  AddPadding 0    record = record
  AddPadding bits record = record :*: Field bits

-- | Get the remainder of the integer division of @x@ by @y@, such that
-- @forall x y. exists k. (Rem x y) == x - y * k@.
--
-- The algorithm counts @x@ down, incrementing the accumulator at each step.
-- Whenever the accumulator is equal to @y@ it is set back to zero.
--
-- It is important to reset the accumulator BEFORE checking if @x == y@ and
-- returning the accumulator, for the case where @x = k * y@ with @k > 0@.
-- For example:
--
-- @
--  6 `Rem` 3     = RemImpl 6 3 0
--  RemImpl 6 3 0 = RemImpl (6-1) 3 (0+1)   -- RemImpl Clause 4
--  RemImpl 5 3 1 = RemImpl (5-1) 3 (1+1)   -- RemImpl Clause 4
--  RemImpl 4 3 2 = RemImpl (4-1) 3 (2+1)   -- RemImpl Clause 4
--  RemImpl 3 3 3 = RemImpl 3 3 0           -- RemImpl Clause 2 !!!
--  RemImpl 3 3 0 = 0                       -- RemImpl Clause 3 !!!
-- @
--
-- NOTE: @y@ must not be 0. For @x > 0@ clause 2 rewrites @RemImpl x 0 0@ to
-- itself, so the reduction never finishes.
type Rem (x :: Nat) (y :: Nat) = RemImpl x y 0

-- | Worker for 'Rem': @x@ is the part of the dividend that is still to be
-- counted down, @y@ the divisor and @acc@ the remainder accumulated so far.
-- The equations are tried top to bottom, and their order matters.
type family
  RemImpl (x :: Nat) (y :: Nat) (acc :: Nat) :: Nat where
  -- Clause 1: finished if x was < y:
  RemImpl 0 y acc = acc
  -- Clause 2: the accumulator has reached y, start counting from 0 again
  RemImpl x y y   = RemImpl x y 0
  -- Clause 3: finished if x was >= y:
  RemImpl y y acc = acc
  -- Clause 4: the base case, move one from x to the accumulator
  RemImpl x y acc = RemImpl (x - 1) y (acc + 1)

-- | Read a single-bit field as a 'Bool'.
--
-- The first proxy selects the field by label (or path), the second proxy
-- carries the field description. The field must be exactly one bit wide:
-- the constraint requires its first and last bit index to be the same.
-- The result is 'testBit' of the value at that index.
--
-- The kind variable @k@ is bound explicitly, because newer GHC versions no
-- longer quantify kind variables implicitly under an explicit @forall@.
getFlag
  :: forall k a (path :: k) (first :: Nat) field p1 p2
  . ( IsFieldC path field first first
    , Bits a )
   => p1 path -> p2 field -> a -> Bool
getFlag _ _ bits = testBit bits bitIndex
  where
    -- index of the flag, taken from the type level
    bitIndex = fromIntegral (natVal (Proxy :: Proxy first))

-- | Write a single-bit field from a 'Bool'.
--
-- The first proxy selects the field by label (or path), the second proxy
-- carries the field description. The field must be exactly one bit wide:
-- the constraint requires its first and last bit index to be the same.
-- 'True' sets the bit ('setBit'), 'False' clears it ('clearBit'); all other
-- bits of the value are left as they are.
--
-- The kind variable @k@ is bound explicitly, because newer GHC versions no
-- longer quantify kind variables implicitly under an explicit @forall@.
setFlag
  :: forall k a (path :: k) (first :: Nat) field p1 p2
  . ( IsFieldC path field first first
    , Bits a )
   => p1 path -> p2 field -> Bool -> a -> a
setFlag _ _ flag bits
  | flag      = setBit bits bitIndex
  | otherwise = clearBit bits bitIndex
  where
    -- index of the flag, taken from the type level
    bitIndex = fromIntegral (natVal (Proxy :: Proxy first))

-- | Read a multi-bit field.
--
-- The first proxy selects the field by label (or path), the second proxy
-- carries the field description. The value is shifted right by the index of
-- the first bit of the field, masked down to the width of the field and then
-- converted to the result type with 'fromIntegral'.
--
-- The kind variable @k@ is bound explicitly, because newer GHC versions no
-- longer quantify kind variables implicitly under an explicit @forall@.
getField
  :: forall k a b (path :: k) (first :: Nat) (last :: Nat) field pxy1 pxy2
  . ( IsFieldC path field first last
    , Integral a
    , Bits a
    , Num b)
   => pxy1 path -> pxy2 field -> a -> b
getField _ _ a = fromIntegral ((a `shiftR` posFirst) .&. bitMask)
    where
      -- a mask with the lowest @bitCount@ bits set
      bitMask =
        let bitCount = 1 + posLast - posFirst
            in (2 ^ bitCount) - 1
      -- first and last bit index of the field, taken from the type level
      posFirst = fromIntegral $ natVal (Proxy :: Proxy first)
      posLast = fromIntegral $ natVal (Proxy :: Proxy last)

-- | Write a multi-bit field.
--
-- The first proxy selects the field by label (or path), the second proxy
-- carries the field description. The new value @v@ is converted with
-- 'fromIntegral' and cut down to the width of the field, the bits of the
-- field are cleared in @x@, and the shifted value is combined into the
-- result. Bits outside of the field are left as they are.
--
-- The kind variable @k@ is bound explicitly, because newer GHC versions no
-- longer quantify kind variables implicitly under an explicit @forall@.
setField
  :: forall k a b (path :: k) (first :: Nat) (last :: Nat) field pxy1 pxy2
  . ( IsFieldC path field first last
    , Num a
    , Bits a
    , Integral b)
   => pxy1 path -> pxy2 field -> b -> a -> a
setField _ _ v x = (x .&. bitMaskField) .|. (v' `shiftL` posFirst)
    where
      -- the new value, limited to the width of the field
      v' = bitMaskValue .&. fromIntegral v
      -- all bits set except the bits of the field
      bitMaskField = complement (bitMaskValue `shiftL` posFirst)
      -- a mask with the lowest @bitCount@ bits set
      bitMaskValue =
        let bitCount = 1 + posLast - posFirst
            in (2 ^ bitCount) - 1
      -- first and last bit index of the field, taken from the type level
      posFirst = fromIntegral $ natVal (Proxy :: Proxy first)
      posLast = fromIntegral $ natVal (Proxy :: Proxy last)



-- | An example record of 28 bits. With the position rules of
-- 'GetFieldPositionUnsafe' the layout is:
--
-- * @"foo"@: bit 0
-- * unlabelled: bits 1-4
-- * @"bar"@: bits 5-6
-- * unlabelled: bits 7-10
-- * @"baz"@: bits 11-27
type Foo =
      ("foo" :=> Flag)
  :*: Field 4
  :*: ("bar" :=> Field 2)
  :*: Field 4
  :*: ("baz" :=> Field 17)

-- | The constraints needed to access the field @name@ of the field
-- description @field@ at run time:
--
-- * the field must exist (otherwise 'HasFieldConstraint' reports an error),
-- * @first@ and @last@ must be known naturals,
-- * @first@ and @last@ must be the bit range that 'GetFieldPosition' computes
--   for the field.
type IsFieldC name field first last =
  ( HasFieldConstraint name field
  , KnownNat first
  , KnownNat last
  , 'Right '(first, last) ~ GetFieldPosition field name
  )

-- | Read the field selected by the label proxy from a 'Word64' that is laid
-- out as described by 'Foo'. This is 'getField' with the field description
-- fixed to 'Foo' and both value types fixed to 'Word64'.
getFooField :: IsFieldC name Foo first last
   => proxy name -> Word64 -> Word64
getFooField label = getField label fooLayout
  where
    fooLayout :: Proxy Foo
    fooLayout = Proxy