Commit da77adc

improve aggregate
1 parent e0ca98f commit da77adc

8 files changed: +227 -96 lines changed

python/pyspark/rdd.py

Lines changed: 3 additions & 5 deletions

@@ -2330,11 +2330,9 @@ def _prepare_for_python_RDD(sc, command, obj=None):
     return pickled_command, broadcast_vars, env, includes


-def _wrap_function(sc, func, deserializer=None, serializer=None, profiler=None):
-    if deserializer is None:
-        deserializer = AutoBatchedSerializer(PickleSerializer())
-    if serializer is None:
-        serializer = AutoBatchedSerializer(PickleSerializer())
+def _wrap_function(sc, func, deserializer, serializer, profiler=None):
+    assert deserializer, "deserializer should not be empty"
+    assert serializer, "serializer should not be empty"
     command = (func, profiler, deserializer, serializer)
     pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
     return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
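
Since `_wrap_function` no longer falls back to a default serializer, every caller now has to pass both explicitly. A minimal sketch of what a call looks like after this change (`sc` is assumed to be an active SparkContext and `my_func` is a placeholder worker function; neither is part of this commit):

    from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
    from pyspark.rdd import _wrap_function

    # Both serializers must be supplied explicitly; passing None would trip
    # the new assertions.
    ser = AutoBatchedSerializer(PickleSerializer())
    wrapped = _wrap_function(sc, my_func, deserializer=ser, serializer=ser)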

python/pyspark/sql/dataframe.py

Lines changed: 27 additions & 15 deletions

@@ -28,7 +28,8 @@

 from pyspark import since
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
+from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, PickleSerializer, \
+    UTF8Deserializer, PairDeserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
 from pyspark.sql.types import _parse_datatype_json_string

@@ -236,9 +237,14 @@ def collect(self):
         >>> df.collect()
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
+
+        if self._jdf.isPickled():
+            deserializer = PickleSerializer()
+        else:
+            deserializer = BatchedSerializer(PickleSerializer())
         with SCCallSiteSync(self._sc) as css:
             port = self._jdf.collectToPython()
-        return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+        return list(_load_from_socket(port, deserializer))

     @ignore_unicode_prefix
     @since(1.3)

@@ -282,13 +288,16 @@ def map(self, f):
     @since(2.0)
     def applySchema(self, schema=None):
         """ TODO """
-        if schema is None:
-            from pyspark.sql.types import _infer_type, _merge_type
-            # If no schema is specified, infer it from the whole data set.
-            jrdd = self._prev_jdf.javaToPython()
-            rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer()))
-            schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type)
-        return PipelinedDataFrame(self, output_schema=schema)
+        if isinstance(self, PipelinedDataFrame):
+            if schema is None:
+                from pyspark.sql.types import _infer_type, _merge_type
+                # If no schema is specified, infer it from the whole data set.
+                jrdd = self._prev_jdf.javaToPython()
+                rdd = RDD(jrdd, self._sc, BatchedSerializer(PickleSerializer()))
+                schema = rdd.mapPartitions(self._func).map(_infer_type).reduce(_merge_type)
+            return PipelinedDataFrame(self, output_schema=schema)
+        else:
+            return self

     @ignore_unicode_prefix
     @since(2.0)

@@ -926,7 +935,7 @@ def groupByKey(self, key_func, key_type):
         wraped_func = _wrap_func(self._sc, self._jdf, f, False)
         jgd = self._jdf.pythonGroupBy(wraped_func, key_type.json())
         from pyspark.sql.group import GroupedData
-        return GroupedData(jgd, self.sql_ctx, key_func)
+        return GroupedData(jgd, self.sql_ctx, not isinstance(key_type, StructType))

     @since(1.4)
     def rollup(self, *cols):

@@ -1396,6 +1405,7 @@ def __init__(self, prev, func=None, output_schema=None):
         from pyspark.sql.group import GroupedData

         if output_schema is None:
+            # should get it from java side
             self._schema = StructType().add("binary", BinaryType(), False, {"pickled": True})
         else:
             self._schema = output_schema

@@ -1446,7 +1456,7 @@ def _jdf(self):
         return self._jdf_val

     def _create_jdf(self, func, schema=None):
-        wrapped_func = _wrap_func(self._sc, self._prev_jdf, func, schema is None)
+        wrapped_func = _wrap_func(self._sc, self._prev_jdf, func, schema is None, self._grouped)
         if schema is None:
             if self._grouped:
                 return self._prev_jdf.flatMapGroups(wrapped_func)

@@ -1460,16 +1470,18 @@ def _create_jdf(self, func, schema=None):
             return self._prev_jdf.pythonMapPartitions(wrapped_func, schema_string)


-def _wrap_func(sc, jdf, func, output_binary):
-    if jdf.isPickled():
+def _wrap_func(sc, jdf, func, output_binary, input_grouped=False):
+    if input_grouped:
+        deserializer = PairDeserializer(PickleSerializer(), PickleSerializer())
+    elif jdf.isPickled():
         deserializer = PickleSerializer()
     else:
-        deserializer = None  # the framework will provide a default one
+        deserializer = AutoBatchedSerializer(PickleSerializer())

     if output_binary:
         serializer = PickleSerializer()
     else:
-        serializer = None  # the framework will provide a default one
+        serializer = AutoBatchedSerializer(PickleSerializer())

     from pyspark.rdd import _wrap_function
     return _wrap_function(sc, lambda _, iterator: func(iterator), deserializer, serializer)
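
For reference, the deserializer selection that the new `_wrap_func` performs can be read as the standalone sketch below; `_choose_deserializer` is a hypothetical helper used only for illustration, not code from this commit:

    from pyspark.serializers import (AutoBatchedSerializer, PairDeserializer,
                                     PickleSerializer)

    def _choose_deserializer(jdf, input_grouped):
        if input_grouped:
            # Grouped input is read as (key, value) pairs, so a PairDeserializer
            # combining two pickle deserializers is used.
            return PairDeserializer(PickleSerializer(), PickleSerializer())
        if jdf.isPickled():
            # A "pickled" DataFrame carries one opaque pickled Python object per row.
            return PickleSerializer()
        # Ordinary rows stream over in auto-sized pickle batches; the explicit
        # serializer replaces the old framework-provided default.
        return AutoBatchedSerializer(PickleSerializer())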

python/pyspark/sql/group.py

Lines changed: 99 additions & 17 deletions

@@ -15,6 +15,15 @@
 # limitations under the License.
 #

+import sys
+
+if sys.version >= '3':
+    basestring = unicode = str
+    long = int
+    from functools import reduce
+else:
+    from itertools import imap as map
+
 from pyspark import since
 from pyspark.rdd import ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal

@@ -54,25 +63,25 @@ class GroupedData(object):
     .. versionadded:: 1.3
     """

-    def __init__(self, jgd, sql_ctx, key_func=None):
+    def __init__(self, jgd, sql_ctx, flat_key=False):
         self._jgd = jgd
         self.sql_ctx = sql_ctx
-        if key_func is None:
-            self.key_func = lambda key: key
+        if flat_key:
+            self._key_converter = lambda key: key[0]
         else:
-            self.key_func = key_func
+            self._key_converter = lambda key: key

     @ignore_unicode_prefix
     @since(2.0)
     def flatMapGroups(self, func):
         """ TODO """
-        import itertools
-        key_func = self.key_func
+        key_converter = self._key_converter

-        def process(iterator):
-            first = iterator.next()
-            key = key_func(first)
-            return func(key, itertools.chain([first], iterator))
+        def process(inputs):
+            record_converter = lambda record: (key_converter(record[0]), record[1])
+            for key, values in GroupedIterator(map(record_converter, inputs)):
+                for output in func(key, values):
+                    yield output

         return PipelinedDataFrame(self, process)

@@ -217,6 +226,86 @@ def pivot(self, pivot_col, values=None):
         return GroupedData(jgd, self.sql_ctx)


+class GroupedIterator(object):
+    """ TODO """
+
+    def __init__(self, inputs):
+        self.inputs = BufferedIterator(inputs)
+        self.current_input = inputs.next()
+        self.current_key = self.current_input[0]
+        self.current_values = GroupValuesIterator(self)
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        if self.current_values is None:
+            self._fetch_next_group()
+
+        ret = (self.current_key, self.current_values)
+        self.current_values = None
+        return ret
+
+    def _fetch_next_group(self):
+        if self.current_input is None:
+            self.current_input = self.inputs.next()
+
+        # Skip to next group, or consume all inputs and throw StopIteration exception.
+        while self.current_input[0] == self.current_key:
+            self.current_input = self.inputs.next()
+
+        self.current_key = self.current_input[0]
+        self.current_values = GroupValuesIterator(self)
+
+
+class GroupValuesIterator(object):
+    """ TODO """
+
+    def __init__(self, outter):
+        self.outter = outter
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        if self.outter.current_input is None:
+            self._fetch_next_value()
+
+        value = self.outter.current_input[1]
+        self.outter.current_input = None
+        return value
+
+    def _fetch_next_value(self):
+        if self.outter.inputs.head()[0] == self.outter.current_key:
+            self.outter.current_input = self.outter.inputs.next()
+        else:
+            raise StopIteration
+
+
+class BufferedIterator(object):
+    """ TODO """
+
+    def __init__(self, iterator):
+        self.iterator = iterator
+        self.buffered = None
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        if self.buffered is None:
+            return self.iterator.next()
+        else:
+            item = self.buffered
+            self.buffered = None
+            return item
+
+    def head(self):
+        if self.buffered is None:
+            self.buffered = self.iterator.next()
+        return self.buffered
+
+
 def _test():
     import doctest
     from pyspark.context import SparkContext

@@ -237,13 +326,6 @@ def _test():
         Row(course="dotNET", year=2013, earnings=48000),
         Row(course="Java", year=2013, earnings=30000)]).toDF()

-    ds = globs['sqlContext'].createDataFrame([(i, i) for i in range(100)], ("key", "value"))
-    grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType())
-    value_sum = lambda rows: sum(map(lambda row: row.value, rows))
-    agged = grouped.mapGroups(lambda key, values: str(key) + ":" + str(value_sum(values)))
-    result = agged.applySchema(StringType()).collect()
-    raise ValueError(result[0][0])
-
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.group, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
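
The new `GroupedIterator` / `GroupValuesIterator` / `BufferedIterator` trio streams one group at a time without materializing a group's values up front. Functionally, the `process` generator built in `flatMapGroups` behaves roughly like the sketch below, which assumes the input records arrive already clustered by key; `flat_map_groups` and its arguments are illustrative names, not code from this commit:

    import itertools

    def flat_map_groups(records, func, key_converter=lambda key: key):
        # records: an iterator of (key, value) pairs already clustered by key.
        converted = ((key_converter(k), v) for k, v in records)
        for key, group in itertools.groupby(converted, key=lambda kv: kv[0]):
            values = (v for _, v in group)  # lazy view over this group's values
            for output in func(key, values):
                yield output

    # Tiny local usage with an already-clustered input:
    data = [(0, 'a'), (0, 'b'), (1, 'c'), (1, 'd'), (1, 'e')]
    out = list(flat_map_groups(iter(data), lambda k, vs: [(k, len(list(vs)))]))
    # out == [(0, 2), (1, 3)]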

python/pyspark/sql/tests.py

Lines changed: 39 additions & 0 deletions

@@ -1188,6 +1188,11 @@ def test_dataset(self):
         self.assertTrue(result[0][0] > 0)
         self.assertTrue(result[1][0] > 0)

+        # If no schema is given, collect will return custom objects instead of rows.
+        result = ds2.collect()
+        self.assertEqual(result[0], 0)
+        self.assertEqual(result[1], 3)
+
         # row count should be corrected even no schema is specified.
         self.assertEqual(ds2.count(), 100)

@@ -1198,6 +1203,40 @@ def test_dataset(self):
         self.assertEqual(result[0][0], 0)
         self.assertEqual(result[1][0], 3)

+    def test_typed_aggregate(self):
+        data = [(i, i * 2) for i in range(100)]
+        ds = self.sqlCtx.createDataFrame(data, ("key", "value"))
+        sum_tuple = lambda values: sum(map(lambda value: value[0] * value[1], values))
+
+        def get_python_result(data, key_func, agg_func):
+            data.sort(key=key_func)
+            expected_result = []
+            import itertools
+            for key, values in itertools.groupby(data, key_func):
+                expected_result.append(agg_func(key, values))
+            return expected_result
+
+        grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType())
+        agg_func = lambda key, values: str(key) + ":" + str(sum_tuple(values))
+        result = sorted(grouped.mapGroups(agg_func).collect())
+        expected_result = get_python_result(data, lambda i: i[0] % 5, agg_func)
+        self.assertEqual(result, expected_result)
+
+        # We can also call groupByKey on a Dataset of custom objects.
+        ds2 = ds.map2(lambda row: row.key)
+        grouped = ds2.groupByKey(lambda i: i % 5, IntegerType())
+        agg_func = lambda key, values: str(key) + ":" + str(sum(values))
+        result = sorted(grouped.mapGroups(agg_func).collect())
+        expected_result = get_python_result(range(100), lambda i: i % 5, agg_func)
+        self.assertEqual(result, expected_result)
+
+        # We can also apply typed aggregate after structured groupBy, the key is row object.
+        grouped = ds.groupBy(ds.key % 2, ds.key % 3)
+        agg_func = lambda key, values: str(key[0]) + str(key[1]) + ":" + str(sum_tuple(values))
+        result = sorted(grouped.mapGroups(agg_func).collect())
+        expected_result = get_python_result(data, lambda i: (i[0] % 2, i[0] % 3), agg_func)
+        self.assertEqual(result, expected_result)
+

 class HiveContextSQLTests(ReusedPySparkTestCase):
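
Outside the test harness, the typed-aggregate path exercised above would be used roughly like this (a sketch against this branch's experimental API; `sqlContext` is assumed to already exist):

    from pyspark.sql.types import IntegerType

    ds = sqlContext.createDataFrame([(i, i * 2) for i in range(100)], ("key", "value"))
    grouped = ds.groupByKey(lambda row: row.key % 5, IntegerType())
    # mapGroups is handed (key, iterator of grouped rows) and emits one output
    # per group; with no output schema, collect() returns the Python objects
    # themselves rather than Rows.
    agged = grouped.mapGroups(lambda key, rows: "%d:%d" % (key, sum(r[1] for r in rows)))
    print(sorted(agged.collect()))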

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala

Lines changed: 1 addition & 3 deletions

@@ -88,8 +88,6 @@ package object expressions {
    */
   implicit class AttributeSeq(attrs: Seq[Attribute]) {
     /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
-    def toStructType: StructType = {
-      StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
-    }
+    def toStructType: StructType = StructType.fromAttributes(attrs)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 7 additions & 3 deletions

@@ -1750,9 +1750,13 @@
    * Converts a JavaRDD to a PythonRDD.
    */
   protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
-    val structType = schema // capture it for closure
-    val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType))
-    EvaluatePython.javaToPython(rdd)
+    if (EvaluatePython.isPickled(schema)) {
+      queryExecution.toRdd.map(_.getBinary(0))
+    } else {
+      val structType = schema // capture it for closure
+      val rdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType))
+      EvaluatePython.javaToPython(rdd)
+    }
   }

   protected[sql] def collectToPython(): Int = {

sql/core/src/main/scala/org/apache/spark/sql/GroupedPythonDataset.scala

Lines changed: 1 addition & 2 deletions

@@ -33,8 +33,7 @@ class GroupedPythonDataset private[sql](

   private def sqlContext = queryExecution.sqlContext

-  protected[sql] def isPickled(): Boolean =
-    EvaluatePython.isPickled(queryExecution.analyzed.output.toStructType)
+  protected[sql] def isPickled(): Boolean = EvaluatePython.isPickled(dataAttributes.toStructType)

   private def groupedData =
     new GroupedData(
