
Commit 3a1004f

Davies Liu authored and rxin committed
Dsl -> functions, toDF()
1 parent fb256af commit 3a1004f

7 files changed: +234 additions, -155 deletions


python/docs/pyspark.sql.rst

Lines changed: 8 additions & 0 deletions
@@ -16,3 +16,11 @@ pyspark.sql.types module
     :members:
     :undoc-members:
     :show-inheritance:
+
+
+pyspark.sql.functions module
+------------------------
+.. automodule:: pyspark.sql.functions
+    :members:
+    :undoc-members:
+    :show-inheritance:
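
The new `automodule` entry will document whatever `pyspark.sql.functions` exposes through its docstrings. As a quick sanity check of that documented surface without running a full Sphinx build, the module can be inspected directly (a minimal sketch, assuming a working PySpark installation on the path; the member names in the comments are illustrative, not exhaustive):

import pyspark.sql.functions as functions

# Public names are what Sphinx's ":members:" option will pick up.
public_members = [name for name in dir(functions) if not name.startswith('_')]
print(public_members)            # e.g. ['col', 'countDistinct', 'lit', 'udf', ...]
print(functions.col.__doc__)     # the per-function docstring rendered in the docs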

python/pyspark/sql/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -34,9 +34,8 @@

 from pyspark.sql.context import SQLContext, HiveContext
 from pyspark.sql.types import Row
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD
+from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD

 __all__ = [
     'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
-    'Dsl',
 ]
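
For downstream code, dropping `Dsl` from `__all__` means `from pyspark.sql import Dsl` now fails with an ImportError; the replacement is the module import used throughout the updated doctests:

# Old (removed by this commit): from pyspark.sql import Dsl
from pyspark.sql import functions as F   # new home of the former Dsl helpers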

python/pyspark/sql/context.py

Lines changed: 22 additions & 1 deletion
@@ -38,6 +38,25 @@
 __all__ = ["SQLContext", "HiveContext"]


+def _monkey_patch_RDD(sqlCtx):
+    def toDF(self, schema=None, sampleRatio=None):
+        """
+        Convert current :class:`RDD` into a :class:`DataFrame`
+
+        This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)`
+
+        :param schema: a StructType or list of names of columns
+        :param samplingRatio: the sample ratio of rows used for inferring
+        :return: a DataFrame
+
+        >>> rdd.toDF().collect()
+        [Row(name=u'Alice', age=1)]
+        """
+        return sqlCtx.createDataFrame(self, schema, sampleRatio)
+
+    RDD.toDF = toDF
+
+
 class SQLContext(object):

     """Main entry point for Spark SQL functionality.
@@ -70,6 +89,7 @@ def __init__(self, sparkContext, sqlContext=None):
         self._jsc = self._sc._jsc
         self._jvm = self._sc._jvm
         self._scala_SQLContext = sqlContext
+        _monkey_patch_RDD(self)

     @property
     def _ssql_ctx(self):
@@ -800,7 +820,8 @@ def _test():
         Row(field1=2, field2="row2"),
         Row(field1=3, field2="row3")]
     )
-    globs['df'] = sqlCtx.createDataFrame(rdd)
+    _monkey_patch_RDD(sqlCtx)
+    globs['df'] = rdd.toDF()
     jsonStrings = [
         '{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
         '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'

python/pyspark/sql/dataframe.py

Lines changed: 21 additions & 138 deletions
@@ -21,21 +21,19 @@
 import random
 import os
 from tempfile import NamedTemporaryFile
-from itertools import imap

 from py4j.java_collections import ListConverter, MapConverter

 from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _prepare_for_python_RDD
-from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
-    UTF8Deserializer
+from pyspark.rdd import RDD
+from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.types import *
 from pyspark.sql.types import _create_cls, _parse_datatype_json_string


-__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"]


 class DataFrame(object):
@@ -310,8 +308,9 @@ def take(self, num):
         return self.limit(num).collect()

     def map(self, f):
-        """ Return a new RDD by applying a function to each Row, it's a
-        shorthand for df.rdd.map()
+        """ Return a new RDD by applying a function to each Row
+
+        It's a shorthand for df.rdd.map()

         >>> df.map(lambda p: p.name).collect()
         [u'Alice', u'Bob']
@@ -586,8 +585,8 @@ def agg(self, *exprs):

         >>> df.agg({"age": "max"}).collect()
         [Row(MAX(age#0)=5)]
-        >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.min(df.age)).collect()
+        >>> from pyspark.sql import functions as F
+        >>> df.agg(F.min(df.age)).collect()
         [Row(MIN(age#0)=2)]
         """
         return self.groupBy().agg(*exprs)
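
Callers of the old `Dsl` aggregators migrate the same way the doctest does; a short before/after sketch, reusing the doctest DataFrame `df` (two rows, columns `age` and `name`):

from pyspark.sql import functions as F

# Before: from pyspark.sql import Dsl; df.agg(Dsl.min(df.age))
df.agg(F.min(df.age)).collect()      # [Row(MIN(age#0)=2)]
df.agg({"age": "max"}).collect()     # dict-style aggregation is unchanged: [Row(MAX(age#0)=5)]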
@@ -616,18 +615,18 @@ def subtract(self, other):
         """
         return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)

-    def addColumn(self, colName, col):
+    def withColumn(self, colName, col):
         """ Return a new :class:`DataFrame` by adding a column.

-        >>> df.addColumn('age2', df.age + 2).collect()
+        >>> df.withColumn('age2', df.age + 2).collect()
         [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
         """
         return self.select('*', col.alias(colName))

-    def renameColumn(self, existing, new):
+    def withColumnRenamed(self, existing, new):
         """ Rename an existing column to a new name

-        >>> df.renameColumn('age', 'age2').collect()
+        >>> df.withColumnRenamed('age', 'age2').collect()
         [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
         """
         cols = [Column(_to_java_column(c), self.sql_ctx).alias(new)
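
The new names match those used by the Scala DataFrame API, so existing callers only need a one-line rename; a sketch against the same doctest `df`:

# addColumn -> withColumn, renameColumn -> withColumnRenamed
df.withColumn('age2', df.age + 2).collect()
# [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
df.withColumnRenamed('age', 'age2').collect()
# [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]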
@@ -689,8 +688,9 @@ def agg(self, *exprs):
         >>> gdf = df.groupBy(df.name)
         >>> gdf.agg({"age": "max"}).collect()
         [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
-        >>> from pyspark.sql import Dsl
-        >>> gdf.agg(Dsl.min(df.age)).collect()
+
+        >>> from pyspark.sql import functions as F
+        >>> gdf.agg(F.min(df.age)).collect()
         [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
         """
         assert exprs, "exprs should not be empty"
@@ -742,12 +742,12 @@ def sum(self):

 def _create_column_from_literal(literal):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Dsl.lit(literal)
+    return sc._jvm.functions.lit(literal)


 def _create_column_from_name(name):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Dsl.col(name)
+    return sc._jvm.functions.col(name)


 def _to_java_column(col):
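
Both helpers now resolve against the JVM-side `functions` object instead of `Dsl`; from Python they surface as `functions.lit` and `functions.col` (a sketch reusing the doctest `df`):

from pyspark.sql import functions as F

# F.col references a column by name, F.lit wraps a Python literal;
# both end up calling sc._jvm.functions under the hood.
df.select(F.col("name"), (df.age + F.lit(1)).alias("age_plus_one")).collect()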
@@ -767,9 +767,9 @@ def _(self):
     return _


-def _dsl_op(name, doc=''):
+def _func_op(name, doc=''):
     def _(self):
-        jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+        jc = getattr(self._sc._jvm.functions, name)(self._jc)
         return Column(jc, self.sql_ctx)
     _.__doc__ = doc
     return _
@@ -818,7 +818,7 @@ def __init__(self, jc, sql_ctx=None):
         super(Column, self).__init__(jc, sql_ctx)

     # arithmetic operators
-    __neg__ = _dsl_op("negate")
+    __neg__ = _func_op("negate")
     __add__ = _bin_op("plus")
     __sub__ = _bin_op("minus")
     __mul__ = _bin_op("multiply")
@@ -842,7 +842,7 @@ def __init__(self, jc, sql_ctx=None):
     # so use bitwise operators as boolean operators
     __and__ = _bin_op('and')
     __or__ = _bin_op('or')
-    __invert__ = _dsl_op('not')
+    __invert__ = _func_op('not')
     __rand__ = _bin_op("and")
     __ror__ = _bin_op("or")
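
`_func_op` covers the Column operators that map to one-argument functions on the JVM side (`negate`, `not`) rather than to Column methods; user code is unaffected by the rename (sketch against the doctest `df`):

# __neg__ and __invert__ now route through sc._jvm.functions instead of sc._jvm.Dsl
df.select((-df.age).alias('neg_age')).collect()   # unary minus -> functions.negate
df.filter(~(df.age > 3)).collect()                # bitwise NOT  -> functions.not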

@@ -934,123 +934,6 @@ def to_pandas(self):
         return pd.Series(data)


-def _aggregate_func(name, doc=""):
-    """ Create a function for aggregator by name"""
-    def _(col):
-        sc = SparkContext._active_spark_context
-        jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
-        return Column(jc)
-    _.__name__ = name
-    _.__doc__ = doc
-    return staticmethod(_)
-
-
-class UserDefinedFunction(object):
-    def __init__(self, func, returnType):
-        self.func = func
-        self.returnType = returnType
-        self._broadcast = None
-        self._judf = self._create_judf()
-
-    def _create_judf(self):
-        f = self.func  # put it in closure `func`
-        func = lambda _, it: imap(lambda x: f(*x), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
-        sc = SparkContext._active_spark_context
-        pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
-        ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
-        jdt = ssql_ctx.parseDataType(self.returnType.json())
-        judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
-                                                 includes, sc.pythonExec, broadcast_vars,
-                                                 sc._javaAccumulator, jdt)
-        return judf
-
-    def __del__(self):
-        if self._broadcast is not None:
-            self._broadcast.unpersist()
-            self._broadcast = None
-
-    def __call__(self, *cols):
-        sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        sc._gateway._gateway_client)
-        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
-        return Column(jc)
-
-
-class Dsl(object):
-    """
-    A collection of builtin aggregators
-    """
-    DSLS = {
-        'lit': 'Creates a :class:`Column` of literal value.',
-        'col': 'Returns a :class:`Column` based on the given column name.',
-        'column': 'Returns a :class:`Column` based on the given column name.',
-        'upper': 'Converts a string expression to upper case.',
-        'lower': 'Converts a string expression to lower case.',
-        'sqrt': 'Computes the square root of the specified float value.',
-        'abs': 'Computes the absolute value.',
-
-        'max': 'Aggregate function: returns the maximum value of the expression in a group.',
-        'min': 'Aggregate function: returns the minimum value of the expression in a group.',
-        'first': 'Aggregate function: returns the first value in a group.',
-        'last': 'Aggregate function: returns the last value in a group.',
-        'count': 'Aggregate function: returns the number of items in a group.',
-        'sum': 'Aggregate function: returns the sum of all values in the expression.',
-        'avg': 'Aggregate function: returns the average of the values in a group.',
-        'mean': 'Aggregate function: returns the average of the values in a group.',
-        'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
-    }
-
-    for _name, _doc in DSLS.items():
-        locals()[_name] = _aggregate_func(_name, _doc)
-    del _name, _doc
-
-    @staticmethod
-    def countDistinct(col, *cols):
-        """ Return a new Column for distinct count of (col, *cols)
-
-        >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
-        [Row(c=2)]
-
-        >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
-        [Row(c=2)]
-        """
-        sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        sc._gateway._gateway_client)
-        jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
-                                       sc._jvm.PythonUtils.toSeq(jcols))
-        return Column(jc)
-
-    @staticmethod
-    def approxCountDistinct(col, rsd=None):
-        """ Return a new Column for approximate distinct count of (col, *cols)
-
-        >>> from pyspark.sql import Dsl
-        >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
-        [Row(c=2)]
-        """
-        sc = SparkContext._active_spark_context
-        if rsd is None:
-            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
-        else:
-            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
-        return Column(jc)
-
-    @staticmethod
-    def udf(f, returnType=StringType()):
-        """Create a user defined function (UDF)
-
-        >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
-        >>> df.select(slen(df.name).alias('slen')).collect()
-        [Row(slen=5), Row(slen=3)]
-        """
-        return UserDefinedFunction(f, returnType)
-
-
 def _test():
     import doctest
     from pyspark.context import SparkContext
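
Everything removed here (the generated aggregators, `countDistinct`, `approxCountDistinct`, `udf`, and the `UserDefinedFunction` wrapper) moves to the new `pyspark.sql.functions` module, one of the seven changed files not shown in this section. A hedged sketch of the replacement call sites, reusing the doctest `df` and assuming the functions module keeps the same names:

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# Former Dsl aggregators:
df.agg(F.countDistinct(df.age, df.name).alias('c')).collect()   # [Row(c=2)]
df.agg(F.approxCountDistinct(df.age).alias('c')).collect()      # [Row(c=2)]

# Former Dsl.udf:
slen = F.udf(lambda s: len(s), IntegerType())
df.select(slen(df.name).alias('slen')).collect()                # [Row(slen=5), Row(slen=3)]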
