diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 161a0d9d265f..accfff9f2b07 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -524,27 +524,11 @@ class Dataset[T] private[sql] (
     result(0)
   }
 
-  /**
-   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func)
   }
 
-  /**
-   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   @scala.annotation.varargs
   def rollup(cols: Column*): RelationalGroupedDataset = {
@@ -1480,4 +1464,10 @@ class Dataset[T] private[sql] (
   /** @inheritdoc */
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
+
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
 }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
index 284a69fe6ee3..6eef034aa515 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -1422,6 +1422,28 @@ abstract class Dataset[T] extends Serializable {
    */
   def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func))
 
+  /**
+   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T]
+
+  /**
+   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = {
+    groupByKey(ToScalaUDF(func))(encoder)
+  }
+
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
    * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 61f9e6ff7c04..ef628ca612b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -62,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF}
+import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
 import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.TypedAggUtils.withInputType
 import org.apache.spark.sql.streaming.DataStreamWriter
@@ -865,24 +865,7 @@ class Dataset[T] private[sql](
     Filter(condition.expr, logicalPlan)
   }
 
-  /**
-   * Groups the Dataset using the specified columns, so we can run aggregation on them. See
-   * [[RelationalGroupedDataset]] for all the available aggregate functions.
-   *
-   * {{{
-   *   // Compute the average for all numeric columns grouped by department.
-   *   ds.groupBy($"department").avg()
-   *
-   *   // Compute the max age and average salary, grouped by department and gender.
-   *   ds.groupBy($"department", $"gender").agg(Map(
-   *     "salary" -> "avg",
-   *     "age" -> "max"
-   *   ))
-   * }}}
-   *
-   * @group untypedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
@@ -914,13 +897,7 @@ class Dataset[T] private[sql](
     rdd.reduce(func)
   }
 
-  /**
-   * (Scala-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     val withGroupingKey = AppendColumns(func, logicalPlan)
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)
@@ -933,16 +910,6 @@ class Dataset[T] private[sql](
       withGroupingKey.newColumns)
   }
 
-  /**
-   * (Java-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   def unpivot(
       ids: Array[Column],
@@ -1640,28 +1607,7 @@ class Dataset[T] private[sql](
     new DataFrameWriterV2Impl[T](table, this)
   }
 
-  /**
-   * Merges a set of updates, insertions, and deletions based on a source table into
-   * a target table.
-   *
-   * Scala Examples:
-   * {{{
-   *   spark.table("source")
-   *     .mergeInto("target", $"source.id" === $"target.id")
-   *     .whenMatched($"salary" === 100)
-   *     .delete()
-   *     .whenNotMatched()
-   *     .insertAll()
-   *     .whenNotMatchedBySource($"salary" === 100)
-   *     .update(Map(
-   *       "salary" -> lit(200)
-   *     ))
-   *     .merge()
-   * }}}
-   *
-   * @group basic
-   * @since 4.0.0
-   */
+  /** @inheritdoc */
   def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = {
     if (isStreaming) {
       logicalPlan.failAnalysis(
@@ -2024,6 +1970,12 @@ class Dataset[T] private[sql](
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
 
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
+
   ////////////////////////////////////////////////////////////////////////////
   // For Python API
   ////////////////////////////////////////////////////////////////////////////
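Note for reviewers: the consolidation should be behavior-preserving. The Java-friendly `groupByKey` is now declared once in `sql/api` and still funnels through `ToScalaUDF` into the Scala overload; the `sql/core` and Connect overrides exist only to narrow the return type to their own `KeyValueGroupedDataset`, hence the `asInstanceOf` cast. Below is a minimal sketch (not part of this patch) exercising both overloads; the `GroupByKeyExample` object, the local-mode session, and the sample data are all hypothetical.

```scala
import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Encoders, SparkSession}

object GroupByKeyExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("groupByKey-demo")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical sample data.
    val words = Seq("spark", "scala", "sql", "stream").toDS()

    // Scala-specific overload: a key function plus an implicit Encoder[K]
    // supplied by spark.implicits._.
    words.groupByKey(_.substring(0, 1)).count().show()

    // Java-friendly overload: explicit MapFunction and Encoder. After this
    // patch it resolves against the shared declaration in sql/api, which
    // forwards through ToScalaUDF to the Scala overload above.
    words
      .groupByKey(
        new MapFunction[String, Integer] {
          override def call(w: String): Integer = w.length
        },
        Encoders.INT)
      .count()
      .show()

    spark.stop()
  }
}
```

Since both calls now resolve against the single declaration in `sql/api`, a snippet like this should compile unchanged against either the classic or the Connect `Dataset`.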