From b2f7cc02aecd572f7f2f1ea594beeebecc75768a Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Sat, 17 Feb 2024 00:03:28 -0800 Subject: [PATCH 01/14] [SPARK-47094][SQL] SPJ : Dynamically rebalance number of buckets when they are not equal ### What changes were proposed in this pull request? -- Allow SPJ between 'compatible' bucket funtions -- Add a mechanism to define 'reducible' functions, one function whose output can be 'reduced' to another for all inputs. ### Why are the changes needed? -- SPJ currently applies only if the partition transform expressions on both sides are identifical. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added new tests in KeyGroupedPartitioningSuite ### Was this patch authored or co-authored using generative AI tooling? No --- .../connector/catalog/functions/Reducer.java | 16 + .../catalog/functions/ReducibleFunction.java | 25 ++ .../expressions/TransformExpression.scala | 27 +- .../plans/physical/partitioning.scala | 71 +++- .../apache/spark/sql/internal/SQLConf.scala | 15 + .../datasources/v2/BatchScanExec.scala | 21 +- .../exchange/EnsureRequirements.scala | 49 ++- .../KeyGroupedPartitioningSuite.scala | 312 ++++++++++++++++++ .../functions/transformFunctions.scala | 19 +- 9 files changed, 533 insertions(+), 22 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java new file mode 100644 index 000000000000..0b9ed7fb681e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -0,0 +1,16 @@ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; + +/** + * A 'reducer' for output of user-defined functions. + * + * A user_defined function f_source(x) is 'reducible' on another user_defined function f_target(x), + * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x. + * @param function output type + * @since 4.0.0 + */ +@Evolving +public interface Reducer { + T reduce(T arg1); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java new file mode 100644 index 000000000000..39103d063f35 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -0,0 +1,25 @@ +package org.apache.spark.sql.connector.catalog.functions; + +import org.apache.spark.annotation.Evolving; +import scala.Option; + +/** + * Base class for user-defined functions that can be 'reduced' on another function. + * + * A function f_source(x) is 'reducible' on another function f_target(x) if + * there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. + * + * @since 4.0.0 + */ +@Evolving +public interface ReducibleFunction extends ScalarFunction { + + /** + * If this function is 'reducible' on another function, return the {@link Reducer} function. 
+ * @param other other function + * @param thisArgument argument for this function instance + * @param otherArgument argument for other function instance + * @return a reduction function if it is reducible, none if not + */ + Option> reducer(ReducibleFunction other, Option thisArgument, Option otherArgument); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 8412de554b71..cc5810993a9f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.connector.catalog.functions.BoundFunction +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ReducibleFunction} import org.apache.spark.sql.types.DataType /** @@ -54,6 +54,31 @@ case class TransformExpression( false } + /** + * Whether this [[TransformExpression]]'s function is compatible with the `other` + * [[TransformExpression]]'s function. + * + * This is true if both are instances of [[ReducibleFunction]] and there exists a [[Reducer]] r(x) + * such that r(t1(x)) = t2(x), or r(t2(x)) = t1(x), for all input x. + * + * @param other the transform expression to compare to + * @return true if compatible, false if not + */ + def isCompatible(other: TransformExpression): Boolean = { + if (isSameFunction(other)) { + true + } else { + (function, other.function) match { + case (f: ReducibleFunction[Any, Any] @unchecked, + o: ReducibleFunction[Any, Any] @unchecked) => + val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt) + val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt) + reducer.isDefined || otherReducer.isDefined + case _ => false + } + } + } + override def dataType: DataType = function.resultType() override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index c98a2a92a3ab..3c1cc5e2e9e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType} @@ -635,6 +636,22 @@ trait ShuffleSpec { */ def createPartitioning(clustering: Seq[Expression]): Partitioning = throw SparkUnsupportedOperationException() + + /** + * Return a set of [[Reducer]] for the partition expressions of this shuffle spec, + * on the partition expressions of another shuffle spec. + *

+   * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
+   * 'reducible' on the corresponding partition expression function of the other shuffle spec.
+   * <p>
+   * If a value is returned, there must be one Option[[Reducer]] per partition expression.
+   * A None value in the set indicates that the particular partition expression is not reducible
+   * on the corresponding expression on the other shuffle spec.
+   * <p>
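As an illustration of the shape documented here (a stand-alone sketch in plain Scala, using a local stand-in for the Reducer interface rather than the connector API): there is one Option per partition expression, reducible positions carry Some(reducer), non-reducible positions carry None, and an all-None result collapses to an overall None.

    object ReducerShapeSketch {
      // Local stand-in for the connector Reducer contract, for illustration only.
      trait Reducer[I, O] { def reduce(arg: I): O }

      // Hypothetical reducer mapping bucket(8, x) ids onto bucket(4, x) ids.
      case class ModReducer(targetBuckets: Int) extends Reducer[Int, Int] {
        override def reduce(bucketId: Int): Int = bucketId % targetBuckets
      }

      def main(args: Array[String]): Unit = {
        // Partition expressions (bucket(8, id), days(ts)) vs (bucket(4, id), days(ts)):
        // position 0 is reducible, position 1 is not.
        val reducers: Option[Seq[Option[Reducer[Int, Int]]]] =
          Some(Seq(Some(ModReducer(4)), None))
        println(reducers.map(_.map(_.map(_.reduce(7)))))  // Some(List(Some(3), None))
      }
    }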

+ * Returning none also indicates that none of the partition expressions can be reduced on the + * corresponding expression on the other shuffle spec. + */ + def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = None } case object SinglePartitionShuffleSpec extends ShuffleSpec { @@ -829,20 +846,60 @@ case class KeyGroupedShuffleSpec( } } + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && + // Only support partition expressions are AttributeReference for now + partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + } + + override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = { + other match { + case otherSpec: KeyGroupedShuffleSpec => + val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) + if e1.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] + && e2.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] => + e1.function.asInstanceOf[ReducibleFunction[Any, Any]].reducer( + e2.function.asInstanceOf[ReducibleFunction[Any, Any]], + e1.numBucketsOpt.map(a => a.asInstanceOf[Any]), + e2.numBucketsOpt.map(a => a.asInstanceOf[Any])) + case (_, _) => None + } + + // optimize to not return a value, if none of the partition expressions need reducing + if (results.forall(p => p.isEmpty)) None else Some(results) + case _ => None + } + } + private def isExpressionCompatible(left: Expression, right: Expression): Boolean = (left, right) match { case (_: LeafExpression, _: LeafExpression) => true case (left: TransformExpression, right: TransformExpression) => - left.isSameFunction(right) + if (SQLConf.get.v2BucketingPushPartValuesEnabled && + !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled && + SQLConf.get.v2BucketingAllowCompatibleTransforms) { + left.isCompatible(right) + } else { + left.isSameFunction(right) + } case _ => false } +} - override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && - // Only support partition expressions are AttributeReference for now - partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) - - override def createPartitioning(clustering: Seq[Expression]): Partitioning = { - KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) +object KeyGroupedShuffleSpec { + def reducePartitionValue(row: InternalRow, + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[Any]]]): + InternalRowComparableWrapper = { + val partitionVals = row.toSeq(expressions.map(_.dataType)) + val reducedRow = partitionVals.zip(reducers).map{ + case (v, Some(reducer)) => reducer.reduce(v) + case (v, _) => v + }.toArray + InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 761b0ea72f3d..fd1fd8dce52a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1541,6 +1541,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS = + buildConf("spark.sql.sources.v2.bucketing.allow.enabled") + 
.doc("Whether to allow storage-partition join in the case where the partition transforms" + + "are compatible but not identical. This config requires both " + + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + + s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + + "to be disabled." + ) + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets") .doc("The maximum number of buckets allowed.") .version("2.4.0") @@ -5233,6 +5245,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean = getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS) + def v2BucketingAllowCompatibleTransforms: Boolean = + getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS) + def dataFrameSelfJoinAutoResolveAmbiguity: Boolean = getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 7cce59904018..3772d1f9f884 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition} +import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper} import org.apache.spark.sql.connector.catalog.Table +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.connector.read._ import org.apache.spark.util.ArrayImplicits._ @@ -164,6 +165,18 @@ case class BatchScanExec( (groupedParts, expressions) } + // Also re-group the partitions if we are reducing compatible partition expressions + val finalGroupedPartitions = spjParams.reducers match { + case Some(reducers) => + val result = groupedPartitions.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) + }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq + val rowOrdering = RowOrdering.createNaturalAscendingOrdering( + expressions.map(_.dataType)) + result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + case _ => groupedPartitions + } + // When partially clustered, the input partitions are not grouped by partition // values. Here we'll need to check `commonPartitionValues` and decide how to group // and replicate splits within a partition. @@ -174,7 +187,7 @@ case class BatchScanExec( .get .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2)) .toMap - val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) => + val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) => // `commonPartValuesMap` should contain the part value since it's the super set. 
val numSplits = commonPartValuesMap .get(InternalRowComparableWrapper(partValue, partExpressions)) @@ -207,7 +220,7 @@ case class BatchScanExec( } else { // either `commonPartitionValues` is not defined, or it is defined but // `applyPartialClustering` is false. - val partitionMapping = groupedPartitions.map { case (partValue, splits) => + val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) => InternalRowComparableWrapper(partValue, partExpressions) -> splits }.toMap @@ -224,7 +237,6 @@ case class BatchScanExec( case _ => filteredPartitions } - new DataSourceRDD( sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) } @@ -259,6 +271,7 @@ case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, + reducers: Option[Seq[Option[Reducer[Any]]]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 2a7c1206bb41..1c4c797861b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} @@ -505,11 +506,28 @@ case class EnsureRequirements( } } - // Now we need to push-down the common partition key to the scan in each child - newLeft = populatePartitionValues(left, mergedPartValues, leftSpec.joinKeyPositions, - applyPartialClustering, replicateLeftSide) - newRight = populatePartitionValues(right, mergedPartValues, rightSpec.joinKeyPositions, - applyPartialClustering, replicateRightSide) + // in case of compatible but not identical partition expressions, we apply 'reduce' + // transforms to group one side's partitions as well as the common partition values + val leftReducers = leftSpec.reducers(rightSpec) + val rightReducers = rightSpec.reducers(leftSpec) + + if (leftReducers.isDefined || rightReducers.isDefined) { + mergedPartValues = reduceCommonPartValues(mergedPartValues, + leftSpec.partitioning.expressions, + leftReducers) + mergedPartValues = reduceCommonPartValues(mergedPartValues, + rightSpec.partitioning.expressions, + rightReducers) + val rowOrdering = RowOrdering + .createNaturalAscendingOrdering(partitionExprs.map(_.dataType)) + mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) + } + + // Now we need to push-down the common partition information to the scan in each child + newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions, + leftReducers, applyPartialClustering, replicateLeftSide) + newRight = populateCommonPartitionInfo(right, mergedPartValues, rightSpec.joinKeyPositions, + rightReducers, applyPartialClustering, 
replicateRightSide) } } @@ -527,11 +545,12 @@ case class EnsureRequirements( joinType == LeftAnti || joinType == LeftOuter } - // Populate the common partition values down to the scan nodes - private def populatePartitionValues( + // Populate the common partition information down to the scan nodes + private def populateCommonPartitionInfo( plan: SparkPlan, values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], + reducers: Option[Seq[Option[Reducer[Any]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => @@ -539,13 +558,25 @@ case class EnsureRequirements( spjParams = scan.spjParams.copy( commonPartitionValues = Some(values), joinKeyPositions = joinKeyPositions, + reducers = reducers, applyPartialClustering = applyPartialClustering, replicatePartitions = replicatePartitions ) ) case node => - node.mapChildren(child => populatePartitionValues( - child, values, joinKeyPositions, applyPartialClustering, replicatePartitions)) + node.mapChildren(child => populateCommonPartitionInfo( + child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) + } + + private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[Any]]]]) = { + reducers match { + case Some(reducers) => commonPartValues.groupBy { case (row, _) => + KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) + }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq + case _ => commonPartValues + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 7fdc703007c2..638247306280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -63,11 +63,17 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Collections.emptyMap[String, String] } private val table: String = "tbl" + private val columns: Array[Column] = Array( Column.create("id", IntegerType), Column.create("data", StringType), Column.create("ts", TimestampType)) + private val columns2: Array[Column] = Array( + Column.create("store_id", IntegerType), + Column.create("dept_id", IntegerType), + Column.create("data", StringType)) + test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) @@ -1309,6 +1315,312 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-47094: Support compatible buckets") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((2, 4), (4, 2)), + ((4, 2), (2, 4)), + ((2, 2), (4, 6)), + ((6, 2), (2, 2))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, "dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, columns2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + 
"(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + val expectedBuckets = Math.min(table1buckets1, table2buckets1) * + Math.min(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp"), + )) + } + } + } + } + + test("SPARK-47094: Support compatible buckets with less join keys than partition keys") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq((2, 4), (4, 2), (2, 6), (6, 2)).foreach { + case (table1buckets, table2buckets) => + catalog.clearTables() + + val partition1 = Array(bucket(3, "store_id"), + bucket(table1buckets, "dept_id")) + val partition2 = Array(bucket(3, "store_id"), + bucket(table2buckets, "dept_id")) + + createTable(table1, columns2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 0, 'ab'), " + + "(2, 1, 'ac'), " + + "(3, 2, 'ad'), " + + "(4, 3, 'ae'), " + + "(5, 4, 'af'), " + + "(6, 5, 'ag'), " + + + // value without other side match + "(6, 6, 'xx')" + ) + + createTable(table2, columns2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(6, 0, '01'), " + + "(5, 1, '02'), " + // duplicate partition key + "(5, 1, '03'), " + + "(4, 2, '04'), " + + "(3, 3, '05'), " + + "(2, 4, '06'), " + + "(1, 5, '07'), " + + + // value without other side match + "(7, 7, '99')" + ) + + + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t2.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + val expectedBuckets = Math.min(table1buckets, table2buckets) + + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 6, 0, 0, "aa", "01"), + Row(1, 6, 0, 0, "ab", "01"), + Row(2, 5, 1, 1, "ac", "02"), + Row(2, 5, 1, 1, "ac", "03"), + Row(3, 4, 2, 2, "ad", "04"), + Row(4, 3, 3, 3, "ae", "05"), + Row(5, 2, 4, 4, "af", "06"), + Row(6, 1, 5, 5, "ag", "07"), + )) + } + } + } + + test("SPARK-47094: Compatible buckets does not support SPJ with " + + "push-down values or partially-clustered") { + val table1 = "tab1e1" + val table2 = "table2" + + val partition1 = Array(bucket(4, "store_id"), + bucket(2, "dept_id")) + val partition2 = Array(bucket(2, "store_id"), + bucket(2, "dept_id")) + + createTable(table1, columns2, partition1) + sql(s"INSERT INTO testcat.ns.$table1 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + createTable(table2, columns2, partition2) + sql(s"INSERT INTO testcat.ns.$table2 VALUES " + + "(0, 0, 'aa'), " + + "(1, 1, 'bb'), " + + "(2, 2, 'cc')" + ) + + Seq(true, false).foreach{ allowPushDown => + Seq(true, false).foreach{ partiallyClustered => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> allowPushDown.toString, + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> + partiallyClustered.toString, + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> "true", + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.store_id, t1.dept_id, t2.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + (allowPushDown, partiallyClustered) match { + case (true, false) => + assert(shuffles.isEmpty, "SPJ should be triggered") + assert(scans == Seq(2, 2)) + case (_, _) => + assert(shuffles.nonEmpty, "SPJ should not be triggered") + assert(scans == Seq(3, 2)) + } + + checkAnswer(df, Seq( + Row(0, 0, 0, 0, "aa", "aa"), + Row(1, 1, 1, 1, "bb", "bb"), + Row(2, 2, 2, 2, "cc", "cc") + )) + } + } + } + } + test("SPARK-44647: test join key is the second cluster key") { val table1 = "tab1e1" val table2 = "table2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 61895d49c4a2..67da85480ef9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -76,7 +76,7 @@ object UnboundBucketFunction extends UnboundFunction { override def name(): String = "bucket" } -object BucketFunction extends ScalarFunction[Int] { +object BucketFunction extends ReducibleFunction[Int, Int] { override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) override def resultType(): DataType = IntegerType override def name(): String = "bucket" @@ -85,6 +85,23 @@ object BucketFunction extends ScalarFunction[Int] { override def produceResult(input: InternalRow): Int = { (input.getLong(1) % input.getInt(0)).toInt } + + override def reducer(func: ReducibleFunction[_, _], + thisNumBuckets: Option[_], + otherNumBuckets: Option[_]): Option[Reducer[Int]] = { + (thisNumBuckets, otherNumBuckets) match { + case (Some(thisNumBucketsVal: Int), Some(otherNumBucketsVal: Int)) + if func.isInstanceOf[ReducibleFunction[_, _]] && + ((thisNumBucketsVal > otherNumBucketsVal) && + (thisNumBucketsVal % otherNumBucketsVal == 0)) => + Some(BucketReducer(thisNumBucketsVal, otherNumBucketsVal)) + case _ => None + } + } +} + +case class BucketReducer(thisNumBuckets: Int, otherNumBuckets: Int) extends Reducer[Int] { + override def reduce(bucket: Int): Int = bucket % otherNumBuckets } object UnboundStringSelfFunction extends UnboundFunction { From cda8b26e7205fd1a55d3cde96fd2b2d036916ae4 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 26 Feb 2024 17:58:11 -0800 Subject: [PATCH 02/14] Fix scalastyle and address review comment --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../spark/sql/execution/datasources/v2/BatchScanExec.scala | 1 + .../spark/sql/connector/KeyGroupedPartitioningSuite.scala | 4 ++-- .../sql/connector/catalog/functions/transformFunctions.scala | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index fd1fd8dce52a..0f8c537cfd29 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1542,7 +1542,7 @@ object SQLConf { .createWithDefault(false) val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS = - buildConf("spark.sql.sources.v2.bucketing.allow.enabled") + buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled") .doc("Whether to allow storage-partition join in the case where the partition transforms" + "are compatible but not identical. 
This config requires both " + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 3772d1f9f884..6bf2f542d53d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -237,6 +237,7 @@ case class BatchScanExec( case _ => filteredPartitions } + new DataSourceRDD( sparkContext, finalPartitions, readerFactory, supportsColumnar, customMetrics) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 638247306280..9639540d8549 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1468,7 +1468,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(5, 2, "bm", "bm"), Row(5, 3, "bn", "bn"), Row(5, 4, "bo", "bo"), - Row(5, 5, "bp", "bp"), + Row(5, 5, "bp", "bp") )) } } @@ -1550,7 +1550,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Row(3, 4, 2, 2, "ad", "04"), Row(4, 3, 3, 3, "ae", "05"), Row(5, 2, 4, 4, "af", "06"), - Row(6, 1, 5, 5, "ag", "07"), + Row(6, 1, 5, 5, "ag", "07") )) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 67da85480ef9..823177cf466a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -91,7 +91,7 @@ object BucketFunction extends ReducibleFunction[Int, Int] { otherNumBuckets: Option[_]): Option[Reducer[Int]] = { (thisNumBuckets, otherNumBuckets) match { case (Some(thisNumBucketsVal: Int), Some(otherNumBucketsVal: Int)) - if func.isInstanceOf[ReducibleFunction[_, _]] && + if func == BucketFunction && ((thisNumBucketsVal > otherNumBucketsVal) && (thisNumBucketsVal % otherNumBucketsVal == 0)) => Some(BucketReducer(thisNumBucketsVal, otherNumBucketsVal)) From eee919813b59f09f2634e278db7e3ab46a6fa6b3 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 27 Feb 2024 10:09:09 -0800 Subject: [PATCH 03/14] Fix bug of using un-projected expression in sorting --- .../execution/datasources/v2/BatchScanExec.scala | 2 +- .../sql/connector/KeyGroupedPartitioningSuite.scala | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 6bf2f542d53d..955d13c8f9be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -172,7 +172,7 @@ case class BatchScanExec( KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers) }.map { case (wrapper, splits) => (wrapper.row, 
splits.flatMap(_._2)) }.toSeq val rowOrdering = RowOrdering.createNaturalAscendingOrdering( - expressions.map(_.dataType)) + partExpressions.map(_.dataType)) result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1)) case _ => groupedPartitions } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 9639540d8549..532ea875d595 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -63,6 +63,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Collections.emptyMap[String, String] } private val table: String = "tbl" +<<<<<<< HEAD private val columns: Array[Column] = Array( Column.create("id", IntegerType), @@ -73,6 +74,16 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Column.create("store_id", IntegerType), Column.create("dept_id", IntegerType), Column.create("data", StringType)) +======= + private val schema = new StructType() + .add("id", IntegerType) + .add("data", StringType) + .add("ts", TimestampType) + private val schema2 = new StructType() + .add("store_id", LongType) + .add("dept_id", IntegerType) + .add("data", StringType) +>>>>>>> 0016169c60a (Fix bug of using un-projected expression in sorting) test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) @@ -1483,7 +1494,7 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { case (table1buckets, table2buckets) => catalog.clearTables() - val partition1 = Array(bucket(3, "store_id"), + val partition1 = Array(identity("data"), bucket(table1buckets, "dept_id")) val partition2 = Array(bucket(3, "store_id"), bucket(table2buckets, "dept_id")) From 7bf16105e5663f8104b334eb059c911cb8f9dc31 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 27 Feb 2024 23:32:22 -0800 Subject: [PATCH 04/14] Add licenses --- .../sql/connector/catalog/functions/Reducer.java | 16 ++++++++++++++++ .../catalog/functions/ReducibleFunction.java | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java index 0b9ed7fb681e..f0f9f4ac7fb7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.connector.catalog.functions; import org.apache.spark.annotation.Evolving; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 39103d063f35..3d3f0edeadd0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -1,3 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ package org.apache.spark.sql.connector.catalog.functions; import org.apache.spark.annotation.Evolving; From 18db83a565049e5f6c32da31337226e2857408b0 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Wed, 28 Feb 2024 01:17:09 -0800 Subject: [PATCH 05/14] Fix compiler warning --- .../catalog/functions/ReducibleFunction.java | 3 ++- .../expressions/TransformExpression.scala | 3 +-- .../plans/physical/partitioning.scala | 19 +++++++++---------- .../datasources/v2/BatchScanExec.scala | 2 +- .../exchange/EnsureRequirements.scala | 4 ++-- 5 files changed, 15 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 3d3f0edeadd0..240abdd468c0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -37,5 +37,6 @@ public interface ReducibleFunction extends ScalarFunction { * @param otherArgument argument for other function instance * @return a reduction function if it is reducible, none if not */ - Option> reducer(ReducibleFunction other, Option thisArgument, Option otherArgument); + Option> reducer(ReducibleFunction other, Option thisArgument, + Option otherArgument); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index cc5810993a9f..9b53ce54b456 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -69,8 +69,7 @@ case class TransformExpression( true } else { (function, other.function) match { - case (f: ReducibleFunction[Any, Any] @unchecked, - o: 
ReducibleFunction[Any, Any] @unchecked) => + case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt) val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt) reducer.isDefined || otherReducer.isDefined diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 3c1cc5e2e9e9..cb505779ece8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -651,7 +651,7 @@ trait ShuffleSpec { * Returning none also indicates that none of the partition expressions can be reduced on the * corresponding expression on the other shuffle spec. */ - def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = None + def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[_]]]] = None } case object SinglePartitionShuffleSpec extends ShuffleSpec { @@ -854,17 +854,16 @@ case class KeyGroupedShuffleSpec( KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) } - override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[Any]]]] = { + override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_]]]] = { other match { case otherSpec: KeyGroupedShuffleSpec => val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { case (e1: TransformExpression, e2: TransformExpression) - if e1.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] - && e2.function.isInstanceOf[ReducibleFunction[Any, Any]@unchecked] => - e1.function.asInstanceOf[ReducibleFunction[Any, Any]].reducer( - e2.function.asInstanceOf[ReducibleFunction[Any, Any]], - e1.numBucketsOpt.map(a => a.asInstanceOf[Any]), - e2.numBucketsOpt.map(a => a.asInstanceOf[Any])) + if e1.function.isInstanceOf[ReducibleFunction[_, _]] + && e2.function.isInstanceOf[ReducibleFunction[_, _]] => + e1.function.asInstanceOf[ReducibleFunction[_, _]].reducer( + e2.function.asInstanceOf[ReducibleFunction[_, _]], + e1.numBucketsOpt, e2.numBucketsOpt) case (_, _) => None } @@ -892,11 +891,11 @@ case class KeyGroupedShuffleSpec( object KeyGroupedShuffleSpec { def reducePartitionValue(row: InternalRow, expressions: Seq[Expression], - reducers: Seq[Option[Reducer[Any]]]): + reducers: Seq[Option[Reducer[_]]]): InternalRowComparableWrapper = { val partitionVals = row.toSeq(expressions.map(_.dataType)) val reducedRow = partitionVals.zip(reducers).map{ - case (v, Some(reducer)) => reducer.reduce(v) + case (v, Some(reducer: Reducer[Any])) => reducer.reduce(v) case (v, _) => v }.toArray InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 955d13c8f9be..43c2d299ad30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -272,7 +272,7 @@ case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, - reducers: 
Option[Seq[Option[Reducer[Any]]]] = None, + reducers: Option[Seq[Option[Reducer[_]]]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 1c4c797861b2..2cbc8143f565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -550,7 +550,7 @@ case class EnsureRequirements( plan: SparkPlan, values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], - reducers: Option[Seq[Option[Reducer[Any]]]], + reducers: Option[Seq[Option[Reducer[_]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => @@ -570,7 +570,7 @@ case class EnsureRequirements( private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], expressions: Seq[Expression], - reducers: Option[Seq[Option[Reducer[Any]]]]) = { + reducers: Option[Seq[Option[Reducer[_]]]]) = { reducers match { case Some(reducers) => commonPartValues.groupBy { case (row, _) => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) From 3fdb0d74d27f74248743131a0857e30d866d6ae6 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 12 Mar 2024 22:22:49 -0400 Subject: [PATCH 06/14] Review comments --- .../connector/catalog/functions/Reducer.java | 7 +- .../catalog/functions/ReducibleFunction.java | 42 +++++++-- .../expressions/TransformExpression.scala | 19 ++++- .../plans/physical/partitioning.scala | 85 +++++++++---------- .../datasources/v2/BatchScanExec.scala | 2 +- .../exchange/EnsureRequirements.scala | 4 +- .../KeyGroupedPartitioningSuite.scala | 11 --- .../functions/transformFunctions.scala | 6 +- 8 files changed, 103 insertions(+), 73 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java index f0f9f4ac7fb7..5f3c3385a943 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -23,10 +23,11 @@ * * A user_defined function f_source(x) is 'reducible' on another user_defined function f_target(x), * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x. - * @param function output type + * @param reducer input type + * @param reducer output type * @since 4.0.0 */ @Evolving -public interface Reducer { - T reduce(T arg1); +public interface Reducer { + O reduce(I arg1); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 240abdd468c0..5df2d4c9f719 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -25,18 +25,48 @@ * A function f_source(x) is 'reducible' on another function f_target(x) if * there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. 
* + *

+ * Examples: + *

    + *
  • Bucket functions + *
      + *
    • f_source(x) = bucket(4, x)
    • + *
    • f_target(x) = bucket(2, x)
    • + *
    • r(x) = x / 2
    • + *
    + *
  • Date functions
  • + *
      + *
    • f_source(x) = days(x)
    • + *
    • f_target(x) = hours(x)
    • + *
    • r(x) = x / 24
    • + *
    + *
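To make the bucket example above concrete (a self-contained sketch, not part of this patch; it uses the modulo form that the BucketReducer test helper later in this change uses):

    object BucketReduceSketch {
      // Reducing bucket(4, x) ids onto bucket(2, x) ids is valid because 4 % 2 == 0:
      // pmod(hash(x), 2) == pmod(hash(x), 4) % 2 for every input x.
      def reduce(bucketId: Int, targetBuckets: Int): Int = bucketId % targetBuckets

      def main(args: Array[String]): Unit = {
        val sourceIds = 0 until 4                             // possible bucket(4, x) ids
        println(sourceIds.map(reduce(_, 2)).mkString(", "))   // 0, 1, 0, 1
      }
    }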
+ * @param reducer function input type + * @param reducer function output type * @since 4.0.0 */ @Evolving -public interface ReducibleFunction extends ScalarFunction { +public interface ReducibleFunction { /** * If this function is 'reducible' on another function, return the {@link Reducer} function. - * @param other other function - * @param thisArgument argument for this function instance - * @param otherArgument argument for other function instance + *

+   * Example:
+   * <ul>
+   *   <li>this_function = bucket(4, x)</li>
+   *   <li>other function = bucket(2, x)</li>
+   * </ul>
+   * Invoke with arguments
+   * <ul>
+   *   <li>other = bucket</li>
+   *   <li>this param = Int(4)</li>
+   *   <li>other param = Int(2)</li>
+   * </ul>
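Assuming the BucketFunction and BucketReducer test helpers added to transformFunctions.scala later in this patch compile as shown, the invocation described above looks roughly like:

    // Sketch only; BucketFunction is the test-side implementation, not a public API.
    val maybeReducer = BucketFunction.reducer(BucketFunction, Some(4), Some(2))
    maybeReducer.foreach(r => assert(r.reduce(3) == 1))  // id 3 of bucket(4, x) maps to id 1 of bucket(2, x)
    // A pair whose bucket counts do not divide evenly is not reducible:
    assert(BucketFunction.reducer(BucketFunction, Some(4), Some(3)).isEmpty)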
+ * @param other the other function + * @param thisParam param for this function + * @param otherParam param for the other function * @return a reduction function if it is reducible, none if not */ - Option> reducer(ReducibleFunction other, Option thisArgument, - Option otherArgument); + Option> reducer(ReducibleFunction other, Option thisParam, + Option otherParam); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 9b53ce54b456..5dc048f5dafb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, ReducibleFunction} +import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction} import org.apache.spark.sql.types.DataType /** @@ -78,6 +78,23 @@ case class TransformExpression( } } + /** + * Return a [[Reducer]] for this transform expression on another + * on the transform expression. + *

+   * A [[Reducer]] exists for a transform expression function if it is
+   * 'reducible' on the other expression function.
+   * <p>
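For instance (hypothetical names: bucketFn is assumed to be a bound function that also implements ReducibleFunction, and col an attribute from the scan output):

    val coarse = TransformExpression(bucketFn, Seq(col), numBucketsOpt = Some(2))
    val fine = TransformExpression(bucketFn, Seq(col), numBucketsOpt = Some(4))
    fine.reducers(coarse)     // Some(reducer) if bucketFn can reduce 4 buckets onto 2
    coarse.reducers(fine)     // typically None: fewer buckets cannot be reduced onto more
    fine.isCompatible(coarse) // true when either direction yields a reducer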

+ * @return reducer function or None if not reducible on the other transform expression + */ + def reducers(other: TransformExpression): Option[Reducer[_, _]] = { + (function, other.function) match { + case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => + e1.reducer(e2, numBucketsOpt, other.numBucketsOpt) + case _ => None + } + } + override def dataType: DataType = function.resultType() override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index cb505779ece8..0bdb1fde67b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -24,7 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper -import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction} +import org.apache.spark.sql.connector.catalog.functions.Reducer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType} @@ -636,22 +636,6 @@ trait ShuffleSpec { */ def createPartitioning(clustering: Seq[Expression]): Partitioning = throw SparkUnsupportedOperationException() - - /** - * Return a set of [[Reducer]] for the partition expressions of this shuffle spec, - * on the partition expressions of another shuffle spec. - *

-   * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
-   * 'reducible' on the corresponding partition expression function of the other shuffle spec.
-   * <p>
-   * If a value is returned, there must be one Option[[Reducer]] per partition expression.
-   * A None value in the set indicates that the particular partition expression is not reducible
-   * on the corresponding expression on the other shuffle spec.
-   * <p>
- * Returning none also indicates that none of the partition expressions can be reduced on the - * corresponding expression on the other shuffle spec. - */ - def reducers(spec: ShuffleSpec): Option[Seq[Option[Reducer[_]]]] = None } case object SinglePartitionShuffleSpec extends ShuffleSpec { @@ -846,33 +830,6 @@ case class KeyGroupedShuffleSpec( } } - override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && - // Only support partition expressions are AttributeReference for now - partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) - - override def createPartitioning(clustering: Seq[Expression]): Partitioning = { - KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) - } - - override def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_]]]] = { - other match { - case otherSpec: KeyGroupedShuffleSpec => - val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { - case (e1: TransformExpression, e2: TransformExpression) - if e1.function.isInstanceOf[ReducibleFunction[_, _]] - && e2.function.isInstanceOf[ReducibleFunction[_, _]] => - e1.function.asInstanceOf[ReducibleFunction[_, _]].reducer( - e2.function.asInstanceOf[ReducibleFunction[_, _]], - e1.numBucketsOpt, e2.numBucketsOpt) - case (_, _) => None - } - - // optimize to not return a value, if none of the partition expressions need reducing - if (results.forall(p => p.isEmpty)) None else Some(results) - case _ => None - } - } - private def isExpressionCompatible(left: Expression, right: Expression): Boolean = (left, right) match { case (_: LeafExpression, _: LeafExpression) => true @@ -886,16 +843,52 @@ case class KeyGroupedShuffleSpec( } case _ => false } + + /** + * Return a set of [[Reducer]] for the partition expressions of this shuffle spec, + * on the partition expressions of another shuffle spec. + *

+   * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
+   * 'reducible' on the corresponding partition expression function of the other shuffle spec.
+   * <p>
+   * If a value is returned, there must be one Option[[Reducer]] per partition expression.
+   * A None value in the set indicates that the particular partition expression is not reducible
+   * on the corresponding expression on the other shuffle spec.
+   * <p>
+ * Returning none also indicates that none of the partition expressions can be reduced on the + * corresponding expression on the other shuffle spec. + */ + def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { + other match { + case otherSpec: KeyGroupedShuffleSpec => + val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) + case (_, _) => None + } + + // optimize to not return a value, if none of the partition expressions are reducible + if (results.forall(p => p.isEmpty)) None else Some(results) + case _ => None + } + } + + override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && + // Only support partition expressions are AttributeReference for now + partitioning.expressions.forall(_.isInstanceOf[AttributeReference]) + + override def createPartitioning(clustering: Seq[Expression]): Partitioning = { + KeyGroupedPartitioning(clustering, partitioning.numPartitions, partitioning.partitionValues) + } } object KeyGroupedShuffleSpec { def reducePartitionValue(row: InternalRow, expressions: Seq[Expression], - reducers: Seq[Option[Reducer[_]]]): + reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = { val partitionVals = row.toSeq(expressions.map(_.dataType)) val reducedRow = partitionVals.zip(reducers).map{ - case (v, Some(reducer: Reducer[Any])) => reducer.reduce(v) + case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v) case (v, _) => v }.toArray InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala index 43c2d299ad30..f949dbf71a37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala @@ -272,7 +272,7 @@ case class StoragePartitionJoinParams( keyGroupedPartitioning: Option[Seq[Expression]] = None, joinKeyPositions: Option[Seq[Int]] = None, commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None, - reducers: Option[Seq[Option[Reducer[_]]]] = None, + reducers: Option[Seq[Option[Reducer[_, _]]]] = None, applyPartialClustering: Boolean = false, replicatePartitions: Boolean = false) { override def equals(other: Any): Boolean = other match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 2cbc8143f565..b34990e1b716 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -550,7 +550,7 @@ case class EnsureRequirements( plan: SparkPlan, values: Seq[(InternalRow, Int)], joinKeyPositions: Option[Seq[Int]], - reducers: Option[Seq[Option[Reducer[_]]]], + reducers: Option[Seq[Option[Reducer[_, _]]]], applyPartialClustering: Boolean, replicatePartitions: Boolean): SparkPlan = plan match { case scan: BatchScanExec => @@ -570,7 +570,7 @@ case class EnsureRequirements( private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], expressions: Seq[Expression], - reducers: Option[Seq[Option[Reducer[_]]]]) = { + reducers: Option[Seq[Option[Reducer[_, _]]]]) = { 
reducers match { case Some(reducers) => commonPartValues.groupBy { case (row, _) => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 532ea875d595..403081b66551 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -63,7 +63,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Collections.emptyMap[String, String] } private val table: String = "tbl" -<<<<<<< HEAD private val columns: Array[Column] = Array( Column.create("id", IntegerType), @@ -74,16 +73,6 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { Column.create("store_id", IntegerType), Column.create("dept_id", IntegerType), Column.create("data", StringType)) -======= - private val schema = new StructType() - .add("id", IntegerType) - .add("data", StringType) - .add("ts", TimestampType) - private val schema2 = new StructType() - .add("store_id", LongType) - .add("dept_id", IntegerType) - .add("data", StringType) ->>>>>>> 0016169c60a (Fix bug of using un-projected expression in sorting) test("clustered distribution: output partitioning should be KeyGroupedPartitioning") { val partitions: Array[Transform] = Array(Expressions.years("ts")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 823177cf466a..7a77c00b577f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -76,7 +76,7 @@ object UnboundBucketFunction extends UnboundFunction { override def name(): String = "bucket" } -object BucketFunction extends ReducibleFunction[Int, Int] { +object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] { override def inputTypes(): Array[DataType] = Array(IntegerType, LongType) override def resultType(): DataType = IntegerType override def name(): String = "bucket" @@ -88,7 +88,7 @@ object BucketFunction extends ReducibleFunction[Int, Int] { override def reducer(func: ReducibleFunction[_, _], thisNumBuckets: Option[_], - otherNumBuckets: Option[_]): Option[Reducer[Int]] = { + otherNumBuckets: Option[_]): Option[Reducer[Int, Int]] = { (thisNumBuckets, otherNumBuckets) match { case (Some(thisNumBucketsVal: Int), Some(otherNumBucketsVal: Int)) if func == BucketFunction && @@ -100,7 +100,7 @@ object BucketFunction extends ReducibleFunction[Int, Int] { } } -case class BucketReducer(thisNumBuckets: Int, otherNumBuckets: Int) extends Reducer[Int] { +case class BucketReducer(thisNumBuckets: Int, otherNumBuckets: Int) extends Reducer[Int, Int] { override def reduce(bucket: Int): Int = bucket % otherNumBuckets } From 23c580fa4cca37ed37d001f3eabe15893448d50a Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Thu, 14 Mar 2024 00:29:26 -0400 Subject: [PATCH 07/14] Try to fix doc --- .../sql/connector/catalog/functions/ReducibleFunction.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 5df2d4c9f719..6d57909e1b98 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -33,8 +33,8 @@ *

  *   <li>f_source(x) = bucket(4, x)
  *   <li>f_target(x) = bucket(2, x)
  *   <li>r(x) = x / 2
- * </ul>
- * <ul>Date functions
+ * </ul>
+ * <ul>Date functions
  *   <li>f_source(x) = days(x)
  *   <li>f_target(x) = hours(x)
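The example above reduces bucket(4, x) onto bucket(2, x). As a stand-alone illustration of that arithmetic (plain Scala, hypothetical names, not the connector API itself; for hash-style buckets the reducer is a modulo, which is how later revisions in this series state it):

    // For hash bucketing, bucket id = hash % numBuckets, so mapping a 4-bucket id
    // onto a 2-bucket id is another modulo: (h % 4) % 2 == h % 2 for any h >= 0.
    object BucketReduceSketch {
      def reduce(sourceBucket: Int, targetNumBuckets: Int): Int =
        sourceBucket % targetNumBuckets

      def main(args: Array[String]): Unit = {
        Seq(0, 1, 5, 10, 42, 99).foreach { h =>
          assert(reduce(h % 4, targetNumBuckets = 2) == h % 2)
        }
        println("bucket(4, x) rows co-locate with bucket(2, x) rows after reduction")
      }
    }
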
    • From 0c6f494aae1f83318fc4023f2463fea768187309 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 18 Mar 2024 18:12:10 -0700 Subject: [PATCH 08/14] Second round review comments --- .../connector/catalog/functions/Reducer.java | 13 +- .../catalog/functions/ReducibleFunction.java | 67 ++++++-- .../expressions/TransformExpression.scala | 12 +- .../plans/physical/partitioning.scala | 26 ++- .../exchange/EnsureRequirements.scala | 4 +- .../KeyGroupedPartitioningSuite.scala | 162 ++++++++++++++++++ .../functions/transformFunctions.scala | 32 ++-- 7 files changed, 263 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java index 5f3c3385a943..af742fe8cb24 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -21,13 +21,20 @@ /** * A 'reducer' for output of user-defined functions. * - * A user_defined function f_source(x) is 'reducible' on another user_defined function f_target(x), - * if there exists a 'reducer' r(x) such that r(f_source(x)) = f_target(x) for all input x. + * @see ReducibleFunction + * + * A user defined function f_source(x) is 'reducible' on another user_defined function f_target(x) if + *
        + *
+ *   <li>There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
+ *   <li>More generally, there exists two reducer functions r1(x) and r2(x) such that
+ *       r1(f_source(x)) = r2(f_target(x)) for all input x.
+ * </ul>
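A connector-side implementation of this interface can be a one-liner. The sketch below assumes the Reducer[I, O] SPI added above and mirrors the BucketReducer used by the test suite later in this series; the class name and fields are illustrative:

    import org.apache.spark.sql.connector.catalog.functions.Reducer

    // Maps a bucket id computed with `sourceBuckets` buckets to the id the same key
    // would receive with `targetBuckets` buckets; only valid when targetBuckets
    // evenly divides sourceBuckets.
    case class ModuloBucketReducer(sourceBuckets: Int, targetBuckets: Int)
      extends Reducer[Int, Int] {
      require(targetBuckets > 0 && sourceBuckets % targetBuckets == 0,
        "target bucket count must evenly divide the source bucket count")
      override def reduce(bucket: Int): Int = bucket % targetBuckets
    }
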
      + * * @param reducer input type * @param reducer output type * @since 4.0.0 */ @Evolving public interface Reducer { - O reduce(I arg1); + O reduce(I arg); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 6d57909e1b98..9d2215c1167c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -17,23 +17,34 @@ package org.apache.spark.sql.connector.catalog.functions; import org.apache.spark.annotation.Evolving; -import scala.Option; /** * Base class for user-defined functions that can be 'reduced' on another function. * * A function f_source(x) is 'reducible' on another function f_target(x) if - * there exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x. - * + *
        + *
+ *   <li>There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
+ *   <li>More generally, there exists two reducer functions r1(x) and r2(x) such that
+ *       r1(f_source(x)) = r2(f_target(x)) for all input x.
+ * </ul>
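The "more generally" clause is what admits two bucket counts where neither divides the other: both sides can be reduced to their greatest common divisor. A stand-alone check of that claim for bucket(16, x) and bucket(12, x), in plain Scala with hypothetical names:

    object TwoSidedReductionCheck {
      def main(args: Array[String]): Unit = {
        val gcd = BigInt(16).gcd(BigInt(12)).toInt  // = 4
        Seq(0, 3, 7, 16, 25, 123).foreach { h =>
          val r1 = (h % 16) % gcd   // reduce the bucket(16, x) side
          val r2 = (h % 12) % gcd   // reduce the bucket(12, x) side
          assert(r1 == r2 && r1 == h % gcd)
        }
        println("both sides agree once reduced to gcd(16, 12) = 4 buckets")
      }
    }
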
      *

  * Examples:
  * <ul>

        - *
- *   <li>Bucket functions
+ *   <li>Bucket functions where one side has reducer
  *     <ul>
  *       <li>f_source(x) = bucket(4, x)
  *       <li>f_target(x) = bucket(2, x)
- *       <li>r(x) = x / 2
+ *       <li>r(x) = x % 2
  *     </ul>
+ *
+ *   <li>Bucket functions where both sides have reducer
+ *     <ul>
+ *       <li>f_source(x) = bucket(16, x)
+ *       <li>f_target(x) = bucket(12, x)
+ *       <li>r1(x) = x % 4
+ *       <li>r2(x) = x % 4
+ *     </ul>
+ *
  *   <li>Date functions
  *     <ul>
  *       <li>f_source(x) = days(x)
        • @@ -49,24 +60,42 @@ public interface ReducibleFunction { /** - * If this function is 'reducible' on another function, return the {@link Reducer} function. + * This method is for bucket functions. + * + * If this bucket function is 'reducible' on another bucket function, return the {@link Reducer} function. *

          - * Example: + * Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) *

            - *
          • this_function = bucket(4, x) - *
          • other function = bucket(2, x) + *
+ *     <ul>
+ *       <li>thisFunction = bucket
+ *       <li>otherFunction = bucket
+ *       <li>thisNumBuckets = Int(4)
+ *       <li>otherNumBuckets = Int(2)
+ *     </ul>
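Putting the pieces together, a connector bucket function under the API shape introduced here might look like the sketch below. It is modelled on the BucketFunction test implementation that appears later in this series; the object name is illustrative, and returning null signals "not reducible":

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction, ScalarFunction}
    import org.apache.spark.sql.types.{DataType, IntegerType, LongType}

    object DemoBucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, Int] {
      override def inputTypes(): Array[DataType] = Array(IntegerType, LongType)
      override def resultType(): DataType = IntegerType
      override def name(): String = "demo_bucket"

      // bucket id = value % numBuckets
      override def produceResult(input: InternalRow): Int =
        (input.getLong(1) % input.getInt(0)).toInt

      // Reducible on itself whenever the other side's bucket count evenly divides ours.
      override def reducer(
          other: ReducibleFunction[_, _],
          thisNumBuckets: Int,
          otherNumBuckets: Int): Reducer[Int, Int] = {
        if (other == this && thisNumBuckets > otherNumBuckets &&
            thisNumBuckets % otherNumBuckets == 0) {
          new Reducer[Int, Int] {
            override def reduce(bucket: Int): Int = bucket % otherNumBuckets
          }
        } else {
          null
        }
      }
    }
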
          - * Invoke with arguments + * + * @param otherFunction the other bucket function + * @param thisNumBuckets number of buckets for this bucket function + * @param otherNumBuckets number of buckets for the other bucket function + * @return a reduction function if it is reducible, null if not + */ + default Reducer reducer(ReducibleFunction otherFunction, int thisNumBuckets, int otherNumBuckets) { + return reducer(otherFunction); + } + + /** + * This method is for all other functions. + * + * If this function is 'reducible' on another function, return the {@link Reducer} function. + *

          + * Example of reducing f_source = days(x) on f_target = hours(x) *

            - *
          • other = bucket
          • - *
          • this param = Int(4)
          • - *
          • other param = Int(2)
          • + *
          • thisFunction = days
          • + *
          • otherFunction = hours
          • *
          - * @param other the other function - * @param thisParam param for this function - * @param otherParam param for the other function - * @return a reduction function if it is reducible, none if not + * + * @param otherFunction the other function + * @return a reduction function if it is reducible, null if not. */ - Option> reducer(ReducibleFunction other, Option thisParam, - Option otherParam); + default Reducer reducer(ReducibleFunction otherFunction) { + return reducer(otherFunction, 0, 0); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 5dc048f5dafb..eff0a0ddfe71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -70,9 +70,10 @@ case class TransformExpression( } else { (function, other.function) match { case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => - val reducer = f.reducer(o, numBucketsOpt, other.numBucketsOpt) - val otherReducer = o.reducer(f, other.numBucketsOpt, numBucketsOpt) - reducer.isDefined || otherReducer.isDefined + val reducer = f.reducer(o, numBucketsOpt.getOrElse(0), other.numBucketsOpt.getOrElse(0)) + val otherReducer = + o.reducer(f, other.numBucketsOpt.getOrElse(0), numBucketsOpt.getOrElse(0)) + reducer != null || otherReducer != null case _ => false } } @@ -90,7 +91,10 @@ case class TransformExpression( def reducers(other: TransformExpression): Option[Reducer[_, _]] = { (function, other.function) match { case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => - e1.reducer(e2, numBucketsOpt, other.numBucketsOpt) + val reducer = e1.reducer(e2, + numBucketsOpt.getOrElse(0), + other.numBucketsOpt.getOrElse(0)) + Option(reducer) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0bdb1fde67b2..33ea5ad52cd5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -851,25 +851,23 @@ case class KeyGroupedShuffleSpec( * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is * 'reducible' on the corresponding partition expression function of the other shuffle spec. *

          - * If a value is returned, there must be one Option[[Reducer]] per partition expression. + * If a value is returned, there must be one [[Reducer]] per partition expression. * A None value in the set indicates that the particular partition expression is not reducible * on the corresponding expression on the other shuffle spec. *

          * Returning none also indicates that none of the partition expressions can be reduced on the * corresponding expression on the other shuffle spec. + * + * @param other other key-grouped shuffle spec */ - def reducers(other: ShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { - other match { - case otherSpec: KeyGroupedShuffleSpec => - val results = partitioning.expressions.zip(otherSpec.partitioning.expressions).map { - case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) - case (_, _) => None - } + def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = { + val results = partitioning.expressions.zip(other.partitioning.expressions).map { + case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2) + case (_, _) => None + } - // optimize to not return a value, if none of the partition expressions are reducible - if (results.forall(p => p.isEmpty)) None else Some(results) - case _ => None - } + // optimize to not return a value, if none of the partition expressions are reducible + if (results.forall(p => p.isEmpty)) None else Some(results) } override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled && @@ -883,8 +881,8 @@ case class KeyGroupedShuffleSpec( object KeyGroupedShuffleSpec { def reducePartitionValue(row: InternalRow, - expressions: Seq[Expression], - reducers: Seq[Option[Reducer[_, _]]]): + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = { val partitionVals = row.toSeq(expressions.map(_.dataType)) val reducedRow = partitionVals.zip(reducers).map{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index b34990e1b716..7ff682178ad2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -569,8 +569,8 @@ case class EnsureRequirements( } private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], - expressions: Seq[Expression], - reducers: Option[Seq[Option[Reducer[_, _]]]]) = { + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { case Some(reducers) => commonPartValues.groupBy { case (row, _) => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala index 403081b66551..ec275fe101fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/KeyGroupedPartitioningSuite.scala @@ -1475,6 +1475,168 @@ class KeyGroupedPartitioningSuite extends DistributionAndOrderingSuiteBase { } } + test("SPARK-47094: Support compatible buckets with common divisor") { + val table1 = "tab1e1" + val table2 = "table2" + + Seq( + ((6, 4), (4, 6)), + ((6, 6), (4, 4)), + ((4, 4), (6, 6)), + ((4, 6), (6, 4))).foreach { + case ((table1buckets1, table1buckets2), (table2buckets1, table2buckets2)) => + catalog.clearTables() + + val partition1 = Array(bucket(table1buckets1, "store_id"), + bucket(table1buckets2, "dept_id")) + val partition2 = Array(bucket(table2buckets1, "store_id"), + bucket(table2buckets2, 
"dept_id")) + + Seq((table1, partition1), (table2, partition2)).foreach { case (tab, part) => + createTable(tab, columns2, part) + val insertStr = s"INSERT INTO testcat.ns.$tab VALUES " + + "(0, 0, 'aa'), " + + "(0, 0, 'ab'), " + // duplicate partition key + "(0, 1, 'ac'), " + + "(0, 2, 'ad'), " + + "(0, 3, 'ae'), " + + "(0, 4, 'af'), " + + "(0, 5, 'ag'), " + + "(1, 0, 'ah'), " + + "(1, 0, 'ai'), " + // duplicate partition key + "(1, 1, 'aj'), " + + "(1, 2, 'ak'), " + + "(1, 3, 'al'), " + + "(1, 4, 'am'), " + + "(1, 5, 'an'), " + + "(2, 0, 'ao'), " + + "(2, 0, 'ap'), " + // duplicate partition key + "(2, 1, 'aq'), " + + "(2, 2, 'ar'), " + + "(2, 3, 'as'), " + + "(2, 4, 'at'), " + + "(2, 5, 'au'), " + + "(3, 0, 'av'), " + + "(3, 0, 'aw'), " + // duplicate partition key + "(3, 1, 'ax'), " + + "(3, 2, 'ay'), " + + "(3, 3, 'az'), " + + "(3, 4, 'ba'), " + + "(3, 5, 'bb'), " + + "(4, 0, 'bc'), " + + "(4, 0, 'bd'), " + // duplicate partition key + "(4, 1, 'be'), " + + "(4, 2, 'bf'), " + + "(4, 3, 'bg'), " + + "(4, 4, 'bh'), " + + "(4, 5, 'bi'), " + + "(5, 0, 'bj'), " + + "(5, 0, 'bk'), " + // duplicate partition key + "(5, 1, 'bl'), " + + "(5, 2, 'bm'), " + + "(5, 3, 'bn'), " + + "(5, 4, 'bo'), " + + "(5, 5, 'bp')" + + // additional unmatched partitions to test push down + val finalStr = if (tab == table1) { + insertStr ++ ", (8, 0, 'xa'), (8, 8, 'xx')" + } else { + insertStr ++ ", (9, 0, 'ya'), (9, 9, 'yy')" + } + + sql(finalStr) + } + + Seq(true, false).foreach { allowJoinKeysSubsetOfPartitionKeys => + withSQLConf( + SQLConf.REQUIRE_ALL_CLUSTER_KEYS_FOR_CO_PARTITION.key -> "false", + SQLConf.V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key -> "true", + SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key -> "false", + SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS.key -> + allowJoinKeysSubsetOfPartitionKeys.toString, + SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS.key -> "true") { + val df = sql( + s""" + |${selectWithMergeJoinHint("t1", "t2")} + |t1.store_id, t1.dept_id, t1.data, t2.data + |FROM testcat.ns.$table1 t1 JOIN testcat.ns.$table2 t2 + |ON t1.store_id = t2.store_id AND t1.dept_id = t2.dept_id + |ORDER BY t1.store_id, t1.dept_id, t1.data, t2.data + |""".stripMargin) + + val shuffles = collectShuffles(df.queryExecution.executedPlan) + assert(shuffles.isEmpty, "SPJ should be triggered") + + val scans = collectScans(df.queryExecution.executedPlan).map(_.inputRDD. 
+ partitions.length) + + def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt + val expectedBuckets = gcd(table1buckets1, table2buckets1) * + gcd(table1buckets2, table2buckets2) + assert(scans == Seq(expectedBuckets, expectedBuckets)) + + checkAnswer(df, Seq( + Row(0, 0, "aa", "aa"), + Row(0, 0, "aa", "ab"), + Row(0, 0, "ab", "aa"), + Row(0, 0, "ab", "ab"), + Row(0, 1, "ac", "ac"), + Row(0, 2, "ad", "ad"), + Row(0, 3, "ae", "ae"), + Row(0, 4, "af", "af"), + Row(0, 5, "ag", "ag"), + Row(1, 0, "ah", "ah"), + Row(1, 0, "ah", "ai"), + Row(1, 0, "ai", "ah"), + Row(1, 0, "ai", "ai"), + Row(1, 1, "aj", "aj"), + Row(1, 2, "ak", "ak"), + Row(1, 3, "al", "al"), + Row(1, 4, "am", "am"), + Row(1, 5, "an", "an"), + Row(2, 0, "ao", "ao"), + Row(2, 0, "ao", "ap"), + Row(2, 0, "ap", "ao"), + Row(2, 0, "ap", "ap"), + Row(2, 1, "aq", "aq"), + Row(2, 2, "ar", "ar"), + Row(2, 3, "as", "as"), + Row(2, 4, "at", "at"), + Row(2, 5, "au", "au"), + Row(3, 0, "av", "av"), + Row(3, 0, "av", "aw"), + Row(3, 0, "aw", "av"), + Row(3, 0, "aw", "aw"), + Row(3, 1, "ax", "ax"), + Row(3, 2, "ay", "ay"), + Row(3, 3, "az", "az"), + Row(3, 4, "ba", "ba"), + Row(3, 5, "bb", "bb"), + Row(4, 0, "bc", "bc"), + Row(4, 0, "bc", "bd"), + Row(4, 0, "bd", "bc"), + Row(4, 0, "bd", "bd"), + Row(4, 1, "be", "be"), + Row(4, 2, "bf", "bf"), + Row(4, 3, "bg", "bg"), + Row(4, 4, "bh", "bh"), + Row(4, 5, "bi", "bi"), + Row(5, 0, "bj", "bj"), + Row(5, 0, "bj", "bk"), + Row(5, 0, "bk", "bj"), + Row(5, 0, "bk", "bk"), + Row(5, 1, "bl", "bl"), + Row(5, 2, "bm", "bm"), + Row(5, 3, "bn", "bn"), + Row(5, 4, "bo", "bo"), + Row(5, 5, "bp", "bp") + )) + } + } + } + } + test("SPARK-47094: Support compatible buckets with less join keys than partition keys") { val table1 = "tab1e1" val table2 = "table2" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 7a77c00b577f..176e597fe44b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -87,21 +87,31 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In } override def reducer(func: ReducibleFunction[_, _], - thisNumBuckets: Option[_], - otherNumBuckets: Option[_]): Option[Reducer[Int, Int]] = { - (thisNumBuckets, otherNumBuckets) match { - case (Some(thisNumBucketsVal: Int), Some(otherNumBucketsVal: Int)) - if func == BucketFunction && - ((thisNumBucketsVal > otherNumBucketsVal) && - (thisNumBucketsVal % otherNumBucketsVal == 0)) => - Some(BucketReducer(thisNumBucketsVal, otherNumBucketsVal)) - case _ => None + thisNumBuckets: Int, + otherNumBuckets: Int): Reducer[Int, Int] = { + + if (func == BucketFunction) { + if ((thisNumBuckets > otherNumBuckets) + && (thisNumBuckets % otherNumBuckets == 0)) { + BucketReducer(thisNumBuckets, otherNumBuckets) + } else { + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) + if (gcd != thisNumBuckets) { + BucketReducer(thisNumBuckets, gcd) + } else { + null + } + } + } else { + null } } + + private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt } -case class BucketReducer(thisNumBuckets: Int, otherNumBuckets: Int) extends Reducer[Int, Int] { - override def reduce(bucket: Int): Int = bucket % otherNumBuckets +case class BucketReducer(thisNumBuckets: Int, divisor: Int) extends Reducer[Int, Int] { + 
override def reduce(bucket: Int): Int = bucket % divisor } object UnboundStringSelfFunction extends UnboundFunction { From d2de9c32c5d24391489704ebb8a762acfc5cb640 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 18 Mar 2024 23:48:12 -0700 Subject: [PATCH 09/14] Fix linting --- .../sql/connector/catalog/functions/Reducer.java | 12 +++++++----- .../catalog/functions/ReducibleFunction.java | 16 ++++++++++------ 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java index af742fe8cb24..561d66092d64 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/Reducer.java @@ -23,11 +23,13 @@ * * @see ReducibleFunction * - * A user defined function f_source(x) is 'reducible' on another user_defined function f_target(x) if + * A user defined function f_source(x) is 'reducible' on another user_defined function + * f_target(x) if *

            - *
- *   <li>There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
- *   <li>More generally, there exists two reducer functions r1(x) and r2(x) such that
- *       r1(f_source(x)) = r2(f_target(x)) for all input x.
+ *   <li>There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for
+ *       all input x, or
+ *   <li>More generally, there exists reducer functions r1(x) and r2(x) such that
+ *       r1(f_source(x)) = r2(f_target(x)) for all input x.
          • *
          * * @param reducer input type @@ -36,5 +38,5 @@ */ @Evolving public interface Reducer { - O reduce(I arg); + O reduce(I arg); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 9d2215c1167c..87bf5efaaf8c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -23,8 +23,9 @@ * * A function f_source(x) is 'reducible' on another function f_target(x) if *
            - *
- *   <li>There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for all input x.
- *   <li>More generally, there exists two reducer functions r1(x) and r2(x) such that
- *       r1(f_source(x)) = r2(f_target(x)) for all input x.
+ *   <li>There exists a reducer function r(x) such that r(f_source(x)) = f_target(x)
+ *       for all input x, or
+ *   <li>More generally, there exists reducer functions r1(x) and r2(x) such that
+ *       r1(f_source(x)) = r2(f_target(x)) for all input x.
          • *
          *

          @@ -62,7 +63,8 @@ public interface ReducibleFunction { /** * This method is for bucket functions. * - * If this bucket function is 'reducible' on another bucket function, return the {@link Reducer} function. + * If this bucket function is 'reducible' on another bucket function, + * return the {@link Reducer} function. *

          * Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) *

            @@ -77,8 +79,10 @@ public interface ReducibleFunction { * @param otherNumBuckets number of buckets for the other bucket function * @return a reduction function if it is reducible, null if not */ - default Reducer reducer(ReducibleFunction otherFunction, int thisNumBuckets, int otherNumBuckets) { - return reducer(otherFunction); + default Reducer reducer(ReducibleFunction otherFunction, + int thisNumBuckets, + int otherNumBuckets) { + return reducer(otherFunction); } /** @@ -96,6 +100,6 @@ default Reducer reducer(ReducibleFunction otherFunction, int thisNum * @return a reduction function if it is reducible, null if not. */ default Reducer reducer(ReducibleFunction otherFunction) { - return reducer(otherFunction, 0, 0); + return reducer(otherFunction, 0, 0); } } From d3e196c1f158ade217fa20fe37ce31019588adcb Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Tue, 19 Mar 2024 10:48:20 -0700 Subject: [PATCH 10/14] Make reducer take generic arguments --- .../catalog/functions/ReducibleFunction.java | 83 ++++++++++--------- .../expressions/TransformExpression.scala | 25 ++++-- .../functions/transformFunctions.scala | 12 ++- 3 files changed, 67 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 87bf5efaaf8c..9b2e629dedf7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -60,46 +60,47 @@ @Evolving public interface ReducibleFunction { - /** - * This method is for bucket functions. - * - * If this bucket function is 'reducible' on another bucket function, - * return the {@link Reducer} function. - *

            - * Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) - *

              - *
            • thisFunction = bucket
            • - *
            • otherFunction = bucket
            • - *
            • thisNumBuckets = Int(4)
            • - *
            • otherNumBuckets = Int(2)
            • - *
            - * - * @param otherFunction the other bucket function - * @param thisNumBuckets number of buckets for this bucket function - * @param otherNumBuckets number of buckets for the other bucket function - * @return a reduction function if it is reducible, null if not - */ - default Reducer reducer(ReducibleFunction otherFunction, - int thisNumBuckets, - int otherNumBuckets) { - return reducer(otherFunction); - } + /** + * This method is for parameterized functions. + * + * If this parameterized function is 'reducible' on another bucket function, + * return the {@link Reducer} function. + *

            + * Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) + *

              + *
            • thisFunction = bucket
            • + *
            • thisParam = Int(4)
            • + *
            • otherFunction = bucket
            • + *
            • otherParam = Int(2)
            • + *
            + * + * @param thisParam parameter for this function + * @param otherFunction the other parameterized function + * @param otherParam parameter for the other function + * @return a reduction function if it is reducible, null if not + */ + default Reducer reducer( + Object thisParam, + ReducibleFunction otherFunction, + Object otherParam) { + throw new UnsupportedOperationException(); + } - /** - * This method is for all other functions. - * - * If this function is 'reducible' on another function, return the {@link Reducer} function. - *

            - * Example of reducing f_source = days(x) on f_target = hours(x) - *

              - *
            • thisFunction = days
            • - *
            • otherFunction = hours
            • - *
            - * - * @param otherFunction the other function - * @return a reduction function if it is reducible, null if not. - */ - default Reducer reducer(ReducibleFunction otherFunction) { - return reducer(otherFunction, 0, 0); - } + /** + * This method is for all other functions. + * + * If this function is 'reducible' on another function, return the {@link Reducer} function. + *

            + * Example of reducing f_source = days(x) on f_target = hours(x) + *

              + *
            • thisFunction = days
            • + *
            • otherFunction = hours
            • + *
            + * + * @param otherFunction the other function + * @return a reduction function if it is reducible, null if not. + */ + default Reducer reducer(ReducibleFunction otherFunction) { + throw new UnsupportedOperationException(); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index eff0a0ddfe71..ed44fdd838fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -70,10 +70,9 @@ case class TransformExpression( } else { (function, other.function) match { case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) => - val reducer = f.reducer(o, numBucketsOpt.getOrElse(0), other.numBucketsOpt.getOrElse(0)) - val otherReducer = - o.reducer(f, other.numBucketsOpt.getOrElse(0), numBucketsOpt.getOrElse(0)) - reducer != null || otherReducer != null + val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt) + val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt) + thisReducer.isDefined || otherReducer.isDefined case _ => false } } @@ -91,14 +90,24 @@ case class TransformExpression( def reducers(other: TransformExpression): Option[Reducer[_, _]] = { (function, other.function) match { case(e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) => - val reducer = e1.reducer(e2, - numBucketsOpt.getOrElse(0), - other.numBucketsOpt.getOrElse(0)) - Option(reducer) + reducer(e1, numBucketsOpt, e2, other.numBucketsOpt) case _ => None } } + // Return a Reducer for a reducible function on another reducible function + private def reducer(thisFunction: ReducibleFunction[_, _], + thisNumBucketsOpt: Option[Int], + otherFunction: ReducibleFunction[_, _], + otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { + val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { + case (Some(numBuckets), Some(otherNumBuckets)) => + thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) + case _ => thisFunction.reducer(otherFunction) + } + Option(res) + } + override def dataType: DataType = function.resultType() override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 176e597fe44b..ae6c9450c576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -86,11 +86,15 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In (input.getLong(1) % input.getInt(0)).toInt } - override def reducer(func: ReducibleFunction[_, _], - thisNumBuckets: Int, - otherNumBuckets: Int): Reducer[Int, Int] = { + override def reducer( + thisNumBucketsArg: Object, + otherFunc: ReducibleFunction[_, _], + otherNumBucketsArg: Object): Reducer[Int, Int] = { - if (func == BucketFunction) { + val thisNumBuckets = thisNumBucketsArg.asInstanceOf[Int] + val otherNumBuckets = otherNumBucketsArg.asInstanceOf[Int] + + if (otherFunc == BucketFunction) { if ((thisNumBuckets > otherNumBuckets) && (thisNumBuckets % otherNumBuckets == 0)) { 
BucketReducer(thisNumBuckets, otherNumBuckets) From 1fa24a4cbb1deb9fdc05141ed348ce8cf5a24e70 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Fri, 22 Mar 2024 16:11:01 -0700 Subject: [PATCH 11/14] Review comments --- .../spark/sql/catalyst/plans/physical/partitioning.scala | 3 ++- .../src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 33ea5ad52cd5..86295ec61575 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -880,7 +880,8 @@ case class KeyGroupedShuffleSpec( } object KeyGroupedShuffleSpec { - def reducePartitionValue(row: InternalRow, + def reducePartitionValue( + row: InternalRow, expressions: Seq[Expression], reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 0f8c537cfd29..1c646eb644ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1543,7 +1543,7 @@ object SQLConf { val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS = buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled") - .doc("Whether to allow storage-partition join in the case where the partition transforms" + + .doc("Whether to allow storage-partition join in the case where the partition transforms " + "are compatible but not identical. This config requires both " + s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " + s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " + From 8053e58f6db8f630cc0dd37a0cbb30e75033cdaf Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Mon, 25 Mar 2024 15:14:41 -0700 Subject: [PATCH 12/14] Revert to bucketReducer function --- .../catalog/functions/ReducibleFunction.java | 24 +++++++++---------- .../expressions/TransformExpression.scala | 4 ++-- .../functions/transformFunctions.scala | 9 +++---- 3 files changed, 17 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 9b2e629dedf7..86156c789630 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -68,21 +68,21 @@ public interface ReducibleFunction { *

            * Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) *

              - *
            • thisFunction = bucket
            • - *
            • thisParam = Int(4)
            • - *
            • otherFunction = bucket
            • - *
            • otherParam = Int(2)
            • + *
            • thisBucketFunction = bucket
            • + *
            • thisNumBuckets = 4
            • + *
            • otherBucketFunction = bucket
            • + *
            • otherNumBuckets = 2
            • *
            * - * @param thisParam parameter for this function - * @param otherFunction the other parameterized function - * @param otherParam parameter for the other function + * @param thisNumBuckets parameter for this function + * @param otherBucketFunction the other parameterized function + * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not */ - default Reducer reducer( - Object thisParam, - ReducibleFunction otherFunction, - Object otherParam) { + default Reducer bucketReducer( + int thisNumBuckets, + ReducibleFunction otherBucketFunction, + int otherNumBuckets) { throw new UnsupportedOperationException(); } @@ -100,7 +100,7 @@ default Reducer reducer( * @param otherFunction the other function * @return a reduction function if it is reducible, null if not. */ - default Reducer reducer(ReducibleFunction otherFunction) { + default Reducer bucketReducer(ReducibleFunction otherFunction) { throw new UnsupportedOperationException(); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index ed44fdd838fa..371e6622d5a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -102,8 +102,8 @@ case class TransformExpression( otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { case (Some(numBuckets), Some(otherNumBuckets)) => - thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) - case _ => thisFunction.reducer(otherFunction) + thisFunction.bucketReducer(numBuckets, otherFunction, otherNumBuckets) + case _ => thisFunction.bucketReducer(otherFunction) } Option(res) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index ae6c9450c576..68f5e774a385 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -86,13 +86,10 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In (input.getLong(1) % input.getInt(0)).toInt } - override def reducer( - thisNumBucketsArg: Object, + override def bucketReducer( + thisNumBuckets: Int, otherFunc: ReducibleFunction[_, _], - otherNumBucketsArg: Object): Reducer[Int, Int] = { - - val thisNumBuckets = thisNumBucketsArg.asInstanceOf[Int] - val otherNumBuckets = otherNumBucketsArg.asInstanceOf[Int] + otherNumBuckets: Int): Reducer[Int, Int] = { if (otherFunc == BucketFunction) { if ((thisNumBuckets > otherNumBuckets) From ad61b0c61efa1e7e7f9374e96d28732f859d3f59 Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Wed, 3 Apr 2024 17:57:18 +0800 Subject: [PATCH 13/14] Formatting --- .../catalog/functions/ReducibleFunction.java | 16 ++++++++-------- .../expressions/TransformExpression.scala | 13 +++++++------ .../catalyst/plans/physical/partitioning.scala | 6 +++--- .../execution/exchange/EnsureRequirements.scala | 4 ++-- .../catalog/functions/transformFunctions.scala | 8 ++++---- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java index 86156c789630..ef1a14e50cda 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ReducibleFunction.java @@ -61,12 +61,12 @@ public interface ReducibleFunction { /** - * This method is for parameterized functions. + * This method is for the bucket function. * - * If this parameterized function is 'reducible' on another bucket function, + * If this bucket function is 'reducible' on another bucket function, * return the {@link Reducer} function. *

            - * Example to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) + * For example, to return reducer for reducing f_source = bucket(4, x) on f_target = bucket(2, x) *

              *
            • thisBucketFunction = bucket
            • *
            • thisNumBuckets = 4
            • @@ -79,10 +79,10 @@ public interface ReducibleFunction { * @param otherNumBuckets parameter for the other function * @return a reduction function if it is reducible, null if not */ - default Reducer bucketReducer( - int thisNumBuckets, - ReducibleFunction otherBucketFunction, - int otherNumBuckets) { + default Reducer reducer( + int thisNumBuckets, + ReducibleFunction otherBucketFunction, + int otherNumBuckets) { throw new UnsupportedOperationException(); } @@ -100,7 +100,7 @@ default Reducer bucketReducer( * @param otherFunction the other function * @return a reduction function if it is reducible, null if not. */ - default Reducer bucketReducer(ReducibleFunction otherFunction) { + default Reducer reducer(ReducibleFunction otherFunction) { throw new UnsupportedOperationException(); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala index 371e6622d5a6..d37c9d9f6452 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala @@ -96,14 +96,15 @@ case class TransformExpression( } // Return a Reducer for a reducible function on another reducible function - private def reducer(thisFunction: ReducibleFunction[_, _], - thisNumBucketsOpt: Option[Int], - otherFunction: ReducibleFunction[_, _], - otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { + private def reducer( + thisFunction: ReducibleFunction[_, _], + thisNumBucketsOpt: Option[Int], + otherFunction: ReducibleFunction[_, _], + otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = { val res = (thisNumBucketsOpt, otherNumBucketsOpt) match { case (Some(numBuckets), Some(otherNumBuckets)) => - thisFunction.bucketReducer(numBuckets, otherFunction, otherNumBuckets) - case _ => thisFunction.bucketReducer(otherFunction) + thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets) + case _ => thisFunction.reducer(otherFunction) } Option(res) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 86295ec61575..2364130f79e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -881,9 +881,9 @@ case class KeyGroupedShuffleSpec( object KeyGroupedShuffleSpec { def reducePartitionValue( - row: InternalRow, - expressions: Seq[Expression], - reducers: Seq[Option[Reducer[_, _]]]): + row: InternalRow, + expressions: Seq[Expression], + reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = { val partitionVals = row.toSeq(expressions.map(_.dataType)) val reducedRow = partitionVals.zip(reducers).map{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 7ff682178ad2..105bded78549 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -569,8 +569,8 @@ case class EnsureRequirements( } private def 
reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], - expressions: Seq[Expression], - reducers: Option[Seq[Option[Reducer[_, _]]]]) = { + expressions: Seq[Expression], + reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { case Some(reducers) => commonPartValues.groupBy { case (row, _) => KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index 68f5e774a385..c4207fd3e092 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -86,10 +86,10 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In (input.getLong(1) % input.getInt(0)).toInt } - override def bucketReducer( - thisNumBuckets: Int, - otherFunc: ReducibleFunction[_, _], - otherNumBuckets: Int): Reducer[Int, Int] = { + override def reducer( + thisNumBuckets: Int, + otherFunc: ReducibleFunction[_, _], + otherNumBuckets: Int): Reducer[Int, Int] = { if (otherFunc == BucketFunction) { if ((thisNumBuckets > otherNumBuckets) From 0356f9e9d1f83b0c379b2267ffd86872b084a3bd Mon Sep 17 00:00:00 2001 From: Szehon Ho Date: Wed, 3 Apr 2024 21:23:45 +0800 Subject: [PATCH 14/14] Simplify bucket reducer logic --- .../execution/exchange/EnsureRequirements.scala | 3 ++- .../catalog/functions/transformFunctions.scala | 16 ++++------------ 2 files changed, 6 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 105bded78549..a0f74ef6c3d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -568,7 +568,8 @@ case class EnsureRequirements( child, values, joinKeyPositions, reducers, applyPartialClustering, replicatePartitions)) } - private def reduceCommonPartValues(commonPartValues: Seq[(InternalRow, Int)], + private def reduceCommonPartValues( + commonPartValues: Seq[(InternalRow, Int)], expressions: Seq[Expression], reducers: Option[Seq[Option[Reducer[_, _]]]]) = { reducers match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala index c4207fd3e092..5cdb90090105 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/catalog/functions/transformFunctions.scala @@ -92,20 +92,12 @@ object BucketFunction extends ScalarFunction[Int] with ReducibleFunction[Int, In otherNumBuckets: Int): Reducer[Int, Int] = { if (otherFunc == BucketFunction) { - if ((thisNumBuckets > otherNumBuckets) - && (thisNumBuckets % otherNumBuckets == 0)) { - BucketReducer(thisNumBuckets, otherNumBuckets) - } else { - val gcd = this.gcd(thisNumBuckets, otherNumBuckets) - if (gcd != thisNumBuckets) { - BucketReducer(thisNumBuckets, gcd) - } else { - null - } + val gcd = this.gcd(thisNumBuckets, otherNumBuckets) + if (gcd != thisNumBuckets) { + return 
BucketReducer(thisNumBuckets, gcd) } - } else { - null } + null } private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt
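As a stand-alone sanity check of the simplified gcd-based logic above (plain Scala, illustrative names, assuming modulo-style bucketing): with 6 buckets on one side and 4 on the other, both sides reduce to gcd(6, 4) = 2 buckets and agree on every key, which is also where the expected partition count in the new "common divisor" test comes from.

    object GcdReductionCheck {
      private def gcd(a: Int, b: Int): Int = BigInt(a).gcd(BigInt(b)).toInt

      def main(args: Array[String]): Unit = {
        val (left, right) = (6, 4)
        val g = gcd(left, right)                      // = 2
        (0 until 100).foreach { h =>
          // both sides land in the same reduced bucket for every key hash
          assert((h % left) % g == (h % right) % g)
        }
        // With two bucketed key columns the suite expects
        // gcd(6, 4) * gcd(4, 6) = 2 * 2 = 4 grouped partitions per scan.
        println(gcd(6, 4) * gcd(4, 6))
      }
    }
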