@@ -51,6 +51,19 @@ def ast_Call(a, b, c):
5151 return ast .Call (a , b , c , None , None )
5252
5353
54+ def ast_Call_helper (func_name , * args , ** kwargs ):
55+ """
56+ func_name: str
57+ args: Iterable[ast.expr]
58+ kwargs: Dict[str,ast.expr]
59+ """
60+ return ast .Call (
61+ ast .Name (func_name , ast .Load ()),
62+ list (args ),
63+ [ast .keyword (key , val ) for key , val in kwargs .items ()],
64+ )
65+
66+
5467class AssertionRewritingHook (object ):
5568 """PEP302 Import hook which rewrites asserts."""
5669
@@ -828,6 +841,12 @@ def visit_Assert(self, assert_):
828841 self .push_format_context ()
829842 # Rewrite assert into a bunch of statements.
830843 top_condition , explanation = self .visit (assert_ .test )
844+ # Check if directly asserting None, in order to warn [Issue #3191]
845+ self .statements .append (
846+ self .warn_about_none_ast (
847+ top_condition , module_path = self .module_path , lineno = assert_ .lineno
848+ )
849+ )
831850 # Create failure message.
832851 body = self .on_failure
833852 negation = ast .UnaryOp (ast .Not (), top_condition )
@@ -858,6 +877,45 @@ def visit_Assert(self, assert_):
858877 set_location (stmt , assert_ .lineno , assert_ .col_offset )
859878 return self .statements
860879
880+ def warn_about_none_ast (self , node , module_path , lineno ):
881+ """Returns an ast warning if node is None with the following statement:
882+ if node is None:
883+ from _pytest.warning_types import PytestWarning
884+ import warnings
885+ warnings.warn_explicit(
886+ PytestWarning('assertion the value None, Please use "assert is None"'),
887+ category=None,
888+ # filename=str(self.module_path),
889+ filename=__file__
890+ lineno=node.lineno,
891+ )
892+ """
893+
894+ warning_msg = ast .Str (
895+ 'Asserting the value None directly, Please use "assert is None" to eliminate ambiguity'
896+ )
897+ AST_NONE = ast .NameConstant (None )
898+ val_is_none = ast .Compare (node , [ast .Is ()], [AST_NONE ])
899+ import_warnings = ast .ImportFrom (
900+ module = "warnings" , names = [ast .alias ("warn_explicit" , None )], level = 0
901+ )
902+ import_pytest_warning = ast .ImportFrom (
903+ module = "pytest" , names = [ast .alias ("PytestWarning" , None )], level = 0
904+ )
905+ pytest_warning = ast_Call_helper ("PytestWarning" , warning_msg )
906+ # This won't work because this isn't the same "self" as an AssertionRewriter!
907+ # ast_filename = improved_ast_Call('str',ast.Attribute('self','module_path',ast.Load).module_path)
908+ warn = ast_Call_helper (
909+ "warn_explicit" ,
910+ pytest_warning ,
911+ category = AST_NONE ,
912+ filename = ast .Str (str (module_path )),
913+ lineno = ast .Num (lineno ),
914+ )
915+ return ast .If (
916+ val_is_none , [import_warnings , import_pytest_warning , ast .Expr (warn )], []
917+ )
918+
861919 def visit_Name (self , name ):
862920 # Display the repr of the name if it's a local variable or
863921 # _should_repr_global_name() thinks it's acceptable.
0 commit comments