Commit 84b245f

koertkuipers authored and rxin committed
[SPARK-15780][SQL] Support mapValues on KeyValueGroupedDataset
## What changes were proposed in this pull request?

Add mapValues to KeyValueGroupedDataset.

## How was this patch tested?

New test in DatasetSuite covering groupBy function, mapValues, flatMap.

Author: Koert Kuipers <[email protected]>

Closes #13526 from koertkuipers/feat-keyvaluegroupeddataset-mapvalues.
1 parent fb0894b commit 84b245f
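
For orientation, a minimal, hypothetical sketch of the API this commit adds; the sample data, the local master, and the reduceGroups step are illustrative, not part of the change:

    // Hypothetical usage sketch of mapValues (Spark 2.1.0+); data is made up.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

    // groupByKey keeps the full (key, value) tuple as the value; mapValues
    // narrows it to just the Int payload before any aggregation runs.
    val sums = ds.groupByKey(_._1)   // KeyValueGroupedDataset[String, (String, Int)]
      .mapValues(_._2)               // KeyValueGroupedDataset[String, Int]
      .reduceGroups(_ + _)           // Dataset[(String, Int)]: ("a", 30), ("b", 1)

    sums.show()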

3 files changed: +66 −0 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala

Lines changed: 13 additions & 0 deletions
@@ -230,6 +230,19 @@ object AppendColumns {
       encoderFor[U].namedExpressions,
       child)
   }
+
+  def apply[T : Encoder, U : Encoder](
+      func: T => U,
+      inputAttributes: Seq[Attribute],
+      child: LogicalPlan): AppendColumns = {
+    new AppendColumns(
+      func.asInstanceOf[Any => Any],
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      UnresolvedDeserializer(encoderFor[T].deserializer, inputAttributes),
+      encoderFor[U].namedExpressions,
+      child)
+  }
 }
 
 /**
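
The only difference from the existing apply above is the explicit inputAttributes parameter, which is threaded into UnresolvedDeserializer so that T's deserializer is resolved against the attributes the caller supplies rather than (presumably) the default resolution against the child plan's output. The mapValues implementation below relies on this, passing just the value columns of the grouped plan; the call site, taken from the next hunk, is:

    val withNewData = AppendColumns(func, dataAttributes, logicalPlan)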

sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala

Lines changed: 42 additions & 0 deletions
@@ -66,6 +66,48 @@ class KeyValueGroupedDataset[K, V] private[sql](
       dataAttributes,
       groupingAttributes)
 
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create values grouped by key from a Dataset[(K, V)]
+   *   ds.groupByKey(_._1).mapValues(_._2) // Scala
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W : Encoder](func: V => W): KeyValueGroupedDataset[K, W] = {
+    val withNewData = AppendColumns(func, dataAttributes, logicalPlan)
+    val projected = Project(withNewData.newColumns ++ groupingAttributes, withNewData)
+    val executed = sparkSession.sessionState.executePlan(projected)
+
+    new KeyValueGroupedDataset(
+      encoderFor[K],
+      encoderFor[W],
+      executed,
+      withNewData.newColumns,
+      groupingAttributes)
+  }
+
+  /**
+   * Returns a new [[KeyValueGroupedDataset]] where the given function `func` has been applied
+   * to the data. The grouping key is unchanged by this.
+   *
+   * {{{
+   *   // Create Integer values grouped by String key from a Dataset<Tuple2<String, Integer>>
+   *   Dataset<Tuple2<String, Integer>> ds = ...;
+   *   KeyValueGroupedDataset<String, Integer> grouped =
+   *     ds.groupByKey(t -> t._1, Encoders.STRING()).mapValues(t -> t._2, Encoders.INT()); // Java 8
+   * }}}
+   *
+   * @since 2.1.0
+   */
+  def mapValues[W](func: MapFunction[V, W], encoder: Encoder[W]): KeyValueGroupedDataset[K, W] = {
+    implicit val uEnc = encoder
+    mapValues { (v: V) => func.call(v) }
+  }
+
   /**
    * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping
    * over the Dataset to extract the keys and then running a distinct operation on those.
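
A hedged sketch of the Scala overload feeding flatMapGroups, as the test below also exercises; the data, the doubling, and the string formatting are illustrative only (assumes a SparkSession and spark.implicits._ in scope, as in the sketch near the top of this page):

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()
    val perKey = ds.groupByKey(_._1).mapValues(_._2 * 2)   // keys unchanged, values doubled

    // flatMapGroups only ever sees the mapped values.
    val lines = perKey.flatMapGroups((k, it) => it.map(v => s"$k=$v"))
    // lines: Dataset[String] with "a=20", "a=40", "b=2" (group order not guaranteed)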

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -336,6 +336,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       "a", "30", "b", "3", "c", "1")
   }
 
+  test("groupBy function, mapValues, flatMap") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val keyValue = ds.groupByKey(_._1).mapValues(_._2)
+    val agged = keyValue.mapGroups { case (g, iter) => (g, iter.sum) }
+    checkDataset(agged, ("a", 30), ("b", 3), ("c", 1))
+
+    val keyValue1 = ds.groupByKey(t => (t._1, "key")).mapValues(t => (t._2, "value"))
+    val agged1 = keyValue1.mapGroups { case (g, iter) => (g._1, iter.map(_._1).sum) }
+    checkDataset(agged1, ("a", 30), ("b", 3), ("c", 1))
+  }
+
   test("groupBy function, reduce") {
     val ds = Seq("abc", "xyz", "hello").toDS()
     val agged = ds.groupByKey(_.length).reduceGroups(_ + _)
