diff --git a/Lib/test/test_ast.py b/Lib/test/test_ast.py index 3fd982c79ea21e..9e94cd5ca0fe1e 100644 --- a/Lib/test/test_ast.py +++ b/Lib/test/test_ast.py @@ -1,7 +1,9 @@ import ast import dis import os +import random import sys +import tokenize import unittest import warnings import weakref @@ -25,6 +27,9 @@ def to_tuple(t): result.append(to_tuple(getattr(t, f))) return tuple(result) +STDLIB = os.path.dirname(ast.__file__) +STDLIB_FILES = [fn for fn in os.listdir(STDLIB) if fn.endswith(".py")] +STDLIB_FILES.extend(["test/test_grammar.py", "test/test_unpack_ex.py"]) # These tests are compiled through "exec" # There should be at least one test per statement @@ -654,6 +659,70 @@ def test_ast_asdl_signature(self): expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}" self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions) + def test_compare_basis(self): + self.assertEqual(ast.parse("x = 10"), ast.parse("x = 10")) + self.assertNotEqual(ast.parse("x = 10"), ast.parse("")) + self.assertNotEqual(ast.parse("x = 10"), ast.parse("x")) + self.assertNotEqual(ast.parse("x = 10;y = 20"), ast.parse("class C:pass")) + + def test_compare_literals(self): + constants = (-20, 20, 20.0, 1, 1.0, True, 0, False, frozenset(), tuple(), "ABCD", "abcd", "中文字", 1e1000, -1e1000) + for next_index, constant in enumerate(constants[:-1], 1): + next_constant = constants[next_index] + with self.subTest(literal=constant, next_literal=next_constant): + self.assertEqual(ast.Constant(constant), ast.Constant(constant)) + self.assertNotEqual(ast.Constant(constant), ast.Constant(next_constant)) + + same_looking_literal_cases = [{1, 1.0, True, 1+0j}, {0, 0.0, False, 0+0j}] + for same_looking_literals in same_looking_literal_cases: + for literal in same_looking_literals: + for same_looking_literal in same_looking_literals - {literal}: + self.assertNotEqual(ast.Constant(literal), ast.Constant(same_looking_literal)) + + def test_compare_operators(self): + self.assertEqual(ast.Add(), ast.Add()) + self.assertEqual(ast.Sub(), ast.Sub()) + + self.assertNotEqual(ast.Add(), ast.Sub()) + self.assertNotEqual(ast.Add(), ast.Constant()) + + def test_compare_stdlib(self): + if support.is_resource_enabled("cpu"): + files = STDLIB_FILES + else: + files = random.sample(STDLIB_FILES, 10) + + for module in files: + with self.subTest(module): + fn = os.path.join(STDLIB, module) + with tokenize.open(fn) as fp: + source = fp.read() + a = ast.parse(source, fn) + b = ast.parse(source, fn) + self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}") + self.assertFalse(a != b) + + def test_exec_compare(self): + for source in exec_tests: + a = ast.parse(source, mode="exec") + b = ast.parse(source, mode="exec") + self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}") + self.assertFalse(a != b) + + def test_single_compare(self): + for source in single_tests: + a = ast.parse(source, mode="single") + b = ast.parse(source, mode="single") + self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}") + self.assertFalse(a != b) + + def test_eval_compare(self): + for source in eval_tests: + a = ast.parse(source, mode="eval") + b = ast.parse(source, mode="eval") + self.assertEqual(a, b, f"{ast.dump(a)} != {ast.dump(b)}") + self.assertFalse(a != b) + class ASTHelpers_Test(unittest.TestCase): maxDiff = None @@ -1369,12 +1438,9 @@ def test_nameconstant(self): self.expr(ast.NameConstant(4)) def test_stdlib_validates(self): - stdlib = os.path.dirname(ast.__file__) - tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")] - tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"]) - for module in tests: + for module in STDLIB_FILES: with self.subTest(module): - fn = os.path.join(stdlib, module) + fn = os.path.join(STDLIB, module) with open(fn, "r", encoding="utf-8") as fp: source = fp.read() mod = ast.parse(source, fn) diff --git a/Misc/ACKS b/Misc/ACKS index ce100b972aa053..ede229ea0c7b50 100644 --- a/Misc/ACKS +++ b/Misc/ACKS @@ -1043,6 +1043,7 @@ Jason Lowe Tony Lownds Ray Loyzaga Kang-Hao (Kenny) Lu +Louie Lu Lukas Lueg Loren Luke Fredrik Lundh diff --git a/Misc/NEWS.d/next/Library/2019-07-20-16-02-36.bpo-15987.sQtFns.rst b/Misc/NEWS.d/next/Library/2019-07-20-16-02-36.bpo-15987.sQtFns.rst new file mode 100644 index 00000000000000..3c6990fa01f7cb --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-07-20-16-02-36.bpo-15987.sQtFns.rst @@ -0,0 +1,2 @@ +Provide a way to compare AST nodes for equality recursively. Patch by Louie +Lu, Flavian Hautbois and Batuhan Taskaya. diff --git a/Parser/asdl_c.py b/Parser/asdl_c.py index bd22fb6bf73fe4..12255018985b5c 100755 --- a/Parser/asdl_c.py +++ b/Parser/asdl_c.py @@ -742,6 +742,88 @@ def visitModule(self, mod): return Py_BuildValue("O()", Py_TYPE(self)); } +static PyObject * +ast_richcompare(PyObject *self, PyObject *other, int op) +{ + Py_ssize_t i, numfields = 0; + PyObject *fields, *key = NULL; + + /* Check operator */ + if ((op != Py_EQ && op != Py_NE) || + !PyAST_Check(self) || !PyAST_Check(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + + /* Compare types */ + if (Py_TYPE(self) != Py_TYPE(other)) { + Py_RETURN_RICHCOMPARE(Py_TYPE(self), Py_TYPE(other), op); + } + + if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), astmodulestate_global->_fields, &fields) < 0) { + return NULL; + } + if (fields) { + numfields = PySequence_Size(fields); + if (numfields == -1) { + goto fail; + } + } + + PyObject *a, *b; + /* Compare fields */ + for (i = 0; i < numfields; i++) { + key = PySequence_GetItem(fields, i); + if (!key) { + goto fail; + } + if (!PyObject_HasAttr(self, key) || !PyObject_HasAttr(other, key)) { + Py_DECREF(key); + goto unsuccessful; + } + Py_DECREF(key); + + a = PyObject_GetAttr(self, key); + b = PyObject_GetAttr(other, key); + if (!a || !b) { + goto unsuccessful; + } + + /* Ensure they belong to the same type */ + if (Py_TYPE(a) != Py_TYPE(b)) { + goto unsuccessful; + } + + if (!PyObject_RichCompareBool(a, b, Py_EQ)) { + goto unsuccessful; + } + Py_DECREF(a); + Py_DECREF(b); + } + Py_DECREF(fields); + + if (op == Py_EQ) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } + + unsuccessful: + Py_XDECREF(a); + Py_XDECREF(b); + Py_DECREF(fields); + if (op == Py_EQ) { + Py_RETURN_FALSE; + } + else { + Py_RETURN_TRUE; + } + + fail: + Py_DECREF(fields); + return NULL; +} + static PyMemberDef ast_type_members[] = { {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY}, {NULL} /* Sentinel */ @@ -770,6 +852,8 @@ def visitModule(self, mod): {Py_tp_alloc, PyType_GenericAlloc}, {Py_tp_new, PyType_GenericNew}, {Py_tp_free, PyObject_GC_Del}, + {Py_tp_richcompare, ast_richcompare}, + {Py_tp_hash, (hashfunc)_Py_HashPointer}, {0, 0}, }; diff --git a/Python/Python-ast.c b/Python/Python-ast.c index c7c7fda45d8519..44852038b78e7c 100644 --- a/Python/Python-ast.c +++ b/Python/Python-ast.c @@ -1177,6 +1177,88 @@ ast_type_reduce(PyObject *self, PyObject *unused) return Py_BuildValue("O()", Py_TYPE(self)); } +static PyObject * +ast_richcompare(PyObject *self, PyObject *other, int op) +{ + Py_ssize_t i, numfields = 0; + PyObject *fields, *key = NULL; + + /* Check operator */ + if ((op != Py_EQ && op != Py_NE) || + !PyAST_Check(self) || !PyAST_Check(other)) { + Py_RETURN_NOTIMPLEMENTED; + } + + /* Compare types */ + if (Py_TYPE(self) != Py_TYPE(other)) { + Py_RETURN_RICHCOMPARE(Py_TYPE(self), Py_TYPE(other), op); + } + + if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), astmodulestate_global->_fields, &fields) < 0) { + return NULL; + } + if (fields) { + numfields = PySequence_Size(fields); + if (numfields == -1) { + goto fail; + } + } + + PyObject *a, *b; + /* Compare fields */ + for (i = 0; i < numfields; i++) { + key = PySequence_GetItem(fields, i); + if (!key) { + goto fail; + } + if (!PyObject_HasAttr(self, key) || !PyObject_HasAttr(other, key)) { + Py_DECREF(key); + goto unsuccessful; + } + Py_DECREF(key); + + a = PyObject_GetAttr(self, key); + b = PyObject_GetAttr(other, key); + if (!a || !b) { + goto unsuccessful; + } + + /* Ensure they belong to the same type */ + if (Py_TYPE(a) != Py_TYPE(b)) { + goto unsuccessful; + } + + if (!PyObject_RichCompareBool(a, b, Py_EQ)) { + goto unsuccessful; + } + Py_DECREF(a); + Py_DECREF(b); + } + Py_DECREF(fields); + + if (op == Py_EQ) { + Py_RETURN_TRUE; + } + else { + Py_RETURN_FALSE; + } + + unsuccessful: + Py_XDECREF(a); + Py_XDECREF(b); + Py_DECREF(fields); + if (op == Py_EQ) { + Py_RETURN_FALSE; + } + else { + Py_RETURN_TRUE; + } + + fail: + Py_DECREF(fields); + return NULL; +} + static PyMemberDef ast_type_members[] = { {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY}, {NULL} /* Sentinel */ @@ -1205,6 +1287,8 @@ static PyType_Slot AST_type_slots[] = { {Py_tp_alloc, PyType_GenericAlloc}, {Py_tp_new, PyType_GenericNew}, {Py_tp_free, PyObject_GC_Del}, + {Py_tp_richcompare, ast_richcompare}, + {Py_tp_hash, (hashfunc)_Py_HashPointer}, {0, 0}, };