33import unittest
44
55from unittest .mock import (ANY , call , AsyncMock , patch , MagicMock ,
6- create_autospec , _AwaitEvent )
6+ create_autospec , _AwaitEvent , sentinel , _CallList )
77
88
99def tearDownModule ():
@@ -595,11 +595,173 @@ class AsyncMockAssert(unittest.TestCase):
595595 def setUp (self ):
596596 self .mock = AsyncMock ()
597597
598- async def _runnable_test (self , * args ):
599- if not args :
600- await self .mock ()
601- else :
602- await self .mock (* args )
598+ async def _runnable_test (self , * args , ** kwargs ):
599+ await self .mock (* args , ** kwargs )
600+
601+ async def _await_coroutine (self , coroutine ):
602+ return await coroutine
603+
604+ def test_assert_called_but_not_awaited (self ):
605+ mock = AsyncMock (AsyncClass )
606+ with self .assertWarns (RuntimeWarning ):
607+ # Will raise a warning because never awaited
608+ mock .async_method ()
609+ self .assertTrue (asyncio .iscoroutinefunction (mock .async_method ))
610+ mock .async_method .assert_called ()
611+ mock .async_method .assert_called_once ()
612+ mock .async_method .assert_called_once_with ()
613+ with self .assertRaises (AssertionError ):
614+ mock .assert_awaited ()
615+ with self .assertRaises (AssertionError ):
616+ mock .async_method .assert_awaited ()
617+
618+ def test_assert_called_then_awaited (self ):
619+ mock = AsyncMock (AsyncClass )
620+ mock_coroutine = mock .async_method ()
621+ mock .async_method .assert_called ()
622+ mock .async_method .assert_called_once ()
623+ mock .async_method .assert_called_once_with ()
624+ with self .assertRaises (AssertionError ):
625+ mock .async_method .assert_awaited ()
626+
627+ asyncio .run (self ._await_coroutine (mock_coroutine ))
628+ # Assert we haven't re-called the function
629+ mock .async_method .assert_called_once ()
630+ mock .async_method .assert_awaited ()
631+ mock .async_method .assert_awaited_once ()
632+ mock .async_method .assert_awaited_once_with ()
633+
634+ def test_assert_called_and_awaited_at_same_time (self ):
635+ with self .assertRaises (AssertionError ):
636+ self .mock .assert_awaited ()
637+
638+ with self .assertRaises (AssertionError ):
639+ self .mock .assert_called ()
640+
641+ asyncio .run (self ._runnable_test ())
642+ self .mock .assert_called_once ()
643+ self .mock .assert_awaited_once ()
644+
645+ def test_assert_called_twice_and_awaited_once (self ):
646+ mock = AsyncMock (AsyncClass )
647+ coroutine = mock .async_method ()
648+ with self .assertWarns (RuntimeWarning ):
649+ # The first call will be awaited so no warning there
650+ # But this call will never get awaited, so it will warn here
651+ mock .async_method ()
652+ with self .assertRaises (AssertionError ):
653+ mock .async_method .assert_awaited ()
654+ mock .async_method .assert_called ()
655+ asyncio .run (self ._await_coroutine (coroutine ))
656+ mock .async_method .assert_awaited ()
657+ mock .async_method .assert_awaited_once ()
658+
659+ def test_assert_called_once_and_awaited_twice (self ):
660+ mock = AsyncMock (AsyncClass )
661+ coroutine = mock .async_method ()
662+ mock .async_method .assert_called_once ()
663+ asyncio .run (self ._await_coroutine (coroutine ))
664+ with self .assertRaises (RuntimeError ):
665+ # Cannot reuse already awaited coroutine
666+ asyncio .run (self ._await_coroutine (coroutine ))
667+ mock .async_method .assert_awaited ()
668+
669+ def test_assert_awaited_but_not_called (self ):
670+ with self .assertRaises (AssertionError ):
671+ self .mock .assert_awaited ()
672+ with self .assertRaises (AssertionError ):
673+ self .mock .assert_called ()
674+ with self .assertRaises (TypeError ):
675+ # You cannot await an AsyncMock, it must be a coroutine
676+ asyncio .run (self ._await_coroutine (self .mock ))
677+
678+ with self .assertRaises (AssertionError ):
679+ self .mock .assert_awaited ()
680+ with self .assertRaises (AssertionError ):
681+ self .mock .assert_called ()
682+
683+ def test_assert_has_calls_not_awaits (self ):
684+ kalls = [call ('foo' )]
685+ with self .assertWarns (RuntimeWarning ):
686+ # Will raise a warning because never awaited
687+ self .mock ('foo' )
688+ self .mock .assert_has_calls (kalls )
689+ with self .assertRaises (AssertionError ):
690+ self .mock .assert_has_awaits (kalls )
691+
692+ def test_assert_has_mock_calls_on_async_mock_no_spec (self ):
693+ with self .assertWarns (RuntimeWarning ):
694+ # Will raise a warning because never awaited
695+ self .mock ()
696+ kalls_empty = [('' , (), {})]
697+ self .assertEqual (self .mock .mock_calls , kalls_empty )
698+
699+ with self .assertWarns (RuntimeWarning ):
700+ # Will raise a warning because never awaited
701+ self .mock ('foo' )
702+ self .mock ('baz' )
703+ mock_kalls = ([call (), call ('foo' ), call ('baz' )])
704+ self .assertEqual (self .mock .mock_calls , mock_kalls )
705+
706+ def test_assert_has_mock_calls_on_async_mock_with_spec (self ):
707+ a_class_mock = AsyncMock (AsyncClass )
708+ with self .assertWarns (RuntimeWarning ):
709+ # Will raise a warning because never awaited
710+ a_class_mock .async_method ()
711+ kalls_empty = [('' , (), {})]
712+ self .assertEqual (a_class_mock .async_method .mock_calls , kalls_empty )
713+ self .assertEqual (a_class_mock .mock_calls , [call .async_method ()])
714+
715+ with self .assertWarns (RuntimeWarning ):
716+ # Will raise a warning because never awaited
717+ a_class_mock .async_method (1 , 2 , 3 , a = 4 , b = 5 )
718+ method_kalls = [call (), call (1 , 2 , 3 , a = 4 , b = 5 )]
719+ mock_kalls = [call .async_method (), call .async_method (1 , 2 , 3 , a = 4 , b = 5 )]
720+ self .assertEqual (a_class_mock .async_method .mock_calls , method_kalls )
721+ self .assertEqual (a_class_mock .mock_calls , mock_kalls )
722+
723+ def test_async_method_calls_recorded (self ):
724+ with self .assertWarns (RuntimeWarning ):
725+ # Will raise warnings because never awaited
726+ self .mock .something (3 , fish = None )
727+ self .mock .something_else .something (6 , cake = sentinel .Cake )
728+
729+ self .assertEqual (self .mock .method_calls , [
730+ ("something" , (3 ,), {'fish' : None }),
731+ ("something_else.something" , (6 ,), {'cake' : sentinel .Cake })
732+ ],
733+ "method calls not recorded correctly" )
734+ self .assertEqual (self .mock .something_else .method_calls ,
735+ [("something" , (6 ,), {'cake' : sentinel .Cake })],
736+ "method calls not recorded correctly" )
737+
738+ def test_async_arg_lists (self ):
739+ def assert_attrs (mock ):
740+ names = ('call_args_list' , 'method_calls' , 'mock_calls' )
741+ for name in names :
742+ attr = getattr (mock , name )
743+ self .assertIsInstance (attr , _CallList )
744+ self .assertIsInstance (attr , list )
745+ self .assertEqual (attr , [])
746+
747+ assert_attrs (self .mock )
748+ with self .assertWarns (RuntimeWarning ):
749+ # Will raise warnings because never awaited
750+ self .mock ()
751+ self .mock (1 , 2 )
752+ self .mock (a = 3 )
753+
754+ self .mock .reset_mock ()
755+ assert_attrs (self .mock )
756+
757+ a_mock = AsyncMock (AsyncClass )
758+ with self .assertWarns (RuntimeWarning ):
759+ # Will raise warnings because never awaited
760+ a_mock .async_method ()
761+ a_mock .async_method (1 , a = 3 )
762+
763+ a_mock .reset_mock ()
764+ assert_attrs (a_mock )
603765
604766 def test_assert_awaited (self ):
605767 with self .assertRaises (AssertionError ):
@@ -645,20 +807,20 @@ def test_assert_awaited_once_with(self):
645807
646808 def test_assert_any_wait (self ):
647809 with self .assertRaises (AssertionError ):
648- self .mock .assert_any_await ('NormalFoo ' )
810+ self .mock .assert_any_await ('foo ' )
649811
650- asyncio .run (self ._runnable_test ('foo ' ))
812+ asyncio .run (self ._runnable_test ('baz ' ))
651813 with self .assertRaises (AssertionError ):
652- self .mock .assert_any_await ('NormalFoo ' )
814+ self .mock .assert_any_await ('foo ' )
653815
654- asyncio .run (self ._runnable_test ('NormalFoo ' ))
655- self .mock .assert_any_await ('NormalFoo ' )
816+ asyncio .run (self ._runnable_test ('foo ' ))
817+ self .mock .assert_any_await ('foo ' )
656818
657819 asyncio .run (self ._runnable_test ('SomethingElse' ))
658- self .mock .assert_any_await ('NormalFoo ' )
820+ self .mock .assert_any_await ('foo ' )
659821
660822 def test_assert_has_awaits_no_order (self ):
661- calls = [call ('NormalFoo ' ), call ('baz' )]
823+ calls = [call ('foo ' ), call ('baz' )]
662824
663825 with self .assertRaises (AssertionError ) as cm :
664826 self .mock .assert_has_awaits (calls )
@@ -668,7 +830,7 @@ def test_assert_has_awaits_no_order(self):
668830 with self .assertRaises (AssertionError ):
669831 self .mock .assert_has_awaits (calls )
670832
671- asyncio .run (self ._runnable_test ('NormalFoo ' ))
833+ asyncio .run (self ._runnable_test ('foo ' ))
672834 with self .assertRaises (AssertionError ):
673835 self .mock .assert_has_awaits (calls )
674836
@@ -703,19 +865,19 @@ async def _custom_mock_runnable_test(*args):
703865 mock_with_spec .assert_any_await (ANY , 1 )
704866
705867 def test_assert_has_awaits_ordered (self ):
706- calls = [call ('NormalFoo ' ), call ('baz' )]
868+ calls = [call ('foo ' ), call ('baz' )]
707869 with self .assertRaises (AssertionError ):
708870 self .mock .assert_has_awaits (calls , any_order = True )
709871
710872 asyncio .run (self ._runnable_test ('baz' ))
711873 with self .assertRaises (AssertionError ):
712874 self .mock .assert_has_awaits (calls , any_order = True )
713875
714- asyncio .run (self ._runnable_test ('foo ' ))
876+ asyncio .run (self ._runnable_test ('bamf ' ))
715877 with self .assertRaises (AssertionError ):
716878 self .mock .assert_has_awaits (calls , any_order = True )
717879
718- asyncio .run (self ._runnable_test ('NormalFoo ' ))
880+ asyncio .run (self ._runnable_test ('foo ' ))
719881 self .mock .assert_has_awaits (calls , any_order = True )
720882
721883 asyncio .run (self ._runnable_test ('qux' ))
0 commit comments