From 11220db7879798967cda85d2aa4e68fefb8ec646 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 31 Jan 2018 11:12:00 +0800
Subject: [PATCH] make DataSourceV2Relation immutable

---
 .../kafka010/KafkaContinuousSourceSuite.scala | 19 ++---
 .../sql/kafka010/KafkaContinuousTest.scala    |  4 +-
 .../spark/sql/kafka010/KafkaSourceSuite.scala |  4 +-
 .../SupportsPushDownCatalystFilters.java      |  8 ---
 .../v2/reader/SupportsPushDownFilters.java    |  7 --
 .../apache/spark/sql/DataFrameReader.scala    |  5 +-
 .../v2/DataSourceReaderHolder.scala           | 68 ------------------
 .../v2/DataSourceV2QueryPlan.scala            | 69 ++++++++++++++++++
 .../datasources/v2/DataSourceV2Relation.scala | 64 ++++++++++++++---
 .../datasources/v2/DataSourceV2ScanExec.scala | 14 ++--
 .../datasources/v2/DataSourceV2Strategy.scala |  8 ++-
 .../v2/PushDownOperatorsToDataSource.scala    | 70 ++++++++++--------
 .../continuous/ContinuousExecution.scala      |  2 +-
 .../sources/v2/JavaAdvancedDataSourceV2.java  | 12 ++--
 .../sql/sources/v2/DataSourceV2Suite.scala    | 71 +++++++++++++++++--
 .../spark/sql/streaming/StreamTest.scala      |  6 +-
 .../continuous/ContinuousSuite.scala          | 12 ++--
 17 files changed, 272 insertions(+), 171 deletions(-)
 delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala

diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
index a7083fa4e341..f679e9bfc045 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSourceSuite.scala
@@ -17,20 +17,9 @@
 
 package org.apache.spark.sql.kafka010
 
-import java.util.Properties
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.scalatest.time.SpanSugar._
-import scala.collection.mutable
-import scala.util.Random
-
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
-import org.apache.spark.sql.streaming.{StreamTest, Trigger}
-import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
+import org.apache.spark.sql.streaming.Trigger
 
 // Run tests in KafkaSourceSuiteBase in continuous execution mode.
 class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
@@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
     eventually(timeout(streamingTimeout)) {
       assert(
         query.lastExecution.logical.collectFirst {
-          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+          case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
         }.exists { r =>
           // Ensure the new topic is present and the old topic is gone.
          r.knownPartitions.exists(_.topic == topic2)
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
index 5a1a14f7a307..48ac3fc1e8f9 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger
 
 import org.apache.spark.SparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming.StreamExecution
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.streaming.Trigger
@@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest {
     eventually(timeout(streamingTimeout)) {
       assert(
         query.lastExecution.logical.collectFirst {
-          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+          case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
         }.exists(_.knownPartitions.size == newCount),
         s"query never reconfigured to $newCount partitions")
     }
diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
index 02c87643568b..d26beca800bc 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
@@ -34,7 +34,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext
 import org.apache.spark.sql.{Dataset, ForeachWriter}
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.functions.{count, window}
@@ -117,7 +117,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
       } ++ (query.get.lastExecution match {
         case null => Seq()
         case e => e.logical.collect {
-          case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
+          case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
         }
       })
     if (sources.isEmpty) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
index 98224102374a..9359c46341e6 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java
@@ -37,12 +37,4 @@ public interface SupportsPushDownCatalystFilters extends DataSourceReader {
    * Pushes down filters, and returns unsupported filters.
    */
  Expression[] pushCatalystFilters(Expression[] filters);
-
-  /**
-   * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}.
-   * It's possible that there is no filters in the query and
-   * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for
-   * this case.
-   */
-  Expression[] pushedCatalystFilters();
 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
index f35c711b0387..8aa5794881e4 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java
@@ -35,11 +35,4 @@ public interface SupportsPushDownFilters extends DataSourceReader {
    * Pushes down filters, and returns unsupported filters.
    */
   Filter[] pushFilters(Filter[] filters);
-
-  /**
-   * Returns the filters that are pushed in {@link #pushFilters(Filter[])}.
-   * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])}
-   * is never called, empty array should be returned for this case.
-   */
-  Filter[] pushedFilters();
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 46b5f54a33f7..7c15e342252a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -185,7 +185,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
 
     val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
     if (classOf[DataSourceV2].isAssignableFrom(cls)) {
-      val ds = cls.newInstance()
+      val ds = cls.newInstance().asInstanceOf[DataSourceV2]
       val options = new DataSourceOptions((extraOptions ++
         DataSourceV2Utils.extractSessionConfigs(
           ds = ds.asInstanceOf[DataSourceV2],
@@ -217,7 +217,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       if (reader == null) {
         loadV1Source(paths: _*)
       } else {
-        Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+        Dataset.ofRows(sparkSession,
+          DataSourceV2Relation(ds, reader.readSchema(), options, userSpecifiedSchema))
       }
     } else {
       loadV1Source(paths: _*)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
deleted file mode 100644
index 6460c97abe34..000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala
+++ /dev/null
@@ -1,68 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
-
 *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import java.util.Objects
-
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.sources.v2.reader._
-
-/**
- * A base class for data source reader holder with customized equals/hashCode methods.
- */
-trait DataSourceReaderHolder {
-
-  /**
-   * The full output of the data source reader, without column pruning.
-   */
-  def fullOutput: Seq[AttributeReference]
-
-  /**
-   * The held data source reader.
-   */
-  def reader: DataSourceReader
-
-  /**
-   * The metadata of this data source reader that can be used for equality test.
-   */
-  private def metadata: Seq[Any] = {
-    val filters: Any = reader match {
-      case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet
-      case s: SupportsPushDownFilters => s.pushedFilters().toSet
-      case _ => Nil
-    }
-    Seq(fullOutput, reader.getClass, reader.readSchema(), filters)
-  }
-
-  def canEqual(other: Any): Boolean
-
-  override def equals(other: Any): Boolean = other match {
-    case other: DataSourceReaderHolder =>
-      canEqual(other) && metadata.length == other.metadata.length &&
-        metadata.zip(other.metadata).forall { case (l, r) => l == r }
-    case _ => false
-  }
-
-  override def hashCode(): Int = {
-    metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
-  }
-
-  lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name =>
-    fullOutput.find(_.name == name).get
-  }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala
new file mode 100644
index 000000000000..1adb7c0c161b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2QueryPlan.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import java.util.Objects
+
+import org.apache.commons.lang3.StringUtils
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.DataSourceV2
+import org.apache.spark.util.Utils
+
+/**
+ * A base trait for data source v2 related query plans (both logical and physical). It defines
+ * the equals/hashCode methods according to some common information.
+ */
+trait DataSourceV2QueryPlan {
+
+  def output: Seq[Attribute]
+  def sourceClass: Class[_ <: DataSourceV2]
+  def filters: Set[Expression]
+
+  // The metadata of this data source plan that is used for the equality test.
+  private def metadata: Seq[Any] = Seq(output, sourceClass, filters)
+
+  def canEqual(other: Any): Boolean
+
+  override def equals(other: Any): Boolean = other match {
+    case other: DataSourceV2QueryPlan =>
+      canEqual(other) && metadata == other.metadata
+    case _ => false
+  }
+
+  override def hashCode(): Int = {
+    metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
+  }
+
+  def metadataString: String = {
+    val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
+    if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]")
+
+    val outputStr = Utils.truncatedString(output, "[", ", ", "]")
+    val entriesStr = Utils.truncatedString(entries.map {
+      case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100)
+    }, " (", ", ", ")")
+
+    s"${sourceClass.getSimpleName}$outputStr$entriesStr"
+  }
+
+  private def redact(text: String): String = {
+    Utils.redact(SQLConf.get.stringRedationPattern, text)
+  }
+}
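The equality contract above can be exercised directly. The following is an illustrative sketch only, not part of the patch; it borrows `AdvancedDataSourceV2`, the test source touched later in this patch, and assumes the relation changes below are applied:

    import org.apache.spark.sql.catalyst.expressions.{Expression, GreaterThan, Literal}
    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
    import org.apache.spark.sql.sources.v2.DataSourceOptions

    val source = new AdvancedDataSourceV2
    val options = new DataSourceOptions(new java.util.HashMap[String, String]())
    val schema = source.createReader(options).readSchema()
    val r1 = DataSourceV2Relation(source, schema, options, userSpecifiedSchema = None)

    // Only (output, sourceClass, filters) feed the metadata used by equals/hashCode;
    // options and existingReader are deliberately ignored.
    assert(r1 == r1.copy(existingReader = Some(r1.reader)))

    // A different set of pushed filters yields a different plan.
    val pushed: Set[Expression] = Set(GreaterThan(r1.output.head, Literal(1)))
    assert(r1 != r1.copy(filters = pushed))

Note that `copy` is used rather than building a second relation from scratch: `schema.toAttributes` mints fresh expression IDs, so two independently constructed relations over the same source do not compare equal.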
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 3d4c64981373..310945ffc00a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -17,36 +17,84 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema}
 import org.apache.spark.sql.sources.v2.reader._
+import org.apache.spark.sql.types.StructType
 
+/**
+ * A logical plan representing a data source relation, which will eventually be planned to a
+ * data scan operator.
+ *
+ * @param output The output of this relation.
+ * @param source The instance of a data source v2 implementation.
+ * @param options The options specified for this scan, used to create the `DataSourceReader`.
+ * @param userSpecifiedSchema The user-specified schema, used to create the `DataSourceReader`.
+ * @param filters The predicates that are pushed down and handled by this data source.
+ * @param existingReader A mutable reader carrying some temporary state during optimization and
+ *                       planning. It's always None before optimization, and does not take part
+ *                       in the equality of this plan, which means this plan is still immutable.
+ */
 case class DataSourceV2Relation(
-    fullOutput: Seq[AttributeReference],
-    reader: DataSourceReader) extends LeafNode with DataSourceReaderHolder {
+    output: Seq[AttributeReference],
+    source: DataSourceV2,
+    options: DataSourceOptions,
+    userSpecifiedSchema: Option[StructType],
+    filters: Set[Expression],
+    existingReader: Option[DataSourceReader]) extends LeafNode with DataSourceV2QueryPlan {
+
+  override def references: AttributeSet = AttributeSet.empty
+
+  override def sourceClass: Class[_ <: DataSourceV2] = source.getClass
 
   override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation]
 
+  def reader: DataSourceReader = existingReader.getOrElse {
+    (source, userSpecifiedSchema) match {
+      case (ds: ReadSupportWithSchema, Some(schema)) =>
+        ds.createReader(schema, options)
+
+      case (ds: ReadSupport, None) =>
+        ds.createReader(options)
+
+      case (ds: ReadSupport, Some(schema)) =>
+        val reader = ds.createReader(options)
+        // Sanity check, this should be guaranteed by `DataFrameReader.load`.
+        assert(reader.readSchema() == schema)
+        reader
+
+      case _ => throw new IllegalStateException()
+    }
+  }
+
   override def computeStats(): Statistics = reader match {
     case r: SupportsReportStatistics =>
       Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
     case _ =>
       Statistics(sizeInBytes = conf.defaultSizeInBytes)
   }
+
+  override def simpleString: String = s"Relation $metadataString"
 }
 
 /**
  * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical
  * to the non-streaming relation.
  */
-class StreamingDataSourceV2Relation(
-    fullOutput: Seq[AttributeReference],
-    reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) {
+case class StreamingDataSourceV2Relation(
+    output: Seq[AttributeReference],
+    reader: DataSourceReader) extends LeafNode {
   override def isStreaming: Boolean = true
 }
 
 object DataSourceV2Relation {
-  def apply(reader: DataSourceReader): DataSourceV2Relation = {
-    new DataSourceV2Relation(reader.readSchema().toAttributes, reader)
+  def apply(
+      source: DataSourceV2,
+      schema: StructType,
+      options: DataSourceOptions,
+      userSpecifiedSchema: Option[StructType]): DataSourceV2Relation = {
+    DataSourceV2Relation(
+      schema.toAttributes, source, options, userSpecifiedSchema, Set.empty, None)
   }
 }
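For reference, a sketch of how `DataFrameReader.load` (changed earlier in this patch) now wires this up. Nothing here is part of the patch; `AdvancedDataSourceV2` again stands in for a real source:

    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
    import org.apache.spark.sql.sources.v2.DataSourceOptions

    val source = new AdvancedDataSourceV2
    val options = new DataSourceOptions(new java.util.HashMap[String, String]())

    // load() creates one reader just to learn the schema...
    val schemaReader = source.createReader(options)
    val relation = DataSourceV2Relation(source, schemaReader.readSchema(), options, None)

    // ...but the relation is not tied to that instance: while existingReader is None,
    // every `reader` call re-creates a reader from (source, options, userSpecifiedSchema).
    assert(relation.reader ne schemaReader)
    assert(relation.reader ne relation.reader)

This is the crux of the immutability fix: the plan carries only immutable inputs, and a reader that has absorbed push-down state is attached explicitly via `existingReader` by the optimizer.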
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index ee085820b077..6f89f53811ba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.v2
 
 import scala.collection.JavaConverters._
+import scala.language.existentials
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Row
@@ -27,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.streaming.reader.ContinuousReader
 import org.apache.spark.sql.types.StructType
@@ -35,13 +37,17 @@ import org.apache.spark.sql.types.StructType
  * Physical plan node for scanning data from a data source.
  */
 case class DataSourceV2ScanExec(
-    fullOutput: Seq[AttributeReference],
-    @transient reader: DataSourceReader)
-  extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {
+    output: Seq[AttributeReference],
+    @transient reader: DataSourceReader,
+    @transient sourceClass: Class[_ <: DataSourceV2],
+    @transient filters: Set[Expression])
+  extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan {
+
+  override def references: AttributeSet = AttributeSet.empty
 
   override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]
 
-  override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
+  override def simpleString: String = s"Scan $metadataString"
 
   override def outputPartitioning: physical.Partitioning = reader match {
     case s: SupportsReportPartitioning =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index df5b524485f5..fb23c3751ce4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -23,8 +23,12 @@ import org.apache.spark.sql.execution.SparkPlan
 
 object DataSourceV2Strategy extends Strategy {
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case DataSourceV2Relation(output, reader) =>
-      DataSourceV2ScanExec(output, reader) :: Nil
+    case relation: DataSourceV2Relation =>
+      DataSourceV2ScanExec(
+        relation.output,
+        relation.reader,
+        relation.sourceClass,
+        relation.filters) :: Nil
 
     case WriteToDataSourceV2(writer, query) =>
       WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
index df034adf1e7d..20823f32af32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.datasources.v2
 
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, AttributeSet, Expression, NamedExpression, PredicateHelper}
 import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -39,10 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
     // TODO: Ideally column pruning should be implemented via a plan property that is propagated
     // top-down, then we can simplify the logic here and only collect target operators.
     val filterPushed = plan transformUp {
-      case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) =>
+      case FilterAndProject(fields, condition, relation: DataSourceV2Relation) =>
         val (candidates, nonDeterministic) =
           splitConjunctivePredicates(condition).partition(_.deterministic)
 
+        val reader = relation.reader
         val stayUpFilters: Seq[Expression] = reader match {
           case r: SupportsPushDownCatalystFilters =>
             r.pushCatalystFilters(candidates.toArray)
@@ -70,8 +71,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
           case _ => candidates
         }
 
+        val newRelation = relation.copy(
+          filters = candidates.toSet -- stayUpFilters,
+          existingReader = Some(reader))
         val filterCondition = (stayUpFilters ++ nonDeterministic).reduceLeftOption(And)
-        val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r)
+        val withFilter = filterCondition.map(Filter(_, newRelation)).getOrElse(newRelation)
         if (withFilter.output == fields) {
           withFilter
         } else {
@@ -81,35 +85,45 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
 
     // TODO: add more push down rules.
 
-    // TODO: nested fields pruning
-    def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = {
-      plan match {
-        case Project(projectList, child) =>
-          val required = projectList.filter(requiredByParent.contains).flatMap(_.references)
-          pushDownRequiredColumns(child, required)
-
-        case Filter(condition, child) =>
-          val required = requiredByParent ++ condition.references
-          pushDownRequiredColumns(child, required)
-
-        case DataSourceV2Relation(fullOutput, reader) => reader match {
-          case r: SupportsPushDownRequiredColumns =>
-            // Match original case of attributes.
-            val attrMap = AttributeMap(fullOutput.zip(fullOutput))
-            val requiredColumns = requiredByParent.map(attrMap)
-            r.pruneColumns(requiredColumns.toStructType)
-          case _ =>
+    val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet)
+    // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
+    RemoveRedundantProject(columnPruned)
+  }
+
+  // TODO: nested fields pruning
+  private def pushDownRequiredColumns(
+      plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = plan match {
+    case p @ Project(projectList, child) =>
+      val required = projectList.flatMap(_.references)
+      p.copy(child = pushDownRequiredColumns(child, AttributeSet(required)))
+
+    case f @ Filter(condition, child) =>
+      val required = requiredByParent ++ condition.references
+      f.copy(child = pushDownRequiredColumns(child, required))
+
+    case relation: DataSourceV2Relation => relation.reader match {
+      case reader: SupportsPushDownRequiredColumns =>
+        if (requiredByParent == relation.outputSet) {
+          relation
+        } else {
+          assert(relation.output.toStructType == reader.readSchema(),
+            "Schema of data source reader does not match the relation plan.")
+
+          // Match original case of attributes.
+          val requiredColumns = relation.output.filter(requiredByParent.contains)
+          reader.pruneColumns(requiredColumns.toStructType)
+
+          val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
+          val newOutput = reader.readSchema().map(_.name).map(nameToAttr)
+          relation.copy(output = newOutput, existingReader = Some(reader))
+        }
-        // TODO: there may be more operators can be used to calculate required columns, we can add
-        // more and more in the future.
-        case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output))
-      }
+      case _ => relation
     }
 
-    pushDownRequiredColumns(filterPushed, filterPushed.output)
-    // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them.
-    RemoveRedundantProject(filterPushed)
+    // TODO: there may be more operators that can be used to calculate required columns; we can
+    // add more in the future.
+    case other => other.mapChildren(c => pushDownRequiredColumns(c, c.outputSet))
   }
 
   /**
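The combined effect of the strategy and push-down changes can be observed end to end. A sketch, not part of the patch, condensed from the suite changes below; it assumes a `SparkSession` named `spark`, `spark.implicits._` in scope, and the `checkAnswer` test helper, with `AdvancedDataSourceV2` from the suite:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec

    val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load()
    val q = df.select('j).filter('i > 6)
    checkAnswer(q, (7 until 10).map(i => Row(-i)))

    val scan = q.queryExecution.executedPlan.collect {
      case s: DataSourceV2ScanExec => s
    }.head

    // The filter on `i` was pushed into the source and the schema pruned to `j`...
    val reader = scan.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
    assert(reader.filters.flatMap(_.references).toSet == Set("i"))
    assert(reader.requiredSchema.map(_.name) == Seq("j"))

    // ...and the pushed predicate also travels on the plan itself, so two scans that
    // pushed different filters no longer compare equal.
    assert(scan.filters.nonEmpty)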
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 9402d7c1dcef..ce69b65cb6e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -202,7 +202,7 @@ class ContinuousExecution(
     val withSink = WriteToDataSourceV2(writer, triggerLogicalPlan)
 
     val reader = withSink.collect {
-      case DataSourceV2Relation(_, r: ContinuousReader) => r
+      case StreamingDataSourceV2Relation(_, r: ContinuousReader) => r
     }.head
 
     reportTimeTaken("queryPlanning") {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
index d421f7d19563..0e8b46751fd2 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java
@@ -32,11 +32,12 @@
 
 public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport {
 
-  class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
+  public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns,
     SupportsPushDownFilters {
 
-    private StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
-    private Filter[] filters = new Filter[0];
+    // Exposed for testing.
+    public StructType requiredSchema = new StructType().add("i", "int").add("j", "int");
+    public Filter[] filters = new Filter[0];
 
     @Override
     public StructType readSchema() {
@@ -54,11 +55,6 @@ public Filter[] pushFilters(Filter[] filters) {
       return new Filter[0];
     }
 
-    @Override
-    public Filter[] pushedFilters() {
-      return filters;
-    }
-
     @Override
     public List<DataReaderFactory<Row>> createDataReaderFactories() {
      List<DataReaderFactory<Row>> res = new ArrayList<>();
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index ee50e8a92270..18e76afc4ab7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
@@ -51,10 +52,70 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     withClue(cls.getName) {
       val df = spark.read.format(cls.getName).load()
       checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
-      checkAnswer(df.select('j), (0 until 10).map(i => Row(-i)))
-      checkAnswer(df.filter('i > 3), (4 until 10).map(i => Row(i, -i)))
-      checkAnswer(df.select('j).filter('i > 6), (7 until 10).map(i => Row(-i)))
-      checkAnswer(df.select('i).filter('i > 10), Nil)
+
+      val q1 = df.select('j)
+      checkAnswer(q1, (0 until 10).map(i => Row(-i)))
+      if (cls == classOf[AdvancedDataSourceV2]) {
+        val reader = q1.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.isEmpty)
+        assert(reader.requiredSchema.map(_.name) == Seq("j"))
+      } else {
+        val reader = q1.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.isEmpty)
+        assert(reader.requiredSchema.map(_.name) == Seq("j"))
+      }
+
+      val q2 = df.filter('i > 3)
+      checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
+      if (cls == classOf[AdvancedDataSourceV2]) {
+        val reader = q2.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+        assert(reader.requiredSchema.map(_.name) == Seq("i", "j"))
+      } else {
+        val reader = q2.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+        assert(reader.requiredSchema.map(_.name) == Seq("i", "j"))
+      }
+
+      val q3 = df.select('j).filter('i > 6)
+      checkAnswer(q3, (7 until 10).map(i => Row(-i)))
+      if (cls == classOf[AdvancedDataSourceV2]) {
+        val reader = q3.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+        assert(reader.requiredSchema.map(_.name) == Seq("j"))
+      } else {
+        val reader = q3.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+        assert(reader.requiredSchema.map(_.name) == Seq("j"))
+      }
+
+      val q4 = df.select('i).filter('i > 10)
+      checkAnswer(q4, Nil)
+      if (cls == classOf[AdvancedDataSourceV2]) {
+        val reader = q4.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+        assert(reader.requiredSchema.map(_.name) == Seq("i"))
+      } else {
+        val reader = q4.queryExecution.executedPlan.collect {
+          case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader]
+        }.head
+        assert(reader.filters.flatMap(_.references).toSet == Set("i"))
+        assert(reader.requiredSchema.map(_.name) == Seq("i"))
+      }
     }
   }
 }
@@ -248,8 +309,6 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport {
       Array.empty
     }
 
-    override def pushedFilters(): Array[Filter] = filters
-
     override def readSchema(): StructType = {
       requiredSchema
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d6433562fb29..06235790331d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -38,9 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
 import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -605,7 +605,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             plan
               .collect {
                 case StreamingExecutionRelation(s, _) => s
-                case DataSourceV2Relation(_, r) => r
+                case StreamingDataSourceV2Relation(_, r) => r
               }
               .zipWithIndex
               .find(_._1 == source)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index 4b4ed82dc652..9f5d08e6eace 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -17,15 +17,12 @@
 
 package org.apache.spark.sql.streaming.continuous
 
-import java.util.UUID
-
-import org.apache.spark.{SparkContext, SparkEnv, SparkException}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskStart}
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
 import org.apache.spark.sql._
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, WriteToDataSourceV2Exec}
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous._
-import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.{StreamTest, Trigger}
 import org.apache.spark.sql.test.TestSparkSession
@@ -43,7 +40,8 @@ class ContinuousSuiteBase extends StreamTest {
       case s: ContinuousExecution =>
         assert(numTriggers >= 2, "must wait for at least 2 triggers to ensure query is initialized")
         val reader = s.lastExecution.executedPlan.collectFirst {
-          case DataSourceV2ScanExec(_, r: RateStreamContinuousReader) => r
+          case d: DataSourceV2ScanExec if d.reader.isInstanceOf[RateStreamContinuousReader] =>
+            d.reader.asInstanceOf[RateStreamContinuousReader]
         }.get
 
         val deltaMs = numTriggers * 1000 + 300