@@ -524,27 +524,11 @@ class Dataset[T] private[sql] (
     result(0)
   }
 
-  /**
-   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     KeyValueGroupedDatasetImpl[K, T](this, encoderFor[K], func)
   }
 
-  /**
-   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
-   * key `func`.
-   *
-   * @group typedrel
-   * @since 3.5.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   @scala.annotation.varargs
   def rollup(cols: Column*): RelationalGroupedDataset = {
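
The pattern in this hunk repeats throughout the PR: a full Scaladoc block is deleted and replaced by `/** @inheritdoc */`, which makes the Scaladoc tool copy the documentation (prose, `@group`, `@since`) from the member being implemented or overridden. A minimal sketch of the mechanism, using hypothetical `Base`/`Impl` names rather than anything from this PR:

abstract class Base {
  /**
   * Returns `x` incremented by one.
   *
   * @group math
   * @since 1.0.0
   */
  def inc(x: Int): Int
}

class Impl extends Base {
  // At doc-generation time the text, @group, and @since tags are copied
  // from Base.inc, so the documentation lives in exactly one place.
  /** @inheritdoc */
  override def inc(x: Int): Int = x + 1
}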
@@ -1480,4 +1464,10 @@ class Dataset[T] private[sql] (
   /** @inheritdoc */
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
+
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
 }
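
The override just added above narrows the return type declared by the parent class, and the `asInstanceOf` cast is what makes that work. A self-contained sketch of the pattern, under assumed names (`AbstractDataset`, `ConcreteDataset`, and a simplified `MapFunction` stand in for the real classes, and the `Encoder` parameter is dropped for brevity):

abstract class KVGroupedDataset[K, T]
class ConcreteKVGroupedDataset[K, T] extends KVGroupedDataset[K, T]

trait MapFunction[T, K] { def call(value: T): K }

abstract class AbstractDataset[T] {
  // Scala-specific variant: each concrete Dataset supplies the implementation.
  def groupByKey[K](func: T => K): KVGroupedDataset[K, T]

  // Java-specific variant: implemented once, in terms of the Scala variant.
  def groupByKey[K](func: MapFunction[T, K]): KVGroupedDataset[K, T] =
    groupByKey(func.call _)
}

class ConcreteDataset[T] extends AbstractDataset[T] {
  override def groupByKey[K](func: T => K): ConcreteKVGroupedDataset[K, T] =
    new ConcreteKVGroupedDataset[K, T]

  // Covariant override of the Java variant. The cast is safe because the
  // parent's default implementation delegates to the Scala variant above,
  // which always builds a ConcreteKVGroupedDataset.
  override def groupByKey[K](func: MapFunction[T, K]): ConcreteKVGroupedDataset[K, T] =
    super.groupByKey(func).asInstanceOf[ConcreteKVGroupedDataset[K, T]]
}

Without the override, callers holding the concrete Dataset would get the abstract grouped type back from the Java-friendly overload, losing access to anything the concrete grouped type adds.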
22 changes: 22 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -1422,6 +1422,28 @@ abstract class Dataset[T] extends Serializable {
    */
   def reduce(func: ReduceFunction[T]): T = reduce(ToScalaUDF(func))
 
+  /**
+   * (Scala-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T]
+
+  /**
+   * (Java-specific) Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given
+   * key `func`.
+   *
+   * @group typedrel
+   * @since 2.0.0
+   */
+  def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] = {
+    groupByKey(ToScalaUDF(func))(encoder)
+  }
+
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
    * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
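For reference, the two overloads declared above are exercised like this (a sketch, assuming a local SparkSession; the `words` Dataset is invented for illustration):

import org.apache.spark.api.java.function.MapFunction
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val words: Dataset[String] = Seq("a", "bb", "cc", "d").toDS()

// Scala-specific overload: the key Encoder is summoned implicitly.
val byLength = words.groupByKey(w => w.length)

// Java-specific overload: the MapFunction and Encoder are passed explicitly;
// the default implementation above adapts it to the Scala overload.
val byLengthJava = words.groupByKey(
  new MapFunction[String, Integer] { override def call(w: String): Integer = w.length },
  Encoders.INT)

byLength.count().show()  // one row per distinct word length
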
68 changes: 10 additions & 58 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -62,7 +62,7 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanRelation, FileTable}
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
-import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf, ToScalaUDF}
+import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
 import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.TypedAggUtils.withInputType
 import org.apache.spark.sql.streaming.DataStreamWriter
@@ -865,24 +865,7 @@ class Dataset[T] private[sql](
     Filter(condition.expr, logicalPlan)
   }
 
-  /**
-   * Groups the Dataset using the specified columns, so we can run aggregation on them. See
-   * [[RelationalGroupedDataset]] for all the available aggregate functions.
-   *
-   * {{{
-   *   // Compute the average for all numeric columns grouped by department.
-   *   ds.groupBy($"department").avg()
-   *
-   *   // Compute the max age and average salary, grouped by department and gender.
-   *   ds.groupBy($"department", $"gender").agg(Map(
-   *     "salary" -> "avg",
-   *     "age" -> "max"
-   *   ))
-   * }}}
-   *
-   * @group untypedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   @scala.annotation.varargs
   def groupBy(cols: Column*): RelationalGroupedDataset = {
     RelationalGroupedDataset(toDF(), cols.map(_.expr), RelationalGroupedDataset.GroupByType)
@@ -914,13 +897,7 @@ class Dataset[T] private[sql](
     rdd.reduce(func)
   }
 
-  /**
-   * (Scala-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
+  /** @inheritdoc */
   def groupByKey[K: Encoder](func: T => K): KeyValueGroupedDataset[K, T] = {
     val withGroupingKey = AppendColumns(func, logicalPlan)
     val executed = sparkSession.sessionState.executePlan(withGroupingKey)
@@ -933,16 +910,6 @@ class Dataset[T] private[sql](
       withGroupingKey.newColumns)
   }
 
-  /**
-   * (Java-specific)
-   * Returns a [[KeyValueGroupedDataset]] where the data is grouped by the given key `func`.
-   *
-   * @group typedrel
-   * @since 2.0.0
-   */
-  def groupByKey[K](func: MapFunction[T, K], encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
-    groupByKey(ToScalaUDF(func))(encoder)
-
   /** @inheritdoc */
   def unpivot(
       ids: Array[Column],
@@ -1640,28 +1607,7 @@ class Dataset[T] private[sql](
     new DataFrameWriterV2Impl[T](table, this)
   }
 
-  /**
-   * Merges a set of updates, insertions, and deletions based on a source table into
-   * a target table.
-   *
-   * Scala Examples:
-   * {{{
-   *   spark.table("source")
-   *     .mergeInto("target", $"source.id" === $"target.id")
-   *     .whenMatched($"salary" === 100)
-   *     .delete()
-   *     .whenNotMatched()
-   *     .insertAll()
-   *     .whenNotMatchedBySource($"salary" === 100)
-   *     .update(Map(
-   *       "salary" -> lit(200)
-   *     ))
-   *     .merge()
-   * }}}
-   *
-   * @group basic
-   * @since 4.0.0
-   */
+  /** @inheritdoc */
   def mergeInto(table: String, condition: Column): MergeIntoWriter[T] = {
     if (isStreaming) {
       logicalPlan.failAnalysis(
@@ -2024,6 +1970,12 @@ class Dataset[T] private[sql](
   @scala.annotation.varargs
   override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)
 
+  /** @inheritdoc */
+  override def groupByKey[K](
+      func: MapFunction[T, K],
+      encoder: Encoder[K]): KeyValueGroupedDataset[K, T] =
+    super.groupByKey(func, encoder).asInstanceOf[KeyValueGroupedDataset[K, T]]
+
   ////////////////////////////////////////////////////////////////////////////
   // For Python API
   ////////////////////////////////////////////////////////////////////////////