
Commit 49abdc4

viirya authored and HyukjinKwon committed
[SPARK-31186][PYSPARK][SQL][2.4] toPandas should not fail on duplicate column names
### What changes were proposed in this pull request?

When the `toPandas` API works on duplicate column names produced by operators such as join, we see an error like:

```
ValueError: The truth value of a Series is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
```

This patch fixes the error in the `toPandas` API. It is a backport of the original patch to branch-2.4.

### Why are the changes needed?

To make `toPandas` work on DataFrames with duplicate column names.

### Does this PR introduce any user-facing change?

Yes. Previously, calling the `toPandas` API on a DataFrame with duplicate column names would fail. After this patch, it produces the correct result.

### How was this patch tested?

Unit test.

Closes #28219 from viirya/SPARK-31186-2.4.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 775e958 commit 49abdc4
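
For context, a minimal reproduction sketch (not part of the patch): a join whose two sides share a column name yields duplicate column names, which used to trip the `ValueError` above on the non-Arrow `toPandas` path. It assumes a local Spark 2.4 session; the column name `v` is illustrative only.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

# Two one-column frames that share the (hypothetical) column name 'v'.
df1 = spark.range(3).selectExpr("id AS v")
df2 = spark.range(3).selectExpr("id AS v")

# The cross join yields a frame whose columns are ['v', 'v'].
joined = df1.crossJoin(df2)

# Before this patch, this call raised the ValueError quoted above;
# with the fix it returns a pandas DataFrame with both 'v' columns.
pdf = joined.toPandas()
print(pdf.columns.tolist())  # ['v', 'v']
```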

2 files changed: +52 −6 lines changed

python/pyspark/sql/dataframe.py

Lines changed: 34 additions & 6 deletions
@@ -27,6 +27,7 @@
     from itertools import imap as map
     from cgi import escape as html_escape
 
+from collections import Counter
 import warnings
 
 from pyspark import copy_func, since, _NoValue
@@ -2148,21 +2149,48 @@ def toPandas(self):
 
         # Below is toPandas without Arrow optimization.
         pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
+        column_counter = Counter(self.columns)
+
+        dtype = [None] * len(self.schema)
+        for fieldIdx, field in enumerate(self.schema):
+            # For a duplicate column name, we use `iloc` to access it.
+            if column_counter[field.name] > 1:
+                pandas_col = pdf.iloc[:, fieldIdx]
+            else:
+                pandas_col = pdf[field.name]
 
-        dtype = {}
-        for field in self.schema:
             pandas_type = _to_corrected_pandas_type(field.dataType)
             # SPARK-21766: if an integer field is nullable and has null values, it can be
             # inferred by pandas as a float column. Once we convert the column with NaN back
             # to an integer type, e.g. np.int16, we will hit an exception. So we use the
             # inferred float type, not the corrected type from the schema, in this case.
             if pandas_type is not None and \
                 not(isinstance(field.dataType, IntegralType) and field.nullable and
-                    pdf[field.name].isnull().any()):
-                dtype[field.name] = pandas_type
+                    pandas_col.isnull().any()):
+                dtype[fieldIdx] = pandas_type
+
+        df = pd.DataFrame()
+        for index, t in enumerate(dtype):
+            column_name = self.schema[index].name
+
+            # For a duplicate column name, we use `iloc` to access it.
+            if column_counter[column_name] > 1:
+                series = pdf.iloc[:, index]
+            else:
+                series = pdf[column_name]
+
+            if t is not None:
+                series = series.astype(t, copy=False)
+
+            # The `insert` API makes a copy of the data, so we only use it for Series
+            # with duplicate column names. `pdf.iloc[:, index] = pdf.iloc[:, index]...`
+            # doesn't always work because `iloc` could return a view or a copy
+            # depending on context.
+            if column_counter[column_name] > 1:
+                df.insert(index, column_name, series, allow_duplicates=True)
+            else:
+                df[column_name] = series
 
-        for f, t in dtype.items():
-            pdf[f] = pdf[f].astype(t, copy=False)
+        pdf = df
 
         if timezone is None:
             return pdf
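
The root cause is a pandas indexing subtlety: with duplicate labels, label-based access `pdf[name]` returns a DataFrame (every matching column) rather than a Series, so `pdf[name].isnull().any()` is itself a Series, and truth-testing it in the `if` raises the ambiguity error. The patch therefore switches to positional access via `iloc` and rebuilds the result column by column with `insert(..., allow_duplicates=True)`, which copies the data but sidesteps the view-vs-copy behavior of `iloc` assignment. A standalone pandas sketch of the mechanics (illustrative only, not part of the patch):

```python
import pandas as pd

# A frame with duplicate column labels, like the result of a self-join.
pdf = pd.DataFrame([[1, 1], [2, None]], columns=['v', 'v'])

# Label-based access matches BOTH 'v' columns and returns a DataFrame,
# so .isnull().any() is a Series and truth-testing it raises ValueError.
try:
    if pdf['v'].isnull().any():
        pass
except ValueError as err:
    print(err)  # The truth value of a Series is ambiguous. ...

# Positional access always yields a single Series, so the check is well defined.
print(pdf.iloc[:, 1].isnull().any())  # True

# Rebuilding with allow_duplicates=True preserves both identically named columns.
out = pd.DataFrame()
out.insert(0, 'v', pdf.iloc[:, 0], allow_duplicates=True)
out.insert(1, 'v', pdf.iloc[:, 1], allow_duplicates=True)
print(list(out.columns))  # ['v', 'v']
```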

python/pyspark/sql/tests.py

Lines changed: 18 additions & 0 deletions
@@ -3296,6 +3296,24 @@ def test_to_pandas(self):
         self.assertEquals(types[4], np.object)  # datetime.date
         self.assertEquals(types[5], 'datetime64[ns]')
 
+    @unittest.skipIf(not _have_pandas, _pandas_requirement_message)
+    def test_to_pandas_on_cross_join(self):
+        import numpy as np
+
+        sql = """
+        select t1.*, t2.* from (
+        select explode(sequence(1, 3)) v
+        ) t1 left join (
+        select explode(sequence(1, 3)) v
+        ) t2
+        """
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            df = self.spark.sql(sql)
+            pdf = df.toPandas()
+            types = pdf.dtypes
+            self.assertEquals(types.iloc[0], np.int32)
+            self.assertEquals(types.iloc[1], np.int32)
+
     @unittest.skipIf(_have_pandas, "Required Pandas was found.")
     def test_to_pandas_required_pandas_not_found(self):
         with QuietTest(self.sc):
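
A small detail in the test: because the joined frame has two columns both named `v`, `pdf.dtypes` is a Series with a duplicated index, so the assertions must use positional `types.iloc[0]` rather than a label lookup. A minimal pandas illustration (not part of the patch):

```python
import pandas as pd

pdf = pd.DataFrame([[1, 2]], columns=['v', 'v'])
types = pdf.dtypes    # a Series indexed by the duplicate labels ['v', 'v']

print(types['v'])     # label lookup matches both entries and returns a Series
print(types.iloc[0])  # positional lookup returns a single dtype (int64)
```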
