Commit 8c718a5

[SPARK-11899][SQL] API audit for GroupedDataset.
1. Renamed map to mapGroup, flatMap to flatMapGroup.
2. Renamed asKey -> keyAs.
3. Added more documentation.
4. Changed type parameter T to V on GroupedDataset.
5. Added since versions for all functions.

Author: Reynold Xin <[email protected]>

Closes #9880 from rxin/SPARK-11899.

(cherry picked from commit ff442bb)
Signed-off-by: Reynold Xin <[email protected]>
1 parent 0554718 commit 8c718a5
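
For illustration, a minimal Scala sketch of how the renamed methods read after this change. This is a hedged example: the `sqlContext` value and the sample data are assumed, and the implicit encoders come from `sqlContext.implicits._`.

  import org.apache.spark.sql.functions.{col, length}
  import sqlContext.implicits._   // assumed SQLContext; supplies the implicit encoders

  val ds = Seq("a", "b", "foo").toDS()
  val grouped = ds.groupBy(_.length)             // GroupedDataset[Int, String]

  // was grouped.map(...): one result per key
  val sizes = grouped.mapGroup((key, iter) => (key, iter.size))

  // was grouped.flatMap(...): any number of results per key
  val expanded = grouped.flatMapGroup((key, iter) => Iterator(key.toString, iter.mkString))

  // was .asKey(...): re-type the grouping key of a column-based groupBy
  val byKey = ds.groupBy(length(col("value"))).keyAs[Int]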

File tree

9 files changed, +131 -45 lines changed


core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 import java.util.Iterator;

 /**
- * Base interface for a map function used in GroupedDataset's map function.
+ * Base interface for a map function used in GroupedDataset's mapGroup function.
  */
 public interface MapGroupFunction<K, V, R> extends Serializable {
   R call(K key, Iterator<V> values) throws Exception;

sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala

Lines changed: 4 additions & 0 deletions
@@ -30,6 +30,8 @@ import org.apache.spark.sql.types._
  *
  * Encoders are not intended to be thread-safe and thus they are allow to avoid internal locking
  * and reuse internal buffers to improve performance.
+ *
+ * @since 1.6.0
  */
 trait Encoder[T] extends Serializable {

@@ -42,6 +44,8 @@ trait Encoder[T] extends Serializable {

 /**
  * Methods for creating encoders.
+ *
+ * @since 1.6.0
  */
 object Encoders {
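
For illustration, the two ways an Encoder typically reaches these APIs, as a hedged sketch (the `sqlContext` value is assumed): explicitly from the Encoders factory, as the Java suite further below does, or implicitly via `sqlContext.implicits._` for the Scala `[U : Encoder]` context bounds.

  import org.apache.spark.sql.{Encoder, Encoders}

  // Explicit: mirrors Encoders.STRING() / Encoders.INT() in the Java test suite below.
  val stringEncoder: Encoder[String] = Encoders.STRING

  // Implicit: toDS()/createDataset resolve the encoder from the implicits in scope.
  import sqlContext.implicits._
  val ds = sqlContext.createDataset(Seq("a", "foo", "bar"))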

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 1 addition & 2 deletions
@@ -29,7 +29,7 @@ import org.apache.spark.sql.types._
 /**
  * Type-inference utilities for POJOs and Java collections.
  */
-private [sql] object JavaTypeInference {
+object JavaTypeInference {

   private val iterableType = TypeToken.of(classOf[JIterable[_]])
   private val mapType = TypeToken.of(classOf[JMap[_, _]])
@@ -53,7 +53,6 @@ private [sql] object JavaTypeInference {
    * @return (SQL data type, nullable)
    */
   private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
-    // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
     typeToken.getRawType match {
       case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
         (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 2 additions & 0 deletions
@@ -46,6 +46,8 @@ private[sql] object Column {
  * @tparam T The input type expected for this expression. Can be `Any` if the expression is type
  *           checked by the analyzer instead of the compiler (i.e. `expr("sum(...)")`).
  * @tparam U The output type of this column.
+ *
+ * @since 1.6.0
  */
 class TypedColumn[-T, U](
     expr: Expression,
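
For illustration, a hedged sketch of how the two type parameters show up in practice (the `sqlContext` value and sample data are assumed): lifting an untyped Column with `as[Long]` yields a `TypedColumn[Any, Long]`, i.e. T = `Any` because the input is checked by the analyzer, and U = `Long` as the output type.

  import org.apache.spark.sql.functions.count
  import sqlContext.implicits._

  val ds = Seq("a", "foo", "bar").toDS()

  // Untyped aggregate column lifted into a TypedColumn[Any, Long].
  val typedCount = count("*").as[Long]

  // Consumed by the typed agg on GroupedDataset: Dataset[(Int, Long)] keyed by string length.
  val counts = ds.groupBy(_.length).agg(typedCount)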

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 0 additions & 1 deletion
@@ -110,7 +110,6 @@ private[sql] object DataFrame {
  * @groupname action Actions
  * @since 1.3.0
  */
-// TODO: Improve documentation.
 @Experimental
 class DataFrame private[sql](
     @transient val sqlContext: SQLContext,

sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala

Lines changed: 106 additions & 26 deletions
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Ou
 import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.expressions.Aggregator

 /**
  * :: Experimental ::
@@ -36,11 +37,13 @@ import org.apache.spark.sql.execution.QueryExecution
  * making this change to the class hierarchy would break some function signatures. As such, this
  * class should be considered a preview of the final API. Changes will be made to the interface
  * after Spark 1.6.
+ *
+ * @since 1.6.0
  */
 @Experimental
-class GroupedDataset[K, T] private[sql](
+class GroupedDataset[K, V] private[sql](
     kEncoder: Encoder[K],
-    tEncoder: Encoder[T],
+    tEncoder: Encoder[V],
     val queryExecution: QueryExecution,
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
@@ -67,8 +70,10 @@ class GroupedDataset[K, T] private[sql](
   /**
    * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
    * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
+   *
+   * @since 1.6.0
    */
-  def asKey[L : Encoder]: GroupedDataset[L, T] =
+  def keyAs[L : Encoder]: GroupedDataset[L, V] =
     new GroupedDataset(
       encoderFor[L],
       unresolvedTEncoder,
@@ -78,6 +83,8 @@ class GroupedDataset[K, T] private[sql](

   /**
    * Returns a [[Dataset]] that contains each unique key.
+   *
+   * @since 1.6.0
    */
   def keys: Dataset[K] = {
     new Dataset[K](
@@ -92,12 +99,18 @@ class GroupedDataset[K, T] private[sql](
    * function can return an iterator containing elements of an arbitrary type which will be returned
    * as a new [[Dataset]].
    *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
    * Internally, the implementation will spill to disk if any given group is too large to fit into
    * memory. However, users must take care to avoid materializing the whole iterator for a group
    * (for example, by calling `toList`) unless they are sure that this is possible given the memory
    * constraints of their cluster.
+   *
+   * @since 1.6.0
    */
-  def flatMap[U : Encoder](f: (K, Iterator[T]) => TraversableOnce[U]): Dataset[U] = {
+  def flatMapGroup[U : Encoder](f: (K, Iterator[V]) => TraversableOnce[U]): Dataset[U] = {
     new Dataset[U](
       sqlContext,
       MapGroups(
@@ -108,41 +121,88 @@ class GroupedDataset[K, T] private[sql](
         logicalPlan))
   }

-  def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
-    flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
+  /**
+   * Applies the given function to each group of data. For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an iterator containing elements of an arbitrary type which will be returned
+   * as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def flatMapGroup[U](f: FlatMapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    flatMapGroup((key, data) => f.call(key, data.asJava).asScala)(encoder)
   }

   /**
    * Applies the given function to each group of data. For each unique group, the function will
    * be passed the group key and an iterator that contains all of the elements in the group. The
    * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
    *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
    * Internally, the implementation will spill to disk if any given group is too large to fit into
    * memory. However, users must take care to avoid materializing the whole iterator for a group
    * (for example, by calling `toList`) unless they are sure that this is possible given the memory
    * constraints of their cluster.
+   *
+   * @since 1.6.0
    */
-  def map[U : Encoder](f: (K, Iterator[T]) => U): Dataset[U] = {
-    val func = (key: K, it: Iterator[T]) => Iterator(f(key, it))
-    flatMap(func)
+  def mapGroup[U : Encoder](f: (K, Iterator[V]) => U): Dataset[U] = {
+    val func = (key: K, it: Iterator[V]) => Iterator(f(key, it))
+    flatMapGroup(func)
   }

-  def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
-    map((key, data) => f.call(key, data.asJava))(encoder)
+  /**
+   * Applies the given function to each group of data. For each unique group, the function will
+   * be passed the group key and an iterator that contains all of the elements in the group. The
+   * function can return an element of arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * This function does not support partial aggregation, and as a result requires shuffling all
+   * the data in the [[Dataset]]. If an application intends to perform an aggregation over each
+   * key, it is best to use the reduce function or an [[Aggregator]].
+   *
+   * Internally, the implementation will spill to disk if any given group is too large to fit into
+   * memory. However, users must take care to avoid materializing the whole iterator for a group
+   * (for example, by calling `toList`) unless they are sure that this is possible given the memory
+   * constraints of their cluster.
+   *
+   * @since 1.6.0
+   */
+  def mapGroup[U](f: MapGroupFunction[K, V, U], encoder: Encoder[U]): Dataset[U] = {
+    mapGroup((key, data) => f.call(key, data.asJava))(encoder)
   }

   /**
    * Reduces the elements of each group of data using the specified binary function.
    * The given function must be commutative and associative or the result may be non-deterministic.
+   *
+   * @since 1.6.0
    */
-  def reduce(f: (T, T) => T): Dataset[(K, T)] = {
-    val func = (key: K, it: Iterator[T]) => Iterator(key -> it.reduce(f))
+  def reduce(f: (V, V) => V): Dataset[(K, V)] = {
+    val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))

     implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedTEncoder)
-    flatMap(func)
+    flatMapGroup(func)
   }

-  def reduce(f: ReduceFunction[T]): Dataset[(K, T)] = {
+  /**
+   * Reduces the elements of each group of data using the specified binary function.
+   * The given function must be commutative and associative or the result may be non-deterministic.
+   *
+   * @since 1.6.0
+   */
+  def reduce(f: ReduceFunction[V]): Dataset[(K, V)] = {
     reduce(f.call _)
   }

@@ -185,41 +245,51 @@ class GroupedDataset[K, T] private[sql](
   /**
    * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing this aggregation over all elements in the group.
+   *
+   * @since 1.6.0
    */
-  def agg[U1](col1: TypedColumn[T, U1]): Dataset[(K, U1)] =
+  def agg[U1](col1: TypedColumn[V, U1]): Dataset[(K, U1)] =
     aggUntyped(col1).asInstanceOf[Dataset[(K, U1)]]

   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
    */
-  def agg[U1, U2](col1: TypedColumn[T, U1], col2: TypedColumn[T, U2]): Dataset[(K, U1, U2)] =
+  def agg[U1, U2](col1: TypedColumn[V, U1], col2: TypedColumn[V, U2]): Dataset[(K, U1, U2)] =
     aggUntyped(col1, col2).asInstanceOf[Dataset[(K, U1, U2)]]

   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
    */
   def agg[U1, U2, U3](
-      col1: TypedColumn[T, U1],
-      col2: TypedColumn[T, U2],
-      col3: TypedColumn[T, U3]): Dataset[(K, U1, U2, U3)] =
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3]): Dataset[(K, U1, U2, U3)] =
     aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, U1, U2, U3)]]

   /**
    * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
    * and the result of computing these aggregations over all elements in the group.
+   *
+   * @since 1.6.0
    */
   def agg[U1, U2, U3, U4](
-      col1: TypedColumn[T, U1],
-      col2: TypedColumn[T, U2],
-      col3: TypedColumn[T, U3],
-      col4: TypedColumn[T, U4]): Dataset[(K, U1, U2, U3, U4)] =
+      col1: TypedColumn[V, U1],
+      col2: TypedColumn[V, U2],
+      col3: TypedColumn[V, U3],
+      col4: TypedColumn[V, U4]): Dataset[(K, U1, U2, U3, U4)] =
     aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, U1, U2, U3, U4)]]

   /**
    * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
    * for that key.
+   *
+   * @since 1.6.0
    */
   def count(): Dataset[(K, Long)] = agg(functions.count("*").as(ExpressionEncoder[Long]))

@@ -228,10 +298,12 @@ class GroupedDataset[K, T] private[sql](
    * be passed the grouping key and 2 iterators containing all elements in the group from
    * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
    * arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 1.6.0
    */
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
-      f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
+      f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
     implicit def uEnc: Encoder[U] = other.unresolvedTEncoder
     new Dataset[R](
       sqlContext,
@@ -243,9 +315,17 @@ class GroupedDataset[K, T] private[sql](
       other.logicalPlan))
   }

+  /**
+   * Applies the given function to each cogrouped data. For each unique group, the function will
+   * be passed the grouping key and 2 iterators containing all elements in the group from
+   * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
+   * arbitrary type which will be returned as a new [[Dataset]].
+   *
+   * @since 1.6.0
+   */
   def cogroup[U, R](
       other: GroupedDataset[K, U],
-      f: CoGroupFunction[K, T, U, R],
+      f: CoGroupFunction[K, V, U, R],
       encoder: Encoder[R]): Dataset[R] = {
     cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
   }
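
Taken together, the renamed [K, V] API reads as follows. This is a hedged end-to-end sketch (the `sqlContext` value, the sample Datasets, and the length-based key are assumed); per the new documentation above, prefer `reduce` or an Aggregator over `flatMapGroup` when the goal is a per-key aggregate.

  import sqlContext.implicits._

  val left  = Seq("a", "bb", "cc").toDS().groupBy(_.length)   // GroupedDataset[Int, String]
  val right = Seq("x", "yy").toDS().groupBy(_.length)

  // Per-key aggregation with reduce, no full-group iterator required.
  val longest = left.reduce((a, b) => if (a.length >= b.length) a else b)   // Dataset[(Int, String)]

  // Per-key counts via the typed count().
  val counts = left.count()                                                 // Dataset[(Int, Long)]

  // cogroup pairs each key's groups from `this` and `other`.
  val merged = left.cogroup(right) { (key, xs, ys) =>
    Iterator(s"$key:${xs.mkString(",")}|${ys.mkString(",")}")
  }                                                                         // Dataset[String]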

sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java

Lines changed: 4 additions & 4 deletions
@@ -170,7 +170,7 @@ public Integer call(String v) throws Exception {
       }
     }, Encoders.INT());

-    Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
+    Dataset<String> mapped = grouped.mapGroup(new MapGroupFunction<Integer, String, String>() {
       @Override
       public String call(Integer key, Iterator<String> values) throws Exception {
         StringBuilder sb = new StringBuilder(key.toString());
@@ -183,7 +183,7 @@ public String call(Integer key, Iterator<String> values) throws Exception {

     Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());

-    Dataset<String> flatMapped = grouped.flatMap(
+    Dataset<String> flatMapped = grouped.flatMapGroup(
       new FlatMapGroupFunction<Integer, String, String>() {
         @Override
         public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
@@ -247,9 +247,9 @@ public void testGroupByColumn() {
     List<String> data = Arrays.asList("a", "foo", "bar");
     Dataset<String> ds = context.createDataset(data, Encoders.STRING());
     GroupedDataset<Integer, String> grouped =
-      ds.groupBy(length(col("value"))).asKey(Encoders.INT());
+      ds.groupBy(length(col("value"))).keyAs(Encoders.INT());

-    Dataset<String> mapped = grouped.map(
+    Dataset<String> mapped = grouped.mapGroup(
       new MapGroupFunction<Integer, String, String>() {
         @Override
         public String call(Integer key, Iterator<String> data) throws Exception {

sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -86,7 +86,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, map") {
     val ds = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11).toDS()
     val grouped = ds.groupBy(_ % 2)
-    val agged = grouped.map { case (g, iter) =>
+    val agged = grouped.mapGroup { case (g, iter) =>
       val name = if (g == 0) "even" else "odd"
       (name, iter.size)
     }
@@ -99,7 +99,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
   test("groupBy function, flatMap") {
     val ds = Seq("a", "b", "c", "xyz", "hello").toDS()
     val grouped = ds.groupBy(_.length)
-    val agged = grouped.flatMap { case (g, iter) => Iterator(g.toString, iter.mkString) }
+    val agged = grouped.flatMapGroup { case (g, iter) => Iterator(g.toString, iter.mkString) }

     checkAnswer(
       agged,
