Skip to content

Commit 331adbc

Browse files
Add testing and include PySparkType
1 parent 5d6dafe commit 331adbc

File tree

3 files changed

+43
-3
lines changed

3 files changed

+43
-3
lines changed

python/pyspark/sql/tests/test_dataframe.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,6 +1077,17 @@ def test_join_without_on(self):
10771077
expected = [Row(a=0, b=0)]
10781078
self.assertEqual(actual, expected)
10791079

1080+
def test_join_on_types(self):
1081+
df1 = self.spark.range(1).toDF("a").withColumn('key', lit(1))
1082+
df2 = self.spark.range(1).toDF("b").withColumn('key', lit(1))
1083+
expected_result = df1.join(df2, on='key')
1084+
1085+
self.assertEqual(expected_result, df1.join(df2, on=['key']))
1086+
self.assertEqual(expected_result, df1.join(df2, on=('key',)))
1087+
self.assertEqual(expected_result, df1.join(df2, on=col('key')))
1088+
self.assertEqual(expected_result, df1.join(df2, on=[col('key')]))
1089+
self.assertEqual(expected_result, df1.join(df2, on=(col('key'),)))
1090+
10801091
# Regression test for invalid join methods when on is None, Spark-14761
10811092
def test_invalid_join_method(self):
10821093
df1 = self.spark.createDataFrame([("Alice", 5), ("Bob", 8)], ["name", "age"])

python/pyspark/sql/tests/test_utils.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@
2424
SparkUpgradeException,
2525
)
2626
from pyspark.testing.sqlutils import ReusedSQLTestCase
27-
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
28-
27+
from pyspark.sql.functions import col, to_date, unix_timestamp, from_unixtime
28+
from pyspark.sql.utils import (
29+
isinstance_iterable,
30+
to_list_column_style,
31+
)
2932

3033
class UtilsTests(ReusedSQLTestCase):
3134
def test_capture_analysis_exception(self):
@@ -75,6 +78,32 @@ def test_get_error_class_state(self):
7578
self.assertEquals(e.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION")
7679
self.assertEquals(e.getSqlState(), "42703")
7780

81+
def test_isinstance_iterable(self):
82+
a_col = col("")
83+
a_str = ""
84+
85+
self.assertFalse(isinstance_iterable(a_col))
86+
self.assertTrue(isinstance_iterable([a_col]))
87+
self.assertTrue(isinstance_iterable((a_col,)))
88+
89+
self.assertFalse(isinstance_iterable(a_str))
90+
self.assertTrue(isinstance_iterable([a_str]))
91+
self.assertTrue(isinstance_iterable((a_str,)))
92+
93+
self.assertIsInstance(to_list_column_style(a_col), list)
94+
self.assertIsInstance(to_list_column_style([a_col]), list)
95+
self.assertIsInstance(to_list_column_style((a_col,)), list)
96+
self.assertFalse(isinstance_iterable(to_list_column_style(a_col)[0]))
97+
self.assertFalse(isinstance_iterable(to_list_column_style([a_col])[0]))
98+
self.assertFalse(isinstance_iterable(to_list_column_style((a_col,))[0]))
99+
100+
self.assertIsInstance(to_list_column_style(a_str), list)
101+
self.assertIsInstance(to_list_column_style([a_str]), list)
102+
self.assertIsInstance(to_list_column_style((a_str,)), list)
103+
self.assertFalse(isinstance_iterable(to_list_column_style(a_str)[0]))
104+
self.assertFalse(isinstance_iterable(to_list_column_style([a_str])[0]))
105+
self.assertFalse(isinstance_iterable(to_list_column_style((a_str,))[0]))
106+
78107

79108
if __name__ == "__main__":
80109
import unittest

python/pyspark/sql/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ def isinstance_iterable(obj: Any) -> bool:
294294
try:
295295
iter(obj)
296296
return True
297-
except TypeError, PySparkTypeError:
297+
except (TypeError, PySparkTypeError):
298298
return False
299299

300300
def to_list_column_style(

0 commit comments

Comments
 (0)