module Data.ProtoLens.Arbitrary
( ArbitraryMessage(..),
) where
import Data.ProtoLens.Message
import Control.Applicative ((<$>), pure)
import Control.Arrow ((&&&))
import Control.Monad (foldM)
import qualified Data.ByteString as BS
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (isJust, fromJust)
import qualified Data.Text as T
import Lens.Family2 (Lens', view, set)
import Lens.Family2.Unchecked (lens)
import Test.QuickCheck (Arbitrary(..), Gen, suchThat, frequency, listOf,
shrinkList)
-- | A thin wrapper that exists to carry QuickCheck's 'Arbitrary' instance
-- for any type with a 'Message' instance. Unwrap it with
-- 'unArbitraryMessage' to get the generated message back.
newtype ArbitraryMessage a = ArbitraryMessage
  { unArbitraryMessage :: a
  }
  deriving (Eq, Show, Functor)
-- | Random generation and shrinking for any 'Message' type. The real work
-- happens field by field in 'arbitraryMessage' and 'shrinkMessage'. This
-- instance only adds and removes the newtype wrapper.
instance Message a => Arbitrary (ArbitraryMessage a) where
  arbitrary = fmap ArbitraryMessage arbitraryMessage
  shrink = map ArbitraryMessage . shrinkMessage . unArbitraryMessage
-- | Builds a random message. It starts from the default message 'def' and
-- overwrites, one after another, every field listed by
-- @fieldsByTag descriptor@ (in the map's ascending key order) with a
-- freshly generated value.
arbitraryMessage :: Message a => Gen a
arbitraryMessage = foldM fillField def allFields
  where
    allFields = M.elems (fieldsByTag descriptor)
    -- Argument order adapted for 'foldM': accumulator first, field second.
    fillField msg field = arbitraryField field msg
-- | Makes a generator optional. By QuickCheck 'frequency' weighting it
-- yields 'Nothing' one time in four and @Just@ a generated value otherwise.
maybeGen :: Gen a -> Gen (Maybe a)
maybeGen gen = frequency choices
  where
    choices = [(absentWeight, pure Nothing), (presentWeight, Just <$> gen)]
    absentWeight = 1
    presentWeight = 3
-- | Generates a map from a list of random entry messages. Each entry is
-- turned into a key/value pair through the given lenses and then collected
-- with 'entriesToMap'. Entries that share a key collapse into a single map
-- binding.
mapGen :: (Ord key, Message entry) => Lens' entry key -> Lens' entry value ->
  Gen entry -> Gen (Map key value)
mapGen keyLens valueLens entryGen =
  fmap (entriesToMap keyLens valueLens) (listOf entryGen)
-- | Runs a lens in the 'Gen' functor. The field the lens focuses on is
-- replaced by a freshly generated value, and its previous value is ignored.
setGen :: Lens' msg a -> Gen a -> msg -> Gen msg
setGen focus gen msg = focus (\_ -> gen) msg
-- | Fills one field of a message with a random value.
--
-- The element generator comes from the field's type descriptor. The
-- container shape comes from the field's accessor:
--
-- * plain fields get one generated value;
-- * optional fields go through 'maybeGen';
-- * repeated fields get a list;
-- * map fields go through 'mapGen'.
arbitraryField :: FieldDescriptor msg -> msg -> Gen msg
arbitraryField (FieldDescriptor _ valueType accessor) =
  let gen = arbitraryFieldValue valueType
  in case accessor of
       PlainField _ focus -> setGen focus gen
       OptionalField focus -> setGen focus (maybeGen gen)
       RepeatedField _ focus -> setGen focus (listOf gen)
       MapField keyLens valueLens mapLens ->
         setGen mapLens (mapGen keyLens valueLens gen)
-- | Generates a random value for a single field, given its type descriptor.
--
-- * Message and group fields recurse through the 'ArbitraryMessage'
--   instance.
-- * Enum fields draw arbitrary 'Int's and keep only those that
--   'maybeToEnum' accepts.
-- * String and bytes fields are packed from arbitrary character and byte
--   lists.
-- * Every other scalar field uses its value type's own 'Arbitrary'
--   instance.
--
-- NOTE(review): nested messages are generated at the same QuickCheck size
-- as their parent. A message type that contains a repeated field of its own
-- type may therefore grow very large or fail to terminate. Confirm that
-- callers never generate such recursive message types.
arbitraryFieldValue :: FieldTypeDescriptor value -> Gen value
arbitraryFieldValue MessageField = fmap unArbitraryMessage arbitrary
arbitraryFieldValue GroupField = fmap unArbitraryMessage arbitrary
-- 'suchThat' only lets 'Just' results through, so 'fromJust' cannot fail.
arbitraryFieldValue EnumField =
  fmap fromJust (fmap maybeToEnum arbitrary `suchThat` isJust)
arbitraryFieldValue Int32Field = arbitrary
arbitraryFieldValue Int64Field = arbitrary
arbitraryFieldValue UInt32Field = arbitrary
arbitraryFieldValue UInt64Field = arbitrary
arbitraryFieldValue SInt32Field = arbitrary
arbitraryFieldValue SInt64Field = arbitrary
arbitraryFieldValue Fixed32Field = arbitrary
arbitraryFieldValue Fixed64Field = arbitrary
arbitraryFieldValue SFixed32Field = arbitrary
arbitraryFieldValue SFixed64Field = arbitrary
arbitraryFieldValue FloatField = arbitrary
arbitraryFieldValue DoubleField = arbitrary
arbitraryFieldValue BoolField = arbitrary
arbitraryFieldValue StringField = fmap T.pack arbitrary
arbitraryFieldValue BytesField = fmap BS.pack arbitrary
-- | All one-step shrinks of a message. For each field listed by
-- @fieldsByTag descriptor@, in order, it emits every shrink of that one
-- field produced by 'shrinkField'.
shrinkMessage :: Message a => a -> [a]
shrinkMessage msg =
  [ smaller | field <- allFields, smaller <- shrinkField field msg ]
  where
    allFields = M.elems (fieldsByTag descriptor)
-- | Shrinks an optional value. A present value shrinks first to 'Nothing'
-- and then to each shrink of the wrapped value. 'Nothing' has no shrinks.
shrinkMaybe :: (a -> [a]) -> Maybe a -> [Maybe a]
shrinkMaybe shrinkInner =
  maybe [] (\inner -> Nothing : map Just (shrinkInner inner))
-- | Shrinks a map in three steps:
--
-- 1. View the map as its list of entry messages.
-- 2. Shrink that list with 'shrinkList' and the given per-entry shrinker.
-- 3. Rebuild a map from each candidate list.
shrinkMap :: (Ord key, Message entry) => Lens' entry key -> Lens' entry value
  -> (entry -> [entry]) -> Map key value -> [Map key value]
shrinkMap keyLens valueLens shrinkEntry =
  map (entriesToMap keyLens valueLens)
    . shrinkList shrinkEntry
    . mapToEntries keyLens valueLens
-- | All one-step shrinks of a single field of a message.
--
-- The element shrinker comes from the field's type descriptor. It is
-- applied through the field's lens, wrapped according to the accessor
-- shape: plain, optional ('shrinkMaybe'), repeated ('shrinkList'), or map
-- ('shrinkMap').
shrinkField :: FieldDescriptor msg -> msg -> [msg]
shrinkField (FieldDescriptor _ valueType accessor) =
  let shrinkValue = shrinkFieldValue valueType
  in case accessor of
       PlainField _ focus -> focus shrinkValue
       OptionalField focus -> focus (shrinkMaybe shrinkValue)
       RepeatedField _ focus -> focus (shrinkList shrinkValue)
       MapField keyLens valueLens mapLens ->
         mapLens (shrinkMap keyLens valueLens shrinkValue)
-- | One-step shrinks of a single field value, chosen by its type descriptor.
--
-- * Message and group values shrink through the 'ArbitraryMessage'
--   instance.
-- * An enum value shrinks to the enum value whose 'Int' is 0, if such a
--   value exists and the input is not already 0. Otherwise it has no
--   shrinks.
-- * Strings and bytes shrink as lists of characters and bytes.
-- * Every other scalar uses its value type's own 'shrink'.
shrinkFieldValue :: FieldTypeDescriptor value -> value -> [value]
shrinkFieldValue MessageField =
  \m -> unArbitraryMessage <$> shrink (ArbitraryMessage m)
shrinkFieldValue GroupField =
  \m -> unArbitraryMessage <$> shrink (ArbitraryMessage m)
shrinkFieldValue EnumField = maybe (const []) towardZero (maybeToEnum 0)
  where
    towardZero zero v
      | fromEnum v == 0 = []
      | otherwise = [zero]
shrinkFieldValue Int32Field = shrink
shrinkFieldValue Int64Field = shrink
shrinkFieldValue UInt32Field = shrink
shrinkFieldValue UInt64Field = shrink
shrinkFieldValue SInt32Field = shrink
shrinkFieldValue SInt64Field = shrink
shrinkFieldValue Fixed32Field = shrink
shrinkFieldValue Fixed64Field = shrink
shrinkFieldValue SFixed32Field = shrink
shrinkFieldValue SFixed64Field = shrink
shrinkFieldValue FloatField = shrink
shrinkFieldValue DoubleField = shrink
shrinkFieldValue BoolField = shrink
shrinkFieldValue StringField = \s -> T.pack <$> shrink (T.unpack s)
shrinkFieldValue BytesField = \b -> BS.pack <$> shrink (BS.unpack b)
-- | Converts a map into entry messages, one per key/value pair, in the
-- order 'M.toList' yields them. Each entry starts from the default 'def'.
-- The value field is set first and the key field second.
mapToEntries :: Message entry =>
  Lens' entry key -> Lens' entry value -> Map key value -> [entry]
mapToEntries keyLens valueLens = map toEntry . M.toList
  where
    toEntry (k, v) = set keyLens k (set valueLens v def)
-- | Collects entry messages into a 'Map' by reading each entry's key and
-- value through the given lenses. The map is built with 'M.fromList', so
-- for duplicate keys the entry appearing later in the list wins.
entriesToMap :: Ord key =>
  Lens' entry key -> Lens' entry value -> [entry] -> Map key value
entriesToMap keyLens valueLens = M.fromList . map toPair
  where
    toPair entry = (view keyLens entry, view valueLens entry)
-- | A lens that presents a 'Map' as its list of entry messages. Reading
-- goes through 'mapToEntries'. Writing ignores the old map and rebuilds it
-- from the new entry list with 'entriesToMap'.
mapEntriesLens :: (Ord key, Message entry) =>
  Lens' entry key -> Lens' entry value -> Lens' (Map key value) [entry]
mapEntriesLens keyLens valueLens = lens toEntries rebuild
  where
    toEntries = mapToEntries keyLens valueLens
    rebuild _oldMap = entriesToMap keyLens valueLens