@@ -19,13 +19,15 @@ package org.apache.spark.sql.connect
import scala.language.implicitConversions

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.sql._
import org.apache.spark.sql.internal.ProtoColumnNode

/**
* Conversions from sql interfaces to the Connect specific implementation.
*
* This class is mainly used by the implementation. In the case of connect it should be extremely
* rare that a developer needs these classes.
* This class is mainly used by the implementation. It is also meant to be used by extension
* developers.
*
* We provide both a trait and an object. The trait is useful in situations where an extension
* developer needs to use these conversions in a project covering multiple Spark versions. They
@@ -46,6 +48,40 @@ trait ConnectConversions {
implicit def castToImpl[K, V](
kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] =
kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]

/**
* Create a [[Column]] from a [[proto.Expression]]
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatibility
* between (minor) versions.
*/
@DeveloperApi
def column(expr: proto.Expression): Column = {
Column(ProtoColumnNode(expr))
}

/**
* Create a [[Column]] using a function that manipulates an [[proto.Expression.Builder]].
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatibility
* between (minor) versions.
*/
@DeveloperApi
def column(f: proto.Expression.Builder => Unit): Column = {
val builder = proto.Expression.newBuilder()
f(builder)
column(builder.build())
}

/**
* Implicit helper that makes it easy to construct a Column from an Expression or an Expression
* builder. This allows developers to create a Column in the same way as in earlier versions of
* Spark (before 4.0).
*/
@DeveloperApi
implicit class ColumnConstructorExt(val c: Column.type) {
def apply(e: proto.Expression): Column = column(e)
}
}

object ConnectConversions extends ConnectConversions
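
For illustration, a rough sketch of how a Connect plugin or extension might use the new helpers once the trait (or object) is in scope. The attribute name and literal value are made up for the example, and the builder calls are assumed to follow the generated protobuf API for proto.Expression:

import org.apache.spark.connect.proto
import org.apache.spark.sql.Column
import org.apache.spark.sql.connect.ConnectConversions._

// Build a Column by mutating a fresh proto.Expression.Builder.
val id: Column = column { builder =>
  builder.getUnresolvedAttributeBuilder.setUnparsedIdentifier("id")
}

// Or build the proto.Expression up front and use the pre-4.0 style
// Column(...) constructor re-enabled by ColumnConstructorExt.
val one: Column = Column(
  proto.Expression
    .newBuilder()
    .setLiteral(proto.Expression.Literal.newBuilder().setInteger(1))
    .build())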
@@ -17,39 +17,12 @@

package org.apache.spark

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.internal.ProtoColumnNode

package object sql {
type DataFrame = Dataset[Row]

private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = {
implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]]
}

/**
* Create a [[Column]] from a [[proto.Expression]]
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatility
* between (minor) versions.
*/
@DeveloperApi
def column(expr: proto.Expression): Column = {
Column(ProtoColumnNode(expr))
}

/**
* Creat a [[Column]] using a function that manipulates an [[proto.Expression.Builder]].
*
* This method is meant to be used by Connect plugins. We do not guarantee any compatility
* between (minor) versions.
*/
@DeveloperApi
def column(f: proto.Expression.Builder => Unit): Column = {
val builder = proto.Expression.newBuilder()
f(builder)
column(builder.build())
}
}
@@ -37,6 +37,7 @@ import org.apache.spark.sql.avro.{functions => avroFn}
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.connect.ConnectConversions._
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.lit
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.types._

/**
@@ -122,7 +121,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
(attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) {
replaceCol(attr, replacementMap)
} else {
column(attr)
Column(attr)
}
}
df.select(projections : _*)
@@ -131,7 +130,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
protected def fillMap(values: Seq[(String, Any)]): DataFrame = {
// Error handling
val attrToValue = AttributeMap(values.map { case (colName, replaceValue) =>
// Check column name exists
// Check Column name exists
val attr = df.resolve(colName) match {
case a: Attribute => a
case _ => throw QueryExecutionErrors.nestedFieldUnsupportedError(colName)
@@ -155,7 +154,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
case v: jl.Integer => fillCol[Integer](attr, v)
case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue())
case v: String => fillCol[String](attr, v)
}.getOrElse(column(attr))
}.getOrElse(Column(attr))
}
df.select(projections : _*)
}
@@ -165,7 +164,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
* with `replacement`.
*/
private def fillCol[T](attr: Attribute, replacement: T): Column = {
fillCol(attr.dataType, attr.name, column(attr), replacement)
fillCol(attr.dataType, attr.name, Column(attr), replacement)
}

/**
@@ -192,7 +191,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
val branches = replacementMap.flatMap { case (source, target) =>
Seq(Literal(source), buildExpr(target))
}.toSeq
column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
}

private def convertToDouble(v: Any): Double = v match {
@@ -219,7 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
// Filtering condition:
// only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
val predicate = AtLeastNNonNulls(minNonNulls.getOrElse(cols.size), cols)
df.filter(column(predicate))
df.filter(Column(predicate))
}

private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame = {
@@ -255,9 +254,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
}
// Only fill if the column is part of the cols list.
if (typeMatches && cols.exists(_.semanticEquals(col))) {
fillCol(col.dataType, col.name, column(col), value)
fillCol(col.dataType, col.name, Column(col), value)
} else {
column(col)
Column(col)
}
}
df.select(projections : _*)
29 changes: 18 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -63,7 +63,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Data
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.execution.stat.StatFunctions
import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.TypedAggUtils.withInputType
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
@@ -303,7 +302,7 @@ class Dataset[T] private[sql](
truncate: Int): Seq[Seq[String]] = {
val newDf = commandResultOptimized.toDF()
val castCols = newDf.logicalPlan.output.map { col =>
column(ToPrettyString(col))
Column(ToPrettyString(col))
}
val data = newDf.select(castCols: _*).take(numRows + 1)

@@ -505,7 +504,7 @@ class Dataset[T] private[sql](
s"New column names (${colNames.size}): " + colNames.mkString(", "))

val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) =>
column(oldAttribute).as(newName)
Column(oldAttribute).as(newName)
}
select(newCols : _*)
}
@@ -760,18 +759,18 @@ class Dataset[T] private[sql](
/** @inheritdoc */
def col(colName: String): Column = colName match {
case "*" =>
column(ResolvedStar(queryExecution.analyzed.output))
Column(ResolvedStar(queryExecution.analyzed.output))
case _ =>
if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
colRegex(colName)
} else {
column(addDataFrameIdToCol(resolve(colName)))
Column(addDataFrameIdToCol(resolve(colName)))
}
}

/** @inheritdoc */
def metadataColumn(colName: String): Column =
column(queryExecution.analyzed.getMetadataAttributeByName(colName))
Column(queryExecution.analyzed.getMetadataAttributeByName(colName))

// Attach the dataset id and column position to the column reference, so that we can detect
// ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`.
@@ -797,11 +796,11 @@ class Dataset[T] private[sql](
val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
colName match {
case ParserUtils.escapedIdentifier(columnNameRegex) =>
column(UnresolvedRegex(columnNameRegex, None, caseSensitive))
Column(UnresolvedRegex(columnNameRegex, None, caseSensitive))
case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
case _ =>
column(addDataFrameIdToCol(resolve(colName)))
Column(addDataFrameIdToCol(resolve(colName)))
}
}

@@ -1194,7 +1193,7 @@ class Dataset[T] private[sql](
resolver(field.name, colName)
} match {
case Some((colName: String, col: Column)) => col.as(colName)
case _ => column(field)
case _ => Column(field)
}
}

@@ -1264,7 +1263,7 @@ class Dataset[T] private[sql](
val allColumns = queryExecution.analyzed.output
val remainingCols = allColumns.filter { attribute =>
colNames.forall(n => !resolver(attribute.name, n))
}.map(attribute => column(attribute))
}.map(attribute => Column(attribute))
if (remainingCols.size == allColumns.size) {
toDF()
} else {
@@ -1975,6 +1974,14 @@ class Dataset[T] private[sql](
// For Python API
////////////////////////////////////////////////////////////////////////////

/**
* It adds a new long column with the name `name` that increases one by one.
* This is for 'distributed-sequence' default index in pandas API on Spark.
*/
private[sql] def withSequenceColumn(name: String) = {
select(Column(DistributedSequenceID()).alias(name), col("*"))
}

/**
* Converts a JavaRDD to a PythonRDD.
*/
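
A minimal sketch of what the new `withSequenceColumn` helper expands to, with an assumed column name and an assumed DataFrame `df` (the method itself is `private[sql]`, so it is only reachable from Spark's own code such as the pandas-on-Spark bridge):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.DistributedSequenceID
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.functions.col

// Equivalent projection: prepend a consecutive, 0-based long id column
// ahead of all existing columns of df.
val withIndex = df.select(
  Column(DistributedSequenceID()).alias("__index_level_0__"),
  col("*"))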
@@ -20,11 +20,13 @@ import scala.language.implicitConversions

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.internal.ExpressionUtils

/**
* Conversions from sql interfaces to the Classic specific implementation.
*
* This class is mainly used by the implementation, but is also meant to be used by extension
* This class is mainly used by the implementation. It is also meant to be used by extension
* developers.
*
* We provide both a trait and an object. The trait is useful in situations where an extension
@@ -45,6 +47,13 @@ trait ClassicConversions {

implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V])
: KeyValueGroupedDataset[K, V] = kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]

/**
* Helper that makes it easy to construct a Column from an Expression.
*/
implicit class ColumnConstructorExt(val c: Column.type) {
def apply(e: Expression): Column = ExpressionUtils.column(e)
}
}

object ClassicConversions extends ClassicConversions
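
As a rough sketch of what this implicit enables for classic-engine internals and extensions (the attribute handling below is illustrative; `Literal` is the usual Catalyst literal expression):

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.classic.ClassicConversions._

// With ColumnConstructorExt in scope, Column(<Catalyst Expression>) compiles
// again, replacing the internal ExpressionUtils.column helper used before.
def isMissing(attr: Attribute): Column = Column(attr).isNull

val one: Column = Column(Literal(1))

Because the extension lives on `Column.type`, call sites keep the familiar `Column(expr)` spelling without reintroducing an `Expression`-typed constructor into the public `Column` companion.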
@@ -84,7 +84,7 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends
sqlConf.contains(key)
}

private def requireNonStaticConf(key: String): Unit = {
private[sql] def requireNonStaticConf(key: String): Unit = {
if (SQLConf.isStaticConfigKey(key)) {
throw QueryCompilationErrors.cannotModifyValueOfStaticConfigError(key)
}