-- Copyright 2016 Google Inc. All Rights Reserved.
--
-- Use of this source code is governed by a BSD-style
-- license that can be found in the LICENSE file or at
-- https://developers.google.com/open-source/licenses/bsd

{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
-- | An Arbitrary instance for protocol buffer Messages to use with QuickCheck.
module Data.ProtoLens.Arbitrary
    ( ArbitraryMessage(..)
    , arbitraryMessage
    , shrinkMessage
    ) where

import Data.ProtoLens.Message

import Control.Arrow ((&&&))
import Control.Monad (foldM)
import qualified Data.ByteString as BS
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (isJust, fromJust)
import qualified Data.Text as T
import Lens.Family2 (Lens', view, set)
import Lens.Family2.Unchecked (lens)
import Test.QuickCheck (Arbitrary(..), Gen, suchThat, frequency, listOf,
                        shrinkList, scale, sized)


-- | Wraps a protocol buffer message so that QuickCheck can generate and
-- shrink it through an 'Arbitrary' instance.  Use 'unArbitraryMessage' to get
-- the underlying message back.
newtype ArbitraryMessage a = ArbitraryMessage
    { unArbitraryMessage :: a
    } deriving (Eq, Show, Functor)

-- Generation and shrinking delegate to 'arbitraryMessage' and
-- 'shrinkMessage' on the wrapped message, re-wrapping every result.
instance Message a => Arbitrary (ArbitraryMessage a) where
    arbitrary = fmap ArbitraryMessage arbitraryMessage
    shrink = map ArbitraryMessage . shrinkMessage . unArbitraryMessage

-- | Generates a message by starting from 'defMessage' and passing it through
-- 'arbitraryField' once for every field listed in 'allFields', in order.
arbitraryMessage :: Message a => Gen a
arbitraryMessage = foldM fillField defMessage allFields
  where
    fillField msg field = arbitraryField field msg

-- | Imitation of QuickCheck's @Arbitrary a => Arbitrary (Maybe a)@ instance:
-- yields 'Nothing' with weight 1 and @'Just' x@ (with @x@ drawn from the given
-- generator) with weight 3.  The choice does not depend on the size parameter.
maybeGen :: Gen a -> Gen (Maybe a)
maybeGen gen = frequency
    [ (1, return Nothing)
    , (3, fmap Just gen)
    ]

-- | Generates a 'Map' by producing a random list of map-entry messages and
-- converting it through 'mapEntriesLens'.  Entries that share a key collapse
-- into a single binding of the resulting map.
mapGen :: (Ord key, Message entry)
       => Lens' entry key -> Lens' entry value
       -> Gen entry -> Gen (Map key value)
mapGen keyLens valueLens entryGen =
    mapEntriesLens keyLens valueLens (\_ -> listOf entryGen) M.empty

-- | Replaces the part of @msg@ focused by the lens with a value drawn from the
-- given generator.  The field's current value is ignored.
setGen :: Lens' msg a -> Gen a -> msg -> Gen msg
setGen focus gen msg = focus (\_ -> gen) msg

-- | Returns a generator that overwrites one field of a message with a freshly
-- generated value.  The shape of that value follows the field's accessor: a
-- single value, an optional value, a list, or a map.
arbitraryField :: FieldDescriptor msg -> msg -> Gen msg
arbitraryField (FieldDescriptor _ ftd fa) = case fa of
    PlainField _ l -> setGen l fieldGen
    OptionalField l -> setGen l (optionalFieldGen ftd fieldGen)
    RepeatedField _ l -> setGen l (listOf fieldGen)
    MapField keyLens valueLens mapLens ->
        setGen mapLens (mapGen keyLens valueLens fieldGen)
  where
    fieldGen = arbitraryFieldValue ftd

-- | Generator for the contents of an optional field.
--
-- 'maybeGen' picks 'Just' three times out of four regardless of the size
-- parameter, so it cannot be used as-is for message-typed fields.  Consider a
-- message with two or more optional fields of its own type (e.g. a binary
-- tree).  Each node would then have 1.5 children on average, and the
-- generated value would be infinite with high probability.
--
-- For message-typed fields we therefore return 'Nothing' once the size has
-- dropped to 0.  'arbitraryFieldValue' halves the size at every level of
-- nesting, so this bounds the depth of the generated value.  Repeated and map
-- fields are already bounded, because 'listOf' only produces @[]@ at size 0.
-- Scalar optional fields keep the plain 'maybeGen' distribution.
optionalFieldGen :: FieldTypeDescriptor value -> Gen value -> Gen (Maybe value)
optionalFieldGen (MessageField _) gen =
    sized $ \n -> if n <= 0 then pure Nothing else maybeGen gen
optionalFieldGen _ gen = maybeGen gen

-- | Generator for a single value of a field's element type.  Nested messages
-- are generated at half of the current size; scalars are delegated to
-- 'arbitraryScalarValue'.
arbitraryFieldValue :: FieldTypeDescriptor value -> Gen value
arbitraryFieldValue (MessageField _) = scale (\n -> n `div` 2) arbitraryMessage
arbitraryFieldValue (ScalarField scalar) = arbitraryScalarValue scalar

-- | Generator for a scalar field value.  Numeric and boolean kinds use their
-- QuickCheck 'arbitrary'.  Text and bytes are built from an arbitrary list of
-- characters or bytes.
arbitraryScalarValue :: ScalarField value -> Gen value
-- For enum fields, all we know is that the value is an instance of
-- MessageEnum, meaning we can only use fromEnum, toEnum, or maybeToEnum.  So
-- we rely on the Arbitrary instance for Int and keep only the cases that can
-- actually be converted to one of the enum values.
--
-- 'fromJust' is safe here because 'suchThat' ensures that every generated
-- value is 'Just _'.
arbitraryScalarValue EnumField =
    fromJust <$> suchThat (maybeToEnum <$> arbitrary) isJust
arbitraryScalarValue Int32Field = arbitrary
arbitraryScalarValue Int64Field = arbitrary
arbitraryScalarValue UInt32Field = arbitrary
arbitraryScalarValue UInt64Field = arbitrary
arbitraryScalarValue SInt32Field = arbitrary
arbitraryScalarValue SInt64Field = arbitrary
arbitraryScalarValue Fixed32Field = arbitrary
arbitraryScalarValue Fixed64Field = arbitrary
arbitraryScalarValue SFixed32Field = arbitrary
arbitraryScalarValue SFixed64Field = arbitrary
arbitraryScalarValue FloatField = arbitrary
arbitraryScalarValue DoubleField = arbitrary
arbitraryScalarValue BoolField = arbitrary
arbitraryScalarValue StringField = fmap T.pack arbitrary
arbitraryScalarValue BytesField = fmap BS.pack arbitrary

-- | Shrinks a message one field at a time.  Each candidate comes from
-- shrinking a single field with 'shrinkField'.  The candidates for all fields
-- are concatenated in 'allFields' order.
shrinkMessage :: Message a => a -> [a]
shrinkMessage msg = concat [shrinkField field msg | field <- allFields]

-- | Shrinks an optional value.  'Nothing' cannot be shrunk.  @'Just' v@ shrinks
-- first to 'Nothing', then to @'Just'@ of every shrink of @v@.
shrinkMaybe :: (a -> [a]) -> Maybe a -> [Maybe a]
shrinkMaybe shrinkValue = maybe [] (\v -> Nothing : map Just (shrinkValue v))

-- | Shrinks a 'Map' by viewing it as a list of entry messages, shrinking that
-- list with 'shrinkList' and the given entry shrinker, and converting each
-- result back into a 'Map'.
shrinkMap :: (Ord key, Message entry) => Lens' entry key -> Lens' entry value
          -> (entry -> [entry]) -> Map key value -> [Map key value]
shrinkMap keyLens valueLens shrinkEntry =
    mapEntriesLens keyLens valueLens (shrinkList shrinkFullySet)
  where
    shrinkFullySet entry = filter allFieldsAreSet (shrinkEntry entry)
    -- Strip out every shrunk entry whose key or value is unset, for fields
    -- that distinguish between being unset and being set to the default
    -- value (proto2 fields, or proto3 message-typed values).
    -- The representation in the Map (effectively a pair of key and value)
    -- does not distinguish between unset and default values.  This can make
    -- shrinkMap behave incorrectly; for example,
    -- `Map.singleton 0 "abc"` gets represented as
    -- `[defMessage & #maybe'key .~ Just 0 & #value .~ "abc"]`, which might be
    -- shrunk to `[defMessage & #maybe'key .~ Nothing & #value .~ "abc"]`,
    -- which maps back to the same Map.
    -- For now we work around this by filtering out entries with unset
    -- optional fields.
    allFieldsAreSet entry = all (fieldIsSet entry) allFields
    fieldIsSet entry (FieldDescriptor _ _ (OptionalField l)) =
        isJust (view l entry)
    fieldIsSet _ _ = True

-- | Produces the shrink candidates obtained by shrinking a single field of a
-- message.  The field value is shrunk according to the field's accessor shape
-- (plain, optional, repeated or map), and each candidate is written back
-- through the accessor's lens.
shrinkField :: FieldDescriptor msg -> msg -> [msg]
shrinkField (FieldDescriptor _ ftd accessor) =
    let shrinkOne = shrinkFieldValue ftd
    in case accessor of
        PlainField _ plainLens -> plainLens shrinkOne
        OptionalField maybeLens -> maybeLens (shrinkMaybe shrinkOne)
        RepeatedField _ listLens -> listLens (shrinkList shrinkOne)
        MapField keyLens valueLens mapLens ->
            mapLens (shrinkMap keyLens valueLens shrinkOne)

-- | Shrinker for a single value of a field's element type.  Nested messages
-- are shrunk with 'shrinkMessage'; scalars are delegated to
-- 'shrinkScalarValue'.
shrinkFieldValue :: FieldTypeDescriptor value -> value -> [value]
shrinkFieldValue (MessageField _) = shrinkMessage
shrinkFieldValue (ScalarField scalar) = shrinkScalarValue scalar

-- | Shrinker for a scalar field value.  Numeric and boolean kinds use their
-- QuickCheck 'shrink'.  Text and bytes are shrunk as lists of characters or
-- bytes and then packed back.
shrinkScalarValue :: ScalarField value -> value -> [value]
-- Shrink to the 0-equivalent enum value if it is a valid enum value and the
-- value isn't already 0; otherwise there is nothing to shrink to.
shrinkScalarValue EnumField = \val ->
    case maybeToEnum 0 of
        Just zeroVal | fromEnum val /= 0 -> [zeroVal]
        _ -> []
shrinkScalarValue Int32Field = shrink
shrinkScalarValue Int64Field = shrink
shrinkScalarValue UInt32Field = shrink
shrinkScalarValue UInt64Field = shrink
shrinkScalarValue SInt32Field = shrink
shrinkScalarValue SInt64Field = shrink
shrinkScalarValue Fixed32Field = shrink
shrinkScalarValue Fixed64Field = shrink
shrinkScalarValue SFixed32Field = shrink
shrinkScalarValue SFixed64Field = shrink
shrinkScalarValue FloatField = shrink
shrinkScalarValue DoubleField = shrink
shrinkScalarValue BoolField = shrink
shrinkScalarValue StringField = \txt -> T.pack <$> shrink (T.unpack txt)
shrinkScalarValue BytesField = \bytes -> BS.pack <$> shrink (BS.unpack bytes)

-- | Converts a 'Map' into a list of entry messages, one per binding, in the
-- ascending key order produced by 'M.toList'.  Each entry starts as
-- 'defMessage' and has its key and value set through the given lenses.
mapToEntries :: Message entry =>
                Lens' entry key -> Lens' entry value -> Map key value -> [entry]
mapToEntries keyLens valueLens m =
    [ set keyLens k (set valueLens v defMessage) | (k, v) <- M.toList m ]

-- | Converts a list of entry messages into a 'Map' by reading each entry's key
-- and value through the given lenses.  When several entries share a key,
-- 'M.fromList' keeps the last one.
entriesToMap :: Ord key =>
                Lens' entry key -> Lens' entry value -> [entry] -> Map key value
entriesToMap keyLens valueLens =
    M.fromList . map (view keyLens &&& view valueLens)

-- | A lens-like view of a 'Map' as a list of entry messages.
--
-- This isn't a true lens, because it doesn't obey the lens laws.
-- Specifically, @view l (set l entries) /= entries@ whenever the input list
-- contains duplicate keys, which are de-duplicated inside the 'Map'.  The
-- same holds when the list is not sorted by key, because entries are read
-- back in key order.  The setter also ignores the previous 'Map' entirely.
-- It is only included here to make it easy to convert from a list of entry
-- messages to a 'Map'.  See the comment in shrinkMap for a related problem.
-- TODO: consider a different Message representation for maps.
mapEntriesLens :: (Ord key, Message entry) =>
        Lens' entry key -> Lens' entry value -> Lens' (Map key value) [entry]
mapEntriesLens keyL valL =
    lens (mapToEntries keyL valL) (\_ entries -> entriesToMap keyL valL entries)