diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala
index 7d81f4ead785..0344152be86e 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/ConnectConversions.scala
@@ -19,13 +19,15 @@ package org.apache.spark.sql.connect
 import scala.language.implicitConversions
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.connect.proto
 import org.apache.spark.sql._
+import org.apache.spark.sql.internal.ProtoColumnNode
 
 /**
  * Conversions from sql interfaces to the Connect specific implementation.
  *
- * This class is mainly used by the implementation. In the case of connect it should be extremely
- * rare that a developer needs these classes.
+ * This class is mainly used by the implementation. It is also meant to be used by extension
+ * developers.
  *
  * We provide both a trait and an object. The trait is useful in situations where an extension
  * developer needs to use these conversions in a project covering multiple Spark versions. They
@@ -46,6 +48,40 @@ trait ConnectConversions {
   implicit def castToImpl[K, V](
       kvds: api.KeyValueGroupedDataset[K, V]): KeyValueGroupedDataset[K, V] =
     kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]
+
+  /**
+   * Create a [[Column]] from a [[proto.Expression]]
+   *
+   * This method is meant to be used by Connect plugins. We do not guarantee any compatibility
+   * between (minor) versions.
+   */
+  @DeveloperApi
+  def column(expr: proto.Expression): Column = {
+    Column(ProtoColumnNode(expr))
+  }
+
+  /**
+   * Create a [[Column]] using a function that manipulates a [[proto.Expression.Builder]].
+   *
+   * This method is meant to be used by Connect plugins. We do not guarantee any compatibility
+   * between (minor) versions.
+   */
+  @DeveloperApi
+  def column(f: proto.Expression.Builder => Unit): Column = {
+    val builder = proto.Expression.newBuilder()
+    f(builder)
+    column(builder.build())
+  }
+
+  /**
+   * Implicit helper that makes it easy to construct a Column from an Expression or an Expression
+   * builder. This allows developers to create a Column in the same way as in earlier versions of
+   * Spark (before 4.0).
+   */
+  @DeveloperApi
+  implicit class ColumnConstructorExt(val c: Column.type) {
+    def apply(e: proto.Expression): Column = column(e)
+  }
 }
 
 object ConnectConversions extends ConnectConversions
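
As a usage sketch (not part of the patch; the attribute name and literal value below are illustrative), a Connect plugin could now construct Columns through the helpers added above:

    import org.apache.spark.connect.proto
    import org.apache.spark.sql.Column
    import org.apache.spark.sql.connect.ConnectConversions._

    // Wrap a hand-built proto.Expression with the familiar pre-4.0 Column(...) shape,
    // provided by the ColumnConstructorExt implicit above.
    val attrExpr = proto.Expression.newBuilder()
    attrExpr.getUnresolvedAttributeBuilder.setUnparsedIdentifier("id")
    val idCol: Column = Column(attrExpr.build())

    // Or let the builder-based overload drive construction of the expression.
    val lit42: Column = column { builder =>
      builder.getLiteralBuilder.setInteger(42)
    }
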
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
index 154f2b0405fc..556b472283a3 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/package.scala
@@ -17,10 +17,7 @@
 package org.apache.spark
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.internal.ProtoColumnNode
 
 package object sql {
   type DataFrame = Dataset[Row]
 
@@ -28,28 +25,4 @@ package object sql {
   private[sql] def encoderFor[E: Encoder]: AgnosticEncoder[E] = {
     implicitly[Encoder[E]].asInstanceOf[AgnosticEncoder[E]]
   }
-
-  /**
-   * Create a [[Column]] from a [[proto.Expression]]
-   *
-   * This method is meant to be used by Connect plugins. We do not guarantee any compatility
-   * between (minor) versions.
-   */
-  @DeveloperApi
-  def column(expr: proto.Expression): Column = {
-    Column(ProtoColumnNode(expr))
-  }
-
-  /**
-   * Creat a [[Column]] using a function that manipulates an [[proto.Expression.Builder]].
-   *
-   * This method is meant to be used by Connect plugins. We do not guarantee any compatility
-   * between (minor) versions.
-   */
-  @DeveloperApi
-  def column(f: proto.Expression.Builder => Unit): Column = {
-    val builder = proto.Expression.newBuilder()
-    f(builder)
-    column(builder.build())
-  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 315f80e13eff..c557b5473279 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.avro.{functions => avroFn}
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
 import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.connect.ConnectConversions._
 import org.apache.spark.sql.connect.client.SparkConnectClient
 import org.apache.spark.sql.expressions.Window
 import org.apache.spark.sql.functions.lit
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
index b356751083fc..53e12f58edd6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala
@@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.classic.ClassicConversions._
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.types._
 
 /**
@@ -122,7 +121,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
           (attr.dataType.isInstanceOf[NumericType] && targetColumnType == DoubleType))) {
         replaceCol(attr, replacementMap)
       } else {
-        column(attr)
+        Column(attr)
       }
     }
     df.select(projections : _*)
@@ -131,7 +130,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
   protected def fillMap(values: Seq[(String, Any)]): DataFrame = {
     // Error handling
     val attrToValue = AttributeMap(values.map { case (colName, replaceValue) =>
-      // Check column name exists
+      // Check Column name exists
       val attr = df.resolve(colName) match {
         case a: Attribute => a
         case _ => throw QueryExecutionErrors.nestedFieldUnsupportedError(colName)
       }
@@ -155,7 +154,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
         case v: jl.Integer => fillCol[Integer](attr, v)
         case v: jl.Boolean => fillCol[Boolean](attr, v.booleanValue())
         case v: String => fillCol[String](attr, v)
-      }.getOrElse(column(attr))
+      }.getOrElse(Column(attr))
     }
     df.select(projections : _*)
   }
@@ -165,7 +164,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
    * with `replacement`.
    */
   private def fillCol[T](attr: Attribute, replacement: T): Column = {
-    fillCol(attr.dataType, attr.name, column(attr), replacement)
+    fillCol(attr.dataType, attr.name, Column(attr), replacement)
   }
 
   /**
@@ -192,7 +191,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
     val branches = replacementMap.flatMap { case (source, target) =>
       Seq(Literal(source), buildExpr(target))
     }.toSeq
-    column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
+    Column(CaseKeyWhen(attr, branches :+ attr)).as(attr.name)
   }
 
   private def convertToDouble(v: Any): Double = v match {
@@ -219,7 +218,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
     // Filtering condition:
     // only keep the row if it has at least `minNonNulls` non-null and non-NaN values.
     val predicate = AtLeastNNonNulls(minNonNulls.getOrElse(cols.size), cols)
-    df.filter(column(predicate))
+    df.filter(Column(predicate))
   }
 
   private[sql] def fillValue(value: Any, cols: Option[Seq[String]]): DataFrame = {
@@ -255,9 +254,9 @@ final class DataFrameNaFunctions private[sql](df: DataFrame)
       }
       // Only fill if the column is part of the cols list.
       if (typeMatches && cols.exists(_.semanticEquals(col))) {
-        fillCol(col.dataType, col.name, column(col), value)
+        fillCol(col.dataType, col.name, Column(col), value)
       } else {
-        column(col)
+        Column(col)
       }
     }
     df.select(projections : _*)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 80ec70a7864c..18fc5787a158 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -63,7 +63,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, Data
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.internal.{DataFrameWriterImpl, DataFrameWriterV2Impl, MergeIntoWriterImpl, SQLConf}
-import org.apache.spark.sql.internal.ExpressionUtils.column
 import org.apache.spark.sql.internal.TypedAggUtils.withInputType
 import org.apache.spark.sql.streaming.DataStreamWriter
 import org.apache.spark.sql.types._
@@ -303,7 +302,7 @@ class Dataset[T] private[sql](
       truncate: Int): Seq[Seq[String]] = {
     val newDf = commandResultOptimized.toDF()
     val castCols = newDf.logicalPlan.output.map { col =>
-      column(ToPrettyString(col))
+      Column(ToPrettyString(col))
     }
     val data = newDf.select(castCols: _*).take(numRows + 1)
 
@@ -505,7 +504,7 @@ class Dataset[T] private[sql](
       s"New column names (${colNames.size}): " + colNames.mkString(", "))
     val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) =>
-      column(oldAttribute).as(newName)
+      Column(oldAttribute).as(newName)
     }
     select(newCols : _*)
   }
 
@@ -760,18 +759,18 @@ class Dataset[T] private[sql](
   /** @inheritdoc */
   def col(colName: String): Column = colName match {
     case "*" =>
-      column(ResolvedStar(queryExecution.analyzed.output))
+      Column(ResolvedStar(queryExecution.analyzed.output))
     case _ =>
       if (sparkSession.sessionState.conf.supportQuotedRegexColumnName) {
         colRegex(colName)
       } else {
-        column(addDataFrameIdToCol(resolve(colName)))
+        Column(addDataFrameIdToCol(resolve(colName)))
       }
   }
 
   /** @inheritdoc */
   def metadataColumn(colName: String): Column =
-    column(queryExecution.analyzed.getMetadataAttributeByName(colName))
+    Column(queryExecution.analyzed.getMetadataAttributeByName(colName))
 
   // Attach the dataset id and column position to the column reference, so that we can detect
   // ambiguous self-join correctly. See the rule `DetectAmbiguousSelfJoin`.
@@ -797,11 +796,11 @@ class Dataset[T] private[sql](
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
     colName match {
       case ParserUtils.escapedIdentifier(columnNameRegex) =>
-        column(UnresolvedRegex(columnNameRegex, None, caseSensitive))
+        Column(UnresolvedRegex(columnNameRegex, None, caseSensitive))
       case ParserUtils.qualifiedEscapedIdentifier(nameParts, columnNameRegex) =>
-        column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
+        Column(UnresolvedRegex(columnNameRegex, Some(nameParts), caseSensitive))
       case _ =>
-        column(addDataFrameIdToCol(resolve(colName)))
+        Column(addDataFrameIdToCol(resolve(colName)))
     }
   }
 
@@ -1194,7 +1193,7 @@ class Dataset[T] private[sql](
         resolver(field.name, colName)
       } match {
         case Some((colName: String, col: Column)) => col.as(colName)
-        case _ => column(field)
+        case _ => Column(field)
       }
     }
 
@@ -1264,7 +1263,7 @@ class Dataset[T] private[sql](
     val allColumns = queryExecution.analyzed.output
     val remainingCols = allColumns.filter { attribute =>
       colNames.forall(n => !resolver(attribute.name, n))
-    }.map(attribute => column(attribute))
+    }.map(attribute => Column(attribute))
     if (remainingCols.size == allColumns.size) {
       toDF()
     } else {
@@ -1975,6 +1974,14 @@ class Dataset[T] private[sql](
   // For Python API
   ////////////////////////////////////////////////////////////////////////////
 
+  /**
+   * It adds a new long column with the name `name` that increases one by one.
+   * This is for 'distributed-sequence' default index in pandas API on Spark.
+   */
+  private[sql] def withSequenceColumn(name: String) = {
+    select(Column(DistributedSequenceID()).alias(name), col("*"))
+  }
+
   /**
    * Converts a JavaRDD to a PythonRDD.
    */
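
For context, `withSequenceColumn` is the classic-side hook behind the pandas-on-Spark "distributed-sequence" default index: it prepends a zero-based long column that increases one by one. A rough, illustrative equivalent written against the new `Column(Expression)` constructor (a sketch only; `withSequenceColumn` itself is `private[sql]`, and the `spark` session and data here are made up):

    import org.apache.spark.sql.catalyst.expressions.DistributedSequenceID
    import org.apache.spark.sql.classic.ClassicConversions._

    val df = spark.range(3).toDF("value")
    // Prepend a distributed, 0-based sequence column in front of the existing columns,
    // mirroring what withSequenceColumn("seq") does internally.
    val withSeq = df.select(Column(DistributedSequenceID()).alias("seq"), df.col("*"))
    withSeq.show()   // seq: 0, 1, 2 (consecutive, even across partitions)
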
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
index af91b57a6848..8c3223fa72f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/ClassicConversions.scala
@@ -20,11 +20,13 @@ import scala.language.implicitConversions
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.internal.ExpressionUtils
 
 /**
  * Conversions from sql interfaces to the Classic specific implementation.
  *
- * This class is mainly used by the implementation, but is also meant to be used by extension
+ * This class is mainly used by the implementation. It is also meant to be used by extension
  * developers.
  *
 * We provide both a trait and an object. The trait is useful in situations where an extension
@@ -45,6 +47,13 @@ trait ClassicConversions {
   implicit def castToImpl[K, V](kvds: api.KeyValueGroupedDataset[K, V])
     : KeyValueGroupedDataset[K, V] =
     kvds.asInstanceOf[KeyValueGroupedDataset[K, V]]
+
+  /**
+   * Helper that makes it easy to construct a Column from an Expression.
+   */
+  implicit class ColumnConstructorExt(val c: Column.type) {
+    def apply(e: Expression): Column = ExpressionUtils.column(e)
+  }
 }
 
 object ClassicConversions extends ClassicConversions
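
For extension developers this implicit restores the familiar `Column(expression)` construction shape on the classic side, delegating to `ExpressionUtils.column`. A minimal sketch of extension-style code (illustrative only; the expression is arbitrary):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
    import org.apache.spark.sql.classic.ClassicConversions._

    // With ColumnConstructorExt in scope, Column(<catalyst Expression>) compiles again,
    // just as it did before the Column/Expression split in Spark 4.0.
    val twoPlusTwo: Column = Column(Add(Literal(2), Literal(2)))
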
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
index ca439cdb8995..f25ca387db29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/RuntimeConfigImpl.scala
@@ -84,7 +84,7 @@ class RuntimeConfigImpl private[sql](val sqlConf: SQLConf = new SQLConf) extends
     sqlConf.contains(key)
   }
 
-  private def requireNonStaticConf(key: String): Unit = {
+  private[sql] def requireNonStaticConf(key: String): Unit = {
     if (SQLConf.isStaticConfigKey(key)) {
       throw QueryCompilationErrors.cannotModifyValueOfStaticConfigError(key)
     }