{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

-- |
-- Module : Test.Method.Protocol
-- Description: DSL for declaring, mocking, and verifying expected method-call protocols
-- License: BSD-3
-- Maintainer: autotaker@gmail.com
-- Stability: experimental
module Test.Method.Protocol
  ( protocol,
    ProtocolM,
    ProtocolEnv,
    Call,
    CallArgs,
    CallId,
    IsMethodName,
    lookupMock,
    lookupMockWithShow,
    decl,
    whenArgs,
    thenMethod,
    thenAction,
    thenReturn,
    dependsOn,
    verify,
  )
where

import Control.Method
  ( Method (Args, Base, curryMethod, uncurryMethod),
    TupleLike (AsTuple, toTuple),
  )
import Control.Monad.Trans.State.Strict (StateT, execStateT, state)
import Data.Maybe (fromJust)
import Data.Typeable
  ( Typeable,
    cast,
    typeOf,
  )
import RIO (IORef, MonadIO (liftIO), Set, forM_, newIORef, on, readIORef, unless, writeIORef, (&))
import qualified RIO.List as L
import qualified RIO.Map as M
import qualified RIO.Set as S
import Test.Method.Behavior (Behave (Condition, MethodOf, thenMethod), thenAction, thenReturn)
import Test.Method.Matcher (ArgsMatcher (EachMatcher, args), Matcher)
import Unsafe.Coerce (unsafeCoerce)

-- | Identifier of a declared method call. Ids are handed out sequentially
--   by 'decl', one per declaration.
newtype CallId = CallId
  { unCallId :: Int
  }
  deriving (Show, Ord, Eq)

-- | Left-hand side of a call specification: which method is expected to be
--   called and which arguments that call must receive.
data CallArgs f m = CallArgs
  { -- | name of the method being called
    methodName :: f m,
    -- | condition the arguments of the call have to satisfy
    argsMatcher :: Matcher (Args m)
  }

-- | One call specification: the expected method and arguments, the
--   implementation that answers the call, and the calls that must have
--   happened before it.
data Call f m = Call
  { -- | expected method name and argument condition
    argsSpec :: CallArgs f m,
    -- | implementation run to produce the result of this call
    retSpec :: m,
    -- | ids of the calls that must already have been made
    dependCall :: [CallId]
  }

-- | A call specification whose method type @m@ has been hidden, so that calls
--   to different methods of @f@ can be stored in one list.
data SomeCall f where
  SomeCall :: forall f m. IsMethodName f m => Call f m -> SomeCall f

-- | A method name whose method type @m@ has been hidden, so that names of
--   different methods of @f@ can be used as keys of one map.
data SomeMethodName f where
  SomeMethodName :: forall f m. IsMethodName f m => f m -> SomeMethodName f

-- | Two hidden names are equal when they have the same type and are equal
--   under that type's own 'Eq' instance; names of different types never are.
instance Eq (SomeMethodName f) where
  SomeMethodName lhs == SomeMethodName rhs =
    maybe False (lhs ==) (cast rhs)

-- | Names of the same type are ordered by that type's 'Ord' instance; names
--   of different types are ordered by their 'TypeRep's.
instance Ord (SomeMethodName f) where
  compare (SomeMethodName lhs) (SomeMethodName rhs) =
    case cast rhs of
      -- Same type: the TypeReps would compare EQ, so the value order decides.
      Just rhs' -> compare lhs rhs'
      -- Different types: the TypeRep order alone decides.
      Nothing -> compare (typeOf lhs) (typeOf rhs)

-- | Shows the wrapped method name exactly as its own 'Show' instance does.
--   Delegating 'showsPrec' (not just 'show') keeps the precedence, so a name
--   with arguments is parenthesised when shown nested (e.g. inside 'Just').
--   Top-level 'show' output is the same as 'show' on the wrapped name.
instance Show (SomeMethodName f) where
  showsPrec d (SomeMethodName x) = showsPrec d x

-- | Per-method bookkeeping built by 'protocol': the calls declared for one
--   method, and a counter of how many times its mock has been invoked.
--   The explicit @forall f m.@ fixes the order used by @MethodCallAssoc \@f \@m@.
data MethodCallAssoc f where
  MethodCallAssoc ::
    forall f m.
    (Typeable (f m), Show (f m)) =>
    { -- declared calls of this method; the i-th invocation uses the i-th entry
      assocCalls :: [(CallId, Call f m)],
      -- number of invocations of the mock so far
      assocCounter :: IORef Int
    } ->
    MethodCallAssoc f

-- | @'ProtocolEnv' f@ provides mock methods, where @f@ is a GADT functor that
--   represents the set of dependent methods. It is built by 'protocol',
--   mocks are obtained from it with 'lookupMock', and 'verify' checks that
--   every declared call was made.
data ProtocolEnv f = ProtocolEnv
  { -- | every declared call, paired with its id
    callSpecs :: [(CallId, SomeCall f)],
    -- | declared calls grouped by method name, each with an invocation counter
    methodEnv :: M.Map (SomeMethodName f) (MethodCallAssoc f),
    -- | ids of the calls that have been made so far
    calledIdSetRef :: IORef (Set CallId)
  }

-- | Monad of the Protocol DSL. Its state is the list of calls declared so
--   far (most recent first) and the next 'CallId' to hand out.
newtype ProtocolM f a
  = ProtocolM (StateT ([(CallId, SomeCall f)], CallId) IO a)
  deriving (Functor, Applicative, Monad)

-- | Name of the method that a (type-hidden) call specification refers to.
getMethodName :: SomeCall f -> SomeMethodName f
getMethodName (SomeCall call) = SomeMethodName (methodName (argsSpec call))

-- | Build 'ProtocolEnv' from Protocol DSL.
--
--   The declared calls are grouped by method name. Within each method the
--   calls are kept in ascending 'CallId' order (i.e. declaration order), and
--   every method gets its own invocation counter starting at 0.
protocol :: ProtocolM f a -> IO (ProtocolEnv f)
protocol (ProtocolM dsl) = do
  (specs, _) <- execStateT dsl ([], CallId 0)
  env <-
    specs
      & map (\(callId, call) -> (getMethodName call, [(callId, call)]))
      -- Group by method name directly in the map (no sort/groupBy/head).
      & M.fromListWith (<>)
      & M.traverseWithKey mkAssoc
  ref <- newIORef S.empty
  pure
    ProtocolEnv
      { callSpecs = specs,
        methodEnv = env,
        calledIdSetRef = ref
      }
  where
    -- Build the bookkeeping entry from all calls declared for one method name.
    mkAssoc :: forall g. SomeMethodName g -> [(CallId, SomeCall g)] -> IO (MethodCallAssoc g)
    mkAssoc (SomeMethodName (_ :: g m)) calls = do
      counter <- newIORef 0
      pure $
        MethodCallAssoc @g @m
          -- Map keys compare equal only when their TypeReps coincide, so every
          -- call in this group has type @Call g m@; 'unsafeCoerce' merely
          -- recovers the type that 'SomeCall' hid.
          [ (callId, unsafeCoerce call)
            | (callId, SomeCall call) <- L.sortBy (compare `on` fst) calls
          ]
          counter

-- | Return the current value of the counter and increment it by one.
--   This is a plain read followed by a write, not an atomic update.
tick :: MonadIO m => IORef Int -> m Int
tick ref =
  liftIO $
    readIORef ref >>= \current -> current <$ writeIORef ref (current + 1)

-- | Constraints on a method-name type @f m@: 'Show' for error messages,
--   'Ord' and 'Typeable' so names of different methods can share one map.
type IsMethodName f m = (Show (f m), Ord (f m), Typeable (f m))

-- | Get the mock method by method name.
--   Return an unstubbed method (which throws an exception on every call)
--   if the behavior of the method is unspecified by 'ProtocolEnv'.
--   Arguments appear in error messages rendered by 'show' on their tuple form.
lookupMock ::
  forall f m.
  (IsMethodName f m, Show (AsTuple (Args m)), TupleLike (Args m), Method m, MonadIO (Base m)) =>
  -- | name of method
  f m ->
  ProtocolEnv f ->
  m
lookupMock name = lookupMockWithShow (\methodArgs -> show (toTuple methodArgs)) name

-- | Get the mock method by method name.
--   Return an unstubbed method (which throws an exception on every call)
--   if the behavior of the method is unspecified by 'ProtocolEnv'.
--   Use this function only if you want to customize
--   the show implementation for the arguments of the method.
--
--   The i-th invocation of the returned mock uses the i-th call declared for
--   this method. It fails with 'error' if:
--
--   * more calls are made than were declared,
--   * the arguments do not satisfy the declared matcher, or
--   * a call listed via 'dependsOn' has not been made yet.
--
--   The invocation counter is advanced before these checks run.
lookupMockWithShow ::
  forall f m.
  (IsMethodName f m, Method m, MonadIO (Base m)) =>
  -- | show function for the argument of method
  (Args m -> String) ->
  -- | name of method
  f m ->
  ProtocolEnv f ->
  m
lookupMockWithShow fshow name ProtocolEnv {..} =
  case M.lookup (SomeMethodName name) methodEnv of
    Nothing -> curryMethod $ \_ ->
      error $ "0-th call of method " <> show name <> " is unspecified"
    Just MethodCallAssoc {assocCalls = assocCalls', ..} ->
      let -- The entry was found under the key @SomeMethodName name@, and keys
          -- only match when their types agree, so its calls are @Call f m@.
          assocCalls = unsafeCoerce assocCalls' :: [(CallId, Call f m)]
       in curryMethod $ \xs -> do
            i <- tick assocCounter
            -- Total lookup of the i-th declared call (no separate bounds check + '!!').
            (callId, spec) <- case drop i assocCalls of
              entry : _ -> pure entry
              [] ->
                error $ show i <> "-th call of method " <> show name <> " is unspecified"
            unless (argsMatcher (argsSpec spec) xs) $
              error $
                "unexpected argument of " <> show i <> "-th call of method " <> show name <> ": "
                  <> fshow xs
            calledIdSet <- liftIO $ readIORef calledIdSetRef
            forM_ (dependCall spec) $ \callId' ->
              unless (S.member callId' calledIdSet) $
                error $
                  "dependent method " <> describeCall callId' <> " is not called: " <> show callId'
            liftIO $ writeIORef calledIdSetRef $! S.insert callId calledIdSet
            uncurryMethod (retSpec spec) xs
  where
    -- Method name of a declared call, for error messages. A call id that was
    -- not declared in this environment (e.g. one obtained from a different
    -- 'protocol' run) is reported as unknown instead of crashing in 'fromJust'.
    describeCall callId' =
      maybe "<unknown call>" (show . getMethodName) (L.lookup callId' callSpecs)

-- | Declare a method call specification and return the 'CallId' assigned to
--   it. Ids are allocated sequentially in declaration order.
decl :: (IsMethodName f m) => Call f m -> ProtocolM f CallId
decl call = ProtocolM (state allocate)
  where
    -- Hand out the next id, record the call under it, and bump the counter.
    allocate (declared, nextId@(CallId n)) =
      (nextId, ((nextId, SomeCall call) : declared, CallId (n + 1)))

-- | Specify the argument condition of a method call: @whenArgs name matchers@
--   refers to a call of method @name@ whose arguments must satisfy the
--   matcher built from @matchers@ with 'args'.
whenArgs :: ArgsMatcher (Args m) => f m -> EachMatcher (Args m) -> CallArgs f m
whenArgs name matcher = CallArgs name (args matcher)

-- | A 'Call' is built from its argument condition and the implementation
--   that answers it.
instance Behave (Call f m) where
  type Condition (Call f m) = CallArgs f m
  type MethodOf (Call f m) = m

  -- A freshly built call has no dependencies; add them with 'dependsOn'.
  thenMethod condition method = Call condition method []

-- | Specify on which method calls the call depends. The given ids are placed
--   in front of any dependencies the call already has.
dependsOn :: Call f m -> [CallId] -> Call f m
dependsOn call depends =
  let existing = dependCall call
   in call {dependCall = depends ++ existing}

-- | Verify that all method calls specified by Protocol DSL are fired.
--   For each declared method, the number of invocations of its mock must
--   equal the number of calls declared for it; otherwise this fails with
--   'error'.
verify :: ProtocolEnv f -> IO ()
verify env =
  forM_ (M.toList (methodEnv env)) $ \(name, assoc) ->
    case assoc of
      MethodCallAssoc {assocCalls = calls, assocCounter = counter} -> do
        actual <- readIORef counter
        let expected = length calls
        unless (actual == expected) $
          error $
            concat
              [ "method ",
                show name,
                " should be called ",
                show expected,
                " times, but actually is called ",
                show actual,
                " times"
              ]