-- Util.hs ---

-- Copyright (C) 2020 Nerd Ed

-- Author: Nerd Ed <nerded.nerded@gmail.com>

-- This program is free software; you can redistribute it and/or
-- modify it under the terms of the GNU General Public License
-- as published by the Free Software Foundation; either version 3
-- of the License, or (at your option) any later version.

-- This program is distributed in the hope that it will be useful,
-- but WITHOUT ANY WARRANTY; without even the implied warranty of
-- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-- GNU General Public License for more details.

-- You should have received a copy of the GNU General Public License
-- along with this program. If not, see <http://www.gnu.org/licenses/>.

{-# LANGUAGE DerivingStrategies  #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}

-- |
-- = This handy module extends the 'Storable' typeclass with default instances for C-like enums and fixed-size arrays (FFI).
--
-- Using 'StorableExt', a @deriving via@ clause can give any 'Enum' sum type a 'Storable' instance.
--
-- @
-- data X
--   = A
--   | B
--   | C
--   deriving stock Enum
--   deriving Storable via StorableExt X
-- @
--
-- This type will be stored as a 'Word32' (C enum FFI).
--
-- Using 'StorableFixedArray', the fixed size of an array is encoded in its type (built on "Foreign.Storable.FixedArray" from the storable-record package).
--
-- @
-- data X = X (StorableFixedArray Word32 10)
-- @
--
-- This type will be stored as 10 contiguous 'Word32' values (C fixed array).
--
module Zydis.Util
  ( StorableExt(..)
  , StorableFixedArray(..)
  , Storable
  )
where

import           Data.Foldable
import           Data.Proxy
import           Data.Vector
import qualified Data.Vector                   as V
import           Data.Word
import           Foreign.Ptr
import           Foreign.Storable
import qualified Foreign.Storable.FixedArray   as Fixed
import           GHC.TypeLits


-- | Wrapper that selects the enum-style 'Storable' instance defined in this
-- module for the wrapped type.
--
-- The wrapper adds no data of its own: 'unStorableExt' gives back the wrapped
-- value unchanged, and 'Show' / 'Eq' are derived structurally.
newtype StorableExt a = StorableExt
  { unStorableExt :: a -- ^ The wrapped value.
  }
  deriving stock (Show, Eq)

-- | Store any 'Enum' as a single 'Word32' holding its 'fromEnum' index.
--
-- * 'sizeOf' and 'alignment' are those of 'Word32'.
-- * 'poke' writes @fromEnum@ of the wrapped value converted with
--   'fromIntegral', so an index that does not fit in a 'Word32' wraps around.
-- * 'peek' reads a 'Word32' and converts it back with 'toEnum'.  The
--   conversion is applied lazily through 'fmap'; NOTE(review): 'toEnum' is
--   partial for most enumerations, so an unexpected value in memory presumably
--   only fails once the result is forced - confirm against callers.
instance forall a. Enum a => Storable (StorableExt a) where
  -- The argument is never inspected, and neither is the 'undefined' proxy:
  -- only its type is used to pick the 'Word32' instance.
  alignment _ = alignment (undefined :: Word32)
  sizeOf _ = sizeOf (undefined :: Word32)
  peek ptr = decode <$> peek (castPtr ptr :: Ptr Word32)
    where
      decode = StorableExt . toEnum . fromIntegral
  poke ptr wrapped = poke (castPtr ptr :: Ptr Word32) tag
    where
      tag = fromIntegral (fromEnum (unStorableExt wrapped))

-- | A 'Vector' of @a@ tagged with a phantom type parameter @b@.
--
-- The parameter @b@ does not appear in the representation; the 'Storable'
-- instance in this module reads it as a type-level natural giving the number
-- of elements of the array.  'Show' and 'Eq' are derived from the underlying
-- vector.
newtype StorableFixedArray a b = StorableFixedArray
  { unStorableFixedArray :: Vector a -- ^ The elements of the array.
  }
  deriving stock (Show, Eq)

-- | Store a 'StorableFixedArray' as @b@ contiguous elements of type @a@.
--
-- * 'alignment' is the alignment of a single element.
-- * 'sizeOf' is 'Fixed.sizeOfArray' for @b@ elements.
-- * 'peek' always reads exactly @b@ elements.
-- * 'poke' writes at most @b@ elements: a vector longer than @b@ is truncated
--   so the write never goes past the @b@ elements accounted for by 'sizeOf',
--   and a shorter vector leaves the remaining destination elements untouched.
instance forall a b. (Storable a, KnownNat b) => Storable (StorableFixedArray a b) where
  -- The argument is never inspected, and neither is the 'undefined' proxy:
  -- only its type is used to pick the element instance.
  alignment _ = alignment (undefined :: a)
  sizeOf _ = Fixed.sizeOfArray count (undefined :: a)
    where
      count :: Int
      count = fromIntegral (natVal (Proxy @b))
  peek ptr = StorableFixedArray <$> Fixed.run (castPtr ptr) readAll
    where
      count :: Int
      count = fromIntegral (natVal (Proxy @b))
      -- Read the elements one after another, advancing the cursor each time.
      readAll = replicateM count Fixed.peekNext
  poke ptr (StorableFixedArray elems) =
    Fixed.run (castPtr ptr) (traverse_ Fixed.pokeNext bounded)
    where
      count :: Int
      count = fromIntegral (natVal (Proxy @b))
      -- Clamp to the declared length so an oversized vector cannot overflow
      -- the destination buffer.
      bounded = V.take count elems