
Commit 11abc64

szehon-ho authored and sunchao committed
[SPARK-47094][SQL] SPJ : Dynamically rebalance number of buckets when they are not equal
### What changes were proposed in this pull request?
- Allow SPJ between 'compatible' bucket functions.
- Add a mechanism to define 'reducible' functions: one function whose output can be 'reduced' to another's for all inputs.

### Why are the changes needed?
- SPJ currently applies only if the partition transform expressions on both sides are identical.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added new tests in KeyGroupedPartitioningSuite

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45267 from szehon-ho/spj-uneven-buckets.

Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
1 parent d1ace24 commit 11abc64

File tree

9 files changed: +821 −15 lines changed
Reducer.java (new file, package org.apache.spark.sql.connector.catalog.functions)

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;

/**
 * A 'reducer' for output of user-defined functions.
 *
 * @see ReducibleFunction
 *
 * A user-defined function f_source(x) is 'reducible' on another user-defined function
 * f_target(x) if
 * <ul>
 *   <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x) for
 *        all input x, or </li>
 *   <li> More generally, there exist reducer functions r1(x) and r2(x) such that
 *        r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
 * </ul>
 *
 * @param <I> reducer input type
 * @param <O> reducer output type
 * @since 4.0.0
 */
@Evolving
public interface Reducer<I, O> {
  O reduce(I arg);
}
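For orientation (not part of the diff itself), here is a minimal Scala sketch of a Reducer implementation: a hypothetical ModReducer that maps a bucket id from a finer bucket space onto a coarser, divisible one, i.e. r(x) = x % numBuckets from the javadoc example above. The name ModReducer is an assumption, not a Spark API.

import org.apache.spark.sql.connector.catalog.functions.Reducer

// Hypothetical reducer: collapse bucket ids into `numBuckets` buckets.
// Reducing bucket(4, x) onto bucket(2, x) would use ModReducer(2).
case class ModReducer(numBuckets: Int) extends Reducer[Int, Int] {
  override def reduce(bucketId: Int): Int = bucketId % numBuckets
}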
ReducibleFunction.java (new file, package org.apache.spark.sql.connector.catalog.functions)

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.connector.catalog.functions;

import org.apache.spark.annotation.Evolving;

/**
 * Base class for user-defined functions that can be 'reduced' on another function.
 *
 * A function f_source(x) is 'reducible' on another function f_target(x) if
 * <ul>
 *   <li> There exists a reducer function r(x) such that r(f_source(x)) = f_target(x)
 *        for all input x, or </li>
 *   <li> More generally, there exist reducer functions r1(x) and r2(x) such that
 *        r1(f_source(x)) = r2(f_target(x)) for all input x. </li>
 * </ul>
 * <p>
 * Examples:
 * <ul>
 *   <li>Bucket functions where one side has a reducer
 *     <ul>
 *       <li>f_source(x) = bucket(4, x)</li>
 *       <li>f_target(x) = bucket(2, x)</li>
 *       <li>r(x) = x % 2</li>
 *     </ul>
 *
 *   <li>Bucket functions where both sides have a reducer
 *     <ul>
 *       <li>f_source(x) = bucket(16, x)</li>
 *       <li>f_target(x) = bucket(12, x)</li>
 *       <li>r1(x) = x % 4</li>
 *       <li>r2(x) = x % 4</li>
 *     </ul>
 *
 *   <li>Date functions
 *     <ul>
 *       <li>f_source(x) = hours(x)</li>
 *       <li>f_target(x) = days(x)</li>
 *       <li>r(x) = x / 24</li>
 *     </ul>
 * </ul>
 * @param <I> reducer function input type
 * @param <O> reducer function output type
 * @since 4.0.0
 */
@Evolving
public interface ReducibleFunction<I, O> {

  /**
   * This method is for bucket functions.
   *
   * If this bucket function is 'reducible' on another bucket function,
   * return the {@link Reducer} function.
   * <p>
   * For example, to return a reducer for reducing f_source = bucket(4, x) on
   * f_target = bucket(2, x):
   * <ul>
   *   <li>thisBucketFunction = bucket</li>
   *   <li>thisNumBuckets = 4</li>
   *   <li>otherBucketFunction = bucket</li>
   *   <li>otherNumBuckets = 2</li>
   * </ul>
   *
   * @param thisNumBuckets parameter for this function
   * @param otherBucketFunction the other parameterized function
   * @param otherNumBuckets parameter for the other function
   * @return a reduction function if it is reducible, null if not
   */
  default Reducer<I, O> reducer(
      int thisNumBuckets,
      ReducibleFunction<?, ?> otherBucketFunction,
      int otherNumBuckets) {
    throw new UnsupportedOperationException();
  }

  /**
   * This method is for all other functions.
   *
   * If this function is 'reducible' on another function, return the {@link Reducer} function.
   * <p>
   * Example of reducing f_source = hours(x) on f_target = days(x):
   * <ul>
   *   <li>thisFunction = hours</li>
   *   <li>otherFunction = days</li>
   * </ul>
   *
   * @param otherFunction the other function
   * @return a reduction function if it is reducible, null if not.
   */
  default Reducer<I, O> reducer(ReducibleFunction<?, ?> otherFunction) {
    throw new UnsupportedOperationException();
  }
}
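To make the contract concrete, here is a hedged Scala sketch (not part of this patch) of a bucket function implementing ReducibleFunction: it declares itself reducible on another instance of the same function by reducing both sides to gcd(thisNumBuckets, otherNumBuckets) buckets. ExampleBucketFunction and ModReducer are hypothetical names for illustration only, not Spark APIs.

import org.apache.spark.sql.connector.catalog.functions.{Reducer, ReducibleFunction}

// Hypothetical reducer: collapse bucket ids into `numBuckets` buckets (same as the earlier sketch).
case class ModReducer(numBuckets: Int) extends Reducer[Int, Int] {
  override def reduce(bucketId: Int): Int = bucketId % numBuckets
}

// Hypothetical bucket function that is reducible on itself whenever the two bucket counts
// share a common divisor.
object ExampleBucketFunction extends ReducibleFunction[Int, Int] {
  override def reducer(
      thisNumBuckets: Int,
      otherBucketFunction: ReducibleFunction[_, _],
      otherNumBuckets: Int): Reducer[Int, Int] = {
    if (otherBucketFunction == ExampleBucketFunction) {
      val gcd = BigInt(thisNumBuckets).gcd(BigInt(otherNumBuckets)).toInt
      // Return a reducer only when this side actually needs reducing; otherwise signal
      // "not reducible" with null, as the javadoc above allows.
      if (gcd != thisNumBuckets) ModReducer(gcd) else null
    } else {
      null
    }
  }
}

With this sketch, bucket(4, x) vs bucket(2, x) yields a reducer only on the 4-bucket side (ModReducer(2)), while bucket(16, x) vs bucket(12, x) yields ModReducer(4) on both sides, matching the javadoc examples.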

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala

Lines changed: 56 additions & 1 deletion
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.connector.catalog.functions.BoundFunction
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, Reducer, ReducibleFunction}
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -54,6 +54,61 @@ case class TransformExpression(
     false
   }
 
+  /**
+   * Whether this [[TransformExpression]]'s function is compatible with the `other`
+   * [[TransformExpression]]'s function.
+   *
+   * This is true if both are instances of [[ReducibleFunction]] and there exists a [[Reducer]]
+   * r(x) such that r(t1(x)) = t2(x), or r(t2(x)) = t1(x), for all input x.
+   *
+   * @param other the transform expression to compare to
+   * @return true if compatible, false if not
+   */
+  def isCompatible(other: TransformExpression): Boolean = {
+    if (isSameFunction(other)) {
+      true
+    } else {
+      (function, other.function) match {
+        case (f: ReducibleFunction[_, _], o: ReducibleFunction[_, _]) =>
+          val thisReducer = reducer(f, numBucketsOpt, o, other.numBucketsOpt)
+          val otherReducer = reducer(o, other.numBucketsOpt, f, numBucketsOpt)
+          thisReducer.isDefined || otherReducer.isDefined
+        case _ => false
+      }
+    }
+  }
+
+  /**
+   * Return a [[Reducer]] for this transform expression's function on the other
+   * transform expression's function.
+   * <p>
+   * A [[Reducer]] exists for a transform expression function if it is
+   * 'reducible' on the other expression's function.
+   * <p>
+   * @return the reducer function, or None if not reducible on the other transform expression
+   */
+  def reducers(other: TransformExpression): Option[Reducer[_, _]] = {
+    (function, other.function) match {
+      case (e1: ReducibleFunction[_, _], e2: ReducibleFunction[_, _]) =>
+        reducer(e1, numBucketsOpt, e2, other.numBucketsOpt)
+      case _ => None
+    }
+  }
+
+  // Return a Reducer for a reducible function on another reducible function
+  private def reducer(
+      thisFunction: ReducibleFunction[_, _],
+      thisNumBucketsOpt: Option[Int],
+      otherFunction: ReducibleFunction[_, _],
+      otherNumBucketsOpt: Option[Int]): Option[Reducer[_, _]] = {
+    val res = (thisNumBucketsOpt, otherNumBucketsOpt) match {
+      case (Some(numBuckets), Some(otherNumBuckets)) =>
+        thisFunction.reducer(numBuckets, otherFunction, otherNumBuckets)
+      case _ => thisFunction.reducer(otherFunction)
+    }
+    Option(res)
+  }
+
   override def dataType: DataType = function.resultType()
 
   override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
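For intuition only: with the hypothetical ExampleBucketFunction sketched earlier, the two lookups performed by isCompatible would resolve as follows for mismatched bucket counts (illustrative Scala, not part of the patch):

// Source side bucket(16, x) vs target side bucket(12, x):
val r1 = ExampleBucketFunction.reducer(16, ExampleBucketFunction, 12)  // ModReducer(4)
val r2 = ExampleBucketFunction.reducer(12, ExampleBucketFunction, 16)  // ModReducer(4)
// Both lookups return a reducer, so isCompatible is true and both sides can be re-grouped
// into a common bucket(4, x) partitioning.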

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala

Lines changed: 49 additions & 1 deletion
@@ -24,6 +24,7 @@ import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.InternalRowComparableWrapper
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, IntegerType}
 
@@ -833,10 +834,42 @@
     (left, right) match {
       case (_: LeafExpression, _: LeafExpression) => true
       case (left: TransformExpression, right: TransformExpression) =>
-        left.isSameFunction(right)
+        if (SQLConf.get.v2BucketingPushPartValuesEnabled &&
+          !SQLConf.get.v2BucketingPartiallyClusteredDistributionEnabled &&
+          SQLConf.get.v2BucketingAllowCompatibleTransforms) {
+          left.isCompatible(right)
+        } else {
+          left.isSameFunction(right)
+        }
       case _ => false
     }
 
+  /**
+   * Return a sequence of [[Reducer]] for the partition expressions of this shuffle spec,
+   * on the partition expressions of another shuffle spec.
+   * <p>
+   * A [[Reducer]] exists for a partition expression function of this shuffle spec if it is
+   * 'reducible' on the corresponding partition expression function of the other shuffle spec.
+   * <p>
+   * If a value is returned, there must be one [[Reducer]] per partition expression.
+   * A None entry indicates that the particular partition expression is not reducible
+   * on the corresponding expression of the other shuffle spec.
+   * <p>
+   * Returning None overall indicates that none of the partition expressions can be reduced on
+   * the corresponding expressions of the other shuffle spec.
+   *
+   * @param other the other key-grouped shuffle spec
+   */
+  def reducers(other: KeyGroupedShuffleSpec): Option[Seq[Option[Reducer[_, _]]]] = {
+    val results = partitioning.expressions.zip(other.partitioning.expressions).map {
+      case (e1: TransformExpression, e2: TransformExpression) => e1.reducers(e2)
+      case (_, _) => None
+    }
+
+    // Optimization: do not return a value if none of the partition expressions are reducible.
+    if (results.forall(p => p.isEmpty)) None else Some(results)
+  }
+
   override def canCreatePartitioning: Boolean = SQLConf.get.v2BucketingShuffleEnabled &&
     // Only support partition expressions are AttributeReference for now
     partitioning.expressions.forall(_.isInstanceOf[AttributeReference])
@@ -846,6 +879,21 @@
   }
 }
 
+object KeyGroupedShuffleSpec {
+  def reducePartitionValue(
+      row: InternalRow,
+      expressions: Seq[Expression],
+      reducers: Seq[Option[Reducer[_, _]]]): InternalRowComparableWrapper = {
+    val partitionVals = row.toSeq(expressions.map(_.dataType))
+    val reducedRow = partitionVals.zip(reducers).map {
+      case (v, Some(reducer: Reducer[Any, Any])) => reducer.reduce(v)
+      case (v, _) => v
+    }.toArray
+    InternalRowComparableWrapper(new GenericInternalRow(reducedRow), expressions)
+  }
+}
+
 case class ShuffleSpecCollection(specs: Seq[ShuffleSpec]) extends ShuffleSpec {
   override def isCompatibleWith(other: ShuffleSpec): Boolean = {
     specs.exists(_.isCompatibleWith(other))
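As a simplified illustration of what reducePartitionValue achieves (plain Scala values stand in for InternalRow columns, and the hypothetical ModReducer from the earlier sketch stands in for a resolved Reducer):

import org.apache.spark.sql.connector.catalog.functions.Reducer

// Single-column partition values from a bucket(4, x) table, reduced onto bucket(2, x).
val partitionValues = Seq(0, 1, 2, 3)
val reducers: Seq[Option[Reducer[Int, Int]]] = Seq(Some(ModReducer(2)))

// Each partition value is reduced column-by-column; values that map to the same reduced key
// (0 and 2, 1 and 3) end up grouped into the same join partition.
val reducedValues = partitionValues.map(v => reducers.head.map(_.reduce(v)).getOrElse(v))
// reducedValues == Seq(0, 1, 0, 1)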

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 15 additions & 0 deletions
@@ -1558,6 +1558,18 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS =
+    buildConf("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled")
+      .doc("Whether to allow storage-partition join in the case where the partition transforms " +
+        "are compatible but not identical. This config requires both " +
+        s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
+        s"enabled and ${V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
+        "to be disabled.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
@@ -5323,6 +5335,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def v2BucketingAllowJoinKeysSubsetOfPartitionKeys: Boolean =
     getConf(SQLConf.V2_BUCKETING_ALLOW_JOIN_KEYS_SUBSET_OF_PARTITION_KEYS)
 
+  def v2BucketingAllowCompatibleTransforms: Boolean =
+    getConf(SQLConf.V2_BUCKETING_ALLOW_COMPATIBLE_TRANSFORMS)
+
   def dataFrameSelfJoinAutoResolveAmbiguity: Boolean =
     getConf(DATAFRAME_SELF_JOIN_AUTO_RESOLVE_AMBIGUITY)
 
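To actually exercise the new code path, the feature has to be switched on together with its prerequisite configs. A hedged usage sketch: the key added in this patch is taken from the diff above, while the other key names are recalled from SQLConf and worth double-checking.

// Enable storage-partition join across compatible-but-unequal bucket transforms:
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled", "false")
spark.conf.set("spark.sql.sources.v2.bucketing.allowCompatibleTransforms.enabled", "true")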

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala

Lines changed: 17 additions & 3 deletions
@@ -24,9 +24,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, Partitioning, SinglePartition}
+import org.apache.spark.sql.catalyst.plans.physical.{KeyGroupedPartitioning, KeyGroupedShuffleSpec, Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.{truncatedString, InternalRowComparableWrapper}
 import org.apache.spark.sql.connector.catalog.Table
+import org.apache.spark.sql.connector.catalog.functions.Reducer
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.util.ArrayImplicits._
 
@@ -164,6 +165,18 @@
       (groupedParts, expressions)
     }
 
+    // Also re-group the partitions if we are reducing compatible partition expressions
+    val finalGroupedPartitions = spjParams.reducers match {
+      case Some(reducers) =>
+        val result = groupedPartitions.groupBy { case (row, _) =>
+          KeyGroupedShuffleSpec.reducePartitionValue(row, partExpressions, reducers)
+        }.map { case (wrapper, splits) => (wrapper.row, splits.flatMap(_._2)) }.toSeq
+        val rowOrdering = RowOrdering.createNaturalAscendingOrdering(
+          partExpressions.map(_.dataType))
+        result.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
+      case _ => groupedPartitions
+    }
+
     // When partially clustered, the input partitions are not grouped by partition
     // values. Here we'll need to check `commonPartitionValues` and decide how to group
     // and replicate splits within a partition.
@@ -174,7 +187,7 @@
         .get
         .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
         .toMap
-      val nestGroupedPartitions = groupedPartitions.map { case (partValue, splits) =>
+      val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
         // `commonPartValuesMap` should contain the part value since it's the super set.
         val numSplits = commonPartValuesMap
           .get(InternalRowComparableWrapper(partValue, partExpressions))
@@ -207,7 +220,7 @@
     } else {
       // either `commonPartitionValues` is not defined, or it is defined but
       // `applyPartialClustering` is false.
-      val partitionMapping = groupedPartitions.map { case (partValue, splits) =>
+      val partitionMapping = finalGroupedPartitions.map { case (partValue, splits) =>
         InternalRowComparableWrapper(partValue, partExpressions) -> splits
       }.toMap
 
@@ -259,6 +272,7 @@ case class StoragePartitionJoinParams(
     keyGroupedPartitioning: Option[Seq[Expression]] = None,
     joinKeyPositions: Option[Seq[Int]] = None,
     commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    reducers: Option[Seq[Option[Reducer[_, _]]]] = None,
     applyPartialClustering: Boolean = false,
     replicatePartitions: Boolean = false) {
   override def equals(other: Any): Boolean = other match {
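The re-grouping that BatchScanExec performs with the resolved reducers can be pictured with a much simplified, self-contained Scala sketch: Ints stand in for InternalRow partition values, Strings for input splits, and a tiny local ModReducer stand-in plays the role of the resolved Reducer. All names here are illustrative assumptions, not the actual implementation.

// Local stand-in for a resolved reducer.
case class ModReducer(numBuckets: Int) { def reduce(bucketId: Int): Int = bucketId % numBuckets }

// Four bucket(4, x) partitions collapsed onto bucket(2, x) before the join.
val groupedPartitions: Seq[(Int, Seq[String])] =
  Seq(0 -> Seq("split-0"), 1 -> Seq("split-1"), 2 -> Seq("split-2"), 3 -> Seq("split-3"))
val reducer = ModReducer(2)

val finalGroupedPartitions = groupedPartitions
  .groupBy { case (partValue, _) => reducer.reduce(partValue) }                 // group by reduced key
  .map { case (reducedValue, groups) => reducedValue -> groups.flatMap(_._2) }  // merge their splits
  .toSeq
  .sortBy(_._1)                                                                 // deterministic order
// finalGroupedPartitions == Seq(0 -> Seq("split-0", "split-2"), 1 -> Seq("split-1", "split-3"))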
