|
24 | 24 | SparkUpgradeException, |
25 | 25 | ) |
26 | 26 | from pyspark.testing.sqlutils import ReusedSQLTestCase |
27 | | -from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime |
28 | | - |
| 27 | +from pyspark.sql.functions import col, to_date, unix_timestamp, from_unixtime |
| 28 | +from pyspark.sql.utils import ( |
| 29 | + isinstance_iterable, |
| 30 | + to_list_column_style, |
| 31 | +) |
29 | 32 |
|
30 | 33 | class UtilsTests(ReusedSQLTestCase): |
31 | 34 | def test_capture_analysis_exception(self): |
@@ -75,6 +78,32 @@ def test_get_error_class_state(self): |
75 | 78 | self.assertEquals(e.getErrorClass(), "UNRESOLVED_COLUMN.WITHOUT_SUGGESTION") |
76 | 79 | self.assertEquals(e.getSqlState(), "42703") |
77 | 80 |
|
| 81 | + def test_isinstance_iterable(self): |
| 82 | + a_col = col("") |
| 83 | + a_str = "" |
| 84 | + |
| 85 | + self.assertFalse(isinstance_iterable(a_col)) |
| 86 | + self.assertTrue(isinstance_iterable([a_col])) |
| 87 | + self.assertTrue(isinstance_iterable((a_col,))) |
| 88 | + |
| 89 | + self.assertFalse(isinstance_iterable(a_str)) |
| 90 | + self.assertTrue(isinstance_iterable([a_str])) |
| 91 | + self.assertTrue(isinstance_iterable((a_str,))) |
| 92 | + |
| 93 | + self.assertIsInstance(to_list_column_style(a_col), list) |
| 94 | + self.assertIsInstance(to_list_column_style([a_col]), list) |
| 95 | + self.assertIsInstance(to_list_column_style((a_col,)), list) |
| 96 | + self.assertFalse(isinstance_iterable(to_list_column_style(a_col)[0])) |
| 97 | + self.assertFalse(isinstance_iterable(to_list_column_style([a_col])[0])) |
| 98 | + self.assertFalse(isinstance_iterable(to_list_column_style((a_col,))[0])) |
| 99 | + |
| 100 | + self.assertIsInstance(to_list_column_style(a_str), list) |
| 101 | + self.assertIsInstance(to_list_column_style([a_str]), list) |
| 102 | + self.assertIsInstance(to_list_column_style((a_str,)), list) |
| 103 | + self.assertFalse(isinstance_iterable(to_list_column_style(a_str)[0])) |
| 104 | + self.assertFalse(isinstance_iterable(to_list_column_style([a_str])[0])) |
| 105 | + self.assertFalse(isinstance_iterable(to_list_column_style((a_str,))[0])) |
| 106 | + |
78 | 107 |
|
79 | 108 | if __name__ == "__main__": |
80 | 109 | import unittest |
|
0 commit comments