{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ConstraintKinds      #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}

-- | This module provides a wrapper type that forces the alignment of
-- 'Storable' types.
module Raaz.Core.Types.Aligned
  ( -- * Types to force alignment.
    Aligned, unAligned, aligned16Bytes, aligned32Bytes, aligned64Bytes
  ) where


#if MIN_VERSION_base(4,7,0)
import           Data.Proxy
#endif

import           GHC.TypeLits
import           Foreign.Ptr                 ( castPtr      )
import           Foreign.Storable            ( Storable(..) )
import           Prelude hiding              ( length       )


-- | A value of type @w@ forced to be aligned to the alignment boundary
-- @align@ (a type-level natural number of bytes). The wrapper is a
-- @newtype@, so it has no runtime cost; the boundary only matters when the
-- value is placed in memory through its 'Storable' instance.
newtype Aligned (align :: Nat) w
  = Aligned { unAligned :: w -- ^ The underlying unaligned value.
            }

-- | Align the value to a 16-byte boundary.
--
-- This only tags the value with the boundary at the type level; the
-- alignment takes effect when the result is stored via 'Storable'.
aligned16Bytes :: w -> Aligned 16 w
aligned16Bytes = Aligned
{-# INLINE aligned16Bytes #-}

-- | Align the value to a 32-byte boundary.
--
-- This only tags the value with the boundary at the type level; the
-- alignment takes effect when the result is stored via 'Storable'.
aligned32Bytes :: w -> Aligned 32 w
aligned32Bytes = Aligned
{-# INLINE aligned32Bytes #-}

-- | Align the value to a 64-byte boundary.
--
-- This only tags the value with the boundary at the type level; the
-- alignment takes effect when the result is stored via 'Storable'.
aligned64Bytes :: w -> Aligned 64 w
aligned64Bytes = Aligned
{-# INLINE aligned64Bytes #-}

#if MIN_VERSION_base(4,7,0)

-- | The constraint on the alignment boundary (since base 4.7.0): the
-- boundary must be a type-level natural whose value is known at runtime.
type AlignBoundary (alg :: Nat) = KnownNat alg

-- | The alignment boundary, in bytes, recorded in the type of the argument.
-- The argument itself is never inspected (it is matched with @_@), so only
-- its type matters.
alignmentBoundary :: AlignBoundary alg => Aligned alg a -> Int
alignmentBoundary = aB Proxy
  where aB :: AlignBoundary algn => Proxy algn -> Aligned algn a -> Int
        -- Reflect the type-level natural to an 'Integer' and narrow it.
        aB algn _ = fromEnum $ natVal algn

#else

-- | The constraint on the alignment boundary (pre base 4.7.0): the
-- boundary must be a type-level natural with a singleton ('SingI').
type AlignBoundary (alg :: Nat) = SingI alg

-- | The alignment boundary recorded in the type of the argument.
-- The argument itself is never inspected (it is matched with @_@).
--
-- NOTE(review): relies on the pre-base-4.7 "GHC.TypeLits" singleton API
-- ('withSing', 'fromSing'); confirm this branch still builds if such old
-- compilers remain supported.
alignmentBoundary :: AlignBoundary algn => Aligned algn a -> Int
alignmentBoundary = withSing aB
  where aB        ::  AlignBoundary algn => Sing algn      -> Aligned algn a -> Int
        -- Demote the singleton to its value and narrow it to an 'Int'.
        aB algn _ = fromEnum $ fromSing algn


#endif


-- | An @'Aligned' alg a@ is stored exactly like the underlying @a@ (same
-- size, same bytes). Its advertised 'alignment' is the least common
-- multiple of @a@'s own alignment and the forced boundary @alg@, so both
-- requirements are met at once.
--
-- NOTE(review): 'sizeOf' is the size of the underlying value and is not
-- rounded up to a multiple of 'alignment'. The default 'peekElemOff' /
-- 'pokeElemOff' step by 'sizeOf', so in a contiguous array only the first
-- element is guaranteed to sit on the @alg@ boundary — confirm callers
-- only store single values.
instance (Storable a, AlignBoundary alg) => Storable (Aligned alg a) where

  sizeOf = sizeOf . unAligned

  alignment x = lcm valueAlignment forceAlignment
    where valueAlignment :: Int
          valueAlignment = alignment $ unAligned x
          forceAlignment :: Int
          forceAlignment = alignmentBoundary x

  -- 'Aligned' is a newtype, so the wrapped value lives at the same address;
  -- only the pointer's phantom type changes.
  peek = fmap Aligned . peek . castPtr

  poke ptr = poke (castPtr ptr) . unAligned