{-# LANGUAGE DataKinds, PolyKinds, TypeOperators, GADTs #-}
{-# LANGUAGE FlexibleInstances, FlexibleContexts, MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving, DeriveTraversable #-}
{-# LANGUAGE Safe #-}

{-|
Description : Combinators and utility types for making 'GHC.Generics.Generic'-based 'Binary' instances more expressive.

This module defines a set of types to be used as fields of records with
'GHC.Generics.Generic'-based deriving of 'Binary' instances. For example:

@
data MyFileFormat = MyFileFormat
  { header      :: MatchBytes "my format header" '[ 0xd3, 0x4d, 0xf0, 0x0d ]
  , slack       :: SkipByte 0xff
  , reserved    :: SkipCount Word8 4
  , subElements :: Some MyElement
  } deriving (Generic, Binary)
@
-}
module Data.Binary.Combinators
  ( Many(..)
  , Some(..)
  , CountedBy(..)
  , SkipCount(..)
  , SkipByte(..)
  , MatchBytes
  , matchBytes
  , MatchByte
  ) where

import Control.Applicative
import Control.Monad
import Data.Binary
import Data.Binary.Get (isEmpty, lookAhead)
import Data.Functor
import Data.Kind
import Data.Proxy
import GHC.TypeLits
import Numeric
import Test.QuickCheck

-- | Zero or more elements of @a@, parsing as long as the parser for @a@ succeeds.
--
-- @Many Word8@ will consume all your input!
newtype Many a = Many { getMany :: [a] }
  deriving (Eq, Ord, Functor, Foldable, Traversable)

-- | Shows just the wrapped list.
instance Show a => Show (Many a) where
  show = show . getMany

-- | Parses with 'many' (repeat until the element parser fails); serializes the
-- elements back to back with no length prefix.
instance Binary a => Binary (Many a) where
  get = Many <$> many get
  put = mapM_ put . getMany

instance Arbitrary a => Arbitrary (Many a) where
  arbitrary = Many <$> arbitrary
  shrink (Many xs) = Many <$> shrink xs

-- | One or more elements of @a@, parsing as long as the parser for @a@ succeeds.
--
-- @Some Word8@ will consume all your non-empty input!
newtype Some a = Some { getSome :: [a] }
  deriving (Eq, Ord, Functor, Foldable, Traversable)

-- | Shows just the wrapped list.
instance Show a => Show (Some a) where
  show = show . getSome

-- | Parses with 'some' (at least one element, then repeat until the element
-- parser fails); serializes the elements back to back with no length prefix.
instance Binary a => Binary (Some a) where
  get = Some <$> some get
  put = mapM_ put . getSome

-- | Generates and shrinks only non-empty lists, matching what 'get' can produce.
instance Arbitrary a => Arbitrary (Some a) where
  arbitrary = Some . getNonEmpty <$> arbitrary
  shrink (Some xs) = Some <$> filter (not . null) (shrink xs)

-- | First, parse the elements count as type @ty@. Then, parse exactly as many elements of type @a@.
--
-- Parsing fails if the decoded count is negative or does not fit into an 'Int'.
-- Serializing calls 'error' if the number of elements cannot be represented in
-- @ty@: writing a wrapped-around count would produce output that does not
-- decode back to the same value.
newtype CountedBy ty a = CountedBy { getCounted :: [a] }
  deriving (Eq, Ord, Functor, Foldable, Traversable)

-- | Shows just the wrapped list.
instance Show a => Show (CountedBy ty a) where
  show = show . getCounted

instance (Integral ty, Binary ty, Binary a) => Binary (CountedBy ty a) where
  get = do
    cnt :: ty <- get
    let n = toInteger cnt
    -- 'replicateM' silently returns [] for a non-positive count, and a count
    -- beyond 'maxBound :: Int' would wrap when converted; reject both.
    when (n < 0 || n > toInteger (maxBound :: Int)) $
      fail $ "Element count " <> show n <> " is out of range when parsing CountedBy"
    CountedBy <$> replicateM (fromInteger n) get
  put (CountedBy xs)
    -- 'fromIntegral' wraps silently on overflow; compare via 'Integer' to detect it.
    -- 'Put' has no failure channel, so 'error' is the only way to refuse.
    | toInteger cnt /= toInteger len =
        error $ "Data.Binary.Combinators.CountedBy: " <> show len
             <> " elements do not fit into the count type"
    | otherwise = put cnt >> mapM_ put xs
    where
      len = length xs
      cnt = fromIntegral len :: ty

instance Arbitrary a => Arbitrary (CountedBy ty a) where
  arbitrary = CountedBy <$> arbitrary
  shrink (CountedBy xs) = CountedBy <$> shrink xs

-- | Parse out and skip @n@ elements of type @ty@.
--
-- Serializing this type produces @n@ elements of type @ty@, each equal to @0@,
-- so the encoding has the same length that 'get' consumes.
data SkipCount ty (n :: Nat) = SkipCount
  deriving (Eq, Ord, Show)

instance (Num ty, Binary ty, KnownNat n) => Binary (SkipCount ty n) where
  get = replicateM_ (fromIntegral $ natVal (Proxy :: Proxy n)) (get :: Get ty) $> SkipCount
  put _ = replicateM_ (fromIntegral $ natVal (Proxy :: Proxy n)) $ put (0 :: ty)

instance Arbitrary (SkipCount ty n) where
  arbitrary = pure SkipCount

-- | Skip any number (including zero) of bytes with value @n@.
--
-- Skipping stops at the first byte that differs from @n@ (that byte is left in
-- the input) or at the end of the input.
--
-- Serializing this type produces no bytes.
data SkipByte (n :: Nat) = SkipByte
  deriving (Eq, Ord, Show)

instance (KnownNat n) => Binary (SkipByte n) where
  get = skipRun
    where
      expected :: Word8
      expected = fromIntegral $ natVal (Proxy :: Proxy n)

      skipRun = do
        atEnd <- isEmpty
        -- Running out of input just ends the run of skipped bytes; without this
        -- check 'lookAhead' would fail and abort the whole parse at EOF.
        if atEnd
          then pure SkipByte
          else do
            nextByte <- lookAhead get
            if nextByte /= expected
              then pure SkipByte
              else (get :: Get Word8) >> skipRun
  put _ = pure ()

instance Arbitrary (SkipByte n) where
  arbitrary = pure SkipByte

-- | @MatchBytes ctx bytes@ ensures that the subsequent bytes in the input stream are the same as @bytes@.
--
-- @ctx@ is a label describing the context in which @MatchBytes@ is used; it only
-- serves to produce better parse failure messages.
-- For example, @MatchBytes "my format header" '[ 0xd3, 0x4d, 0xf0, 0x0d ]@ consumes four bytes from the input stream
-- if they are equal to @[ 0xd3, 0x4d, 0xf0, 0x0d ]@ respectively, or fails otherwise.
--
-- Serializing this type produces the @bytes@.
--
-- Each byte is taken as @fromInteger@ of the type-level natural at type 'Word8',
-- so values above 255 are reduced modulo 256.
data MatchBytes (ctx :: Symbol) (bytes :: [Nat]) :: Type where
  ConsumeNil  :: MatchBytes ctx '[]
  ConsumeCons :: KnownNat n => Proxy n -> MatchBytes ctx ns -> MatchBytes ctx (n ': ns)

deriving instance Eq (MatchBytes ctx ns)
deriving instance Ord (MatchBytes ctx ns)

-- | Matching an empty byte list consumes nothing and writes nothing.
instance Binary (MatchBytes ctx '[]) where
  get = pure ConsumeNil
  put _ = pure ()

-- | Reads one byte, fails with a message mentioning @ctx@ if it differs from @n@,
-- then matches the remaining bytes. Writes @n@ followed by the remaining bytes.
instance (KnownSymbol ctx, KnownNat n, Binary (MatchBytes ctx ns)) => Binary (MatchBytes ctx (n ': ns)) where
  get = do
    let wanted = fromInteger (natVal (Proxy :: Proxy n)) :: Word8
    actual <- get :: Get Word8
    unless (actual == wanted) $
      fail $ concat
        [ "Unexpected byte 0x", showHex actual ""
        , ", expected 0x", showHex wanted ""
        , " when parsing ", symbolVal (Proxy :: Proxy ctx)
        ]
    rest <- get
    pure (ConsumeCons Proxy rest)
  put (ConsumeCons _ rest) = do
    put (fromInteger (natVal (Proxy :: Proxy n)) :: Word8)
    put rest

-- | Renders as @Marker [ 0x.. 0x.. ]@, each byte in lowercase hex followed by a space.
instance Show (MatchBytes ctx ns) where
  show marker = "Marker [ " <> concatMap render (bytesOf marker) <> "]"
    where
      render b = "0x" <> showHex b " "

      -- Polymorphic recursion over the GADT needs an explicit signature.
      bytesOf :: MatchBytes c bs -> [Integer]
      bytesOf ConsumeNil = []
      bytesOf (ConsumeCons p more) = natVal p : bytesOf more

-- | Builds the unique 'MatchBytes' value for a given byte list by recursion on
-- that list. Not exported; reachable through 'matchBytes'.
class MatchBytesSing ctx ns where
  matchBytesSing :: MatchBytes ctx ns

instance MatchBytesSing ctx '[] where
  matchBytesSing = ConsumeNil

instance (KnownNat n, MatchBytesSing ctx ns) => MatchBytesSing ctx (n ': ns) where
  matchBytesSing = ConsumeCons (Proxy :: Proxy n) matchBytesSing

-- | Produce the (singleton) value of type 'MatchBytes' @ctx ns@.
matchBytes :: MatchBytesSing ctx ns => MatchBytes ctx ns
matchBytes = matchBytesSing

-- | The type has exactly one value, so generation always returns it.
instance MatchBytesSing ctx ns => Arbitrary (MatchBytes ctx ns) where
  arbitrary = pure matchBytesSing

-- | An alias for 'MatchBytes' when you only need to match a single byte.
type MatchByte ctx byte = MatchBytes ctx '[ byte ]