
Commit aadf953

rdblue authored and cloud-fan committed
[SPARK-23203][SQL] DataSourceV2: Use immutable logical plans.
## What changes were proposed in this pull request?

SPARK-23203: DataSourceV2 should use immutable catalyst trees instead of wrapping a mutable DataSourceV2Reader. This commit updates DataSourceV2Relation and consolidates much of the DataSourceV2 API requirements for the read path in it. Instead of wrapping a reader that changes, the relation lazily produces a reader from its configuration.

This commit also updates predicate and projection push-down. Instead of the implementation from SPARK-22197, it reuses the rule matching from the Hive and DataSource read paths (using `PhysicalOperation`) and copies most of the implementation of `SparkPlanner.pruneFilterProject`, with updates for DataSourceV2. Reusing the implementation from the other read paths should mean fewer regressions and less code to maintain.

The new push-down rules also support the following edge cases:

* The output of DataSourceV2Relation should be what is returned by the reader, in case the reader can only partially satisfy the requested schema projection
* The requested projection passed to the DataSourceV2Reader should include filter columns
* The push-down rule may be run more than once if filters are not pushed through projections

## How was this patch tested?

Existing push-down and read tests.

Author: Ryan Blue <[email protected]>

Closes #20387 from rdblue/SPARK-22386-push-down-immutable-trees.
1 parent 651b027 commit aadf953

10 files changed: +269, -187 lines

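The core idea of the change described above is easiest to see in miniature: the relation becomes a plain immutable value of (source, options, projection, filters), and the reader is derived from that configuration on demand instead of being mutated by the push-down rule. The sketch below is a simplified stand-in written for this review, not code from the patch; `Relation`, `pushDown`, and the string-typed fields are hypothetical placeholders for the real Catalyst and DataSourceV2 types.

```scala
// Simplified, self-contained sketch of the immutable-plan approach (hypothetical names).
object ImmutablePlanSketch {
  final case class Relation(
      source: String,                        // stands in for DataSourceV2
      options: Map[String, String],
      projection: Seq[String],               // stands in for Seq[AttributeReference]
      filters: Seq[String] = Nil) {          // stands in for Seq[Expression]

    // Derived lazily from the immutable configuration; never mutated in place.
    lazy val reader: String =
      s"Reader($source, columns=${projection.mkString(",")}, filters=${filters.mkString(",")})"
  }

  // "Push-down" returns a new plan node instead of reconfiguring a shared reader,
  // so the rule can safely run more than once (one of the edge cases listed above).
  def pushDown(r: Relation, requiredColumns: Seq[String], predicates: Seq[String]): Relation =
    r.copy(projection = requiredColumns, filters = (r.filters ++ predicates).distinct)

  def main(args: Array[String]): Unit = {
    val base = Relation("example-source", Map("path" -> "/tmp/data"), Seq("a", "b", "c"))
    val once = pushDown(base, Seq("a", "b"), Seq("a > 1"))
    val twice = pushDown(once, Seq("a", "b"), Seq("a > 1"))   // running the rule again is a no-op
    println(once == twice)                                    // true
    println(twice.reader)                                     // reader built from the final config
  }
}
```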

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala

Lines changed: 4 additions & 15 deletions
@@ -17,20 +17,9 @@
 
 package org.apache.spark.sql.kafka010
 
-import java.util.Properties
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.scalatest.time.SpanSugar._
-import scala.collection.mutable
-import scala.util.Random
-
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
-import org.apache.spark.sql.streaming.{StreamTest, Trigger}
-import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
+import org.apache.spark.sql.streaming.Trigger
 
 // Run tests in KafkaSourceSuiteBase in continuous execution mode.
 class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
@@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
     eventually(timeout(streamingTimeout)) {
       assert(
         query.lastExecution.logical.collectFirst {
-          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+          case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
         }.exists { r =>
           // Ensure the new topic is present and the old topic is gone.
           r.knownPartitions.exists(_.topic == topic2)

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.streaming.Trigger
@@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest {
     eventually(timeout(streamingTimeout)) {
       assert(
         query.lastExecution.logical.collectFirst {
-          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+          case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
         }.exists(_.knownPartitions.size == newCount),
         s"query never reconfigured to $newCount partitions")
     }

external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Dataset, ForeachWriter}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.functions.{count, window}
@@ -119,7 +119,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
       } ++ (query.get.lastExecution match {
         case null => Seq()
         case e => e.logical.collect {
-          case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
+          case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
         }
       })
     }.distinct

sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala

Lines changed: 9 additions & 32 deletions
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._
 import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
-import org.apache.spark.sql.sources.v2._
+import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema}
 import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -189,39 +189,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 
     val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
     if (classOf[DataSourceV2].isAssignableFrom(cls)) {
-      val ds = cls.newInstance()
-      val options = new DataSourceOptions((extraOptions ++
-        DataSourceV2Utils.extractSessionConfigs(
-          ds = ds.asInstanceOf[DataSourceV2],
-          conf = sparkSession.sessionState.conf)).asJava)
-
-      // Streaming also uses the data source V2 API. So it may be that the data source implements
-      // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
-      // the dataframe as a v1 source.
-      val reader = (ds, userSpecifiedSchema) match {
-        case (ds: ReadSupportWithSchema, Some(schema)) =>
-          ds.createReader(schema, options)
-
-        case (ds: ReadSupport, None) =>
-          ds.createReader(options)
-
-        case (ds: ReadSupportWithSchema, None) =>
-          throw new AnalysisException(s"A schema needs to be specified when using $ds.")
-
-        case (ds: ReadSupport, Some(schema)) =>
-          val reader = ds.createReader(options)
-          if (reader.readSchema() != schema) {
-            throw new AnalysisException(s"$ds does not allow user-specified schemas.")
-          }
-          reader
-
-        case _ => null // fall back to v1
-      }
+      val ds = cls.newInstance().asInstanceOf[DataSourceV2]
+      if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) {
+        val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
+          ds = ds, conf = sparkSession.sessionState.conf)
+        Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
+          ds, extraOptions.toMap ++ sessionOptions,
+          userSpecifiedSchema = userSpecifiedSchema))
 
-      if (reader == null) {
-        loadV1Source(paths: _*)
       } else {
-        Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+        loadV1Source(paths: _*)
       }
     } else {
       loadV1Source(paths: _*)
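
For reference, here is roughly what the updated read path means for user code. The snippet below is a hypothetical usage example, not part of the patch; the format name is made up, and any DataSourceV2 implementing ReadSupport (or ReadSupportWithSchema) would behave the same way.

```scala
// In spark-shell, or any scope with a SparkSession named `spark`.
// "com.example.SimpleV2Source" is a made-up class name standing in for a v2 source.
val df = spark.read
  .format("com.example.SimpleV2Source")
  .option("rows", "100")
  .load()

// The analyzed plan now contains an immutable DataSourceV2Relation built by
// DataSourceV2Relation.create(source, options, ...); the DataSourceReader itself is
// only produced lazily from that configuration (see the relation's `reader` lazy val below).
df.explain(true)
```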

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala

Lines changed: 200 additions & 12 deletions
@@ -17,17 +17,80 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
+import scala.collection.JavaConverters._
+
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
-import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
+import org.apache.spark.sql.execution.datasources.DataSourceStrategy
+import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema}
+import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics}
+import org.apache.spark.sql.types.StructType
 
 case class DataSourceV2Relation(
-    output: Seq[AttributeReference],
-    reader: DataSourceReader)
-  extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {
+    source: DataSourceV2,
+    options: Map[String, String],
+    projection: Seq[AttributeReference],
+    filters: Option[Seq[Expression]] = None,
+    userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation {
+
+  import DataSourceV2Relation._
+
+  override def simpleString: String = {
+    s"DataSourceV2Relation(source=${source.name}, " +
+      s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " +
+      s"filters=[${pushedFilters.mkString(", ")}], options=$options)"
+  }
+
+  override lazy val schema: StructType = reader.readSchema()
+
+  override lazy val output: Seq[AttributeReference] = {
+    // use the projection attributes to avoid assigning new ids. fields that are not projected
+    // will be assigned new ids, which is okay because they are not projected.
+    val attrMap = projection.map(a => a.name -> a).toMap
+    schema.map(f => attrMap.getOrElse(f.name,
+      AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()))
+  }
+
+  private lazy val v2Options: DataSourceOptions = makeV2Options(options)
+
+  lazy val (
+      reader: DataSourceReader,
+      unsupportedFilters: Seq[Expression],
+      pushedFilters: Seq[Expression]) = {
+    val newReader = userSpecifiedSchema match {
+      case Some(s) =>
+        source.asReadSupportWithSchema.createReader(s, v2Options)
+      case _ =>
+        source.asReadSupport.createReader(v2Options)
+    }
+
+    DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType)
 
-  override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation]
+    val (remainingFilters, pushedFilters) = filters match {
+      case Some(filterSeq) =>
+        DataSourceV2Relation.pushFilters(newReader, filterSeq)
+      case _ =>
+        (Nil, Nil)
+    }
+
+    (newReader, remainingFilters, pushedFilters)
+  }
+
+  override def doCanonicalize(): LogicalPlan = {
+    val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation]
+
+    // override output with canonicalized output to avoid attempting to configure a reader
+    val canonicalOutput: Seq[AttributeReference] = this.output
+        .map(a => QueryPlan.normalizeExprId(a, projection))
+
+    new DataSourceV2Relation(c.source, c.options, c.projection) {
+      override lazy val output: Seq[AttributeReference] = canonicalOutput
+    }
+  }
 
   override def computeStats(): Statistics = reader match {
     case r: SupportsReportStatistics =>
@@ -37,22 +100,147 @@ case class DataSourceV2Relation(
   }
 
   override def newInstance(): DataSourceV2Relation = {
-    copy(output = output.map(_.newInstance()))
+    // projection is used to maintain id assignment.
+    // if projection is not set, use output so the copy is not equal to the original
+    copy(projection = projection.map(_.newInstance()))
  }
 }
 
 /**
  * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical
  * to the non-streaming relation.
  */
-class StreamingDataSourceV2Relation(
+case class StreamingDataSourceV2Relation(
     output: Seq[AttributeReference],
-    reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
+    reader: DataSourceReader)
+  extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation {
   override def isStreaming: Boolean = true
+
+  override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation]
+
+  override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))
+
+  override def computeStats(): Statistics = reader match {
+    case r: SupportsReportStatistics =>
+      Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
+    case _ =>
+      Statistics(sizeInBytes = conf.defaultSizeInBytes)
+  }
 }
 
 object DataSourceV2Relation {
-  def apply(reader: DataSourceReader): DataSourceV2Relation = {
-    new DataSourceV2Relation(reader.readSchema().toAttributes, reader)
+  private implicit class SourceHelpers(source: DataSourceV2) {
+    def asReadSupport: ReadSupport = {
+      source match {
+        case support: ReadSupport =>
+          support
+        case _: ReadSupportWithSchema =>
+          // this method is only called if there is no user-supplied schema. if there is no
+          // user-supplied schema and ReadSupport was not implemented, throw a helpful exception.
+          throw new AnalysisException(s"Data source requires a user-supplied schema: $name")
+        case _ =>
+          throw new AnalysisException(s"Data source is not readable: $name")
+      }
+    }
+
+    def asReadSupportWithSchema: ReadSupportWithSchema = {
+      source match {
+        case support: ReadSupportWithSchema =>
+          support
+        case _: ReadSupport =>
+          throw new AnalysisException(
+            s"Data source does not support user-supplied schema: $name")
+        case _ =>
+          throw new AnalysisException(s"Data source is not readable: $name")
+      }
+    }
+
+    def name: String = {
+      source match {
+        case registered: DataSourceRegister =>
+          registered.shortName()
+        case _ =>
+          source.getClass.getSimpleName
+      }
+    }
+  }
+
+  private def makeV2Options(options: Map[String, String]): DataSourceOptions = {
+    new DataSourceOptions(options.asJava)
+  }
+
+  private def schema(
+      source: DataSourceV2,
+      v2Options: DataSourceOptions,
+      userSchema: Option[StructType]): StructType = {
+    val reader = userSchema match {
+      // TODO: remove this case because it is confusing for users
+      case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] =>
+        val reader = source.asReadSupport.createReader(v2Options)
+        if (reader.readSchema() != s) {
+          throw new AnalysisException(s"${source.name} does not allow user-specified schemas.")
+        }
+        reader
+      case Some(s) =>
+        source.asReadSupportWithSchema.createReader(s, v2Options)
+      case _ =>
+        source.asReadSupport.createReader(v2Options)
+    }
+    reader.readSchema()
+  }
+
+  def create(
+      source: DataSourceV2,
+      options: Map[String, String],
+      filters: Option[Seq[Expression]] = None,
+      userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = {
+    val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes
+    DataSourceV2Relation(source, options, projection, filters,
+      // if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must
+      // be equal to the reader's schema. the schema method enforces this. because the user schema
+      // and the reader's schema are identical, drop the user schema.
+      if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None)
+  }
+
+  private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = {
+    reader match {
+      case projectionSupport: SupportsPushDownRequiredColumns =>
+        projectionSupport.pruneColumns(struct)
+      case _ =>
+    }
+  }
+
+  private def pushFilters(
+      reader: DataSourceReader,
+      filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+    reader match {
+      case catalystFilterSupport: SupportsPushDownCatalystFilters =>
+        (
+          catalystFilterSupport.pushCatalystFilters(filters.toArray),
+          catalystFilterSupport.pushedCatalystFilters()
+        )
+
+      case filterSupport: SupportsPushDownFilters =>
+        // A map from original Catalyst expressions to corresponding translated data source
+        // filters. If a predicate is not in this map, it means it cannot be pushed down.
+        val translatedMap: Map[Expression, Filter] = filters.flatMap { p =>
+          DataSourceStrategy.translateFilter(p).map(f => p -> f)
+        }.toMap
+
+        // Catalyst predicate expressions that cannot be converted to data source filters.
+        val nonConvertiblePredicates = filters.filterNot(translatedMap.contains)
+
+        // Data source filters that cannot be pushed down. An unhandled filter means
+        // the data source cannot guarantee the rows returned can pass the filter.
+        // As a result we must return it so Spark can plan an extra filter operator.
+        val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet
+        val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) =>
+          unhandledFilters.contains(f)
+        }
+
+        (nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq)
+
+      case _ => (filters, Nil)
+    }
  }
 }
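
The bookkeeping in `pushFilters` for the `SupportsPushDownFilters` case is the subtle part: a predicate may fail to translate to a source `Filter` at all, or it may translate but be rejected by the source, and only the truly handled predicates should be reported as pushed. The following is a small, self-contained sketch of that logic in plain Scala, with string stand-ins for Catalyst expressions and source filters; `translate` and `pushToSource` are made-up stand-ins for `DataSourceStrategy.translateFilter` and `SupportsPushDownFilters.pushFilters`.

```scala
// Toy model of the pushFilters bookkeeping above (all names hypothetical).
object PushFiltersSketch {
  type Predicate = String      // stands in for a Catalyst Expression
  type SourceFilter = String   // stands in for org.apache.spark.sql.sources.Filter

  // Not every predicate has a data source filter form (e.g. one wrapping a UDF call).
  def translate(p: Predicate): Option[SourceFilter] =
    if (p.startsWith("udf")) None else Some(p.toUpperCase)

  // The source returns the filters it could NOT handle; pretend it rejects filters on `b`.
  def pushToSource(fs: Seq[SourceFilter]): Seq[SourceFilter] = fs.filter(_.startsWith("B"))

  def main(args: Array[String]): Unit = {
    val predicates: Seq[Predicate] = Seq("a > 1", "b = 'x'", "udf(c) = 3")

    val translatedMap: Map[Predicate, SourceFilter] =
      predicates.flatMap(p => translate(p).map(p -> _)).toMap

    val nonConvertible = predicates.filterNot(translatedMap.contains)
    val unhandled = pushToSource(translatedMap.values.toSeq).toSet
    val (unhandledPreds, pushedPreds) = translatedMap.partition { case (_, f) => unhandled(f) }

    // Spark must still plan a Filter operator for these two:
    println(nonConvertible ++ unhandledPreds.keys)  // List(udf(c) = 3, b = 'x')
    // Only this one is reported as pushed to (and fully handled by) the source:
    println(pushedPreds.keys.toSeq)                 // List(a > 1)
  }
}
```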

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala

Lines changed: 5 additions & 2 deletions
@@ -23,8 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan
 
 object DataSourceV2Strategy extends Strategy {
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case DataSourceV2Relation(output, reader) =>
-      DataSourceV2ScanExec(output, reader) :: Nil
+    case relation: DataSourceV2Relation =>
+      DataSourceV2ScanExec(relation.output, relation.reader) :: Nil
+
+    case relation: StreamingDataSourceV2Relation =>
+      DataSourceV2ScanExec(relation.output, relation.reader) :: Nil
 
     case WriteToDataSourceV2(writer, query) =>
       WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
