39 changes: 39 additions & 0 deletions connector/connect/src/main/protobuf/spark/connect/commands.proto
@@ -17,6 +17,8 @@

syntax = 'proto3';

import "spark/connect/expressions.proto";
import "spark/connect/relations.proto";
import "spark/connect/types.proto";

package spark.connect;
@@ -29,6 +31,7 @@ option java_package = "org.apache.spark.connect.proto";
message Command {
oneof command_type {
CreateScalarFunction create_function = 1;
WriteOperation write_operation = 2;
}
}

@@ -62,3 +65,39 @@ message CreateScalarFunction {
FUNCTION_LANGUAGE_SCALA = 3;
}
}

// As writes are not directly handled during analysis and planning, they are modeled as commands.
message WriteOperation {
// The output of the `input` relation will be persisted according to the options.
Relation input = 1;
// The data source format of the write, per the Spark documentation (e.g. text, parquet, delta).
string source = 2;
// The destination of the write operation must be either a path or a table.
oneof save_type {
string path = 3;
string table_name = 4;
}
SaveMode mode = 5;
// List of columns to sort the output by.
repeated string sort_column_names = 6;
// List of columns for partitioning.
repeated string partitioning_columns = 7;
// Optional bucketing specification. Bucketing must set the number of buckets and the columns
// to bucket by.
BucketBy bucket_by = 8;
// Optional list of configuration options.
map<string, string> options = 9;

message BucketBy {
repeated string bucket_column_names = 1;
int32 num_buckets = 2;
}

enum SaveMode {
SAVE_MODE_UNSPECIFIED = 0;
SAVE_MODE_APPEND = 1;
SAVE_MODE_OVERWRITE = 2;
SAVE_MODE_ERROR_IF_EXISTS = 3;
SAVE_MODE_IGNORE = 4;
}
}
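
For context, a minimal Scala sketch of how a client could assemble this command with the generated protobuf builders; the input relation, output path, and partition column below are placeholders, not part of this change:

import org.apache.spark.connect.proto

// Placeholder input relation; a real client would attach its actual plan here.
val input = proto.Relation.newBuilder().build()

// An overwriting parquet write to a path, partitioned by a hypothetical "date" column.
val command = proto.Command
  .newBuilder()
  .setWriteOperation(
    proto.WriteOperation
      .newBuilder()
      .setInput(input)
      .setSource("parquet")
      .setPath("/tmp/connect-demo")
      .setMode(proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE)
      .addPartitioningColumns("date")
      .build())
  .build()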
SparkConnectCommandPlanner.scala
@@ -24,10 +24,16 @@ import com.google.common.collect.{Lists, Maps}
import org.apache.spark.annotation.{Since, Unstable}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.sql.SparkSession
import org.apache.spark.connect.proto.WriteOperation
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner}
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
import org.apache.spark.sql.types.StringType

final case class InvalidCommandInput(
private val message: String = "",
private val cause: Throwable = null)
extends Exception(message, cause)

@Unstable
@Since("3.4.0")
@@ -40,6 +46,8 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
command.getCommandTypeCase match {
case proto.Command.CommandTypeCase.CREATE_FUNCTION =>
handleCreateScalarFunction(command.getCreateFunction)
case proto.Command.CommandTypeCase.WRITE_OPERATION =>
handleWriteOperation(command.getWriteOperation)
case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
}
@@ -74,4 +82,64 @@ class SparkConnectCommandPlanner(session: SparkSession, command: proto.Command)
session.udf.registerPython(cf.getPartsList.asScala.head, udf)
}

/**
* Transforms the write operation and executes it.
*
* The write operation contains a reference to the input plan, which is first transformed into
* the corresponding logical plan. Afterwards, a DataFrameWriter is created and the parameters
* of the WriteOperation are translated into the corresponding method calls.
*
* @param writeOperation the WriteOperation proto to translate and execute
*/
def handleWriteOperation(writeOperation: WriteOperation): Unit = {
// Transform the input plan into the logical plan.
val planner = new SparkConnectPlanner(writeOperation.getInput, session)
val plan = planner.transform()
// And create a Dataset from the plan.
val dataset = Dataset.ofRows(session, logicalPlan = plan)

val w = dataset.write
if (writeOperation.getMode != proto.WriteOperation.SaveMode.SAVE_MODE_UNSPECIFIED) {
w.mode(DataTypeProtoConverter.toSaveMode(writeOperation.getMode))
}

if (writeOperation.getOptionsCount > 0) {
writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
}

if (writeOperation.getSortColumnNamesCount > 0) {
val names = writeOperation.getSortColumnNamesList.asScala
w.sortBy(names.head, names.tail.toSeq: _*)
}

if (writeOperation.hasBucketBy) {
val op = writeOperation.getBucketBy
val cols = op.getBucketColumnNamesList.asScala
if (op.getNumBuckets <= 0) {
throw InvalidCommandInput(
s"BucketBy must specify a bucket count > 0, received ${op.getNumBuckets} instead.")
}
w.bucketBy(op.getNumBuckets, cols.head, cols.tail.toSeq: _*)
}

if (writeOperation.getPartitioningColumnsCount > 0) {
val names = writeOperation.getPartitioningColumnsList.asScala
w.partitionBy(names.toSeq: _*)
}

// Proto3 string fields default to "" and are never null, so check for non-empty instead.
if (writeOperation.getSource.nonEmpty) {
w.format(writeOperation.getSource)
}

writeOperation.getSaveTypeCase match {
case proto.WriteOperation.SaveTypeCase.PATH => w.save(writeOperation.getPath)
case proto.WriteOperation.SaveTypeCase.TABLE_NAME =>
w.saveAsTable(writeOperation.getTableName)
case _ =>
throw new UnsupportedOperationException(
  s"WriteOperation SaveTypeCase not supported: ${writeOperation.getSaveTypeCase.getNumber}")
}
}

}
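
Taken together, the translation above amounts to chaining the familiar DataFrameWriter calls. A rough sketch, with made-up column, option, and table names, of what a fully populated WriteOperation produces (not the actual server code):

// `dataset` stands for the Dataset built from the input relation.
dataset.write
  .mode(SaveMode.Overwrite)        // from mode = SAVE_MODE_OVERWRITE
  .option("compression", "snappy") // from the options map
  .sortBy("ts")                    // from sort_column_names
  .bucketBy(4, "id")               // from bucket_by
  .partitionBy("date")             // from partitioning_columns
  .format("parquet")               // from source
  .saveAsTable("demo_table")       // from the table_name save_type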
dsl/package.scala
@@ -21,7 +21,9 @@ import scala.language.implicitConversions

import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.connect.planner.DataTypeProtoConverter

/**
* A collection of implicit conversions that create a DSL for constructing connect protos.
@@ -34,59 +36,106 @@ package object dsl {
val identifier = CatalystSqlParser.parseMultipartIdentifier(s)

def protoAttr: proto.Expression =
proto.Expression
  .newBuilder()
  .setUnresolvedAttribute(
    proto.Expression.UnresolvedAttribute
      .newBuilder()
      .addAllParts(identifier.asJava)
      .build())
  .build()
}

implicit class DslExpression(val expr: proto.Expression) {
def as(alias: String): proto.Expression = proto.Expression
  .newBuilder()
  .setAlias(proto.Expression.Alias.newBuilder().setName(alias).setExpr(expr))
  .build()

def <(other: proto.Expression): proto.Expression =
  proto.Expression
    .newBuilder()
    .setUnresolvedFunction(
      proto.Expression.UnresolvedFunction
        .newBuilder()
        .addParts("<")
        .addArguments(expr)
        .addArguments(other))
    .build()
}

implicit def intToLiteral(i: Int): proto.Expression =
proto.Expression
  .newBuilder()
  .setLiteral(proto.Expression.Literal.newBuilder().setI32(i))
  .build()
}
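
With these implicits in scope, a test can build expression protos straight from Scala values. A small sketch; `id` is just an example column name:

// Builds the proto tree for `id < 10`: the string becomes an
// UnresolvedAttribute and the Int is lifted to a Literal via intToLiteral.
val predicate: proto.Expression = "id".protoAttr < 10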

object commands { // scalastyle:ignore
implicit class DslCommands(val logicalPlan: proto.Relation) {
def write(
format: Option[String] = None,
path: Option[String] = None,
tableName: Option[String] = None,
mode: Option[String] = None,
sortByColumns: Seq[String] = Seq.empty,
partitionByCols: Seq[String] = Seq.empty,
bucketByCols: Seq[String] = Seq.empty,
numBuckets: Option[Int] = None): proto.Command = {
val writeOp = proto.WriteOperation.newBuilder()
format.foreach(writeOp.setSource(_))

mode
.map(SaveMode.valueOf(_))
.map(DataTypeProtoConverter.toSaveModeProto(_))
.foreach(writeOp.setMode(_))

if (tableName.nonEmpty) {
tableName.foreach(writeOp.setTableName(_))
} else {
path.foreach(writeOp.setPath(_))
}
sortByColumns.foreach(writeOp.addSortColumnNames(_))
partitionByCols.foreach(writeOp.addPartitioningColumns(_))

if (numBuckets.nonEmpty && bucketByCols.nonEmpty) {
val op = proto.WriteOperation.BucketBy.newBuilder()
numBuckets.foreach(op.setNumBuckets(_))
bucketByCols.foreach(op.addBucketColumnNames(_))
writeOp.setBucketBy(op.build())
}
writeOp.setInput(logicalPlan)
proto.Command.newBuilder().setWriteOperation(writeOp.build()).build()
}
}
}
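
A sketch of how a test might drive this DSL, assuming a proto.Relation named `plan` and the implicits above imported; the format, path, and column are illustrative:

// Produces a proto.Command wrapping a WriteOperation: an overwriting,
// partitioned parquet write to a path.
val cmd: proto.Command = plan.write(
  format = Some("parquet"),
  path = Some("/tmp/connect-out"),
  mode = Some("Overwrite"),
  partitionByCols = Seq("date"))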

object plans { // scalastyle:ignore
implicit class DslLogicalPlan(val logicalPlan: proto.Relation) {
def select(exprs: proto.Expression*): proto.Relation = {
proto.Relation
  .newBuilder()
  .setProject(
    proto.Project
      .newBuilder()
      .setInput(logicalPlan)
      .addAllExpressions(exprs.toIterable.asJava)
      .build())
  .build()
}

def where(condition: proto.Expression): proto.Relation = {
proto.Relation.newBuilder()
.setFilter(
proto.Filter.newBuilder().setInput(logicalPlan).setCondition(condition)
).build()
}


def join(
otherPlan: proto.Relation,
joinType: JoinType = JoinType.JOIN_TYPE_INNER,
condition: Option[proto.Expression] = None): proto.Relation = {
val relation = proto.Relation.newBuilder()
val join = proto.Join.newBuilder()
join
  .setLeft(logicalPlan)
.setRight(otherPlan)
.setJoinType(joinType)
if (condition.isDefined) {
DataTypeProtoConverter.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connect.planner

import org.apache.spark.connect.proto
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.types.{DataType, IntegerType, StringType}

/**
@@ -43,4 +44,28 @@ object DataTypeProtoConverter {
throw InvalidPlanInput(s"Does not support convert ${t.typeName} to connect proto types.")
}
}

def toSaveMode(mode: proto.WriteOperation.SaveMode): SaveMode = {
mode match {
case proto.WriteOperation.SaveMode.SAVE_MODE_APPEND => SaveMode.Append
case proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE => SaveMode.Ignore
case proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE => SaveMode.Overwrite
case proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS => SaveMode.ErrorIfExists
case _ =>
throw new IllegalArgumentException(
s"Cannot convert from WriteOperaton.SaveMode to Spark SaveMode: ${mode.getNumber}")
}
}

def toSaveModeProto(mode: SaveMode): proto.WriteOperation.SaveMode = {
mode match {
case SaveMode.Append => proto.WriteOperation.SaveMode.SAVE_MODE_APPEND
case SaveMode.Ignore => proto.WriteOperation.SaveMode.SAVE_MODE_IGNORE
case SaveMode.Overwrite => proto.WriteOperation.SaveMode.SAVE_MODE_OVERWRITE
case SaveMode.ErrorIfExists => proto.WriteOperation.SaveMode.SAVE_MODE_ERROR_IF_EXISTS
case _ =>
throw new IllegalArgumentException(
s"Cannot convert from SaveMode to WriteOperation.SaveMode: ${mode.name()}")
}
}
}
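
Because the two converters are inverses on the supported modes, a quick sanity check can round-trip a value; a minimal sketch for a test:

// Round-trip: Spark SaveMode -> proto SaveMode -> Spark SaveMode.
assert(
  DataTypeProtoConverter.toSaveMode(
    DataTypeProtoConverter.toSaveModeProto(SaveMode.Append)) == SaveMode.Append)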