{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}

-- | Union (as in C)
--
-- Unions are storable and can contain any storable data.
--
-- Use 'fromUnion' to read an alternative:
--
-- @
-- {-# LANGUAGE DataKinds #-}
--
-- getUnion :: IO (Union '[Word16, Word32, Word64])
-- getUnion = ...
--
-- test = do
--    u <- getUnion
--
--    -- to get one of the members
--    let v = fromUnion u :: Word16
--    let v = fromUnion u :: Word32
--    let v = fromUnion u :: Word64
--
--    -- This won't compile (Word8 is not a member of the union)
--    let v = fromUnion u :: Word8
-- @
--
-- Use 'toUnion' to create a new union:
--
-- @
-- let
--    u2 :: Union '[Word32, Vector 4 Word8]
--    u2 = toUnion (0x12345678 :: Word32)
-- @
--
module Haskus.Binary.Union
   ( Union
   , fromUnion
   , toUnion
   , toUnionZero
   )
where

import Haskus.Utils.Types hiding (Union)
import Haskus.Utils.HList
import Haskus.Utils.Flow (when)
import Haskus.Binary.Storable
import Haskus.Memory.Utils (memCopy, memSet)

import System.IO.Unsafe (unsafePerformIO)
import Foreign.ForeignPtr
import Foreign.Ptr
import qualified Foreign.Storable as FS

-- TODO: rewrite rules
-- poke p (toUnion x) = poke (castPtr p) x
--
-- (fromUnion <$> peek p) :: IO a = peek (castPtr p) :: IO a

-- | A union
--
-- We use a list of types as a parameter.
--
-- The union is just a pointer to a buffer containing the value(s). The size of
-- the buffer is implicitly known from the types in the list.
--
-- The constructor is not exported: values are only built by 'toUnion',
-- 'toUnionZero' and the storable instances below.
newtype Union (x :: [*]) = Union (ForeignPtr ()) deriving (Show)

-- | Allocate the backing buffer of a union.
--
-- The first argument is the size in bytes, the second one the required
-- alignment in bytes. The alignment is clamped to at least 1 because the
-- fold-based alignment of an empty member list is 0 (see 'unionAlignment').
allocUnionBuffer :: Int -> Int -> IO (ForeignPtr ())
allocUnionBuffer sz al = mallocForeignPtrAlignedBytes sz (max 1 al)

-- | Retrieve a union member from its type
--
-- The member is read from the start of the union buffer.
--
-- 'unsafePerformIO' is used to expose the read as a pure value: none of the
-- functions in this module writes into a buffer once it has been wrapped in
-- 'Union', and the constructor is not exported.
fromUnion :: (Storable a, Member a l) => Union l -> a
fromUnion (Union fp) = unsafePerformIO $ withForeignPtr fp (peek . castPtr)

-- | Create a new union from one of the union types
--
-- Bytes of the buffer located after the stored member are not written. Use
-- 'toUnionZero' to have them set to 0.
toUnion :: forall a l . (Storable (Union l), Storable a, Member a l) => a -> Union l
toUnion = toUnion' False

-- | Like 'toUnion' but set the remaining bytes to 0
toUnionZero :: forall a l . (Storable (Union l), Storable a, Member a l) => a -> Union l
toUnionZero = toUnion' True

-- | Create a new union from one of the union types
--
-- The 'Bool' selects whether the bytes following the stored member are set to
-- 0 ('True') or left untouched ('False').
--
-- The buffer is allocated with the size and the alignment of the whole union
-- so that any member can later be read back from its start.
toUnion' :: forall a l . (Storable (Union l), Storable a, Member a l) => Bool -> a -> Union l
toUnion' zero v = unsafePerformIO $ do
   let sz = sizeOfT @(Union l)
       al = alignmentT @(Union l)
   fp <- allocUnionBuffer (fromIntegral sz) (fromIntegral al)
   withForeignPtr fp $ \p -> do
      -- set bytes after the object to 0
      when zero $ do
         let psz = sizeOfT @a
         memSet (p `plusPtr` fromIntegral psz) (fromIntegral (sz - psz)) 0
      poke (castPtr p) v
   return $ Union fp

-- | Map 'SizeOf' over a list of types
type family MapSizeOf fs where
   MapSizeOf '[]       = '[]
   MapSizeOf (x ': xs) = SizeOf x ': MapSizeOf xs

-- | Map 'Alignment' over a list of types
type family MapAlignment fs where
   MapAlignment '[]       = '[]
   MapAlignment (x ': xs) = Alignment x ': MapAlignment xs

-- | Statically sized union: the size (resp. alignment) is the 'ListMax' of the
-- members' sizes (resp. alignments).
instance forall fs.
      ( KnownNat (ListMax (MapSizeOf fs))
      , KnownNat (ListMax (MapAlignment fs))
      )
      => StaticStorable (Union fs)
   where
      type SizeOf (Union fs)    = ListMax (MapSizeOf fs)
      type Alignment (Union fs) = ListMax (MapAlignment fs)

      -- Copy the whole union out of the source pointer into a fresh buffer
      -- allocated with the union alignment.
      staticPeekIO ptr = do
         let sz = natValue @(SizeOf (Union fs))
             al = natValue @(Alignment (Union fs))
         fp <- allocUnionBuffer sz al
         withForeignPtr fp $ \p ->
            memCopy p (castPtr ptr) (fromIntegral sz)
         return (Union fp)

      -- Copy the whole union buffer to the destination pointer.
      staticPokeIO ptr (Union fp) = do
         withForeignPtr fp $ \p ->
            memCopy (castPtr ptr) p (natValue @(SizeOf (Union fs)))

-------------------------------------------------------------------------------------
-- We use HFoldr' to get the maximum size and alignment of the types in the union
-------------------------------------------------------------------------------------

-- | Fold function tag: maximum of the member sizes
data FoldSizeOf    = FoldSizeOf

-- | Fold function tag: maximum of the member alignments
data FoldAlignment = FoldAlignment

-- The element itself is never inspected: only its type is used.
instance (r ~ Word, Storable a) => Apply FoldSizeOf (a, Word) r where
   apply _ (_,r) = max r (sizeOfT @a)

-- The element itself is never inspected: only its type is used.
instance (r ~ Word, Storable a) => Apply FoldAlignment (a, Word) r where
   apply _ (_,r) = max r (alignmentT @a)

-- | Get the union size (i.e. the maximum of the types in the union)
--
-- The fold starts at 0. The 'undefined' list only carries the member types to
-- the fold; the 'Apply' instances above ignore the elements.
unionSize :: forall l . HFoldr' FoldSizeOf Word l Word => Union l -> Word
unionSize _ = hFoldr' FoldSizeOf (0 :: Word) (undefined :: HList l)

-- | Get the union alignment (i.e. the maximum of the types in the union)
--
-- The fold starts at 0. The 'undefined' list only carries the member types to
-- the fold; the 'Apply' instances above ignore the elements.
unionAlignment :: forall l . HFoldr' FoldAlignment Word l Word => Union l -> Word
unionAlignment _ = hFoldr' FoldAlignment (0 :: Word) (undefined :: HList l)

-------------------------------------------------------------------------------------
-- Finally we can write the Storable instance
-------------------------------------------------------------------------------------

instance
      ( HFoldr' FoldSizeOf Word l Word
      , HFoldr' FoldAlignment Word l Word
      )
      => Storable (Union l)
   where
      sizeOf    = unionSize
      alignment = unionAlignment

      -- Copy the whole union out of the source pointer into a fresh buffer
      -- allocated with the union alignment.
      peekIO ptr = do
         let sz = sizeOfT' @(Union l)
             al = fromIntegral (alignmentT @(Union l))
         fp <- allocUnionBuffer sz al
         withForeignPtr fp $ \p ->
            memCopy p (castPtr ptr) (fromIntegral sz)
         return (Union fp)

      -- Copy the whole union buffer to the destination pointer.
      pokeIO ptr (Union fp) =
         withForeignPtr fp $ \p ->
            memCopy (castPtr ptr) p (sizeOfT' @(Union l))

-- compatibility instance with Foreign.Storable
instance
      ( HFoldr' FoldSizeOf Word l Word
      , HFoldr' FoldAlignment Word l Word
      )
      => FS.Storable (Union l)
   where
      sizeOf    = fromIntegral . unionSize
      alignment = fromIntegral . unionAlignment
      peek      = peekIO
      poke      = pokeIO