
Commit 934d14d

Authored by vicennial and committed by hvanhovell
[SPARK-42133] Add basic Dataset API methods to Spark Connect Scala Client
### What changes were proposed in this pull request?

Adds the following methods:

- Dataset API methods
  - select (project)
  - filter
  - limit
- SparkSession
  - range (and its variations)

This PR also introduces `Column` and `functions` to support the above changes.

### Why are the changes needed?

Incremental development of the Spark Connect Scala Client.

### Does this PR introduce _any_ user-facing change?

Yes, users may now use the proposed API methods. Example: `val df = sparkSession.range(5).limit(3)`

### How was this patch tested?

Unit tests + a simple E2E test.

Closes apache#39672 from vicennial/SPARK-42133.

Authored-by: vicennial <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
1 parent: cc1674d
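For illustration only, a minimal sketch of the surface added in this commit, assuming an already-connected Spark Connect `SparkSession` named `sparkSession` (`collectResult()` is the pre-existing action on this client's `Dataset`):

// Illustrative only, not part of the diff.
import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.lit

def example(sparkSession: SparkSession): Unit = {
  val df = sparkSession
    .range(0, 100, 5, numPartitions = 2)                        // single LongType column `id`
    .select(Column("id"), (Column("id") + lit(1)).name("id1"))  // project two expressions
    .limit(3)                                                   // keep the first 3 rows
  val result = df.collectResult()                               // executes the plan on the server
}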

File tree: 10 files changed, +595 −0 lines changed

Lines changed: 107 additions & 0 deletions (new file defining the `Column` class)

@@ -0,0 +1,107 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.connect.proto
import org.apache.spark.sql.Column.fn
import org.apache.spark.sql.connect.client.unsupported
import org.apache.spark.sql.functions.lit

/**
 * A column that will be computed based on the data in a `DataFrame`.
 *
 * A new column can be constructed based on the input columns present in a DataFrame:
 *
 * {{{
 *   df("columnName")            // On a specific `df` DataFrame.
 *   col("columnName")           // A generic column not yet associated with a DataFrame.
 *   col("columnName.field")     // Extracting a struct field
 *   col("`a.column.with.dots`") // Escape `.` in column names.
 *   $"columnName"               // Scala short hand for a named column.
 * }}}
 *
 * [[Column]] objects can be composed to form complex expressions:
 *
 * {{{
 *   $"a" + 1
 * }}}
 *
 * @since 3.4.0
 */
class Column private[sql] (private[sql] val expr: proto.Expression) {

  /**
   * Sum of this expression and another expression.
   * {{{
   *   // Scala: The following selects the sum of a person's height and weight.
   *   people.select( people("height") + people("weight") )
   *
   *   // Java:
   *   people.select( people.col("height").plus(people.col("weight")) );
   * }}}
   *
   * @group expr_ops
   * @since 3.4.0
   */
  def +(other: Any): Column = fn("+", this, lit(other))

  /**
   * Gives the column a name (alias).
   * {{{
   *   // Renames colA to colB in select output.
   *   df.select($"colA".name("colB"))
   * }}}
   *
   * If the current column has metadata associated with it, this metadata will be propagated to
   * the new column. If this is not desired, use the API `as(alias: String, metadata: Metadata)`
   * with explicit metadata.
   *
   * @group expr_ops
   * @since 3.4.0
   */
  def name(alias: String): Column = Column { builder =>
    builder.getAliasBuilder.addName(alias).setExpr(expr)
  }
}

object Column {

  def apply(name: String): Column = Column { builder =>
    name match {
      case "*" =>
        builder.getUnresolvedStarBuilder
      case _ if name.endsWith(".*") =>
        unsupported("* with prefix is not supported yet.")
      case _ =>
        builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
    }
  }

  private[sql] def apply(f: proto.Expression.Builder => Unit): Column = {
    val builder = proto.Expression.newBuilder()
    f(builder)
    new Column(builder.build())
  }

  private[sql] def fn(name: String, inputs: Column*): Column = Column { builder =>
    builder.getUnresolvedFunctionBuilder
      .setFunctionName(name)
      .addAllArguments(inputs.map(_.expr).asJava)
  }
}
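As a side note (not part of the diff), composing these helpers simply nests protobuf expression builders. A hedged sketch of what `Column("id") + 1` produces, inferred from the `apply`, `fn`, and `lit` definitions in this change:

// Illustrative only, not part of the diff.
import org.apache.spark.sql.Column

val c: Column = Column("id") + 1
// c.expr is roughly (protobuf text format):
//   unresolved_function {
//     function_name: "+"
//     arguments { unresolved_attribute { unparsed_identifier: "id" } }
//     arguments { literal { integer: 1 } }   // the Int is wrapped by functions.lit
//   }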

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 50 additions & 0 deletions
@@ -16,9 +16,59 @@
  */
 package org.apache.spark.sql
 
+import scala.collection.JavaConverters._
+
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.client.SparkResult
 
 class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
+
+  /**
+   * Selects a set of column-based expressions.
+   * {{{
+   *   ds.select($"colA", $"colB" + 1)
+   * }}}
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  @scala.annotation.varargs
+  def select(cols: Column*): Dataset = session.newDataset { builder =>
+    builder.getProjectBuilder
+      .setInput(plan.getRoot)
+      .addAllExpressions(cols.map(_.expr).asJava)
+  }
+
+  /**
+   * Filters rows using the given condition.
+   * {{{
+   *   // The following are equivalent:
+   *   peopleDs.filter($"age" > 15)
+   *   peopleDs.where($"age" > 15)
+   * }}}
+   *
+   * @group typedrel
+   * @since 3.4.0
+   */
+  def filter(condition: Column): Dataset = session.newDataset { builder =>
+    builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
+  }
+
+  /**
+   * Returns a new Dataset by taking the first `n` rows. The difference between this function and
+   * `head` is that `head` is an action and returns an array (by triggering query execution) while
+   * `limit` returns a new Dataset.
+   *
+   * @group typedrel
+   * @since 3.4.0
+   */
+  def limit(n: Int): Dataset = session.newDataset { builder =>
+    builder.getLimitBuilder
+      .setInput(plan.getRoot)
+      .setLimit(n)
+  }
+
+  private[sql] def analyze: proto.AnalyzePlanResponse = session.analyze(plan)
+
   def collectResult(): SparkResult = session.execute(plan)
 }
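A short, non-authoritative illustration of how these builders compose: each method feeds `plan.getRoot` of the current Dataset into the `input` of the new relation, so chained calls produce a nested plan. The session parameter below is assumed.

// Illustrative only, not part of the diff.
import org.apache.spark.sql.{Column, SparkSession}

def nestedPlanExample(spark: SparkSession): Unit = {
  val ds = spark.range(10).select(Column("id") + 1).limit(3)
  // ds.plan.getRoot is roughly (protobuf text format):
  //   limit {
  //     limit: 3
  //     input {
  //       project {
  //         expressions { ... }                        // the `id + 1` expression
  //         input { range { start: 0 end: 10 step: 1 } }
  //       }
  //     }
  //   }
}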

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala

Lines changed: 51 additions & 0 deletions
@@ -57,13 +57,64 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
     builder.setSql(proto.SQL.newBuilder().setQuery(query))
   }
 
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from 0 to `end` (exclusive) with step value 1.
+   *
+   * @since 3.4.0
+   */
+  def range(end: Long): Dataset = range(0, end)
+
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from `start` to `end` (exclusive) with step value 1.
+   *
+   * @since 3.4.0
+   */
+  def range(start: Long, end: Long): Dataset = {
+    range(start, end, step = 1)
+  }
+
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from `start` to `end` (exclusive) with a step value.
+   *
+   * @since 3.4.0
+   */
+  def range(start: Long, end: Long, step: Long): Dataset = {
+    range(start, end, step, None)
+  }
+
+  /**
+   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
+   * range from `start` to `end` (exclusive) with a step value, with partition number specified.
+   *
+   * @since 3.4.0
+   */
+  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset = {
+    range(start, end, step, Option(numPartitions))
+  }
+
+  private def range(start: Long, end: Long, step: Long, numPartitions: Option[Int]): Dataset = {
+    newDataset { builder =>
+      val rangeBuilder = builder.getRangeBuilder
+        .setStart(start)
+        .setEnd(end)
+        .setStep(step)
+      numPartitions.foreach(rangeBuilder.setNumPartitions)
+    }
+  }
+
   private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = {
     val builder = proto.Relation.newBuilder()
     f(builder)
     val plan = proto.Plan.newBuilder().setRoot(builder).build()
     new Dataset(this, plan)
   }
 
+  private[sql] def analyze(plan: proto.Plan): proto.AnalyzePlanResponse =
+    client.analyze(plan)
+
   private[sql] def execute(plan: proto.Plan): SparkResult = {
     val value = client.execute(plan)
     val result = new SparkResult(value, allocator)
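For reference (not part of the diff), every public `range` overload funnels into the private four-argument version; a hedged sketch of the equivalences, assuming a connected session passed in as `spark`:

// Illustrative only, not part of the diff.
import org.apache.spark.sql.{Dataset, SparkSession}

def rangeOverloads(spark: SparkSession): Seq[Dataset] = Seq(
  spark.range(5),            // same as range(0, 5, step = 1); num_partitions left unset
  spark.range(2, 5),         // same as range(2, 5, step = 1)
  spark.range(0, 10, 2),     // same as range(0, 10, 2, None)
  spark.range(0, 10, 2, 4)   // additionally sets num_partitions = 4 on the Range relation
)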

connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala

Lines changed: 22 additions & 0 deletions
@@ -21,6 +21,7 @@ import scala.language.existentials
 
 import io.grpc.{ManagedChannel, ManagedChannelBuilder}
 import java.net.URI
+import java.util.UUID
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -41,6 +42,11 @@ class SparkConnectClient(
    */
   def userId: String = userContext.getUserId()
 
+  // Generate a unique session ID for this client. This UUID must be unique to allow
+  // concurrent Spark sessions of the same user. If the channel is closed, creating
+  // a new client will create a new session ID.
+  private[client] val sessionId: String = UUID.randomUUID.toString
+
   /**
    * Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
    * @return
@@ -58,6 +64,22 @@
     stub.executePlan(request)
   }
 
+  /**
+   * Builds a [[proto.AnalyzePlanRequest]] from `plan` and dispatches it to the Spark Connect
+   * server.
+   * @return
+   *   A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
+   */
+  def analyze(plan: proto.Plan): proto.AnalyzePlanResponse = {
+    val request = proto.AnalyzePlanRequest
+      .newBuilder()
+      .setPlan(plan)
+      .setUserContext(userContext)
+      .setClientId(sessionId)
+      .build()
+    analyze(request)
+  }
+
   /**
    * Shutdown the client's connection to the server.
    */
Lines changed: 28 additions & 0 deletions (new file defining the `org.apache.spark.sql.connect.client` package object)

@@ -0,0 +1,28 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.connect

package object client {

  private[sql] def unsupported(): Nothing = {
    throw new UnsupportedOperationException
  }

  private[sql] def unsupported(message: String): Nothing = {
    throw new UnsupportedOperationException(message)
  }
}
Lines changed: 83 additions & 0 deletions (new file defining the `functions` object)

@@ -0,0 +1,83 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql

import java.math.{BigDecimal => JBigDecimal}
import java.time.LocalDate

import com.google.protobuf.ByteString

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.client.unsupported

/**
 * Commonly used functions available for DataFrame operations.
 *
 * @since 3.4.0
 */
// scalastyle:off
object functions {
// scalastyle:on

  private def createLiteral(f: proto.Expression.Literal.Builder => Unit): Column = Column {
    builder =>
      val literalBuilder = proto.Expression.Literal.newBuilder()
      f(literalBuilder)
      builder.setLiteral(literalBuilder)
  }

  private def createDecimalLiteral(precision: Int, scale: Int, value: String): Column =
    createLiteral { builder =>
      builder.getDecimalBuilder
        .setPrecision(precision)
        .setScale(scale)
        .setValue(value)
    }

  /**
   * Creates a [[Column]] of literal value.
   *
   * The passed in object is returned directly if it is already a [[Column]]. If the object is a
   * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created
   * to represent the literal value.
   *
   * @since 3.4.0
   */
  def lit(literal: Any): Column = {
    literal match {
      case c: Column => c
      case s: Symbol => Column(s.name)
      case v: Boolean => createLiteral(_.setBoolean(v))
      case v: Byte => createLiteral(_.setByte(v))
      case v: Short => createLiteral(_.setShort(v))
      case v: Int => createLiteral(_.setInteger(v))
      case v: Long => createLiteral(_.setLong(v))
      case v: Float => createLiteral(_.setFloat(v))
      case v: Double => createLiteral(_.setDouble(v))
      case v: BigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
      case v: JBigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
      case v: String => createLiteral(_.setString(v))
      case v: Char => createLiteral(_.setString(v.toString))
      case v: Array[Char] => createLiteral(_.setString(String.valueOf(v)))
      case v: Array[Byte] => createLiteral(_.setBinary(ByteString.copyFrom(v)))
      case v: collection.mutable.WrappedArray[_] => lit(v.array)
      case v: LocalDate => createLiteral(_.setDate(v.toEpochDay.toInt))
      case null => unsupported("Null literals not supported yet.")
      case _ => unsupported(s"literal $literal not supported (yet).")
    }
  }
}
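For illustration only (not part of the diff), how `lit` handles the supported inputs; only identifiers defined in this change are used:

// Illustrative only, not part of the diff.
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit

val asIs    = lit(Column("id"))        // already a Column: returned unchanged
val attr    = lit(Symbol("id"))        // Scala Symbol: becomes the column named "id"
val integer = lit(1)                   // Int literal
val decimal = lit(BigDecimal("1.23"))  // decimal literal, precision and scale preserved
val date    = lit(java.time.LocalDate.of(2023, 1, 1)) // date literal, stored as days since epoch
// lit(null) and unsupported types currently throw UnsupportedOperationException.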
