Commit af45902

[SPARK-49422][CONNECT][SQL] Add groupByKey to sql/api
### What changes were proposed in this pull request?

This PR adds `Dataset.groupByKey(..)` to the shared interface. I forgot to add it in the previous PR.

### Why are the changes needed?

The shared interface needs to support all functionality.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48147 from hvanhovell/SPARK-49422-follow-up.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
1 parent 3b34891 commit af45902
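
With both `groupByKey` overloads declared once in the shared `sql/api` interface, the Classic and Connect `Dataset` implementations expose the same call pattern. A minimal usage sketch for context; the session setup, data, and key functions below are illustrative, not part of this change:

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Encoders, SparkSession}

object GroupByKeyDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    import spark.implicits._

    val words = Seq("apple", "avocado", "banana", "blueberry").toDS()

    // Scala-specific overload: a plain T => K with an implicit key encoder.
    words.groupByKey(w => w.substring(0, 1)).count().show()

    // Java-specific overload: an explicit MapFunction plus an explicit key Encoder.
    words
      .groupByKey(
        new MapFunction[String, Integer] {
          override def call(w: String): Integer = w.length
        },
        Encoders.INT)
      .count()
      .show()

    spark.stop()
  }
}
```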

File tree

3 files changed: +39, -75 lines

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 7 additions & 17 deletions
@@ -524,27 +524,11 @@ class Dataset[T] private[sql] (
     result(0)
   }
 
-  /**
-   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func)
   }
 
-  /**
-   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   @scala.annotation.varargs
   def rollup(cols: Column*): RelationalGroupedDataset = {
@@ -1480,4 +1464,10 @@ class Dataset[T] private[sql] (
   /** @inheritdoc */
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
+
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
 }
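
The override at the end of this file exists only to narrow the return type: the shared base class implements the Java-specific overload by delegating to the abstract Scala-specific one, so each concrete `Dataset` re-declares it with its own concrete `KeyValueGroupedDataset` and casts. A self-contained sketch of that pattern; all names here are illustrative, not Spark's:

```scala
// Illustrative names only; this mirrors the shape of the pattern, not Spark's code.
trait Grouped[K]
class ConcreteGrouped[K](val key: K) extends Grouped[K]

abstract class BaseDataset[T] {
  // Scala-specific variant: abstract, implemented by each backend.
  def groupScala[K](f: T => K): Grouped[K]

  // Java-friendly variant: implemented once, in terms of the Scala one.
  def groupJava[K](f: java.util.function.Function[T, K]): Grouped[K] =
    groupScala(t => f.apply(t))
}

class ClientDataset[T](value: T) extends BaseDataset[T] {
  override def groupScala[K](f: T => K): ConcreteGrouped[K] =
    new ConcreteGrouped(f(value))

  // Narrow the return type for callers of the concrete class. The cast is safe
  // because BaseDataset.groupJava delegates to groupScala, which this class
  // implements to return a ConcreteGrouped.
  override def groupJava[K](f: java.util.function.Function[T, K]): ConcreteGrouped[K] =
    super.groupJava(f).asInstanceOf[ConcreteGrouped[K]]
}
```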

sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala

Lines changed: 22 additions & 0 deletions
@@ -1422,6 +1422,28 @@ abstract class Dataset[T] extends Serializable {
    */
   def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func))
 
+  /**
+   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T, DS]
+
+  /**
+   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T, DS] = {
+    groupByKey(ToScalaUDF(func))(encoder)
+  }
+
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
    * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
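
Note how the Java-specific overload now lives once in the shared interface and simply adapts the `MapFunction` into the `T => K` that the Scala-specific variant expects. A rough sketch of what an adapter like `ToScalaUDF` does here; this is an assumption about its shape, not the actual implementation:

```scala
import org.apache.spark.api.java.function.MapFunction

// Assumed shape only: adapt a Java MapFunction[T, K] into the plain Scala
// function T => K consumed by the Scala-specific groupByKey overload.
object MapFunctionAdapter {
  def toScalaFunction[T, K](func: MapFunction[T, K]): T => K =
    (t: T) => func.call(t)
}
```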

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 10 additions & 58 deletions
@@ -62,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF}
+import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
 import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.TypedAggUtils.withInputType
 import org.apache.spark.sql.streaming.DataStreamWriter
@@ -865,24 +865,7 @@ class Dataset[T] private[sql](
     Filter(condition.expr, logicalPlan)
   }
 
-  /**
-   * Groups the Dataset using the specified columns, so we can run aggregation on them. See
-   * [[RelationalGroupedDataset]] for all the available aggregate functions.
-   *
-   * {{{
-   *   // Compute the average for all numeric columns grouped by department.
-   *   ds.groupBy($"department").avg()
-   *
-   *   // Compute the max age and average salary, grouped by department and gender.
-   *   ds.groupBy($"department", $"gender").agg(Map(
-   *     "salary" -> "avg",
-   *     "age" -> "max"
-   *   ))
-   * }}}
-   *
-   * @group untypedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
@@ -914,13 +897,7 @@ class Dataset[T] private[sql](
     rdd.reduce(func)
   }
 
-  /**
-   * (Scala-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     val withGroupingKey = AppendColumns(func, logicalPlan)
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)
@@ -933,16 +910,6 @@ class Dataset[T] private[sql](
       withGroupingKey.newColumns)
   }
 
-  /**
-   * (Java-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   def unpivot(
       ids: Array[Column],
@@ -1640,28 +1607,7 @@ class Dataset[T] private[sql](
     new DataFrameWriterV2Impl[T](table, this)
   }
 
-  /**
-   * Merges a set of updates, insertions, and deletions based on a source table into
-   * a target table.
-   *
-   * Scala Examples:
-   * {{{
-   *   spark.table("source")
-   *     .mergeInto("target", $"source.id" === $"target.id")
-   *     .whenMatched($"salary" === 100)
-   *     .delete()
-   *     .whenNotMatched()
-   *     .insertAll()
-   *     .whenNotMatchedBySource($"salary" === 100)
-   *     .update(Map(
-   *       "salary" -> lit(200)
-   *     ))
-   *     .merge()
-   * }}}
-   *
-   * @group basic
-   * @since 4.0.0
-   */
+  /** @inheritdoc */
   def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = {
     if (isStreaming) {
       logicalPlan.failAnalysis(
@@ -2024,6 +1970,12 @@ class Dataset[T] private[sql](
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
 
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
+
   ////////////////////////////////////////////////////////////////////////////
   // For Python API
   ////////////////////////////////////////////////////////////////////////////
