From 9f86b703536f2a2cd8c49215d9ddf0d96aa6d432 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Tue, 26 Apr 2022 22:39:44 -0700 Subject: [PATCH 1/2] [SPARK-38918][SQL] Nested column pruning should filter out attributes that do not belong to the current relation This PR updates `ProjectionOverSchema` to use the outputs of the data source relation to filter the attributes in the nested schema pruning. This is needed because the attributes in the schema do not necessarily belong to the current data source relation. For example, if a filter contains a correlated subquery, then the subquery's children can contain attributes from both the inner query and the outer query. Since the `RewriteSubquery` batch happens after early scan pushdown rules, nested schema pruning can wrongly use the inner query's attributes to prune the outer query data schema, thus causing wrong results and unexpected exceptions. To fix a bug in `SchemaPruning`. No Unit test Closes #36216 from allisonwang-db/spark-38918-nested-column-pruning. Authored-by: allisonwang-db Signed-off-by: Liang-Chi Hsieh (cherry picked from commit 150434b5d7909dcf8248ffa5ec3d937ea3da09fd) Signed-off-by: Liang-Chi Hsieh (cherry picked from commit 793ba608181b3eba8f1f57fcdd12dcd3fe035362) Signed-off-by: allisonwang-db --- .../expressions/ProjectionOverSchema.scala | 8 +++- .../sql/catalyst/optimizer/Optimizer.scala | 1 + .../sql/catalyst/optimizer/objects.scala | 2 +- .../execution/datasources/SchemaPruning.scala | 2 +- .../v2/V2ScanRelationPushDown.scala | 5 ++- .../datasources/SchemaPruningSuite.scala | 45 ++++++++++++++++++- 6 files changed, 56 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala index 5920ec749469..7a62bb294e40 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala @@ -24,15 +24,19 @@ import org.apache.spark.sql.types._ * field indexes and field counts of complex type extractors and attributes * are adjusted to fit the schema. All other expressions are left as-is. This * class is motivated by columnar nested schema pruning. + * + * @param schema nested column schema + * @param output output attributes of the data source relation. They are used to filter out + * attributes in the schema that do not belong to the current relation. */ -case class ProjectionOverSchema(schema: StructType) { +case class ProjectionOverSchema(schema: StructType, output: AttributeSet) { private val fieldNames = schema.fieldNames.toSet def unapply(expr: Expression): Option[Expression] = getProjection(expr) private def getProjection(expr: Expression): Option[Expression] = expr match { - case a: AttributeReference if fieldNames.contains(a.name) => + case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) => Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier)) case GetArrayItem(child, arrayItemOrdinal) => getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 95985d4ae485..54309edd4851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -56,6 +56,7 @@ abstract class Optimizer(catalogManager: CatalogManager) override protected val blacklistedOnceBatches: Set[String] = Set( "PartitionPruning", + "RewriteSubquery", "Extract Python UDFs") protected def fixedPoint = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 9643f5827b91..abc6e3d16681 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -225,7 +225,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { } // Builds new projection. - val projectionOverSchema = ProjectionOverSchema(prunedSchema) + val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output)) val newProjects = p.projectList.map(_.transformDown { case projectionOverSchema(expr) => expr }).map { case expr: NamedExpression => expr } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index e3e89f6ca451..e605cd79aca9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -82,7 +82,7 @@ object SchemaPruning extends Rule[LogicalPlan] { // in dataSchema. if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { val prunedRelation = leafNodeBuilder(prunedDataSchema) - val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) + val projectionOverSchema = ProjectionOverSchema(prunedDataSchema, AttributeSet(output)) Some(buildNewProjection(projects, normalizedProjects, normalizedFilters, prunedRelation, projectionOverSchema)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 2864529cfc6b..2c6866a9d4c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule @@ -67,7 +67,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] { val scanRelation = DataSourceV2ScanRelation(relation.table, wrappedScan, output) - val projectionOverSchema = ProjectionOverSchema(output.toStructType) + val projectionOverSchema = + ProjectionOverSchema(output.toStructType, AttributeSet(output)) val projectionFunc = (expr: Expression) => expr transformDown { case projectionOverSchema(newExpr) => newExpr } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index a481fe571400..0773c72b9c14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -50,11 +50,15 @@ abstract class SchemaPruningSuite employer: Employer = null, relations: Map[FullName, String] = Map.empty) + case class Employee(id: Int, name: FullName, employer: Company) + val janeDoe = FullName("Jane", "X.", "Doe") val johnDoe = FullName("John", "Y.", "Doe") val susanSmith = FullName("Susan", "Z.", "Smith") - val employer = Employer(0, Company("abc", "123 Business Street")) + val company = Company("abc", "123 Business Street") + + val employer = Employer(0, company) val employerWithNullCompany = Employer(1, null) val contacts = @@ -64,6 +68,8 @@ abstract class SchemaPruningSuite Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe), employer = employerWithNullCompany, relations = Map(janeDoe -> "sister")) :: Nil + val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil + case class Name(first: String, last: String) case class BriefContact(id: Int, name: Name, address: String) @@ -313,6 +319,26 @@ abstract class SchemaPruningSuite } } + testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") { + withContacts { + withEmployees { + val query = sql( + """ + |select count(*) + |from contacts c + |where not exists (select null from employees e where e.name.first = c.name.first + | and e.employer.name = c.employer.company.name) + |""".stripMargin) + checkScan(query, + "struct," + + "employer:struct>>", + "struct," + + "employer:struct>") + checkAnswer(query, Row(3)) + } + } + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") { @@ -381,6 +407,23 @@ abstract class SchemaPruningSuite } } + private def withEmployees(testThunk: => Unit): Unit = { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeDataSourceFile(employees, new File(path + "/employees")) + + // Providing user specified schema. Inferred schema from different data sources might + // be different. + val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " + + "`employer` STRUCT<`name`: STRING, `address`: STRING>" + spark.read.format(dataSourceName).schema(schema).load(path + "/employees") + .createOrReplaceTempView("employees") + + testThunk + } + } + case class MixedCaseColumn(a: String, B: Int) case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn) From 9e799b769c097576deb397e83f6b6930a9042019 Mon Sep 17 00:00:00 2001 From: allisonwang-db Date: Wed, 27 Apr 2022 21:46:18 -0700 Subject: [PATCH 2/2] fix test --- .../spark/sql/execution/datasources/SchemaPruningSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 0773c72b9c14..89c2c016009e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -330,8 +330,7 @@ abstract class SchemaPruningSuite | and e.employer.name = c.employer.company.name) |""".stripMargin) checkScan(query, - "struct," + - "employer:struct>>", + "struct,employer:struct>>", "struct," + "employer:struct>") checkAnswer(query, Row(3))