From 3de67748093ea8ee74a42dd2e70c024e39c2a02d Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Wed, 18 Sep 2024 13:29:06 +0800 Subject: [PATCH 1/3] fix --- python/pyspark/sql/connect/expressions.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py index 0b5512b61925c..85f1b3565c696 100644 --- a/python/pyspark/sql/connect/expressions.py +++ b/python/pyspark/sql/connect/expressions.py @@ -809,7 +809,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"WithField({self._structExpr}, {self._fieldName}, {self._valueExpr})" + return f"update_field({self._structExpr}, {self._fieldName}, {self._valueExpr})" class DropField(Expression): @@ -833,7 +833,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"DropField({self._structExpr}, {self._fieldName})" + return f"drop_field({self._structExpr}, {self._fieldName})" class UnresolvedExtractValue(Expression): @@ -857,7 +857,7 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression: return expr def __repr__(self) -> str: - return f"UnresolvedExtractValue({str(self._child)}, {str(self._extraction)})" + return f"{self._child}['{self._extraction}']" class UnresolvedRegex(Expression): From 75d6f0fc9648f99e9d683b893b541f9d6aa5226e Mon Sep 17 00:00:00 2001 From: Ruifeng Zheng Date: Mon, 7 Oct 2024 21:50:18 +0800 Subject: [PATCH 2/3] fix --- python/pyspark/sql/tests/test_column.py | 70 +++++++++++++++++++++++++ 1 file changed, 70 insertions(+) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 1972dd2804d98..68b0f9733b0d4 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -283,6 +283,76 @@ def test_expr_str_representation(self): when_cond = sf.when(expression, sf.lit(None)) self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>") + def test_col_field_ops_representation(self): + c = sf.col("c") + + # getField + self.assertEqual(str(c.x), "Column<'c['x']'>") + self.assertEqual(str(c.x.y), "Column<'c['x']['y']'>") + self.assertEqual(str(c.x.y.z), "Column<'c['x']['y']['z']'>") + + self.assertEqual(str(c["x"]), "Column<'c['x']'>") + self.assertEqual(str(c["x"]["y"]), "Column<'c['x']['y']'>") + self.assertEqual(str(c["x"]["y"]["z"]), "Column<'c['x']['y']['z']'>") + + self.assertEqual(str(c.getField("x")), "Column<'c['x']'>") + self.assertEqual( + str(c.getField("x").getField("y")), + "Column<'c['x']['y']'>", + ) + self.assertEqual( + str(c.getField("x").getField("y").getField("z")), + "Column<'c['x']['y']['z']'>", + ) + + self.assertEqual(str(c.getItem("x")), "Column<'c['x']'>") + self.assertEqual( + str(c.getItem("x").getItem("y")), + "Column<'c['x']['y']'>", + ) + self.assertEqual( + str(c.getItem("x").getItem("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + + self.assertEqual( + str(c.x["y"].getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].getField("y").getItem("z")), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c.getField("x").getItem("y").z), + "Column<'c['x']['y']['z']'>", + ) + self.assertEqual( + str(c["x"].y.getField("z")), + "Column<'c['x']['y']['z']'>", + ) + + # WithField + self.assertEqual( + str(c.withField("x", sf.col("y"))), + "Column<'update_field(c, x, y)'>", + ) + self.assertEqual( + str(c.withField("x", sf.col("y")).withField("x", sf.col("z"))), + "Column<'update_field(update_field(c, x, y), x, z)'>", + ) + + # DropFields + self.assertEqual(str(c.dropFields("x")), "Column<'drop_field(c, x)'>") + self.assertEqual( + str(c.dropFields("x", "y")), + "Column<'drop_field(drop_field(c, x), y)'>", + ) + self.assertEqual( + str(c.dropFields("x", "y", "z")), + "Column<'drop_field(drop_field(drop_field(c, x), y), z)'>", + ) + def test_lit_time_representation(self): dt = datetime.date(2021, 3, 4) self.assertEqual(str(sf.lit(dt)), "Column<'2021-03-04'>") From 38d4c49b7d95c3b21e44f80cc06db9e836116b15 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 8 Oct 2024 09:37:07 +0900 Subject: [PATCH 3/3] Update python/pyspark/sql/tests/test_column.py --- python/pyspark/sql/tests/test_column.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/tests/test_column.py b/python/pyspark/sql/tests/test_column.py index 68b0f9733b0d4..5f1991973d27d 100644 --- a/python/pyspark/sql/tests/test_column.py +++ b/python/pyspark/sql/tests/test_column.py @@ -284,6 +284,7 @@ def test_expr_str_representation(self): self.assertEqual(str(when_cond), "Column<'CASE WHEN foo THEN NULL END'>") def test_col_field_ops_representation(self): + # SPARK-49894: Test string representation of columns c = sf.col("c") # getField