
Commit db8010b

[SPARK-49568][CONNECT][SQL] Remove self type from Dataset
### What changes were proposed in this pull request?

This PR removes the self type parameter from Dataset; the self type turned out to be a bit noisy. It is replaced by a combination of covariant return types and abstract types. Abstract types are used when a method takes a Dataset (or a KeyValueGroupedDataset) as an argument.

### Why are the changes needed?

The self type made using the classes in sql/api a bit noisy.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48146 from hvanhovell/SPARK-49568.

Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
1 parent 5c48806 commit db8010b

33 files changed: 500 additions, 364 deletions
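To make the refactoring concrete, here is a minimal, hypothetical sketch of the pattern (the trait and class names below are invented for illustration; they are not the actual sql/api sources): the F-bounded self type parameter disappears, implementations narrow return types covariantly, and methods that accept another Dataset go through an abstract type member.

// Hypothetical "before": an F-bounded self type parameter that every
// signature and every subclass has to thread through.
trait DatasetBefore[T, DS[U] <: DatasetBefore[U, DS]] {
  def filter(f: T => Boolean): DS[T]
  def union(other: DS[T]): DS[T]
}

// Hypothetical "after": no self type parameter. Return positions rely on
// covariant return types (implementations narrow them); argument positions
// use an abstract type member instead.
trait DatasetAfter[T] {
  type DS[U] <: DatasetAfter[U]
  def filter(f: T => Boolean): DatasetAfter[T]
  def union(other: DS[T]): DatasetAfter[T]
}

// A concrete implementation binds the type member and narrows the returns.
class ConnectLikeDataset[T] extends DatasetAfter[T] {
  type DS[U] = ConnectLikeDataset[U]
  override def filter(f: T => Boolean): ConnectLikeDataset[T] = this
  override def union(other: ConnectLikeDataset[T]): ConnectLikeDataset[T] = this
}

Callers of the shared interface no longer need to spell out the second type parameter, which is the noise the commit message refers to.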

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala

Lines changed: 2 additions & 1 deletion
@@ -22,14 +22,15 @@ import scala.jdk.CollectionConverters._
 import org.apache.spark.connect.proto.{NAReplace, Relation}
 import org.apache.spark.connect.proto.Expression.{Literal => GLiteral}
 import org.apache.spark.connect.proto.NAReplace.Replacement
+import org.apache.spark.sql.connect.ConnectConversions._
 
 /**
  * Functionality for working with missing data in `DataFrame`s.
  *
  * @since 3.4.0
  */
 final class DataFrameNaFunctions private[sql] (sparkSession: SparkSession, root: Relation)
-    extends api.DataFrameNaFunctions[Dataset] {
+    extends api.DataFrameNaFunctions {
   import sparkSession.RichColumn
 
   override protected def drop(minNonNulls: Option[Int]): Dataset[Row] =

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 3 additions & 2 deletions
@@ -23,6 +23,7 @@ import scala.jdk.CollectionConverters._
 
 import org.apache.spark.annotation.Stable
 import org.apache.spark.connect.proto.Parse.ParseFormat
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.types.StructType
 
@@ -33,8 +34,8 @@ import org.apache.spark.sql.types.StructType
  * @since 3.4.0
  */
 @Stable
-class DataFrameReader private[sql] (sparkSession: SparkSession)
-    extends api.DataFrameReader[Dataset] {
+class DataFrameReader private[sql] (sparkSession: SparkSession) extends api.DataFrameReader {
+  type DS[U] = Dataset[U]
 
   /** @inheritdoc */
   override def format(source: String): this.type = super.format(source)

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 2 additions & 1 deletion
@@ -22,6 +22,7 @@ import java.{lang => jl, util => ju}
 import org.apache.spark.connect.proto.{Relation, StatSampleBy}
 import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, PrimitiveDoubleEncoder}
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.functions.lit
 
 /**
@@ -30,7 +31,7 @@ import org.apache.spark.sql.functions.lit
  * @since 3.4.0
  */
 final class DataFrameStatFunctions private[sql] (protected val df: DataFrame)
-    extends api.DataFrameStatFunctions[Dataset] {
+    extends api.DataFrameStatFunctions {
   private def root: Relation = df.plan.getRoot
   private val sparkSession: SparkSession = df.sparkSession
 
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 3 additions & 2 deletions
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.OrderUtils
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.SparkResult
 import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, StorageLevelProtoConverter}
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
@@ -134,8 +135,8 @@ class Dataset[T] private[sql] (
     val sparkSession: SparkSession,
     @DeveloperApi val plan: proto.Plan,
     val encoder: Encoder[T])
-    extends api.Dataset[T, Dataset] {
-  type RGD = RelationalGroupedDataset
+    extends api.Dataset[T] {
+  type DS[U] = Dataset[U]
 
   import sparkSession.RichColumn
 
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala

Lines changed: 2 additions & 2 deletions
@@ -26,6 +26,7 @@ import org.apache.spark.api.java.function._
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.common.UdfUtils
 import org.apache.spark.sql.expressions.SparkUserDefinedFunction
 import org.apache.spark.sql.functions.col
@@ -40,8 +41,7 @@ import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode
  *
  * @since 3.5.0
  */
-class KeyValueGroupedDataset[K, V] private[sql] ()
-    extends api.KeyValueGroupedDataset[K, V, Dataset] {
+class KeyValueGroupedDataset[K, V] private[sql] () extends api.KeyValueGroupedDataset[K, V] {
   type KVDS[KY, VL] = KeyValueGroupedDataset[KY, VL]
 
   private def unsupported(): Nothing = throw new UnsupportedOperationException()
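The KVDS type member kept in the hunk above serves the same purpose for grouped data. A simplified, hypothetical sketch (invented names; the real methods such as cogroup also thread encoders and return Datasets):

// Hypothetical, simplified interface: the abstract member KVDS guarantees
// that `other` is the same concrete implementation as `this`, without a
// self type parameter on the trait.
trait KeyValueGroupedLike[K, V] {
  type KVDS[KY, VL] <: KeyValueGroupedLike[KY, VL]
  def cogrouped[U](other: KVDS[K, U]): Unit
}

class ConnectLikeGrouped[K, V] extends KeyValueGroupedLike[K, V] {
  type KVDS[KY, VL] = ConnectLikeGrouped[KY, VL]
  override def cogrouped[U](other: ConnectLikeGrouped[K, U]): Unit = ()
}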

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 2 additions & 2 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.ConnectConversions._
 
 /**
  * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]],
@@ -39,8 +40,7 @@ class RelationalGroupedDataset private[sql] (
     groupType: proto.Aggregate.GroupType,
     pivot: Option[proto.Aggregate.Pivot] = None,
     groupingSets: Option[Seq[proto.Aggregate.GroupingSets]] = None)
-    extends api.RelationalGroupedDataset[Dataset] {
-  type RGD = RelationalGroupedDataset
+    extends api.RelationalGroupedDataset {
   import df.sparkSession.RichColumn
 
   protected def toDF(aggExprs: Seq[Column]): DataFrame = {

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ import org.apache.spark.util.ArrayImplicits._
 class SparkSession private[sql] (
     private[sql] val client: SparkConnectClient,
     private val planIdGenerator: AtomicLong)
-    extends api.SparkSession[Dataset]
+    extends api.SparkSession
     with Logging {
 
   private[this] val allocator = new RootAllocator()

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala

Lines changed: 2 additions & 1 deletion
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalog
 import java.util
 
 import org.apache.spark.sql.{api, DataFrame, Dataset}
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.types.StructType
 
 /** @inheritdoc */
-abstract class Catalog extends api.Catalog[Dataset] {
+abstract class Catalog extends api.Catalog {
 
   /** @inheritdoc */
   override def listDatabases(): Dataset[Database]
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala

Lines changed: 51 additions & 0 deletions

@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect
+
+import scala.language.implicitConversions
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql._
+
+/**
+ * Conversions from sql interfaces to the Connect specific implementation.
+ *
+ * This class is mainly used by the implementation. In the case of connect it should be extremely
+ * rare that a developer needs these classes.
+ *
+ * We provide both a trait and an object. The trait is useful in situations where an extension
+ * developer needs to use these conversions in a project covering multiple Spark versions. They
+ * can create a shim for these conversions, the Spark 4+ version of the shim implements this
+ * trait, and shims for older versions do not.
+ */
+@DeveloperApi
+trait ConnectConversions {
+  implicit def castToImpl(session: api.SparkSession): SparkSession =
+    session.asInstanceOf[SparkSession]
+
+  implicit def castToImpl[T](ds: api.Dataset[T]): Dataset[T] =
+    ds.asInstanceOf[Dataset[T]]
+
+  implicit def castToImpl(rgds: api.RelationalGroupedDataset): RelationalGroupedDataset =
+    rgds.asInstanceOf[RelationalGroupedDataset]
+
+  implicit def castToImpl[K, V](
+      kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] =
+    kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]
+}
+
+object ConnectConversions extends ConnectConversions
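A hedged usage sketch for the conversions above (the object and helper below are invented for illustration): importing ConnectConversions._ lets implementation code that receives the interface types recover the Connect classes without explicit asInstanceOf casts.

object PlanInspector {
  import org.apache.spark.sql.{api, Dataset}
  import org.apache.spark.sql.connect.ConnectConversions._

  // Hypothetical helper: shared code hands us an api.Dataset, but reaching
  // the Connect-only `plan` member requires the concrete Connect Dataset.
  def describePlan(ds: api.Dataset[String]): String = {
    val impl: Dataset[String] = ds // castToImpl is applied implicitly
    impl.plan.toString
  }
}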

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala

Lines changed: 2 additions & 2 deletions
@@ -26,10 +26,10 @@ import org.apache.spark.connect.proto.ExecutePlanResponse
 import org.apache.spark.connect.proto.StreamingQueryCommand
 import org.apache.spark.connect.proto.StreamingQueryCommandResult
 import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
-import org.apache.spark.sql.{api, Dataset, SparkSession}
+import org.apache.spark.sql.{api, SparkSession}
 
 /** @inheritdoc */
-trait StreamingQuery extends api.StreamingQuery[Dataset] {
+trait StreamingQuery extends api.StreamingQuery {
 
   /** @inheritdoc */
   override def sparkSession: SparkSession
