
Commit 0867964

aokolnychyi authored and cloud-fan committed
[SPARK-34026][SQL] Inject repartition and sort nodes to satisfy required distribution and ordering
### What changes were proposed in this pull request?

This PR adds repartition and sort nodes to satisfy the required distribution and ordering introduced in SPARK-33779.

Note: This PR contains the final part of changes discussed in PR #29066.

### Why are the changes needed?

These changes are the next step as discussed in the [design doc](https://docs.google.com/document/d/1X0NsQSryvNmXBY9kcvfINeYyKC-AahZarUqg3nS1GQs/edit#) for SPARK-23889.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

This PR comes with a new test suite.

Closes #31083 from aokolnychyi/spark-34026.

Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 8999e88 commit 0867964
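
For context, the connector API at the center of this change is RequiresDistributionAndOrdering, introduced in SPARK-33779: a v2 Write can mix it in to declare how rows must be distributed and ordered before they reach the data source. Below is a minimal sketch of such a Write; the column names (`dep`, `id`) and the caller-supplied BatchWrite are hypothetical, not part of this commit.

import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder, SortValue}
import org.apache.spark.sql.connector.write.{BatchWrite, RequiresDistributionAndOrdering, Write}

// Sketch of a connector Write that asks Spark to cluster incoming rows by `dep`
// and to sort each task's partition by `id` (ascending, nulls first) before writing.
class ClusteredSortedWrite(batchWrite: BatchWrite) extends Write with RequiresDistributionAndOrdering {

  override def requiredDistribution: Distribution =
    Distributions.clustered(Array[Expression](FieldReference("dep")))

  override def requiredOrdering: Array[SortOrder] =
    Array(SortValue(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))

  override def toBatch: BatchWrite = batchWrite
}

With such a Write in place, the rule added in this commit (DistributionAndOrderingUtils wired into V2Writes) rewrites the write's input query: the required distribution becomes a RepartitionByExpression node and the required ordering becomes a task-local Sort.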

File tree: 5 files changed (+708, −8 lines)

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 10 additions & 3 deletions
@@ -30,7 +30,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}
 import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils}
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, SortOrder, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
 import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
@@ -46,7 +47,9 @@ class InMemoryTable(
     val name: String,
     val schema: StructType,
     override val partitioning: Array[Transform],
-    override val properties: util.Map[String, String])
+    override val properties: util.Map[String, String],
+    val distribution: Distribution = Distributions.unspecified(),
+    val ordering: Array[SortOrder] = Array.empty)
   extends Table with SupportsRead with SupportsWrite with SupportsDelete
   with SupportsMetadataColumns {

@@ -284,7 +287,11 @@ class InMemoryTable(
       this
     }

-    override def build(): Write = new Write {
+    override def build(): Write = new Write with RequiresDistributionAndOrdering {
+      override def requiredDistribution: Distribution = distribution
+
+      override def requiredOrdering: Array[SortOrder] = ordering
+
       override def toBatch: BatchWrite = writer

       override def toStreaming: StreamingWrite = streamingWriter match {

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTableCatalog.scala

Lines changed: 14 additions & 2 deletions
@@ -24,7 +24,8 @@ import scala.collection.JavaConverters._

 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
 import org.apache.spark.sql.connector.catalog._
-import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
+import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap

@@ -69,13 +70,24 @@ class BasicInMemoryTableCatalog extends TableCatalog {
       schema: StructType,
       partitions: Array[Transform],
       properties: util.Map[String, String]): Table = {
+    createTable(ident, schema, partitions, properties, Distributions.unspecified(), Array.empty)
+  }
+
+  def createTable(
+      ident: Identifier,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String],
+      distribution: Distribution,
+      ordering: Array[SortOrder]): Table = {
     if (tables.containsKey(ident)) {
       throw new TableAlreadyExistsException(ident)
     }

     InMemoryTableCatalog.maybeSimulateFailedTableCreation(properties)

-    val table = new InMemoryTable(s"$name.${ident.quoted}", schema, partitions, properties)
+    val tableName = s"$name.${ident.quoted}"
+    val table = new InMemoryTable(tableName, schema, partitions, properties, distribution, ordering)
     tables.put(ident, table)
     namespaces.putIfAbsent(ident.namespace.toList, Map())
     table
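
To exercise the new overload without a full connector, the test catalog can now register a table with a required distribution and ordering directly. A minimal, test-only sketch follows; the catalog name, namespace, table, and column names are made up for illustration, and it assumes the in-memory catalog's usual initialize/name behavior.

import java.util.Collections

import org.apache.spark.sql.connector.InMemoryTableCatalog
import org.apache.spark.sql.connector.catalog.Identifier
import org.apache.spark.sql.connector.distributions.Distributions
import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NullOrdering, SortDirection, SortOrder, SortValue, Transform}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object RegisterRequiringTable {
  def main(args: Array[String]): Unit = {
    val catalog = new InMemoryTableCatalog
    catalog.initialize("testcat", CaseInsensitiveStringMap.empty())

    val schema = new StructType().add("id", "bigint").add("dep", "string")

    // cluster by `dep`, sort by `id` ascending with nulls first
    val distribution = Distributions.clustered(Array[Expression](FieldReference("dep")))
    val ordering: Array[SortOrder] =
      Array(SortValue(FieldReference("id"), SortDirection.ASCENDING, NullOrdering.NULLS_FIRST))

    catalog.createTable(
      Identifier.of(Array("ns"), "tbl"),
      schema,
      Array.empty[Transform],
      Collections.emptyMap[String, String](),
      distribution,
      ordering)
  }
}

Writes to a table registered this way go through the RequiresDistributionAndOrdering path shown in the InMemoryTable change above.
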
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DistributionAndOrderingUtils.scala

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.Resolver
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, NamedExpression, NullOrdering, NullsFirst, NullsLast, SortDirection, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, RepartitionByExpression, Sort}
+import org.apache.spark.sql.connector.distributions.{ClusteredDistribution, OrderedDistribution, UnspecifiedDistribution}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, IdentityTransform, NullOrdering => V2NullOrdering, SortDirection => V2SortDirection, SortValue}
+import org.apache.spark.sql.connector.write.{RequiresDistributionAndOrdering, Write}
+import org.apache.spark.sql.internal.SQLConf
+
+object DistributionAndOrderingUtils {
+
+  def prepareQuery(write: Write, query: LogicalPlan, conf: SQLConf): LogicalPlan = write match {
+    case write: RequiresDistributionAndOrdering =>
+      val resolver = conf.resolver
+
+      val distribution = write.requiredDistribution match {
+        case d: OrderedDistribution =>
+          d.ordering.map(e => toCatalyst(e, query, resolver))
+        case d: ClusteredDistribution =>
+          d.clustering.map(e => toCatalyst(e, query, resolver))
+        case _: UnspecifiedDistribution =>
+          Array.empty[Expression]
+      }
+
+      val queryWithDistribution = if (distribution.nonEmpty) {
+        val numShufflePartitions = conf.numShufflePartitions
+        // the conversion to catalyst expressions above produces SortOrder expressions
+        // for OrderedDistribution and generic expressions for ClusteredDistribution
+        // this allows RepartitionByExpression to pick either range or hash partitioning
+        RepartitionByExpression(distribution, query, numShufflePartitions)
+      } else {
+        query
+      }
+
+      val ordering = write.requiredOrdering.toSeq
+        .map(e => toCatalyst(e, query, resolver))
+        .asInstanceOf[Seq[SortOrder]]
+
+      val queryWithDistributionAndOrdering = if (ordering.nonEmpty) {
+        Sort(ordering, global = false, queryWithDistribution)
+      } else {
+        queryWithDistribution
+      }
+
+      queryWithDistributionAndOrdering
+
+    case _ =>
+      query
+  }
+
+  private def toCatalyst(
+      expr: V2Expression,
+      query: LogicalPlan,
+      resolver: Resolver): Expression = {
+
+    // we cannot perform the resolution in the analyzer since we need to optimize expressions
+    // in nodes like OverwriteByExpression before constructing a logical write
+    def resolve(ref: FieldReference): NamedExpression = {
+      query.resolve(ref.parts, resolver) match {
+        case Some(attr) => attr
+        case None => throw new AnalysisException(s"Cannot resolve '$ref' using ${query.output}")
+      }
+    }
+
+    expr match {
+      case SortValue(child, direction, nullOrdering) =>
+        val catalystChild = toCatalyst(child, query, resolver)
+        SortOrder(catalystChild, toCatalyst(direction), toCatalyst(nullOrdering), Seq.empty)
+      case IdentityTransform(ref) =>
+        resolve(ref)
+      case ref: FieldReference =>
+        resolve(ref)
+      case _ =>
+        throw new AnalysisException(s"$expr is not currently supported")
+    }
+  }
+
+  private def toCatalyst(direction: V2SortDirection): SortDirection = direction match {
+    case V2SortDirection.ASCENDING => Ascending
+    case V2SortDirection.DESCENDING => Descending
+  }
+
+  private def toCatalyst(nullOrdering: V2NullOrdering): NullOrdering = nullOrdering match {
+    case V2NullOrdering.NULLS_FIRST => NullsFirst
+    case V2NullOrdering.NULLS_LAST => NullsLast
+  }
+}
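
The nodes injected by prepareQuery are not new operators: they are the same RepartitionByExpression and non-global Sort a user could produce by hand. As a rough DataFrame-level illustration of the prepared query shape, assuming a ClusteredDistribution on `dep` and an ascending sort on `id` (column names and data are hypothetical):

import org.apache.spark.sql.SparkSession

object PreparedQueryShape {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("prepare-query-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1L, "hr"), (2L, "it"), (3L, "hr")).toDF("id", "dep")

    val prepared = df
      .repartition($"dep")               // RepartitionByExpression: hash partitioning for ClusteredDistribution
      .sortWithinPartitions($"id".asc)   // Sort(..., global = false, ...): the required task-local ordering

    prepared.explain(true)
    spark.stop()
  }
}

For an OrderedDistribution the clustering expressions are themselves SortOrder values, so RepartitionByExpression picks range partitioning instead, as the comment in prepareQuery notes; the DataFrame analogue would be repartitionByRange.
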

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala

Lines changed: 6 additions & 3 deletions
@@ -40,7 +40,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
     case a @ AppendData(r: DataSourceV2Relation, query, options, _, None) =>
       val writeBuilder = newWriteBuilder(r.table, query, options)
       val write = writeBuilder.build()
-      a.copy(write = Some(write))
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      a.copy(write = Some(write), query = newQuery)

     case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
       // fail if any filter cannot be converted. correctness depends on removing all matching data.
@@ -63,7 +64,8 @@
         throw new SparkException(s"Table does not support overwrite by expression: $table")
       }

-      o.copy(write = Some(write))
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      o.copy(write = Some(write), query = newQuery)

     case o @ OverwritePartitionsDynamic(r: DataSourceV2Relation, query, options, _, None) =>
       val table = r.table
@@ -74,7 +76,8 @@
        case _ =>
          throw new SparkException(s"Table does not support dynamic partition overwrite: $table")
      }
-      o.copy(write = Some(write))
+      val newQuery = DistributionAndOrderingUtils.prepareQuery(write, query, conf)
+      o.copy(write = Some(write), query = newQuery)
  }

  private def isTruncate(filters: Array[Filter]): Boolean = {
