From 6708f2a4035c048c106b220e433b2be528da3813 Mon Sep 17 00:00:00 2001 From: Anton Okolnychyi Date: Tue, 12 Apr 2022 09:58:58 -0700 Subject: [PATCH 1/2] [SPARK-38959][SQL] DataSource V2: Support runtime group filtering in row-level commands --- .../connector/write/RowLevelOperation.java | 14 ++ .../apache/spark/sql/internal/SQLConf.scala | 18 ++ .../InMemoryRowLevelOperationTable.scala | 5 +- .../spark/sql/execution/SparkOptimizer.scala | 5 +- .../PlanAdaptiveDynamicPruningFilters.scala | 2 +- .../PlanDynamicPruningFilters.scala | 2 +- ...wLevelOperationRuntimeGroupFiltering.scala | 98 ++++++++++ .../sql/connector/DeleteFromTableSuite.scala | 167 ++++++++++++++++-- 8 files changed, 295 insertions(+), 16 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java index 7acd27759a1ba..844734ff7ccb7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/RowLevelOperation.java @@ -21,6 +21,7 @@ import org.apache.spark.sql.connector.expressions.NamedReference; import org.apache.spark.sql.connector.read.Scan; import org.apache.spark.sql.connector.read.ScanBuilder; +import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering; import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** @@ -68,6 +69,19 @@ default String description() { * be returned by the scan, even if a filter can narrow the set of changes to a single file * in the partition. Similarly, a data source that can swap individual files must produce all * rows from files where at least one record must be changed, not just rows that must be changed. + *
+ * <p>
+ * Data sources that replace groups of data (e.g. files, partitions) may prune entire groups + * using provided data source filters when building a scan for this row-level operation. + * However, such data skipping is limited as not all expressions can be converted into data source + * filters and some can only be evaluated by Spark (e.g. subqueries). Since rewriting groups is + * expensive, Spark allows group-based data sources to filter groups at runtime. The runtime + * filtering enables data sources to narrow down the scope of rewriting to only groups that must + * be rewritten. If the row-level operation scan implements {@link SupportsRuntimeV2Filtering}, + * Spark will execute a query at runtime to find which records match the row-level condition. + * The runtime group filter subquery will leverage a regular batch scan, which isn't required to + * produce all rows in a group if any are returned. The information about matching records will + * be passed back into the row-level operation scan, allowing data sources to discard groups + * that don't have to be rewritten. */ ScanBuilder newScanBuilder(CaseInsensitiveStringMap options); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 44d34af1e47e5..5478155b1ae12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -412,6 +412,21 @@ object SQLConf { .longConf .createWithDefault(67108864L) + val RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED = + buildConf("spark.sql.optimizer.runtime.rowLevelOperationGroupFilter.enabled") + .doc("Enables runtime group filtering for group-based row-level operations. " + + "Data sources that replace groups of data (e.g. files, partitions) may prune entire " + + "groups using provided data source filters when planning a row-level operation scan. " + + "However, such filtering is limited as not all expressions can be converted into data " + + "source filters and some expressions can only be evaluated by Spark (e.g. subqueries). " + + "Since rewriting groups is expensive, Spark can execute a query at runtime to find what " + + "records match the condition of the row-level operation. 
The information about matching " + + "records will be passed back to the row-level operation scan, allowing data sources to " + + "discard groups that don't have to be rewritten.") + .version("3.4.0") + .booleanConf + .createWithDefault(true) + val PLANNED_WRITE_ENABLED = buildConf("spark.sql.optimizer.plannedWrite.enabled") .internal() .doc("When set to true, Spark optimizer will add logical sort operators to V1 write commands " + @@ -4084,6 +4099,9 @@ class SQLConf extends Serializable with Logging { def runtimeFilterCreationSideThreshold: Long = getConf(RUNTIME_BLOOM_FILTER_CREATION_SIDE_THRESHOLD) + def runtimeRowLevelOperationGroupFilterEnabled: Boolean = + getConf(RUNTIME_ROW_LEVEL_OPERATION_GROUP_FILTER_ENABLED) + def stateStoreProviderClass: String = getConf(STATE_STORE_PROVIDER_CLASS) def isStateSchemaCheckEnabled: Boolean = getConf(STATE_SCHEMA_CHECK_ENABLED) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index cb061602ec151..5099d73f18380 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -34,6 +34,8 @@ class InMemoryRowLevelOperationTable( properties: util.Map[String, String]) extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations { + var replacedPartitions: Seq[Seq[Any]] = Seq.empty + override def newRowLevelOperationBuilder( info: RowLevelOperationInfo): RowLevelOperationBuilder = { () => PartitionBasedOperation(info.command) @@ -88,8 +90,9 @@ class InMemoryRowLevelOperationTable( override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized { val newData = messages.map(_.asInstanceOf[BufferedRows]) val readRows = scan.data.flatMap(_.asInstanceOf[BufferedRows].rows) - val readPartitions = readRows.map(r => getKey(r, schema)) + val readPartitions = readRows.map(r => getKey(r, schema)).distinct dataMap --= readPartitions + replacedPartitions = readPartitions withData(newData, schema) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 72bdab409a9e6..017d1f937c34c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes} import org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning, OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering, V2ScanRelationPushDown, V2Writes} -import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} +import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning, RowLevelOperationRuntimeGroupFiltering} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} class SparkOptimizer( @@ -50,7 +50,8 @@ class SparkOptimizer( override def defaultBatches: Seq[Batch] = 
(preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("PartitionPruning", Once, - PartitionPruning) :+ + PartitionPruning, + RowLevelOperationRuntimeGroupFiltering(OptimizeSubqueries)) :+ Batch("InjectRuntimeFilter", FixedPoint(1), InjectRuntimeFilter) :+ Batch("MergeScalarSubqueries", Once, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala index 9a780c11eefab..21bc55110fe80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveDynamicPruningFilters.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelati case class PlanAdaptiveDynamicPruningFilters( rootPlan: AdaptiveSparkPlanExec) extends Rule[SparkPlan] with AdaptiveSparkPlanHelper { def apply(plan: SparkPlan): SparkPlan = { - if (!conf.dynamicPartitionPruningEnabled) { + if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala index c9ff28eb0459f..df5e3ea13652d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/PlanDynamicPruningFilters.scala @@ -45,7 +45,7 @@ case class PlanDynamicPruningFilters(sparkSession: SparkSession) extends Rule[Sp } override def apply(plan: SparkPlan): SparkPlan = { - if (!conf.dynamicPartitionPruningEnabled) { + if (!conf.dynamicPartitionPruningEnabled && !conf.runtimeRowLevelOperationGroupFilterEnabled) { return plan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala new file mode 100644 index 0000000000000..232c320bcd454 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/dynamicpruning/RowLevelOperationRuntimeGroupFiltering.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.dynamicpruning
+
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruningSubquery, Expression, PredicateHelper, V2ExpressionUtils}
+import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
+import org.apache.spark.sql.catalyst.planning.GroupBasedRowLevelOperation
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.connector.read.SupportsRuntimeV2Filtering
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Implicits, DataSourceV2Relation, DataSourceV2ScanRelation}
+
+/**
+ * A rule that assigns a subquery to filter groups in row-level operations at runtime.
+ *
+ * Data skipping during job planning for row-level operations is limited to expressions that can be
+ * converted to data source filters. Since not all expressions can be pushed down that way and
+ * rewriting groups is expensive, Spark allows data sources to filter groups at runtime.
+ * If the primary scan in a group-based row-level operation supports runtime filtering, this rule
+ * will inject a subquery to find all rows that match the condition so that data sources know
+ * exactly which groups must be rewritten.
+ *
+ * Note this rule only applies to group-based row-level operations.
+ */
+case class RowLevelOperationRuntimeGroupFiltering(optimizeSubqueries: Rule[LogicalPlan])
+  extends Rule[LogicalPlan] with PredicateHelper {
+
+  import DataSourceV2Implicits._
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+    // apply special dynamic filtering only for group-based row-level operations
+    case GroupBasedRowLevelOperation(replaceData, cond,
+        DataSourceV2ScanRelation(_, scan: SupportsRuntimeV2Filtering, _, _, _))
+        if conf.runtimeRowLevelOperationGroupFilterEnabled && cond != TrueLiteral =>
+
+      // use reference equality on scan to find required scan relations
+      val newQuery = replaceData.query transformUp {
+        case r: DataSourceV2ScanRelation if r.scan eq scan =>
+          // use the original table instance that was loaded for this row-level operation
+          // in order to leverage a regular batch scan in the group filter query
+          val originalTable = r.relation.table.asRowLevelOperationTable.table
+          val relation = r.relation.copy(table = originalTable)
+          val matchingRowsPlan = buildMatchingRowsPlan(relation, cond)
+
+          val filterAttrs = scan.filterAttributes
+          val buildKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, matchingRowsPlan)
+          val pruningKeys = V2ExpressionUtils.resolveRefs[Attribute](filterAttrs, r)
+          val dynamicPruningCond = buildDynamicPruningCond(matchingRowsPlan, buildKeys, pruningKeys)
+
+          Filter(dynamicPruningCond, r)
+      }
+
+      // optimize subqueries to rewrite them as joins and trigger job planning
+      replaceData.copy(query = optimizeSubqueries(newQuery))
+  }
+
+  private def buildMatchingRowsPlan(
+      relation: DataSourceV2Relation,
+      cond: Expression): LogicalPlan = {
+
+    val matchingRowsPlan = Filter(cond, relation)
+
+    // clone the relation and assign new expr IDs to avoid conflicts
+    matchingRowsPlan transformUpWithNewOutput {
+      case r: DataSourceV2Relation if r eq relation =>
+        val oldOutput = r.output
+        val newOutput = oldOutput.map(_.newInstance())
+        r.copy(output = newOutput) -> oldOutput.zip(newOutput)
+    }
+  }
+
+  private def buildDynamicPruningCond(
+      matchingRowsPlan: LogicalPlan,
+      buildKeys: Seq[Attribute],
+      pruningKeys: Seq[Attribute]): Expression = {
+
+    val buildQuery =
Project(buildKeys, matchingRowsPlan) + val dynamicPruningSubqueries = pruningKeys.zipWithIndex.map { case (key, index) => + DynamicPruningSubquery(key, buildQuery, buildKeys, index, onlyInBroadcast = false) + } + dynamicPruningSubqueries.reduce(And) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala index a2cfdde2671f6..905f940816ec7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala @@ -22,14 +22,17 @@ import java.util.Collections import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTableCatalog} +import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog} import org.apache.spark.sql.connector.expressions.LogicalExpressions._ -import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec} +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DeleteFromTableExec, ReplaceDataExec} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.unsafe.types.UTF8String abstract class DeleteFromTableSuiteBase extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { @@ -46,15 +49,19 @@ abstract class DeleteFromTableSuiteBase spark.sessionState.conf.unsetConf("spark.sql.catalog.cat") } - private val namespace = Array("ns1") - private val ident = Identifier.of(namespace, "test_table") - private val tableNameAsString = "cat." + ident.toString + protected val namespace: Array[String] = Array("ns1") + protected val ident: Identifier = Identifier.of(namespace, "test_table") + protected val tableNameAsString: String = "cat." 
+ ident.toString - private def catalog: InMemoryRowLevelOperationTableCatalog = { + protected def catalog: InMemoryRowLevelOperationTableCatalog = { val catalog = spark.sessionState.catalogManager.catalog("cat") catalog.asTableCatalog.asInstanceOf[InMemoryRowLevelOperationTableCatalog] } + protected def table: InMemoryRowLevelOperationTable = { + catalog.loadTable(ident).asInstanceOf[InMemoryRowLevelOperationTable] + } + test("EXPLAIN only delete") { createAndInitTable("id INT, dep STRING", """{ "id": 1, "dep": "hr" }""") @@ -553,13 +560,13 @@ abstract class DeleteFromTableSuiteBase } } - private def createTable(schemaString: String): Unit = { + protected def createTable(schemaString: String): Unit = { val schema = StructType.fromDDL(schemaString) val tableProps = Collections.emptyMap[String, String] catalog.createTable(ident, schema, Array(identity(reference(Seq("dep")))), tableProps) } - private def createAndInitTable(schemaString: String, jsonData: String): Unit = { + protected def createAndInitTable(schemaString: String, jsonData: String): Unit = { createTable(schemaString) append(schemaString, jsonData) } @@ -606,7 +613,7 @@ abstract class DeleteFromTableSuiteBase } // executes an operation and keeps the executed plan - private def executeAndKeepPlan(func: => Unit): SparkPlan = { + protected def executeAndKeepPlan(func: => Unit): SparkPlan = { var executedPlan: SparkPlan = null val listener = new QueryExecutionListener { @@ -626,4 +633,142 @@ abstract class DeleteFromTableSuiteBase } } -class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase +class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { + + import testImplicits._ + + test("delete with IN predicate and runtime group filtering") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "salary INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + + test("delete with subqueries and runtime group filtering") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |{ "id": 4, "salary": 150, "dep": 'software' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(2), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq(Some("software"), None).toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + executeDeleteAndCheckScans( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IN (SELECT * FROM deleted_id) + | AND + | dep IN (SELECT * FROM deleted_dep) + |""".stripMargin, + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, 300, "hr") :: Row(3, 120, "hr") :: Row(4, 150, "software") :: Nil) + + checkReplacedPartitions(Seq("software")) + } + } + + test("delete runtime group filtering (DPP enabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> 
"true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (DPP disabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (AQE enabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (AQE disabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + private def checkDeleteRuntimeGroupFiltering(): Unit = { + withTempView("deleted_id") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(1), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + } + + private def executeDeleteAndCheckScans( + query: String, + primaryScanSchema: String, + groupFilterScanSchema: String): Unit = { + + val executedPlan = executeAndKeepPlan { + sql(query) + } + + val primaryScan = collect(executedPlan) { + case s: BatchScanExec => s + }.head + assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema))) + + primaryScan.runtimeFilters match { + case Seq(DynamicPruningExpression(child: InSubqueryExec)) => + val groupFilterScan = collect(child.plan) { + case s: BatchScanExec => s + }.head + assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema))) + + case _ => + fail("could not find group filter scan") + } + } + + private def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = { + val actualPartitions = table.replacedPartitions.map { + case Seq(partValue: UTF8String) => partValue.toString + case Seq(partValue) => partValue + case other => fail(s"expected only one partition value: $other" ) + } + assert(actualPartitions == expectedPartitions, "replaced partitions must match") + } +} From ccccfd40a22509521f36602cd3588d36a925ba3a Mon Sep 17 00:00:00 2001 From: aokolnychyi Date: Mon, 10 Oct 2022 16:33:06 -0700 Subject: [PATCH 2/2] Review feedback --- .../InMemoryRowLevelOperationTable.scala | 1 + ...e.scala => DeleteFromTableSuiteBase.scala} | 147 +--------------- .../GroupBasedDeleteFromTableSuite.scala | 166 ++++++++++++++++++ 3 files changed, 169 insertions(+), 145 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/connector/{DeleteFromTableSuite.scala => DeleteFromTableSuiteBase.scala} (79%) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala index 5099d73f18380..08c22a02b8555 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTable.scala @@ -34,6 +34,7 @@ class InMemoryRowLevelOperationTable( properties: util.Map[String, String]) extends InMemoryTable(name, schema, partitioning, properties) with SupportsRowLevelOperations { + // used in row-level operation tests to verify replaced partitions var replacedPartitions: Seq[Seq[Any]] = Seq.empty override def newRowLevelOperationBuilder( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala similarity index 79% rename from sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala index 905f940816ec7..d9a12b47ec269 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTableSuiteBase.scala @@ -22,17 +22,14 @@ import java.util.Collections import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, DataFrame, Encoders, QueryTest, Row} -import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryRowLevelOperationTable, InMemoryRowLevelOperationTableCatalog} import org.apache.spark.sql.connector.expressions.LogicalExpressions._ -import org.apache.spark.sql.execution.{InSubqueryExec, QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DeleteFromTableExec, ReplaceDataExec} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.execution.datasources.v2.{DeleteFromTableExec, ReplaceDataExec} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.QueryExecutionListener -import org.apache.spark.unsafe.types.UTF8String abstract class DeleteFromTableSuiteBase extends QueryTest with SharedSparkSession with BeforeAndAfter with AdaptiveSparkPlanHelper { @@ -632,143 +629,3 @@ abstract class DeleteFromTableSuiteBase stripAQEPlan(executedPlan) } } - -class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { - - import testImplicits._ - - test("delete with IN predicate and runtime group filtering") { - createAndInitTable("id INT, salary INT, dep STRING", - """{ "id": 1, "salary": 300, "dep": 'hr' } - |{ "id": 2, "salary": 150, "dep": 'software' } - |{ "id": 3, "salary": 120, "dep": 'hr' } - |""".stripMargin) - - executeDeleteAndCheckScans( - s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)", - primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", - groupFilterScanSchema = "salary INT, dep STRING") - - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) - - checkReplacedPartitions(Seq("hr")) - } - - test("delete with subqueries and runtime group filtering") { - withTempView("deleted_id", "deleted_dep") { - createAndInitTable("id INT, salary INT, dep STRING", - """{ "id": 1, "salary": 300, "dep": 'hr' } - |{ "id": 2, "salary": 150, "dep": 'software' } - |{ "id": 3, "salary": 120, "dep": 'hr' } - |{ "id": 4, "salary": 150, "dep": 'software' } - |""".stripMargin) 
- - val deletedIdDF = Seq(Some(2), None).toDF() - deletedIdDF.createOrReplaceTempView("deleted_id") - - val deletedDepDF = Seq(Some("software"), None).toDF() - deletedDepDF.createOrReplaceTempView("deleted_dep") - - executeDeleteAndCheckScans( - s"""DELETE FROM $tableNameAsString - |WHERE - | id IN (SELECT * FROM deleted_id) - | AND - | dep IN (SELECT * FROM deleted_dep) - |""".stripMargin, - primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", - groupFilterScanSchema = "id INT, dep STRING") - - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Row(1, 300, "hr") :: Row(3, 120, "hr") :: Row(4, 150, "software") :: Nil) - - checkReplacedPartitions(Seq("software")) - } - } - - test("delete runtime group filtering (DPP enabled)") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { - checkDeleteRuntimeGroupFiltering() - } - } - - test("delete runtime group filtering (DPP disabled)") { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { - checkDeleteRuntimeGroupFiltering() - } - } - - test("delete runtime group filtering (AQE enabled)") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { - checkDeleteRuntimeGroupFiltering() - } - } - - test("delete runtime group filtering (AQE disabled)") { - withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { - checkDeleteRuntimeGroupFiltering() - } - } - - private def checkDeleteRuntimeGroupFiltering(): Unit = { - withTempView("deleted_id") { - createAndInitTable("id INT, salary INT, dep STRING", - """{ "id": 1, "salary": 300, "dep": 'hr' } - |{ "id": 2, "salary": 150, "dep": 'software' } - |{ "id": 3, "salary": 120, "dep": 'hr' } - |""".stripMargin) - - val deletedIdDF = Seq(Some(1), None).toDF() - deletedIdDF.createOrReplaceTempView("deleted_id") - - executeDeleteAndCheckScans( - s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)", - primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", - groupFilterScanSchema = "id INT, dep STRING") - - checkAnswer( - sql(s"SELECT * FROM $tableNameAsString"), - Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) - - checkReplacedPartitions(Seq("hr")) - } - } - - private def executeDeleteAndCheckScans( - query: String, - primaryScanSchema: String, - groupFilterScanSchema: String): Unit = { - - val executedPlan = executeAndKeepPlan { - sql(query) - } - - val primaryScan = collect(executedPlan) { - case s: BatchScanExec => s - }.head - assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema))) - - primaryScan.runtimeFilters match { - case Seq(DynamicPruningExpression(child: InSubqueryExec)) => - val groupFilterScan = collect(child.plan) { - case s: BatchScanExec => s - }.head - assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema))) - - case _ => - fail("could not find group filter scan") - } - } - - private def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = { - val actualPartitions = table.replacedPartitions.map { - case Seq(partValue: UTF8String) => partValue.toString - case Seq(partValue) => partValue - case other => fail(s"expected only one partition value: $other" ) - } - assert(actualPartitions == expectedPartitions, "replaced partitions must match") - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala new file mode 100644 index 0000000000000..36905027cb0cb --- /dev/null 
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.execution.InSubqueryExec +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructType +import org.apache.spark.unsafe.types.UTF8String + +class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase { + + import testImplicits._ + + test("delete with IN predicate and runtime group filtering") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE salary IN (300, 400, 500)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "salary INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + + test("delete with subqueries and runtime group filtering") { + withTempView("deleted_id", "deleted_dep") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |{ "id": 4, "salary": 150, "dep": 'software' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(2), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + val deletedDepDF = Seq(Some("software"), None).toDF() + deletedDepDF.createOrReplaceTempView("deleted_dep") + + executeDeleteAndCheckScans( + s"""DELETE FROM $tableNameAsString + |WHERE + | id IN (SELECT * FROM deleted_id) + | AND + | dep IN (SELECT * FROM deleted_dep) + |""".stripMargin, + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(1, 300, "hr") :: Row(3, 120, "hr") :: Row(4, 150, "software") :: Nil) + + checkReplacedPartitions(Seq("software")) + } + } + + test("delete runtime group filtering (DPP enabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (DPP disabled)") { + withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + 
test("delete runtime group filtering (AQE enabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + checkDeleteRuntimeGroupFiltering() + } + } + + test("delete runtime group filtering (AQE disabled)") { + withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + checkDeleteRuntimeGroupFiltering() + } + } + + private def checkDeleteRuntimeGroupFiltering(): Unit = { + withTempView("deleted_id") { + createAndInitTable("id INT, salary INT, dep STRING", + """{ "id": 1, "salary": 300, "dep": 'hr' } + |{ "id": 2, "salary": 150, "dep": 'software' } + |{ "id": 3, "salary": 120, "dep": 'hr' } + |""".stripMargin) + + val deletedIdDF = Seq(Some(1), None).toDF() + deletedIdDF.createOrReplaceTempView("deleted_id") + + executeDeleteAndCheckScans( + s"DELETE FROM $tableNameAsString WHERE id IN (SELECT * FROM deleted_id)", + primaryScanSchema = "id INT, salary INT, dep STRING, _partition STRING", + groupFilterScanSchema = "id INT, dep STRING") + + checkAnswer( + sql(s"SELECT * FROM $tableNameAsString"), + Row(2, 150, "software") :: Row(3, 120, "hr") :: Nil) + + checkReplacedPartitions(Seq("hr")) + } + } + + private def executeDeleteAndCheckScans( + query: String, + primaryScanSchema: String, + groupFilterScanSchema: String): Unit = { + + val executedPlan = executeAndKeepPlan { + sql(query) + } + + val primaryScan = collect(executedPlan) { + case s: BatchScanExec => s + }.head + assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema))) + + primaryScan.runtimeFilters match { + case Seq(DynamicPruningExpression(child: InSubqueryExec)) => + val groupFilterScan = collect(child.plan) { + case s: BatchScanExec => s + }.head + assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema))) + + case _ => + fail("could not find group filter scan") + } + } + + private def checkReplacedPartitions(expectedPartitions: Seq[Any]): Unit = { + val actualPartitions = table.replacedPartitions.map { + case Seq(partValue: UTF8String) => partValue.toString + case Seq(partValue) => partValue + case other => fail(s"expected only one partition value: $other" ) + } + assert(actualPartitions == expectedPartitions, "replaced partitions must match") + } +}