Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,19 @@ import org.apache.spark.sql.types._
* field indexes and field counts of complex type extractors and attributes
* are adjusted to fit the schema. All other expressions are left as-is. This
* class is motivated by columnar nested schema pruning.
*
* @param schema nested column schema
* @param output output attributes of the data source relation. They are used to filter out
* attributes in the schema that do not belong to the current relation.
*/
case class ProjectionOverSchema(schema: StructType) {
case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
private val fieldNames = schema.fieldNames.toSet

def unapply(expr: Expression): Option[Expression] = getProjection(expr)

private def getProjection(expr: Expression): Option[Expression] =
expr match {
case a: AttributeReference if fieldNames.contains(a.name) =>
case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
getProjection(child).map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
override protected val excludedOnceBatches: Set[String] =
Set(
"PartitionPruning",
"RewriteSubquery",
"Extract Python UDFs")

protected def fixedPoint =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
}

// Builds new projection.
val projectionOverSchema = ProjectionOverSchema(prunedSchema)
val projectionOverSchema = ProjectionOverSchema(prunedSchema, AttributeSet(s.output))
val newProjects = p.projectList.map(_.transformDown {
case projectionOverSchema(expr) => expr
}).map { case expr: NamedExpression => expr }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
// in dataSchema.
if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
val prunedRelation = leafNodeBuilder(prunedDataSchema)
val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
val projectionOverSchema = ProjectionOverSchema(prunedDataSchema, AttributeSet(output))

Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
prunedRelation, projectionOverSchema))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.planning.ScanOperation
Expand Down Expand Up @@ -199,7 +199,8 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {

val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)

val projectionOverSchema = ProjectionOverSchema(output.toStructType)
val projectionOverSchema =
ProjectionOverSchema(output.toStructType, AttributeSet(output))
val projectionFunc = (expr: Expression) => expr transformDown {
case projectionOverSchema(newExpr) => newExpr
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,15 @@ abstract class SchemaPruningSuite
contactId: Int,
employer: Employer)

case class Employee(id: Int, name: FullName, employer: Company)

val janeDoe = FullName("Jane", "X.", "Doe")
val johnDoe = FullName("John", "Y.", "Doe")
val susanSmith = FullName("Susan", "Z.", "Smith")

val employer = Employer(0, Company("abc", "123 Business Street"))
val company = Company("abc", "123 Business Street")

val employer = Employer(0, company)
val employerWithNullCompany = Employer(1, null)
val employerWithNullCompany2 = Employer(2, null)

Expand All @@ -77,6 +81,8 @@ abstract class SchemaPruningSuite
Department(1, "Marketing", 1, employerWithNullCompany) ::
Department(2, "Operation", 4, employerWithNullCompany2) :: Nil

val employees = Employee(0, janeDoe, company) :: Employee(1, johnDoe, company) :: Nil

case class Name(first: String, last: String)
case class BriefContact(id: Int, name: Name, address: String)

Expand Down Expand Up @@ -617,6 +623,26 @@ abstract class SchemaPruningSuite
}
}

testSchemaPruning("SPARK-38918: nested schema pruning with correlated subqueries") {
withContacts {
withEmployees {
val query = sql(
"""
|select count(*)
|from contacts c
|where not exists (select null from employees e where e.name.first = c.name.first
| and e.employer.name = c.employer.company.name)
|""".stripMargin)
checkScan(query,
"struct<name:struct<first:string,middle:string,last:string>," +
"employer:struct<id:int,company:struct<name:string,address:string>>>",
"struct<name:struct<first:string,middle:string,last:string>," +
"employer:struct<name:string,address:string>>")
checkAnswer(query, Row(3))
}
}
}

protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
test(s"Spark vectorized reader - without partition data column - $testName") {
withSQLConf(vectorizedReaderEnabledKey -> "true") {
Expand Down Expand Up @@ -697,6 +723,23 @@ abstract class SchemaPruningSuite
}
}

private def withEmployees(testThunk: => Unit): Unit = {
withTempPath { dir =>
val path = dir.getCanonicalPath

makeDataSourceFile(employees, new File(path + "/employees"))

// Providing user specified schema. Inferred schema from different data sources might
// be different.
val schema = "`id` INT,`name` STRUCT<`first`: STRING, `middle`: STRING, `last`: STRING>, " +
"`employer` STRUCT<`name`: STRING, `address`: STRING>"
spark.read.format(dataSourceName).schema(schema).load(path + "/employees")
.createOrReplaceTempView("employees")

testThunk
}
}

case class MixedCaseColumn(a: String, B: Int)
case class MixedCase(id: Int, CoL1: String, coL2: MixedCaseColumn)

Expand Down