Commit 150434b

allisonwang-db authored and viirya committed
[SPARK-38918][SQL] Nested column pruning should filter out attributes that do not belong to the current relation
### What changes were proposed in this pull request?

This PR updates `ProjectionOverSchema` to use the output of the data source relation to filter the attributes used in nested schema pruning. This is needed because the attributes in the schema do not necessarily belong to the current data source relation. For example, if a filter contains a correlated subquery, then the subquery's children can contain attributes from both the inner query and the outer query. Since the `RewriteSubquery` batch happens after the early scan pushdown rules, nested schema pruning can wrongly use the inner query's attributes to prune the outer query's data schema, causing wrong results and unexpected exceptions.

### Why are the changes needed?

To fix a bug in `SchemaPruning`.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test

Closes #36216 from allisonwang-db/spark-38918-nested-column-pruning.

Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Liang-Chi Hsieh <[email protected]>
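To make the failure mode concrete: the affected plans are those where a pushed-down filter contains a correlated subquery whose predicate mixes nested attributes from the inner and outer relations. A minimal sketch of such a query, assuming a `spark` session with `contacts` and `employees` temp views shaped like the ones the new test registers:

```scala
// Illustrative repro sketch: both views expose nested structs, and the
// correlated predicate references nested fields of the outer relation (c)
// and the inner relation (e). Before this fix, schema pruning on the outer
// scan could match the inner query's attributes and mis-prune the schema.
spark.sql(
  """
    |select count(*)
    |from contacts c
    |where not exists (
    |  select null from employees e
    |  where e.name.first = c.name.first
    |    and e.employer.name = c.employer.company.name)
    |""".stripMargin).show()
```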
1 parent 1b7c636 commit 150434b

6 files changed: 57 additions, 9 deletions


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala

Lines changed: 6 additions & 2 deletions

@@ -24,15 +24,19 @@ import org.apache.spark.sql.types._
  * field indexes and field counts of complex type extractors and attributes
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
+ *
+ * @param schema nested column schema
+ * @param output output attributes of the data source relation. They are used to filter out
+ *               attributes in the schema that do not belong to the current relation.
  */
-case class ProjectionOverSchema(schema: StructType) {
+case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
   private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] =
     expr match {
-      case a: AttributeReference if fieldNames.contains(a.name) =>
+      case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
         Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
       case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
         getProjection(child).map {
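With the new signature, every call site passes the relation's output alongside the pruned schema, so the `AttributeReference` case above only rewrites attributes that actually belong to the relation being pruned. A minimal usage sketch, assuming stand-in values `prunedSchema: StructType`, a `relation` with `output: Seq[Attribute]`, and a `projectList: Seq[NamedExpression]` (names are illustrative, mirroring the call sites in the files below):

```scala
import org.apache.spark.sql.catalyst.expressions.{AttributeSet, NamedExpression, ProjectionOverSchema}

// Scope the extractor to this relation's output: attributes that merely share
// a field name (e.g. from a correlated subquery's inner plan) no longer match.
val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(relation.output))

// Rewrite expressions top-down over the pruned schema, as the optimizer rules do.
val newProjects = projectList.map(_.transformDown {
  case projectionOverSchema(expr) => expr
}).map { case expr: NamedExpression => expr }
```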

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val excludedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
+      "RewriteSubquery",
       "Extract Python UDFs")
 
   protected def fixedPoint =

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala

Lines changed: 1 addition & 1 deletion

@@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
     }
 
     // Builds new projection.
-    val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+    val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output))
     val newProjects = p.projectList.map(_.transformDown {
       case projectionOverSchema(expr) => expr
     }).map { case expr: NamedExpression => expr }

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala

Lines changed: 2 additions & 2 deletions

@@ -91,8 +91,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
       if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) ||
         countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
         val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
-        val projectionOverSchema =
-          ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema))
+        val projectionOverSchema = ProjectionOverSchema(
+          prunedDataSchema.merge(prunedMetadataSchema), AttributeSet(relation.output))
         Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
       } else {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala

Lines changed: 3 additions & 3 deletions

@@ -19,8 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
-import org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.sql.catalyst.expressions.{aggregate, Alias, AliasHelper, And, Attribute, AttributeReference, AttributeSet, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.optimizer.CollapseProject
 import org.apache.spark.sql.catalyst.planning.ScanOperation
@@ -320,7 +319,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
 
     val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
 
-    val projectionOverSchema = ProjectionOverSchema(output.toStructType)
+    val projectionOverSchema =
+      ProjectionOverSchema(output.toStructType, AttributeSet(output))
     val projectionFunc = (expr: Expression) => expr transformDown {
       case projectionOverSchema(newExpr) => newExpr
     }

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala

Lines changed: 44 additions & 1 deletion

@@ -61,11 +61,15 @@ abstract class SchemaPruningSuite
   override protected def sparkConf: SparkConf =
     super.sparkConf.set(SQLConf.ANSI_STRICT_INDEX_OPERATOR.key, "false")
 
+  case class Employee(id: Int, name: FullName, employer: Company)
+
   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
   val susanSmith = FullName("Susan", "Z.", "Smith")
 
-  val employer = Employer(0, Company("abc", "123 Business Street"))
+  val company = Company("abc", "123 Business Street")
+
+  val employer = Employer(0, company)
   val employerWithNullCompany = Employer(1, null)
   val employerWithNullCompany2 = Employer(2, null)
 
@@ -81,6 +85,8 @@ abstract class SchemaPruningSuite
     Department(1, "Marketing", 1, employerWithNullCompany) ::
     Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
 
+  val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil
+
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)
 
@@ -621,6 +627,26 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") {
+    withContacts {
+      withEmployees {
+        val query = sql(
+          """
+            |select count(*)
+            |from contacts c
+            |where not exists (select null from employees e where e.name.first = c.name.first
+            |  and e.employer.name = c.employer.company.name)
+            |""".stripMargin)
+        checkScan(query,
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<id:int,company:struct<name:string,address:string>>>",
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<name:string,address:string>>")
+        checkAnswer(query, Row(3))
+      }
+    }
+  }
+
   protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
     test(s"Spark vectorized reader - without partition data column - $testName") {
       withSQLConf(vectorizedReaderEnabledKey -> "true") {
@@ -701,6 +727,23 @@ abstract class SchemaPruningSuite
     }
   }
 
+  private def withEmployees(testThunk: => Unit): Unit = {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeDataSourceFile(employees, new File(path + "/employees"))
+
+      // Providing user specified schema. Inferred schema from different data sources might
+      // be different.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
+        "`employer` STRUCT<`name`: STRING, `address`: STRING>"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/employees")
+        .createOrReplaceTempView("employees")
+
+      testThunk
+    }
+  }
+
   case class MixedCaseColumn(a: String, B: Int)
   case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)
