diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 046139f5952d3..127c58656d366 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -60,7 +60,7 @@
     Row,
     _parse_datatype_json_string,
 )
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, to_list_column_style
 from pyspark.sql.pandas.conversion import PandasConversionMixin
 from pyspark.sql.pandas.map_ops import PandasMapOpsMixin
 
@@ -2374,9 +2374,7 @@ def join(
         +-----+---+
         """
-        if on is not None and not isinstance(on, list):
-            on = [on]  # type: ignore[assignment]
-
+        on = to_list_column_style(on)
         if on is not None:
             if isinstance(on[0], str):
                 on = self._jseq(cast(List[str], on))
@@ -2484,9 +2482,7 @@ def _joinAsOf(
             rightAsOfColumn = other[rightAsOfColumn]
         right_as_of_jcol = rightAsOfColumn._jc
 
-        if on is not None and not isinstance(on, list):
-            on = [on]  # type: ignore[assignment]
-
+        on = to_list_column_style(on)
         if on is not None:
             if isinstance(on[0], str):
                 on = self._jseq(cast(List[str], on))
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 527a51cc239ed..38084e005193f 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1077,6 +1077,21 @@ def test_join_without_on(self):
         expected = [Row(a=0, b=0)]
         self.assertEqual(actual, expected)
 
+    def test_join_on_types(self):
+        df1 = self.spark.range(1).selectExpr("id AS a", "1 AS key")
+        df2 = self.spark.range(1).selectExpr("id AS b", "1 AS key")
+
+        # on as a single column name, a list of names, and a tuple of names
+        expected = df1.join(df2, on="key").collect()
+        self.assertEqual(expected, df1.join(df2, on=["key"]).collect())
+        self.assertEqual(expected, df1.join(df2, on=("key",)).collect())
+
+        # on as a single join condition, a list, and a tuple of conditions
+        cond = df1["key"] == df2["key"]
+        expected = df1.join(df2, on=cond).collect()
+        self.assertEqual(expected, df1.join(df2, on=[cond]).collect())
+        self.assertEqual(expected, df1.join(df2, on=(cond,)).collect())
+
     # Regression test for invalid join methods when on is None, Spark-14761
     def test_invalid_join_method(self):
         df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index f735442d2b7c9..ee082b9779667 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -24,7 +24,11 @@
     SparkUpgradeException,
 )
 from pyspark.testing.sqlutils import ReusedSQLTestCase
-from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.functions import col, to_date, unix_timestamp, from_unixtime
+from pyspark.sql.utils import (
+    isinstance_iterable,
+    to_list_column_style,
+)
 
 
 class UtilsTests(ReusedSQLTestCase):
@@ -75,6 +79,32 @@ def test_get_error_class_state(self):
         self.assertEquals(e.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
         self.assertEquals(e.getSqlState(), "42703")
 
+    def test_isinstance_iterable(self):
+        a_col = col("")
+        a_str = ""
+
+        self.assertFalse(isinstance_iterable(a_col))
+        self.assertTrue(isinstance_iterable([a_col]))
+        self.assertTrue(isinstance_iterable((a_col,)))
+
+        self.assertFalse(isinstance_iterable(a_str))
+        self.assertTrue(isinstance_iterable([a_str]))
+        self.assertTrue(isinstance_iterable((a_str,)))
+
+        self.assertIsInstance(to_list_column_style(a_col), list)
+        self.assertIsInstance(to_list_column_style([a_col]), list)
+        self.assertIsInstance(to_list_column_style((a_col,)), list)
+        self.assertFalse(isinstance_iterable(to_list_column_style(a_col)[0]))
+        self.assertFalse(isinstance_iterable(to_list_column_style([a_col])[0]))
+        self.assertFalse(isinstance_iterable(to_list_column_style((a_col,))[0]))
+
+        self.assertIsInstance(to_list_column_style(a_str), list)
+        self.assertIsInstance(to_list_column_style([a_str]), list)
+        self.assertIsInstance(to_list_column_style((a_str,)), list)
+        self.assertFalse(isinstance_iterable(to_list_column_style(a_str)[0]))
+        self.assertFalse(isinstance_iterable(to_list_column_style([a_str])[0]))
+        self.assertFalse(isinstance_iterable(to_list_column_style((a_str,))[0]))
+
 
 if __name__ == "__main__":
     import unittest
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index 7ecfa65dcd13f..9a96aa4411d88 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -17,6 +17,7 @@
 import functools
 import os
 from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING, cast, TypeVar, Union, Type
+from typing import Iterable, List
 
 from py4j.java_collections import JavaArray
 from py4j.java_gateway import (
@@ -38,6 +39,7 @@
     UnknownException,
     SparkUpgradeException,
     PySparkNotImplementedError,
+    PySparkTypeError,
 )
 from pyspark.errors.exceptions.captured import CapturedException  # noqa: F401
 from pyspark.find_spark_home import _find_spark_home
@@ -282,3 +284,39 @@ def get_dataframe_class() -> Type["DataFrame"]:
         return ConnectDataFrame  # type: ignore[return-value]
     else:
         return PySparkDataFrame
+
+
+def isinstance_iterable(obj: Any) -> bool:
+    """
+    Check whether ``obj`` is iterable, treating ``str`` as a scalar (a
+    single column name) rather than as an iterable of characters.
+    ``Column`` raises on ``iter()`` and so is reported as non-iterable.
+    """
+    if isinstance(obj, str):
+        return False
+    try:
+        iter(obj)
+        return True
+    except (TypeError, PySparkTypeError):
+        return False
+
+
+def to_list_column_style(
+    cols: Optional[Union[str, Any, Iterable[Any]]]
+) -> Optional[List[Any]]:
+    """
+    Normalize a column-like argument into the list form expected by
+    join-style APIs:
+
+        None           -> None
+        str            -> [str]
+        container[Any] -> list[Any]
+        Any            -> [Any]
+
+    No type checking is performed on the contained objects.
+    """
+    if cols is None:
+        return None
+    if isinstance_iterable(cols):
+        return list(cols)
+    return [cols]
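
Illustrative usage of the new helper (a sketch, not part of the patch;
assumes the patched pyspark is importable):

    >>> from pyspark.sql.utils import to_list_column_style
    >>> to_list_column_style(None) is None
    True
    >>> to_list_column_style("key")
    ['key']
    >>> to_list_column_style(("a", "b"))
    ['a', 'b']
    >>> to_list_column_style(["a"])
    ['a']

Strings are deliberately treated as scalars, so a single column name is
wrapped rather than exploded into its characters.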