Skip to content

Commit 1150a19

Browse files
cloud-fanmarmbrus
authored andcommitted
[SPARK-8670] [SQL] Nested columns can't be referenced in pyspark
This bug is caused by a wrong column-exist-check in `__getitem__` of pyspark dataframe. `DataFrame.apply` accepts not only top level column names, but also nested column name like `a.b`, so we should remove that check from `__getitem__`. Author: Wenchen Fan <[email protected]> Closes #8202 from cloud-fan/nested.
1 parent 2a6590e commit 1150a19

File tree

3 files changed

+5
-3
lines changed

3 files changed

+5
-3
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -722,8 +722,6 @@ def __getitem__(self, item):
722722
[Row(age=5, name=u'Bob')]
723723
"""
724724
if isinstance(item, basestring):
725-
if item not in self.columns:
726-
raise IndexError("no such column: %s" % item)
727725
jc = self._jdf.apply(item)
728726
return Column(jc)
729727
elif isinstance(item, Column):

python/pyspark/sql/tests.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -770,7 +770,7 @@ def test_access_column(self):
770770
self.assertTrue(isinstance(df['key'], Column))
771771
self.assertTrue(isinstance(df[0], Column))
772772
self.assertRaises(IndexError, lambda: df[2])
773-
self.assertRaises(IndexError, lambda: df["bad_key"])
773+
self.assertRaises(AnalysisException, lambda: df["bad_key"])
774774
self.assertRaises(TypeError, lambda: df[{}])
775775

776776
def test_column_name_with_non_ascii(self):
@@ -794,7 +794,9 @@ def test_field_accessor(self):
794794
df = self.sc.parallelize([Row(l=[1], r=Row(a=1, b="b"), d={"k": "v"})]).toDF()
795795
self.assertEqual(1, df.select(df.l[0]).first()[0])
796796
self.assertEqual(1, df.select(df.r["a"]).first()[0])
797+
self.assertEqual(1, df.select(df["r.a"]).first()[0])
797798
self.assertEqual("b", df.select(df.r["b"]).first()[0])
799+
self.assertEqual("b", df.select(df["r.b"]).first()[0])
798800
self.assertEqual("v", df.select(df.d["k"]).first()[0])
799801

800802
def test_infer_long_type(self):

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -634,13 +634,15 @@ class DataFrame private[sql](
634634

635635
/**
636636
* Selects column based on the column name and return it as a [[Column]].
637+
* Note that the column name can also reference to a nested column like `a.b`.
637638
* @group dfops
638639
* @since 1.3.0
639640
*/
640641
def apply(colName: String): Column = col(colName)
641642

642643
/**
643644
* Selects column based on the column name and return it as a [[Column]].
645+
* Note that the column name can also reference to a nested column like `a.b`.
644646
* @group dfops
645647
* @since 1.3.0
646648
*/

0 commit comments

Comments
 (0)