@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.connect.proto
import org.apache.spark.sql.Column.fn
import org.apache.spark.sql.connect.client.unsupported
import org.apache.spark.sql.functions.lit

/**
* A column that will be computed based on the data in a `DataFrame`.
*
* A new column can be constructed based on the input columns present in a DataFrame:
*
* {{{
* df("columnName") // On a specific `df` DataFrame.
* col("columnName") // A generic column not yet associated with a DataFrame.
* col("columnName.field") // Extracting a struct field
* col("`a.column.with.dots`") // Escape `.` in column names.
* $"columnName" // Scala shorthand for a named column.
* }}}
*
* [[Column]] objects can be composed to form complex expressions:
*
* {{{
* $"a" + 1
* }}}
*
* @since 3.4.0
*/
class Column private[sql] (private[sql] val expr: proto.Expression) {

  /**
   * Sum of this expression and another expression.
   * {{{
   *   // Scala: The following selects the sum of a person's height and weight.
   *   people.select( people("height") + people("weight") )
   *
   *   // Java:
   *   people.select( people.col("height").plus(people.col("weight")) );
   * }}}
   *
   * @group expr_ops
   * @since 3.4.0
   */
  def +(other: Any): Column = fn("+", this, lit(other))

  /**
   * Gives the column a name (alias).
   * {{{
   *   // Renames colA to colB in select output.
   *   df.select($"colA".name("colB"))
   * }}}
   *
   * If the current column has metadata associated with it, this metadata will be propagated to
   * the new column. If this is not desired, use the API `as(alias: String, metadata: Metadata)`
   * with explicit metadata.
   *
   * @group expr_ops
   * @since 3.4.0
   */
  def name(alias: String): Column = Column { builder =>
    builder.getAliasBuilder.addName(alias).setExpr(expr)
  }
}

object Column {

  def apply(name: String): Column = Column { builder =>
    name match {
      case "*" =>
        builder.getUnresolvedStarBuilder
      case _ if name.endsWith(".*") =>
        unsupported("* with prefix is not supported yet.")
      case _ =>
        builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier(name)
    }
  }

  private[sql] def apply(f: proto.Expression.Builder => Unit): Column = {
    val builder = proto.Expression.newBuilder()
    f(builder)
    new Column(builder.build())
  }

  private[sql] def fn(name: String, inputs: Column*): Column = Column { builder =>
    builder.getUnresolvedFunctionBuilder
      .setFunctionName(name)
      .addAllArguments(inputs.map(_.expr).asJava)
  }
}
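
// --- Usage sketch (illustrative, not part of this change) ---
// A minimal example of how the pieces above compose, relying only on the public
// `Column.apply`, `+`, and `name` entry points together with `functions.lit` from this PR.
object ColumnUsageSketch {
  import org.apache.spark.sql.Column
  import org.apache.spark.sql.functions.lit

  // Column("height") resolves to an UnresolvedAttribute expression; `+` turns it into an
  // UnresolvedFunction named "+" whose arguments are the two operand expressions.
  val heightPlusOne: Column = Column("height") + lit(1)

  // `name` wraps the expression in an Alias, so the projected column is called "h1".
  val aliased: Column = heightPlusOne.name("h1")
}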
@@ -16,9 +16,59 @@
*/
package org.apache.spark.sql

import scala.collection.JavaConverters._

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.client.SparkResult

class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {

  /**
   * Selects a set of column-based expressions.
   * {{{
   *   ds.select($"colA", $"colB" + 1)
   * }}}
   *
   * @group untypedrel
   * @since 3.4.0
   */
  @scala.annotation.varargs
  def select(cols: Column*): Dataset = session.newDataset { builder =>
    builder.getProjectBuilder
      .setInput(plan.getRoot)
      .addAllExpressions(cols.map(_.expr).asJava)
  }

  /**
   * Filters rows using the given condition.
   * {{{
   *   // The following are equivalent:
   *   peopleDs.filter($"age" > 15)
   *   peopleDs.where($"age" > 15)
   * }}}
   *
   * @group typedrel
   * @since 3.4.0
   */
  def filter(condition: Column): Dataset = session.newDataset { builder =>
    builder.getFilterBuilder.setInput(plan.getRoot).setCondition(condition.expr)
  }

  /**
   * Returns a new Dataset by taking the first `n` rows. The difference between this function and
   * `head` is that `head` is an action and returns an array (by triggering query execution) while
   * `limit` returns a new Dataset.
   *
   * @group typedrel
   * @since 3.4.0
   */
  def limit(n: Int): Dataset = session.newDataset { builder =>
    builder.getLimitBuilder
      .setInput(plan.getRoot)
      .setLimit(n)
  }

  private[sql] def analyze: proto.AnalyzePlanResponse = session.analyze(plan)

  def collectResult(): SparkResult = session.execute(plan)
}
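
// --- Usage sketch (illustrative, not part of this change) ---
// End-to-end composition of the operators above, assuming a `SparkSession` has already been
// obtained. Each call only appends a relation (Range, Project, Limit) to the logical plan;
// nothing is sent to the server until an action such as `collectResult()` runs.
object DatasetUsageSketch {
  import org.apache.spark.sql.{Column, Dataset, SparkSession}
  import org.apache.spark.sql.functions.lit

  def firstTen(session: SparkSession): Dataset = {
    session
      .range(0, 100)                 // Range relation producing ids 0..99
      .select(Column("id") + lit(1)) // Project relation on top of the range
      .limit(10)                     // Limit relation on top of the projection
  }
}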
@@ -57,13 +57,64 @@ class SparkSession(private val client: SparkConnectClient, private val cleaner:
    builder.setSql(proto.SQL.newBuilder().setQuery(query))
  }

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from 0 to `end` (exclusive) with step value 1.
   *
   * @since 3.4.0
   */
  def range(end: Long): Dataset = range(0, end)

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from `start` to `end` (exclusive) with step value 1.
   *
   * @since 3.4.0
   */
  def range(start: Long, end: Long): Dataset = {
    range(start, end, step = 1)
  }

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from `start` to `end` (exclusive) with a step value.
   *
   * @since 3.4.0
   */
  def range(start: Long, end: Long, step: Long): Dataset = {
    range(start, end, step, None)
  }

  /**
   * Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
   * range from `start` to `end` (exclusive) with a step value and the specified number of
   * partitions.
   *
   * @since 3.4.0
   */
  def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset = {
    range(start, end, step, Option(numPartitions))
  }

  private def range(start: Long, end: Long, step: Long, numPartitions: Option[Int]): Dataset = {
    newDataset { builder =>
      val rangeBuilder = builder.getRangeBuilder
        .setStart(start)
        .setEnd(end)
        .setStep(step)
      numPartitions.foreach(rangeBuilder.setNumPartitions)
    }
  }
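
  // Usage sketch (illustrative, not part of this change): all public overloads funnel into
  // the private method above, so for an existing session `spark`,
  //   spark.range(100)          builds a Range relation with start = 0, end = 100, step = 1;
  //   spark.range(5, 50, 5, 4)  additionally sets numPartitions = 4 on the relation.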

  private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = {
    val builder = proto.Relation.newBuilder()
    f(builder)
    val plan = proto.Plan.newBuilder().setRoot(builder).build()
    new Dataset(this, plan)
  }

  private[sql] def analyze(plan: proto.Plan): proto.AnalyzePlanResponse =
    client.analyze(plan)

  private[sql] def execute(plan: proto.Plan): SparkResult = {
    val value = client.execute(plan)
    val result = new SparkResult(value, allocator)
@@ -21,6 +21,7 @@ import scala.language.existentials

import io.grpc.{ManagedChannel, ManagedChannelBuilder}
import java.net.URI
import java.util.UUID

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.common.config.ConnectCommon
@@ -41,6 +42,11 @@ class SparkConnectClient(
   */
  def userId: String = userContext.getUserId()

  // Generate a unique session ID for this client. This UUID must be unique to allow
  // concurrent Spark sessions of the same user. If the channel is closed, creating
  // a new client will create a new session ID.
  private[client] val sessionId: String = UUID.randomUUID.toString

  /**
   * Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
   * @return
@@ -58,6 +64,22 @@
    stub.executePlan(request)
  }

  /**
   * Builds a [[proto.AnalyzePlanRequest]] from `plan` and dispatches it to the Spark Connect
   * server.
   * @return
   *   A [[proto.AnalyzePlanResponse]] from the Spark Connect server.
   */
  def analyze(plan: proto.Plan): proto.AnalyzePlanResponse = {
    val request = proto.AnalyzePlanRequest
      .newBuilder()
      .setPlan(plan)
      .setUserContext(userContext)
      .setClientId(sessionId)
      .build()
    analyze(request)
  }
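
  // Usage sketch (illustrative, not part of this change): `Dataset.analyze` calls
  // `SparkSession.analyze`, which delegates here, so every analysis request carries this
  // client's userContext plus the generated sessionId as its client ID.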

  /**
   * Shutdown the client's connection to the server.
   */
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.connect

package object client {

  private[sql] def unsupported(): Nothing = {
    throw new UnsupportedOperationException
  }

  private[sql] def unsupported(message: String): Nothing = {
    throw new UnsupportedOperationException(message)
  }
}
@@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

import java.math.{BigDecimal => JBigDecimal}
import java.time.LocalDate

import com.google.protobuf.ByteString

import org.apache.spark.connect.proto
import org.apache.spark.sql.connect.client.unsupported

/**
* Commonly used functions available for DataFrame operations.
*
* @since 3.4.0
*/
// scalastyle:off
object functions {
// scalastyle:on

  private def createLiteral(f: proto.Expression.Literal.Builder => Unit): Column = Column {
    builder =>
      val literalBuilder = proto.Expression.Literal.newBuilder()
      f(literalBuilder)
      builder.setLiteral(literalBuilder)
  }

  private def createDecimalLiteral(precision: Int, scale: Int, value: String): Column =
    createLiteral { builder =>
      builder.getDecimalBuilder
        .setPrecision(precision)
        .setScale(scale)
        .setValue(value)
    }

  /**
   * Creates a [[Column]] of literal value.
   *
   * The passed in object is returned directly if it is already a [[Column]]. If the object is a
   * Scala Symbol, it is converted into a [[Column]] also. Otherwise, a new [[Column]] is created
   * to represent the literal value.
   *
   * @since 3.4.0
   */
  def lit(literal: Any): Column = {
    literal match {
      case c: Column => c
      case s: Symbol => Column(s.name)
      case v: Boolean => createLiteral(_.setBoolean(v))
      case v: Byte => createLiteral(_.setByte(v))
      case v: Short => createLiteral(_.setShort(v))
      case v: Int => createLiteral(_.setInteger(v))
      case v: Long => createLiteral(_.setLong(v))
      case v: Float => createLiteral(_.setFloat(v))
      case v: Double => createLiteral(_.setDouble(v))
      case v: BigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
      case v: JBigDecimal => createDecimalLiteral(v.precision, v.scale, v.toString)
      case v: String => createLiteral(_.setString(v))
      case v: Char => createLiteral(_.setString(v.toString))
      case v: Array[Char] => createLiteral(_.setString(String.valueOf(v)))
      case v: Array[Byte] => createLiteral(_.setBinary(ByteString.copyFrom(v)))
      case v: collection.mutable.WrappedArray[_] => lit(v.array)
      case v: LocalDate => createLiteral(_.setDate(v.toEpochDay.toInt))
      case null => unsupported("Null literals not supported yet.")
      case _ => unsupported(s"literal $literal not supported (yet).")
    }
  }
}
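
// --- Usage sketch (illustrative, not part of this change) ---
// How `lit` maps a few common Scala values onto literal expressions; the value names
// below are made up for the example.
object LiteralUsageSketch {
  import java.math.{BigDecimal => JBigDecimal}

  import org.apache.spark.sql.Column
  import org.apache.spark.sql.functions.lit

  val int: Column = lit(42)                      // Literal with the integer field set
  val str: Column = lit("spark")                 // Literal with the string field set
  val dec: Column = lit(new JBigDecimal("3.14")) // Decimal literal carrying precision and scale
  val col: Column = lit(Column("id"))            // An existing Column is returned unchanged
}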