diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index 5920ec749469..7a62bb294e40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -24,15 +24,19 @@ import org.apache.spark.sql.types._
  * field indexes and field counts of complex type extractors and attributes
  * are adjusted to fit the schema. All other expressions are left as-is. This
  * class is motivated by columnar nested schema pruning.
+ *
+ * @param schema nested column schema
+ * @param output output attributes of the data source relation. They are used to filter out
+ *               attributes in the schema that do not belong to the current relation.
  */
-case class ProjectionOverSchema(schema: StructType) {
+case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
   private val fieldNames = schema.fieldNames.toSet
 
   def unapply(expr: Expression): Option[Expression] = getProjection(expr)
 
   private def getProjection(expr: Expression): Option[Expression] =
     expr match {
-      case a: AttributeReference if fieldNames.contains(a.name) =>
+      case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
         Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
       case GetArrayItem(child, arrayItemOrdinal) =>
         getProjection(child).map { projection => GetArrayItem(projection, arrayItemOrdinal) }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 95985d4ae485..54309edd4851 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -56,6 +56,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
   override protected val blacklistedOnceBatches: Set[String] =
     Set(
       "PartitionPruning",
+      "RewriteSubquery",
       "Extract Python UDFs")
 
   protected def fixedPoint =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 9643f5827b91..abc6e3d16681 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -225,7 +225,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
       }
 
       // Builds new projection.
-      val projectionOverSchema = ProjectionOverSchema(prunedSchema)
+      val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output))
      val newProjects = p.projectList.map(_.transformDown {
        case projectionOverSchema(expr) => expr
      }).map { case expr: NamedExpression => expr }
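The guard added above is easiest to see at a call site. Below is a minimal sketch, not part of the patch, of how the extractor is typically applied when rewriting expressions over a pruned schema; rewriteOverPrunedSchema is a hypothetical helper, and the plan node passed as relation stands in for whichever relation owns the pruned output:

    // Hypothetical helper illustrating the call pattern; not part of the patch.
    import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ProjectionOverSchema}
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.types.StructType

    def rewriteOverPrunedSchema(
        expr: Expression,
        prunedSchema: StructType,
        relation: LogicalPlan): Expression = {
      // Anchoring the projection on the relation's own output is the point of
      // this change: with a name-only match, an outer attribute from a
      // correlated subquery that happens to share a field name with the pruned
      // schema could be rewritten against the wrong relation's schema.
      val projection = ProjectionOverSchema(prunedSchema, AttributeSet(relation.output))
      expr.transformDown {
        case projection(rewritten) => rewritten
      }
    }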
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
index e3e89f6ca451..e605cd79aca9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala
@@ -82,7 +82,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
       // in dataSchema.
       if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
         val prunedRelation = leafNodeBuilder(prunedDataSchema)
-        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
+        val projectionOverSchema = ProjectionOverSchema(prunedDataSchema, AttributeSet(output))
         Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
           prunedRelation, projectionOverSchema))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 2864529cfc6b..2c6866a9d4c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet, Expression, NamedExpression, ProjectionOverSchema, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -67,7 +67,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] {
 
     val scanRelation = DataSourceV2ScanRelation(relation.table, wrappedScan, output)
 
-    val projectionOverSchema = ProjectionOverSchema(output.toStructType)
+    val projectionOverSchema =
+      ProjectionOverSchema(output.toStructType, AttributeSet(output))
     val projectionFunc = (expr: Expression) => expr transformDown {
       case projectionOverSchema(newExpr) => newExpr
     }
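For context, the query shape this protects is the one in the regression test added below: inner and outer relations that both expose a nested name.first. A hedged way to reproduce it interactively, assuming contacts and employees temp views shaped like the suite's fixtures:

    // Assumes `contacts` and `employees` temp views like the suite's fixtures.
    // Before this fix, once the NOT EXISTS subquery was rewritten into an
    // anti join, `c.name.first` could be rewritten against the employees-side
    // pruned schema, because ProjectionOverSchema matched on field name alone.
    val query = spark.sql(
      """
        |SELECT count(*)
        |FROM contacts c
        |WHERE NOT EXISTS (
        |  SELECT null FROM employees e
        |  WHERE e.name.first = c.name.first
        |    AND e.employer.name = c.employer.company.name
        |)
        |""".stripMargin)
    query.explain(true) // inspect the per-scan read schemas in the physical plan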
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index a481fe571400..89c2c016009e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -50,11 +50,15 @@ abstract class SchemaPruningSuite
       employer: Employer = null,
       relations: Map[FullName, String] = Map.empty)
 
+  case class Employee(id: Int, name: FullName, employer: Company)
+
   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
   val susanSmith = FullName("Susan", "Z.", "Smith")
 
-  val employer = Employer(0, Company("abc", "123 Business Street"))
+  val company = Company("abc", "123 Business Street")
+
+  val employer = Employer(0, company)
   val employerWithNullCompany = Employer(1, null)
 
   val contacts =
@@ -64,6 +68,8 @@ abstract class SchemaPruningSuite
       Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe),
         employer = employerWithNullCompany, relations = Map(janeDoe -> "sister")) :: Nil
 
+  val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil
+
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)
 
@@ -313,6 +319,25 @@ abstract class SchemaPruningSuite
     }
   }
 
+  testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") {
+    withContacts {
+      withEmployees {
+        val query = sql(
+          """
+            |select count(*)
+            |from contacts c
+            |where not exists (select null from employees e where e.name.first = c.name.first
+            | and e.employer.name = c.employer.company.name)
+            |""".stripMargin)
+        checkScan(query,
+          "struct<name:struct<first:string>,employer:struct<company:struct<name:string>>>",
+          "struct<name:struct<first:string,middle:string,last:string>," +
+            "employer:struct<name:string,address:string>>")
+        checkAnswer(query, Row(3))
+      }
+    }
+  }
+
   protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
     test(s"Spark vectorized reader - without partition data column - $testName") {
       withSQLConf(vectorizedReaderEnabledKey -> "true") {
@@ -381,6 +406,23 @@ abstract class SchemaPruningSuite
     }
   }
 
+  private def withEmployees(testThunk: => Unit): Unit = {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+
+      makeDataSourceFile(employees, new File(path + "/employees"))
+
+      // Providing user specified schema. Inferred schema from different data sources might
+      // be different.
+      val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
+        "`employer` STRUCT<`name`: STRING, `address`: STRING>"
+      spark.read.format(dataSourceName).schema(schema).load(path + "/employees")
+        .createOrReplaceTempView("employees")
+
+      testThunk
+    }
+  }
+
   case class MixedCaseColumn(a: String, B: Int)
   case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)
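Outside the test harness, the withEmployees fixture amounts to roughly the following standalone setup. This is a sketch under stated assumptions: the Parquet format, the /tmp path, and the local session are illustrative, and Employee, FullName, and Company are the suite's case classes:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    // Write the fixture rows, then re-read them with the explicit schema the
    // suite supplies, since inferred schemas can differ across data sources.
    val path = "/tmp/employees" // illustrative location
    Seq(Employee(0, janeDoe, company), Employee(1, johnDoe, company))
      .toDS().write.mode("overwrite").parquet(path)

    val employeeSchema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, " +
      "`last`: STRING>, `employer` STRUCT<`name`: STRING, `address`: STRING>"
    spark.read.schema(employeeSchema).parquet(path).createOrReplaceTempView("employees")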