
Commit 3aa0438

[SPARK-23315] failed to get output from canonicalized data source v2 related plans
1 parent 19c7c7e commit 3aa0438
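
In short: the v2 plan nodes used to expose a `fullOutput` field and derive their pruned `output` lazily, by looking up `reader.readSchema()` column names in `fullOutput`. On a canonicalized copy of the plan the attribute names are normalized, the lookup misses, and asking the canonicalized plan for its output fails. This commit stores the pruned output directly on each node and keeps it in sync during column pruning. A hedged reproduction sketch (not part of the commit), assuming a SparkSession `spark` and the AdvancedDataSourceV2 test source (columns `i` and `j`) referenced by the new test below:

    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

    val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
    val relation = df.queryExecution.optimizedPlan.collect {
      case r: DataSourceV2Relation => r
    }.head
    // Before this patch the lazy name lookup behind `output` missed on the canonicalized
    // copy and threw; with `output` stored on the node this simply returns two attributes.
    relation.canonicalized.output.length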

5 files changed (+48, -25 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala

Lines changed: 4 additions & 8 deletions
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2
 
 import java.util.Objects
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.sources.v2.reader._
 
 /**
@@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._
 trait DataSourceReaderHolder {
 
   /**
-   * The full output of the data source reader, without column pruning.
+   * The output of the data source reader, w.r.t. column pruning.
    */
-  def fullOutput: Seq[AttributeReference]
+  def output: Seq[Attribute]
 
   /**
    * The held data source reader.
@@ -46,7 +46,7 @@ trait DataSourceReaderHolder {
       case s: SupportsPushDownFilters => s.pushedFilters().toSet
       case _ => Nil
     }
-    Seq(fullOutput, reader.getClass, reader.readSchema(), filters)
+    Seq(output, reader.getClass, filters)
   }
 
   def canEqual(other: Any): Boolean
@@ -61,8 +61,4 @@
   override def hashCode(): Int = {
     metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
   }
-
-  lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
-    fullOutput.find(_.name == name).get
-  }
 }
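
The removed lazy val is where the reported failure came from: it resolved `reader.readSchema()` column names against `fullOutput`, and canonicalization rewrites attribute names, so the `.get` on the lookup had nothing to return. A toy, self-contained REPL sketch of that failure mode (plain Scala, not Spark code; the renaming to "none" mimics how expression canonicalization normalizes attribute names):

    case class Attr(name: String)

    val fullOutput    = Seq(Attr("i"), Attr("j"))          // attributes as declared on the plan node
    val canonicalized = fullOutput.map(_ => Attr("none"))  // canonicalization rewrites the names
    val readSchema    = Seq("i")                            // columns the reader still reports

    // Old approach: resolve schema names against the (renamed) attributes -- the lookup misses.
    // readSchema.map(n => canonicalized.find(_.name == n).get)   // NoSuchElementException

    // New approach: the pruned output is a plain field on the node, so no lookup happens at all.
    val output = Seq(Attr("i"))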

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala

Lines changed: 4 additions & 4 deletions
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.sources.v2.reader._
 
 case class DataSourceV2Relation(
-    fullOutput: Seq[AttributeReference],
+    output: Seq[AttributeReference],
     reader: DataSourceReader)
   extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {
 
@@ -37,7 +37,7 @@ case class DataSourceV2Relation(
   }
 
   override def newInstance(): DataSourceV2Relation = {
-    copy(fullOutput = fullOutput.map(_.newInstance()))
+    copy(output = output.map(_.newInstance()))
   }
 }
 
@@ -46,8 +46,8 @@ case class DataSourceV2Relation(
  * to the non-streaming relation.
  */
 class StreamingDataSourceV2Relation(
-    fullOutput: Seq[AttributeReference],
-    reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) {
+    output: Seq[AttributeReference],
+    reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
   override def isStreaming: Boolean = true
 }
 
5353

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala

Lines changed: 1 addition & 3 deletions
@@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType
  * Physical plan node for scanning data from a data source.
  */
 case class DataSourceV2ScanExec(
-    fullOutput: Seq[AttributeReference],
+    output: Seq[AttributeReference],
     @transient reader: DataSourceReader)
   extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {
 
   override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]
 
-  override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
-
   override def outputPartitioning: physical.Partitioning = reader match {
     case s: SupportsReportPartitioning =>
       new DataSourcePartitioning(
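
Dropping the explicit `producedAttributes` override works because, once the constructor field is the node's `output`, the leaf-node default already covers it. A hedged recollection of the relevant definition in the Spark sources of this era (not part of this diff):

    trait LeafExecNode extends SparkPlan {
      // ...
      // Leaf physical nodes produce exactly the attributes they output, which is what the
      // removed DataSourceV2ScanExec override restated by hand for `fullOutput`.
      override def producedAttributes: AttributeSet = outputSet
    }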

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala

Lines changed: 20 additions & 9 deletions
@@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
 
     // TODO: add more push down rules.
 
-    pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
+    val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
     // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
-    RemoveRedundantProject(filterPushed)
+    RemoveRedundantProject(columnPruned)
   }
 
   // TODO: nested fields pruning
-  private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = {
+  private def pushDownRequiredColumns(
+      plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = {
     plan match {
-      case Project(projectList, child) =>
+      case p @ Project(projectList, child) =>
         val required = projectList.flatMap(_.references)
-        pushDownRequiredColumns(child, AttributeSet(required))
+        p.copy(child = pushDownRequiredColumns(child, AttributeSet(required)))
 
-      case Filter(condition, child) =>
+      case f @ Filter(condition, child) =>
        val required = requiredByParent ++ condition.references
-        pushDownRequiredColumns(child, required)
+        f.copy(child = pushDownRequiredColumns(child, required))
 
       case relation: DataSourceV2Relation => relation.reader match {
         case reader: SupportsPushDownRequiredColumns =>
+          // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. For now
+          // it's possible that the mutable reader is being updated by someone else, and we need to
+          // always call `reader.pruneColumns` here to correct it.
+          // assert(relation.output.toStructType == reader.readSchema(),
+          //   "Schema of data source reader does not match the relation plan.")
+
           val requiredColumns = relation.output.filter(requiredByParent.contains)
           reader.pruneColumns(requiredColumns.toStructType)
 
-        case _ =>
+          val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
+          val newOutput = reader.readSchema().map(_.name).map(nameToAttr)
+          relation.copy(output = newOutput)
+
+        case _ => relation
       }
 
       // TODO: there may be more operators that can be used to calculate the required columns. We
       // can add more and more in the future.
-      case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet))
+      case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet))
     }
   }
 
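
The key structural change here is that `pushDownRequiredColumns` now returns a rewritten plan instead of mutating in place, so the pruned relation `output` (refreshed from `reader.readSchema()` after `pruneColumns`) is visible to later rules and to canonicalization. A toy model of that rebuild-on-return recursion (plain Scala, not Spark code):

    sealed trait Node
    case class Scan(cols: Seq[String]) extends Node
    case class Project(cols: Seq[String], child: Node) extends Node

    // Each case returns a (possibly new) node whose child is the pruned result, mirroring the
    // switch from a Unit-returning recursion to one that returns a LogicalPlan.
    def prune(plan: Node, required: Seq[String]): Node = plan match {
      case p @ Project(cols, child) => p.copy(child = prune(child, cols))
      case s @ Scan(cols)           => s.copy(cols = cols.filter(required.contains))
    }

    // prune(Project(Seq("i"), Scan(Seq("i", "j"))), Seq("i", "j"))
    //   == Project(Seq("i"), Scan(Seq("i")))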

sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala

Lines changed: 19 additions & 1 deletion
@@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec}
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.functions._
@@ -316,6 +316,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     val reader4 = getReader(q4)
     assert(reader4.requiredSchema.fieldNames === Seq("i"))
   }
+
+  test("SPARK-23315: get output from canonicalized data source v2 related plans") {
+    def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = {
+      val logical = df.queryExecution.optimizedPlan.collect {
+        case d: DataSourceV2Relation => d
+      }.head
+      assert(logical.canonicalized.output.length == numOutput)
+
+      val physical = df.queryExecution.executedPlan.collect {
+        case d: DataSourceV2ScanExec => d
+      }.head
+      assert(physical.canonicalized.output.length == numOutput)
+    }
+
+    val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
+    checkCanonicalizedOutput(df, 2)
+    checkCanonicalizedOutput(df.select('i), 1)
+  }
 }
 
 class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {
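
For reference, the invariant the new test asserts can also be checked interactively against the physical plan; a hedged sketch under the same assumptions as the test (SparkSession `spark` available, the AdvancedDataSourceV2 test source on the classpath):

    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec

    val pruned = spark.read.format(classOf[AdvancedDataSourceV2].getName).load().select("i")
    val scan = pruned.queryExecution.executedPlan.collect {
      case s: DataSourceV2ScanExec => s
    }.head
    // Column pruning pushes the `i`-only schema down to the reader, and the canonicalized
    // copy of the scan now reports that single attribute instead of failing.
    assert(scan.canonicalized.output.length == 1)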
