Skip to content

Commit 734c435

Browse files
authored
Merge pull request #2870 from Perlence/rewrite-python-37-docstring
Adapt the Python 3.7 AST changes
2 parents 0b540f9 + 27bb2ec commit 734c435

File tree

2 files changed

+46
-18
lines changed

2 files changed

+46
-18
lines changed

_pytest/assertion/rewrite.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -595,23 +595,26 @@ def run(self, mod):
595595
# docstrings and __future__ imports.
596596
aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
597597
ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
598-
expect_docstring = True
598+
doc = getattr(mod, "docstring", None)
599+
expect_docstring = doc is None
600+
if doc is not None and self.is_rewrite_disabled(doc):
601+
return
599602
pos = 0
600-
lineno = 0
603+
lineno = 1
601604
for item in mod.body:
602605
if (expect_docstring and isinstance(item, ast.Expr) and
603606
isinstance(item.value, ast.Str)):
604607
doc = item.value.s
605-
if "PYTEST_DONT_REWRITE" in doc:
606-
# The module has disabled assertion rewriting.
608+
if self.is_rewrite_disabled(doc):
607609
return
608-
lineno += len(doc) - 1
609610
expect_docstring = False
610611
elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
611612
item.module != "__future__"):
612613
lineno = item.lineno
613614
break
614615
pos += 1
616+
else:
617+
lineno = item.lineno
615618
imports = [ast.Import([alias], lineno=lineno, col_offset=0)
616619
for alias in aliases]
617620
mod.body[pos:pos] = imports
@@ -637,6 +640,9 @@ def run(self, mod):
637640
not isinstance(field, ast.expr)):
638641
nodes.append(field)
639642

643+
def is_rewrite_disabled(self, docstring):
644+
return "PYTEST_DONT_REWRITE" in docstring
645+
640646
def variable(self):
641647
"""Get a new variable."""
642648
# Use a character invalid in python identifiers to avoid clashing.

testing/test_assertrewrite.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,18 @@ class TestAssertionRewrite(object):
6565
def test_place_initial_imports(self):
6666
s = """'Doc string'\nother = stuff"""
6767
m = rewrite(s)
68-
assert isinstance(m.body[0], ast.Expr)
69-
assert isinstance(m.body[0].value, ast.Str)
70-
for imp in m.body[1:3]:
68+
# Module docstrings in 3.7 are part of Module node, it's not in the body
69+
# so we remove it so the following body items have the same indexes on
70+
# all Python versions
71+
if sys.version_info < (3, 7):
72+
assert isinstance(m.body[0], ast.Expr)
73+
assert isinstance(m.body[0].value, ast.Str)
74+
del m.body[0]
75+
for imp in m.body[0:2]:
7176
assert isinstance(imp, ast.Import)
7277
assert imp.lineno == 2
7378
assert imp.col_offset == 0
74-
assert isinstance(m.body[3], ast.Assign)
79+
assert isinstance(m.body[2], ast.Assign)
7580
s = """from __future__ import with_statement\nother_stuff"""
7681
m = rewrite(s)
7782
assert isinstance(m.body[0], ast.ImportFrom)
@@ -80,16 +85,29 @@ def test_place_initial_imports(self):
8085
assert imp.lineno == 2
8186
assert imp.col_offset == 0
8287
assert isinstance(m.body[3], ast.Expr)
88+
s = """'doc string'\nfrom __future__ import with_statement"""
89+
m = rewrite(s)
90+
if sys.version_info < (3, 7):
91+
assert isinstance(m.body[0], ast.Expr)
92+
assert isinstance(m.body[0].value, ast.Str)
93+
del m.body[0]
94+
assert isinstance(m.body[0], ast.ImportFrom)
95+
for imp in m.body[1:3]:
96+
assert isinstance(imp, ast.Import)
97+
assert imp.lineno == 2
98+
assert imp.col_offset == 0
8399
s = """'doc string'\nfrom __future__ import with_statement\nother"""
84100
m = rewrite(s)
85-
assert isinstance(m.body[0], ast.Expr)
86-
assert isinstance(m.body[0].value, ast.Str)
87-
assert isinstance(m.body[1], ast.ImportFrom)
88-
for imp in m.body[2:4]:
101+
if sys.version_info < (3, 7):
102+
assert isinstance(m.body[0], ast.Expr)
103+
assert isinstance(m.body[0].value, ast.Str)
104+
del m.body[0]
105+
assert isinstance(m.body[0], ast.ImportFrom)
106+
for imp in m.body[1:3]:
89107
assert isinstance(imp, ast.Import)
90108
assert imp.lineno == 3
91109
assert imp.col_offset == 0
92-
assert isinstance(m.body[4], ast.Expr)
110+
assert isinstance(m.body[3], ast.Expr)
93111
s = """from . import relative\nother_stuff"""
94112
m = rewrite(s)
95113
for imp in m.body[0:2]:
@@ -101,10 +119,14 @@ def test_place_initial_imports(self):
101119
def test_dont_rewrite(self):
102120
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
103121
m = rewrite(s)
104-
assert len(m.body) == 2
105-
assert isinstance(m.body[0].value, ast.Str)
106-
assert isinstance(m.body[1], ast.Assert)
107-
assert m.body[1].msg is None
122+
if sys.version_info < (3, 7):
123+
assert len(m.body) == 2
124+
assert isinstance(m.body[0], ast.Expr)
125+
assert isinstance(m.body[0].value, ast.Str)
126+
del m.body[0]
127+
else:
128+
assert len(m.body) == 1
129+
assert m.body[0].msg is None
108130

109131
def test_name(self):
110132
def f():

0 commit comments

Comments
 (0)