Skip to content

Commit be88832

Browse files
ueshin authored and HyukjinKwon committed
[SPARK-42612][CONNECT][PYTHON][TESTS] Enable more parity tests related to functions
### What changes were proposed in this pull request? Enables more parity tests related to `functions`. ### Why are the changes needed? There are still some more tests we should enable. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Modified/enabled related tests. Closes #40203 from ueshin/issues/SPARK-42612/tests. Authored-by: Takuya UESHIN <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]> (cherry picked from commit a9f20c1) Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 1f3d9a9 commit be88832

File tree

5 files changed

+36
-31
lines changed

5 files changed

+36
-31
lines changed

python/pyspark/sql/connect/functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,10 @@ def lit(col: Any) -> Column:
224224
if isinstance(col, Column):
225225
return col
226226
elif isinstance(col, list):
227+
if any(isinstance(c, Column) for c in col):
228+
raise PySparkValueError(
229+
error_class="COLUMN_IN_LIST", message_parameters={"func_name": "lit"}
230+
)
227231
return array(*[lit(c) for c in col])
228232
elif isinstance(col, np.ndarray) and col.ndim == 1:
229233
if _from_numpy_type(col.dtype) is None:

python/pyspark/sql/tests/connect/test_connect_plan.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -838,9 +838,6 @@ def test_list_to_literal(self):
838838
p = multi_type_lit.to_plan(None)
839839
self.assertIsNotNone(p)
840840

841-
lit_list_plan = lit([lit(10), lit("str")]).to_plan(None)
842-
self.assertIsNotNone(lit_list_plan)
843-
844841
def test_column_alias(self) -> None:
845842
# SPARK-40809: Support for Column Aliases
846843
col0 = col("a").alias("martin")

python/pyspark/sql/tests/connect/test_parity_functions.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,23 +38,13 @@ def test_function_parity(self):
3838
def test_input_file_name_reset_for_rdd(self):
3939
super().test_input_file_name_reset_for_rdd()
4040

41-
# TODO(SPARK-41901): Parity in String representation of Column
42-
@unittest.skip("Fails in Spark Connect, should enable.")
43-
def test_inverse_trig_functions(self):
44-
super().test_inverse_trig_functions()
45-
46-
# TODO(SPARK-41834): Implement SparkSession.conf
47-
@unittest.skip("Fails in Spark Connect, should enable.")
48-
def test_lit_list(self):
49-
super().test_lit_list()
50-
5141
def test_raise_error(self):
5242
self.check_raise_error(SparkConnectException)
5343

54-
# Comparing column type of connect and pyspark
55-
@unittest.skip("Fails in Spark Connect, should enable.")
5644
def test_sorting_functions_with_column(self):
57-
super().test_sorting_functions_with_column()
45+
from pyspark.sql.connect.column import Column
46+
47+
self.check_sorting_functions_with_column(Column)
5848

5949

6050
if __name__ == "__main__":

python/pyspark/sql/tests/test_functions.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -318,19 +318,29 @@ def test_math_functions(self):
318318
)
319319

320320
def test_inverse_trig_functions(self):
321-
from pyspark.sql import functions
321+
df = self.spark.createDataFrame([Row(a=i * 0.2, b=i * -0.2) for i in range(10)])
322322

323-
funs = [
324-
(functions.acosh, "ACOSH"),
325-
(functions.asinh, "ASINH"),
326-
(functions.atanh, "ATANH"),
327-
]
323+
def check(trig, inv, y_axis_symmetrical):
324+
SQLTestUtils.assert_close(
325+
[n * 0.2 for n in range(10)],
326+
df.select(inv(trig(df.a))).collect(),
327+
)
328+
if y_axis_symmetrical:
329+
SQLTestUtils.assert_close(
330+
[n * 0.2 for n in range(10)],
331+
df.select(inv(trig(df.b))).collect(),
332+
)
333+
else:
334+
SQLTestUtils.assert_close(
335+
[n * -0.2 for n in range(10)],
336+
df.select(inv(trig(df.b))).collect(),
337+
)
328338

329-
cols = ["a", functions.col("a")]
339+
from pyspark.sql import functions
330340

331-
for f, alias in funs:
332-
for c in cols:
333-
self.assertIn(f"{alias}(a)", repr(f(c)))
341+
check(functions.cosh, functions.acosh, y_axis_symmetrical=True)
342+
check(functions.sinh, functions.asinh, y_axis_symmetrical=False)
343+
check(functions.tanh, functions.atanh, y_axis_symmetrical=False)
334344

335345
def test_reciprocal_trig_functions(self):
336346
# SPARK-36683: Tests for reciprocal trig functions (SEC, CSC and COT)
@@ -578,9 +588,13 @@ def test_approxQuantile(self):
578588
self.assertRaises(TypeError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
579589

580590
def test_sorting_functions_with_column(self):
581-
from pyspark.sql import functions
582591
from pyspark.sql.column import Column
583592

593+
self.check_sorting_functions_with_column(Column)
594+
595+
def check_sorting_functions_with_column(self, tpe):
596+
from pyspark.sql import functions
597+
584598
funs = [
585599
functions.asc_nulls_first,
586600
functions.asc_nulls_last,
@@ -592,17 +606,17 @@ def test_sorting_functions_with_column(self):
592606
for fun in funs:
593607
for _expr in exprs:
594608
res = fun(_expr)
595-
self.assertIsInstance(res, Column)
609+
self.assertIsInstance(res, tpe)
596610
self.assertIn(f"""'x {fun.__name__.replace("_", " ").upper()}'""", str(res))
597611

598612
for _expr in exprs:
599613
res = functions.asc(_expr)
600-
self.assertIsInstance(res, Column)
614+
self.assertIsInstance(res, tpe)
601615
self.assertIn("""'x ASC NULLS FIRST'""", str(res))
602616

603617
for _expr in exprs:
604618
res = functions.desc(_expr)
605-
self.assertIsInstance(res, Column)
619+
self.assertIsInstance(res, tpe)
606620
self.assertIn("""'x DESC NULLS LAST'""", str(res))
607621

608622
def test_sort_with_nulls_order(self):

python/pyspark/testing/sqlutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def function(self, *functions):
251251
def assert_close(a, b):
252252
c = [j[0] for j in b]
253253
diff = [abs(v - c[k]) < 1e-6 if math.isfinite(v) else v == c[k] for k, v in enumerate(a)]
254-
return sum(diff) == len(a)
254+
assert sum(diff) == len(a), f"sum: {sum(diff)}, len: {len(a)}"
255255

256256

257257
class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils, PySparkErrorTestUtils):

0 commit comments

Comments (0)