From 43ca8037c500c547076e2319622b0f80c2be91dd Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 25 Apr 2024 18:36:48 -0700 Subject: [PATCH 1/5] [SPARK-48012][SQL] SPJ: Support Transform Expressions for One Side Shuffle ### Why are the changes needed? Support SPJ one-side shuffle when the other side has a partition transform expression. ### How was this patch tested? New unit tests in KeyGroupedPartitioningSuite. ### Was this patch authored or co-authored using generative AI tooling? No. --- .../scala/org/apache/spark/Partitioner.scala | 4 +- .../expressions/TransformExpression.scala | 35 ++++- .../expressions/V2ExpressionUtils.scala | 9 +- .../plans/physical/partitioning.scala | 26 +++- .../KeyGroupedPartitioningSuite.scala | 136 +++++++++++++++--- .../functions/transformFunctions.scala | 12 +- 6 files changed, 195 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index ae39e2e183e4a..9950d336074d2 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import scala.collection.immutable.ArraySeq import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.log10 @@ -149,7 +150,8 @@ private[spark] class KeyGroupedPartitioner( override val numPartitions: Int) extends Partitioner { override def getPartition(key: Any): Int = { val keys = key.asInstanceOf[Seq[Any]] - valueMap.getOrElseUpdate(keys, Utils.nonNegativeMod(keys.hashCode, numPartitions)) + val normalizedKeys = ArraySeq.from(keys) + valueMap.getOrElseUpdate(normalizedKeys, Utils.nonNegativeMod(keys.hashCode, numPartitions)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index d37c9d9f6452a..98b5c641096fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction, ScalarFunction} +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.DataType /** @@ -30,7 +33,7 @@ import org.apache.spark.sql.types.DataType case class TransformExpression( function: BoundFunction, children: Seq[Expression], - numBucketsOpt: Option[Int] = None) extends Expression with Unevaluable { + numBucketsOpt: Option[Int] = None) extends Expression { override def nullable: Boolean = true @@ -113,4 +116,32 @@ case class TransformExpression( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) + + lazy val resolvedFunction: Option[Expression] = this match { + case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => + Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, + Seq(Literal(numBuckets)) ++ arguments)) +
case TransformExpression(scalarFunc: ScalarFunction[_], arguments, None) => + Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, arguments)) + case _ => None + } + + override def eval(input: InternalRow): Any = { + resolvedFunction match { + case Some(fn) => fn.eval(input) + case None => throw QueryExecutionErrors.cannotEvaluateExpressionError(this) + } + } + + /** + * Returns Java source code that can be compiled to evaluate this expression. + * The default behavior is to call the eval method of the expression. Concrete expression + * implementations should override this to do actual code generation. + * + * @param ctx a [[CodegenContext]] + * @param ev an [[ExprCode]] with unique terms. + * @return an [[ExprCode]] containing the Java source code to generate the given expression + */ + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index 220920a5a3198..44a585da5134d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -161,7 +161,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { val declaredInputTypes = scalarFunc.inputTypes().toImmutableArraySeq val argClasses = declaredInputTypes.map(EncoderUtils.dataTypeJavaClass) findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { - case Some(m) if Modifier.isStatic(m.getModifiers) => + case Some(m) if isStatic(m) => StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes, propagateNull = false, returnNullable = scalarFunc.isResultNullable, @@ -204,4 +204,11 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { None } } + + private def isStatic(m: Method) = { + val javaStatic = Modifier.isStatic(m.getModifiers) + val scalaObjModule = m.getDeclaringClass.getField("MODULE$") + val scalaStatic = scalaObjModule != null && Modifier.isStatic(scalaObjModule.getModifiers) + javaStatic || scalaStatic + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 43aba478c37be..a4756ac1fde74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -871,12 +871,30 @@ case class KeyGroupedShuffleSpec( if (results.forall(p => p.isEmpty)) None else Some(results) } - override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && - // Only support partition expressions are AttributeReference for now - partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + override def canCreatePartitioning: Boolean = { + // Allow one side shuffle for SPJ for now only if partially-clustered is not enabled + // and for join keys less than partition keys only if transforms are not enabled. 
+ val checkExprType = if (SQLConf.get.v2BucketingAllowJoinKeysSubsetOfPartitionKeys) { + e: Expression => e.isInstanceOf[AttributeReference] + } else { + e: Expression => e.isInstanceOf[AttributeReference] || e.isInstanceOf[TransformExpression] + } + SQLConf.get.v2BucketingShuffleEnabled && + !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && + partitioning.expressions.forall(checkExprType) + } + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { - KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map{ + case (c, e: TransformExpression) => TransformExpression( + e.function, Seq(c), e.numBucketsOpt) + case (c, _) => c + } + KeyGroupedPartitioning(newExpressions, + partitioning.numPartitions, + partitioning.partitionValues) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 10a32441b6cd9..a5de5bc1913b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1136,7 +1136,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { val df = createJoinTestDF(Seq("arrive_time" -> "time")) val shuffles = collectShuffles(df.queryExecution.executedPlan) if (shuffle) { - assert(shuffles.size == 2, "partitioning with transform not work now") + assert(shuffles.size == 1, "partitioning with transform should trigger SPJ") } else { assert(shuffles.size == 2, "should add two side shuffle when bucketing shuffle one side" + " is not enabled") @@ -1991,22 +1991,19 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { "(6, 50.0, cast('2023-02-01' as timestamp))") Seq(true, false).foreach { pushdownValues => - Seq(true, false).foreach { partiallyClustered => - withSQLConf( - SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", - SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, - SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key - -> partiallyClustered.toString, - SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { - val df = createJoinTestDF(Seq("id" -> "item_id")) - val shuffles = collectShuffles(df.queryExecution.executedPlan) - assert(shuffles.size == 1, "SPJ should be triggered") - checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), - Row(1, "aa", 30.0, 89.0), - Row(1, "aa", 40.0, 42.0), - Row(1, "aa", 40.0, 89.0), - Row(3, "bb", 10.0, 19.5))) - } + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> pushdownValues.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "SPJ should be triggered") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) } } } @@ -2052,4 +2049,109 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } } + + test("SPARK-48012: one-side shuffle with 
partition transforms") { + val items_partitions = Array(bucket(2, "id"), identity("arrive_time")) + val items_partitions2 = Array(identity("arrive_time"), bucket(2, "id")) + + Seq(items_partitions, items_partitions2).foreach { partition => + catalog.clearTables() + + createTable(items, itemsColumns, partition) + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " + + "(1, 'cc', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'dd', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'ee', 15.5, cast('2020-02-01' as timestamp)), " + + "(5, 'ff', 32.1, cast('2020-03-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 10.7, cast('2020-01-01' as timestamp))," + + "(3, 19.5, cast('2020-02-01' as timestamp))," + + "(4, 56.5, cast('2020-02-01' as timestamp))") + + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "only shuffle side that does not report partitioning") + + checkAnswer(df, Seq( + Row(1, "bb", 30.0, 42.0), + Row(1, "aa", 40.0, 42.0), + Row(4, "ee", 15.5, 56.5))) + } + } + } + + test("SPARK-48012: one-side shuffle with partition transforms and pushdown values") { + val items_partitions = Array(bucket(2, "id"), identity("arrive_time")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'bb', 30.0, cast('2020-01-01' as timestamp)), " + + "(1, 'cc', 30.0, cast('2020-01-02' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(2, 10.7, cast('2020-01-01' as timestamp))") + + Seq(true, false).foreach { pushDown => { + withSQLConf( + SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> + pushDown.toString) { + val df = createJoinTestDF(Seq("id" -> "item_id", "arrive_time" -> "time")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 1, "only shuffle side that does not report partitioning") + + checkAnswer(df, Seq( + Row(1, "bb", 30.0, 42.0), + Row(1, "aa", 40.0, 42.0))) + } + } + } + } + + test("SPARK-48012: one-side shuffle with partition transforms " + + "with fewer join keys than partition kes") { + val items_partitions = Array(bucket(2, "id"), identity("name")) + createTable(items, itemsColumns, items_partitions) + + sql(s"INSERT INTO testcat.ns.$items VALUES " + + "(1, 'aa', 40.0, cast('2020-01-01' as timestamp)), " + + "(1, 'aa', 30.0, cast('2020-01-02' as timestamp)), " + + "(3, 'bb', 10.0, cast('2020-01-01' as timestamp)), " + + "(4, 'cc', 15.5, cast('2020-02-01' as timestamp))") + + createTable(purchases, purchasesColumns, Array.empty) + sql(s"INSERT INTO testcat.ns.$purchases VALUES " + + "(1, 42.0, cast('2020-01-01' as timestamp)), " + + "(1, 89.0, cast('2020-01-03' as timestamp)), " + + "(3, 19.5, cast('2020-02-01' as timestamp)), " + + "(5, 26.0, cast('2023-01-01' as timestamp)), " + + "(6, 50.0, cast('2023-02-01' as timestamp))") + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + 
SQLConf.V2_BUCKETING_SHUFFLE_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true") { + val df = createJoinTestDF(Seq("id" -> "item_id")) + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.size == 2, "SPJ should not be triggered for transform expression with" + + "less join keys than partition keys for now.") + checkAnswer(df, Seq(Row(1, "aa", 30.0, 42.0), + Row(1, "aa", 30.0, 89.0), + Row(1, "aa", 40.0, 42.0), + Row(1, "aa", 40.0, 89.0), + Row(3, "bb", 10.0, 19.5))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 5cdb900901056..5364fc5d62423 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.sql.connector.catalog.functions -import java.sql.Timestamp +import java.time.{Instant, LocalDate, ZoneId} +import java.time.temporal.ChronoUnit import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -44,7 +46,13 @@ object YearsFunction extends ScalarFunction[Long] { override def name(): String = "years" override def canonicalName(): String = name() - def invoke(ts: Long): Long = new Timestamp(ts).getYear + 1900 + val UTC: ZoneId = ZoneId.of("UTC") + val EPOCH_LOCAL_DATE: LocalDate = Instant.EPOCH.atZone(UTC).toLocalDate + + def invoke(ts: Long): Long = { + val localDate = DateTimeUtils.microsToInstant(ts).atZone(UTC).toLocalDate + ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, localDate) + } } object DaysFunction extends BoundFunction { From 4d8a2237470226b52923b2dab4d6f3454af03b4f Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sun, 28 Apr 2024 12:55:57 -0700 Subject: [PATCH 2/5] Fix test cases --- .../scala/org/apache/spark/Partitioner.scala | 3 ++- .../WriteDistributionAndOrderingSuite.scala | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 9950d336074d2..357e71cdf4457 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -151,7 +151,8 @@ private[spark] class KeyGroupedPartitioner( override def getPartition(key: Any): Int = { val keys = key.asInstanceOf[Seq[Any]] val normalizedKeys = ArraySeq.from(keys) - valueMap.getOrElseUpdate(normalizedKeys, Utils.nonNegativeMod(keys.hashCode, numPartitions)) + valueMap.getOrElseUpdate(normalizedKeys, + Utils.nonNegativeMod(normalizedKeys.hashCode, numPartitions)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index 12d5f13df01c7..e301cd8e3f359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.connector import java.sql.Date import java.util.Collections - -import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, catalyst} import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, RangePartitioning, UnknownPartitioning} import org.apache.spark.sql.connector.catalog.{Column, Identifier} @@ -40,7 +39,7 @@ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger} -import org.apache.spark.sql.types.{DateType, IntegerType, LongType, ObjectType, StringType, TimestampType} +import org.apache.spark.sql.types.{DataTypes, DateType, IntegerType, LongType, StringType, TimestampType} import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.tags.SlowSQLTest @@ -1128,13 +1127,16 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase Seq.empty ), catalyst.expressions.SortOrder( - Invoke( - Literal.create(YearsFunction, ObjectType(YearsFunction.getClass)), + StaticInvoke( + YearsFunction.getClass, + DataTypes.LongType, "invoke", - LongType, Seq(Cast(attr("day"), TimestampType, Some("America/Los_Angeles"))), Seq(TimestampType), - propagateNull = false), + propagateNull = false, + returnNullable = true, + isDeterministic = true, + Some(YearsFunction)), catalyst.expressions.Descending, catalyst.expressions.NullsFirst, Seq.empty From 98be0dfa8490faf6ead8adce9625eee43ac43c88 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sun, 28 Apr 2024 23:49:57 -0700 Subject: [PATCH 3/5] Scalafmt --- .../sql/catalyst/expressions/TransformExpression.scala | 9 --------- .../connector/WriteDistributionAndOrderingSuite.scala | 3 ++- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 98b5c641096fb..c388e7691dd58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -133,15 +133,6 @@ case class TransformExpression( } } - /** - * Returns Java source code that can be compiled to evaluate this expression. - * The default behavior is to call the eval method of the expression. Concrete expression - * implementations should override this to do actual code generation. - * - * @param ctx a [[CodegenContext]] - * @param ev an [[ExprCode]] with unique terms. 
- * @return an [[ExprCode]] containing the Java source code to generate the given expression - */ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw QueryExecutionErrors.cannotGenerateCodeForExpressionError(this) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index e301cd8e3f359..4d9c001aebb13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.connector import java.sql.Date import java.util.Collections -import org.apache.spark.sql.{AnalysisException, DataFrame, Row, catalyst} + +import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression, Cast, Literal} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.plans.physical From 36a8a9269494693da5754570964155ecfda2920e Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 29 Apr 2024 00:12:43 -0700 Subject: [PATCH 4/5] Revert unnecessary changes --- .../catalyst/expressions/V2ExpressionUtils.scala | 9 +-------- .../WriteDistributionAndOrderingSuite.scala | 15 ++++++--------- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala index 44a585da5134d..220920a5a3198 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/V2ExpressionUtils.scala @@ -161,7 +161,7 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { val declaredInputTypes = scalarFunc.inputTypes().toImmutableArraySeq val argClasses = declaredInputTypes.map(EncoderUtils.dataTypeJavaClass) findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match { - case Some(m) if isStatic(m) => + case Some(m) if Modifier.isStatic(m.getModifiers) => StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(), MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes, propagateNull = false, returnNullable = scalarFunc.isResultNullable, @@ -204,11 +204,4 @@ object V2ExpressionUtils extends SQLConfHelper with Logging { None } } - - private def isStatic(m: Method) = { - val javaStatic = Modifier.isStatic(m.getModifiers) - val scalaObjModule = m.getDeclaringClass.getField("MODULE$") - val scalaStatic = scalaObjModule != null && Modifier.isStatic(scalaObjModule.getModifiers) - javaStatic || scalaStatic - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala index 4d9c001aebb13..12d5f13df01c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/WriteDistributionAndOrderingSuite.scala @@ -23,7 +23,7 @@ import java.util.Collections import org.apache.spark.sql.{catalyst, AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{ApplyFunctionExpression,
Cast, Literal} -import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.catalyst.plans.physical.{CoalescedBoundary, CoalescedHashPartitioning, HashPartitioning, RangePartitioning, UnknownPartitioning} import org.apache.spark.sql.connector.catalog.{Column, Identifier} @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{StreamingQueryException, Trigger} -import org.apache.spark.sql.types.{DataTypes, DateType, IntegerType, LongType, StringType, TimestampType} +import org.apache.spark.sql.types.{DateType, IntegerType, LongType, ObjectType, StringType, TimestampType} import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.tags.SlowSQLTest @@ -1128,16 +1128,13 @@ class WriteDistributionAndOrderingSuite extends DistributionAndOrderingSuiteBase Seq.empty ), catalyst.expressions.SortOrder( - StaticInvoke( - YearsFunction.getClass, - DataTypes.LongType, + Invoke( + Literal.create(YearsFunction, ObjectType(YearsFunction.getClass)), "invoke", + LongType, Seq(Cast(attr("day"), TimestampType, Some("America/Los_Angeles"))), Seq(TimestampType), - propagateNull = false, - returnNullable = true, - isDeterministic = true, - Some(YearsFunction)), + propagateNull = false), catalyst.expressions.Descending, catalyst.expressions.NullsFirst, Seq.empty From e7f0ec032e0b3e6c50998512b3f0549f43c7a754 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 7 Jun 2024 18:51:08 -0700 Subject: [PATCH 5/5] Review comments --- .../spark/sql/catalyst/expressions/TransformExpression.scala | 2 +- .../apache/spark/sql/catalyst/plans/physical/partitioning.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index c388e7691dd58..9041ed15fc501 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -117,7 +117,7 @@ case class TransformExpression( override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = copy(children = newChildren) - lazy val resolvedFunction: Option[Expression] = this match { + private lazy val resolvedFunction: Option[Expression] = this match { case TransformExpression(scalarFunc: ScalarFunction[_], arguments, Some(numBuckets)) => Some(V2ExpressionUtils.resolveScalarFunction(scalarFunc, Seq(Literal(numBuckets)) ++ arguments)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index a4756ac1fde74..19595eef10b34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -887,7 +887,7 @@ case class KeyGroupedShuffleSpec( override def createPartitioning(clustering: Seq[Expression]): Partitioning = { - val newExpressions: Seq[Expression] = 
clustering.zip(partitioning.expressions).map{ + val newExpressions: Seq[Expression] = clustering.zip(partitioning.expressions).map { case (c, e: TransformExpression) => TransformExpression( e.function, Seq(c), e.numBucketsOpt) case (c, _) => c
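
A note for reviewers on the mechanics, since the change spans several files: with this series, when only one side of a join reports a KeyGroupedPartitioning, KeyGroupedShuffleSpec.createPartitioning wraps the other side's join keys in the reporting side's TransformExpressions, and TransformExpression is no longer Unevaluable, so the added shuffle can evaluate the partition transform per row and send each row to the matching reported partition. The Scala sketch below is only a simplified, self-contained illustration of that grouping step: Bucket, Years, partitionIdFor and OneSideShuffleSketch are made-up stand-ins, not the suite's BucketFunction/YearsFunction or Spark's KeyGroupedPartitioner, and the lookup is a naive indexOf rather than the partitioner's value map.

import java.time.{Instant, ZoneOffset}
import java.time.temporal.ChronoUnit

object OneSideShuffleSketch {

  // A partition transform maps a raw join-key value to a partition value,
  // roughly what an evaluable TransformExpression computes for each input row.
  sealed trait Transform { def apply(value: Any): Any }

  // Stand-in for bucket(n, col): non-negative hash of the value modulo n.
  final case class Bucket(numBuckets: Int) extends Transform {
    def apply(value: Any): Any = ((value.hashCode % numBuckets) + numBuckets) % numBuckets
  }

  // Stand-in for years(ts): whole years between the epoch and a timestamp given in
  // microseconds, loosely mirroring the corrected YearsFunction in the test catalog.
  case object Years extends Transform {
    private val epochDate = Instant.EPOCH.atZone(ZoneOffset.UTC).toLocalDate
    def apply(value: Any): Any = {
      val micros = value.asInstanceOf[Long]
      val date = Instant.ofEpochSecond(micros / 1000000L).atZone(ZoneOffset.UTC).toLocalDate
      ChronoUnit.YEARS.between(epochDate, date)
    }
  }

  // The reporting side exposes one transformed value tuple per partition, in partition
  // order. The non-reporting side applies the same transforms to its join keys and finds
  // the matching partition index before shuffling the row there.
  def partitionIdFor(
      joinKeys: Seq[Any],
      transforms: Seq[Transform],
      reportedPartitionValues: Seq[Seq[Any]]): Int = {
    require(joinKeys.length == transforms.length, "one transform per join key")
    val transformed = joinKeys.zip(transforms).map { case (key, t) => t(key) }
    reportedPartitionValues.indexOf(transformed)
  }

  def main(args: Array[String]): Unit = {
    // Partitions reported by a side partitioned by (bucket(2, id), years(arrive_time)).
    val reported = Seq(Seq(1, 50L), Seq(0, 50L))
    // A row from the other side with id = 1 and a 2020-01-01 timestamp in microseconds.
    val ts = Instant.parse("2020-01-01T00:00:00Z").getEpochSecond * 1000000L
    println(partitionIdFor(Seq(1, ts), Seq(Bucket(2), Years), reported)) // prints 0
  }
}

A lookup like this only behaves consistently if logically equal keys are stored and probed the same way regardless of the concrete Seq implementation, which appears to be the motivation for the ArraySeq.from normalization this series adds to KeyGroupedPartitioner.getPartition.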