Skip to content

Commit 8ca3418

Browse files
rxin authored and marmbrus committed
[SPARK-5904][SQL] DataFrame API fixes.
1. Column is no longer a DataFrame to simplify class hierarchy. 2. Don't use varargs on abstract methods (see Scala compiler bug SI-9013). Author: Reynold Xin <[email protected]> Closes apache#4686 from rxin/SPARK-5904 and squashes the following commits: fd9b199 [Reynold Xin] Fixed Python tests. df25cef [Reynold Xin] Non final. 5221530 [Reynold Xin] [SPARK-5904][SQL] DataFrame API fixes.
1 parent 94cdb05 commit 8ca3418

File tree

9 files changed

+427
-1024
lines changed

9 files changed

+427
-1024
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def first(self):
546546
def __getitem__(self, item):
547547
""" Return the column by given name
548548
549-
>>> df['age'].collect()
549+
>>> df.select(df['age']).collect()
550550
[Row(age=2), Row(age=5)]
551551
>>> df[ ["name", "age"]].collect()
552552
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
@@ -555,7 +555,7 @@ def __getitem__(self, item):
555555
"""
556556
if isinstance(item, basestring):
557557
jc = self._jdf.apply(item)
558-
return Column(jc, self.sql_ctx)
558+
return Column(jc)
559559
elif isinstance(item, Column):
560560
return self.filter(item)
561561
elif isinstance(item, list):
@@ -566,13 +566,13 @@ def __getitem__(self, item):
566566
def __getattr__(self, name):
567567
""" Return the column by given name
568568
569-
>>> df.age.collect()
569+
>>> df.select(df.age).collect()
570570
[Row(age=2), Row(age=5)]
571571
"""
572572
if name.startswith("__"):
573573
raise AttributeError(name)
574574
jc = self._jdf.apply(name)
575-
return Column(jc, self.sql_ctx)
575+
return Column(jc)
576576

577577
def select(self, *cols):
578578
""" Selecting a set of expressions.
@@ -698,7 +698,7 @@ def withColumnRenamed(self, existing, new):
698698
>>> df.withColumnRenamed('age', 'age2').collect()
699699
[Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
700700
"""
701-
cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
701+
cols = [Column(_to_java_column(c)).alias(new)
702702
if c == existing else c
703703
for c in self.columns]
704704
return self.select(*cols)
@@ -873,15 +873,16 @@ def _unary_op(name, doc="unary operator"):
873873
""" Create a method for given unary operator """
874874
def _(self):
875875
jc = getattr(self._jc, name)()
876-
return Column(jc, self.sql_ctx)
876+
return Column(jc)
877877
_.__doc__ = doc
878878
return _
879879

880880

881881
def _func_op(name, doc=''):
882882
def _(self):
883-
jc = getattr(self._sc._jvm.functions, name)(self._jc)
884-
return Column(jc, self.sql_ctx)
883+
sc = SparkContext._active_spark_context
884+
jc = getattr(sc._jvm.functions, name)(self._jc)
885+
return Column(jc)
885886
_.__doc__ = doc
886887
return _
887888

@@ -892,7 +893,7 @@ def _bin_op(name, doc="binary operator"):
892893
def _(self, other):
893894
jc = other._jc if isinstance(other, Column) else other
894895
njc = getattr(self._jc, name)(jc)
895-
return Column(njc, self.sql_ctx)
896+
return Column(njc)
896897
_.__doc__ = doc
897898
return _
898899

@@ -903,12 +904,12 @@ def _reverse_op(name, doc="binary operator"):
903904
def _(self, other):
904905
jother = _create_column_from_literal(other)
905906
jc = getattr(jother, name)(self._jc)
906-
return Column(jc, self.sql_ctx)
907+
return Column(jc)
907908
_.__doc__ = doc
908909
return _
909910

910911

911-
class Column(DataFrame):
912+
class Column(object):
912913

913914
"""
914915
A column in a DataFrame.
@@ -924,9 +925,8 @@ class Column(DataFrame):
924925
1 / df.colName
925926
"""
926927

927-
def __init__(self, jc, sql_ctx=None):
928+
def __init__(self, jc):
928929
self._jc = jc
929-
super(Column, self).__init__(jc, sql_ctx)
930930

931931
# arithmetic operators
932932
__neg__ = _func_op("negate")
@@ -975,7 +975,7 @@ def substr(self, startPos, length):
975975
:param startPos: start position (int or Column)
976976
:param length: length of the substring (int or Column)
977977
978-
>>> df.name.substr(1, 3).collect()
978+
>>> df.select(df.name.substr(1, 3).alias("col")).collect()
979979
[Row(col=u'Ali'), Row(col=u'Bob')]
980980
"""
981981
if type(startPos) != type(length):
@@ -986,7 +986,7 @@ def substr(self, startPos, length):
986986
jc = self._jc.substr(startPos._jc, length._jc)
987987
else:
988988
raise TypeError("Unexpected type: %s" % type(startPos))
989-
return Column(jc, self.sql_ctx)
989+
return Column(jc)
990990

991991
__getslice__ = substr
992992

@@ -1000,10 +1000,10 @@ def substr(self, startPos, length):
10001000
def alias(self, alias):
10011001
"""Return a alias for this column
10021002
1003-
>>> df.age.alias("age2").collect()
1003+
>>> df.select(df.age.alias("age2")).collect()
10041004
[Row(age2=2), Row(age2=5)]
10051005
"""
1006-
return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
1006+
return Column(getattr(self._jc, "as")(alias))
10071007

10081008
def cast(self, dataType):
10091009
""" Convert the column into type `dataType`
@@ -1013,34 +1013,18 @@ def cast(self, dataType):
10131013
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
10141014
[Row(ages=u'2'), Row(ages=u'5')]
10151015
"""
1016-
if self.sql_ctx is None:
1017-
sc = SparkContext._active_spark_context
1018-
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
1019-
else:
1020-
ssql_ctx = self.sql_ctx._ssql_ctx
10211016
if isinstance(dataType, basestring):
10221017
jc = self._jc.cast(dataType)
10231018
elif isinstance(dataType, DataType):
1019+
sc = SparkContext._active_spark_context
1020+
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
10241021
jdt = ssql_ctx.parseDataType(dataType.json())
10251022
jc = self._jc.cast(jdt)
1026-
return Column(jc, self.sql_ctx)
1023+
return Column(jc)
10271024

10281025
def __repr__(self):
10291026
return 'Column<%s>' % self._jdf.toString().encode('utf8')
10301027

1031-
def toPandas(self):
1032-
"""
1033-
Return a pandas.Series from the column
1034-
1035-
>>> df.age.toPandas() # doctest: +SKIP
1036-
0 2
1037-
1 5
1038-
dtype: int64
1039-
"""
1040-
import pandas as pd
1041-
data = [c for c, in self.collect()]
1042-
return pd.Series(data)
1043-
10441028

10451029
def _test():
10461030
import doctest

0 commit comments

Comments (0)