@@ -17,20 +17,9 @@

 package org.apache.spark.sql.kafka010

-import java.util.Properties
-import java.util.concurrent.atomic.AtomicInteger
-
-import org.scalatest.time.SpanSugar._
-import scala.collection.mutable
-import scala.util.Random
-
-import org.apache.spark.SparkContext
-import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
+import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
-import org.apache.spark.sql.execution.streaming.StreamExecution
-import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
-import org.apache.spark.sql.streaming.{StreamTest, Trigger}
-import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
+import org.apache.spark.sql.streaming.Trigger

 // Run tests in KafkaSourceSuiteBase in continuous execution mode.
 class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
@@ -71,7 +60,8 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
     eventually(timeout(streamingTimeout)) {
       assert(
         query.lastExecution.logical.collectFirst {
-          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+          case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] =>
+            r.reader.asInstanceOf[KafkaContinuousReader]
         }.exists { r =>
           // Ensure the new topic is present and the old topic is gone.
           r.knownPartitions.exists(_.topic == topic2)
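The Kafka suites in this hunk and the next two switch to the same matching recipe: DataSourceV2Relation gained a source field, so the old two-field extractor no longer lines up, and the assertions instead match on the node type and narrow the reader. A standalone, REPL-style sketch of that recipe (it assumes KafkaContinuousReader is visible from the calling code, as it is inside the kafka010 test suites; the helper name is made up):

// Sketch only, not part of the patch: find the continuous Kafka reader in a
// query's logical plan by matching the relation node's type and narrowing its
// reader, instead of destructuring DataSourceV2Relation positionally.
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.kafka010.KafkaContinuousReader

def findContinuousReader(plan: LogicalPlan): Option[KafkaContinuousReader] = {
  plan.collectFirst {
    case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] =>
      r.reader.asInstanceOf[KafkaContinuousReader]
  }
}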
@@ -47,7 +47,8 @@ trait KafkaContinuousTest extends KafkaSourceTest {
     eventually(timeout(streamingTimeout)) {
       assert(
         query.lastExecution.logical.collectFirst {
-          case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
+          case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] =>
+            r.reader.asInstanceOf[KafkaContinuousReader]
         }.exists(_.knownPartitions.size == newCount),
         s"query never reconfigured to $newCount partitions")
     }
@@ -117,7 +117,8 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
       } ++ (query.get.lastExecution match {
         case null => Seq()
         case e => e.logical.collect {
-          case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
+          case r: DataSourceV2Relation if r.reader.isInstanceOf[KafkaContinuousReader] =>
+            r.reader.asInstanceOf[KafkaContinuousReader]
         }
       })
       if (sources.isEmpty) {
@@ -189,11 +189,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

     val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
     if (classOf[DataSourceV2].isAssignableFrom(cls)) {
-      val ds = cls.newInstance()
+      val ds = cls.newInstance().asInstanceOf[DataSourceV2]
       val options = new DataSourceOptions((extraOptions ++
-        DataSourceV2Utils.extractSessionConfigs(
-          ds = ds.asInstanceOf[DataSourceV2],
-          conf = sparkSession.sessionState.conf)).asJava)
+        DataSourceV2Utils.extractSessionConfigs(ds, sparkSession.sessionState.conf)).asJava)

       // Streaming also uses the data source V2 API. So it may be that the data source implements
       // v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
@@ -221,7 +219,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       if (reader == null) {
         loadV1Source(paths: _*)
       } else {
-        Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
+        Dataset.ofRows(sparkSession, DataSourceV2Relation(ds, reader))
       }
     } else {
       loadV1Source(paths: _*)
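The user-visible payoff of threading ds into the relation is in explain output: once the scan and relation nodes carry the DataSourceV2 instance, the metadataString added later in this patch can name the implementation class. A rough illustration, assuming an active SparkSession called spark and a hypothetical v2 implementation; the rendered strings are approximations, not captured output:

// Hypothetical source and columns; the plan text in the comments is approximate.
val df = spark.read
  .format("com.example.ExampleSource")  // assumed DataSourceV2 implementation
  .load()

df.explain()
// With this patch the v2 nodes identify their source, roughly:
//   == Physical Plan ==
//   Scan ExampleSource[id#0L, data#1]
// and the logical node renders as:
//   Relation ExampleSource[id#0L, data#1]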

This file was deleted.

@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.datasources.v2

import java.util.Objects

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.util.Utils

/**
 * A base class for data source v2 related query plans (both logical and physical). It defines the
 * equals/hashCode methods, and provides a string representation of the query plan based on
 * some common information.
 */
trait DataSourceV2QueryPlan {

  /**
   * The output of the data source reader, w.r.t. column pruning.
   */
  def output: Seq[Attribute]

  /**
   * The instance of this data source implementation. Note that we only consider its class in
   * equals/hashCode, not the instance itself.
   */
  def source: DataSourceV2

  /**
   * The created data source reader. Here we use it to get the filters that have been pushed down
   * so far; the reader itself doesn't take part in equals/hashCode.
   */
  def reader: DataSourceReader

  private lazy val filters = reader match {
    case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet
    case s: SupportsPushDownFilters => s.pushedFilters().toSet
    case _ => Set.empty
  }

  /**
   * The metadata of this data source query plan that can be used for equality checks.
   */
  private def metadata: Seq[Any] = Seq(output, source.getClass, filters)

  def canEqual(other: Any): Boolean

  override def equals(other: Any): Boolean = other match {
    case other: DataSourceV2QueryPlan => canEqual(other) && metadata == other.metadata
    case _ => false
  }

  override def hashCode(): Int = {
    metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b)
  }

  def metadataString: String = {
    val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)]
    if (filters.nonEmpty) entries += "PushedFilter" -> filters.mkString("[", ", ", "]")

    val outputStr = Utils.truncatedString(output, "[", ", ", "]")

    val entriesStr = if (entries.nonEmpty) {
      Utils.truncatedString(entries.map {
        case (key, value) => key + ": " + StringUtils.abbreviate(redact(value), 100)
      }, " (", ", ", ")")
    } else {
      ""
    }

    s"${source.getClass.getSimpleName}$outputStr$entriesStr"
  }

  private def redact(text: String): String = {
    Utils.redact(SQLConf.get.stringRedationPattern, text)
  }
}
@@ -20,15 +20,23 @@ package org.apache.spark.sql.execution.datasources.v2
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._

 case class DataSourceV2Relation(
     output: Seq[AttributeReference],
-    reader: DataSourceReader)
-  extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {
+    source: DataSourceV2,
+    reader: DataSourceReader,
+    override val isStreaming: Boolean)
+  extends LeafNode with MultiInstanceRelation with DataSourceV2QueryPlan {

   override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation]

+  override def simpleString: String = {
+    val streamingHeader = if (isStreaming) "Streaming " else ""
+    s"${streamingHeader}Relation $metadataString"
+  }
+
   override def computeStats(): Statistics = reader match {
     case r: SupportsReportStatistics =>
       Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
@@ -41,18 +49,8 @@ case class DataSourceV2Relation(
   }
 }

-/**
- * A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical
- * to the non-streaming relation.
- */
-class StreamingDataSourceV2Relation(
-    output: Seq[AttributeReference],
-    reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
-  override def isStreaming: Boolean = true
-}
-
 object DataSourceV2Relation {
-  def apply(reader: DataSourceReader): DataSourceV2Relation = {
-    new DataSourceV2Relation(reader.readSchema().toAttributes, reader)
+  def apply(source: DataSourceV2, reader: DataSourceReader): DataSourceV2Relation = {
+    new DataSourceV2Relation(reader.readSchema().toAttributes, source, reader, isStreaming = false)
   }
 }
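Before the planner changes below, a short REPL-style sketch of what the new relation shape plus the metadata-based equality from DataSourceV2QueryPlan amount to. It assumes the Spark 2.3-era v2 reader API (a DataSourceReader with readSchema and createDataReaderFactories); ExampleSource and ExampleReader are made-up stand-ins, not part of this patch.

import java.util.Collections

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader.{DataReaderFactory, DataSourceReader}
import org.apache.spark.sql.types.StructType

// A do-nothing source and reader, just enough to build a relation node.
object ExampleSource extends DataSourceV2

class ExampleReader extends DataSourceReader {
  override def readSchema(): StructType = new StructType().add("id", "long")
  override def createDataReaderFactories(): java.util.List[DataReaderFactory[Row]] =
    Collections.emptyList[DataReaderFactory[Row]]()
}

// The companion apply above builds a batch relation (isStreaming = false).
val batch = DataSourceV2Relation(ExampleSource, new ExampleReader)

// Only output, the source class and the pushed filters feed equals/hashCode,
// so swapping in another reader instance (with nothing pushed) changes nothing.
assert(batch == batch.copy(reader = new ExampleReader))

// A streaming relation is now the same class with the flag set; it renders
// with a "Streaming " prefix in simpleString instead of being a subclass.
val streaming = batch.copy(isStreaming = true)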
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.sources.v2.DataSourceV2
 import org.apache.spark.sql.sources.v2.reader._
 import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 import org.apache.spark.sql.types.StructType
@@ -36,11 +37,14 @@ import org.apache.spark.sql.types.StructType
  */
 case class DataSourceV2ScanExec(
     output: Seq[AttributeReference],
+    @transient source: DataSourceV2,
     @transient reader: DataSourceReader)
-  extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan {
+  extends LeafExecNode with DataSourceV2QueryPlan with ColumnarBatchScan {

   override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec]

+  override def simpleString: String = s"Scan $metadataString"
+
   override def outputPartitioning: physical.Partitioning = reader match {
     case s: SupportsReportPartitioning =>
       new DataSourcePartitioning(
@@ -23,8 +23,8 @@ import org.apache.spark.sql.execution.SparkPlan

 object DataSourceV2Strategy extends Strategy {
   override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case DataSourceV2Relation(output, reader) =>
-      DataSourceV2ScanExec(output, reader) :: Nil
+    case r: DataSourceV2Relation =>
+      DataSourceV2ScanExec(r.output, r.source, r.reader) :: Nil

     case WriteToDataSourceV2(writer, query) =>
       WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
@@ -39,11 +39,11 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel
     // TODO: Ideally column pruning should be implemented via a plan property that is propagated
     // top-down, then we can simplify the logic here and only collect target operators.
     val filterPushed = plan transformUp {
-      case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) =>
+      case FilterAndProject(fields, condition, r: DataSourceV2Relation) =>
         val (candidates, nonDeterministic) =
           splitConjunctivePredicates(condition).partition(_.deterministic)

-        val stayUpFilters: Seq[Expression] = reader match {
+        val stayUpFilters: Seq[Expression] = r.reader match {
           case r: SupportsPushDownCatalystFilters =>
             r.pushCatalystFilters(candidates.toArray)

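The rule splits the rewritten filter before it talks to the reader: only deterministic conjuncts are push-down candidates, and anything non-deterministic has to be evaluated by Spark. A small REPL-style illustration of that split, with made-up column names and predicates:

import org.apache.spark.sql.functions.{col, lit, rand}

// Two conjuncts of a hypothetical filter condition, handled the way
// splitConjunctivePredicates(condition).partition(_.deterministic) handles them.
val conjuncts = Seq(col("a") > 1, rand() > lit(0.5)).map(_.expr)
val (candidates, nonDeterministic) = conjuncts.partition(_.deterministic)
// candidates       contains (a > 1)           -- eligible for push-down to the reader
// nonDeterministic contains (rand(...) > 0.5) -- not pushed; evaluated in Spark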
@@ -27,9 +27,9 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, CurrentBatchTimestamp, CurrentDate, CurrentTimestamp}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, WriteToDataSourceV2}
 import org.apache.spark.sql.execution.streaming.sources.{InternalRowMicroBatchWriter, MicroBatchWriter}
-import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport, StreamWriteSupport}
+import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, MicroBatchReadSupport, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
 import org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
@@ -52,6 +52,8 @@ class MicroBatchExecution(

   @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty

+  private val readerToDataSourceMap = MutableMap.empty[MicroBatchReader, DataSourceV2]
+
   private val triggerExecutor = trigger match {
     case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock)
     case OneTimeTrigger => OneTimeExecutor()
@@ -90,6 +92,7 @@ class MicroBatchExecution(
               metadataPath,
               new DataSourceOptions(options.asJava))
             nextSourceId += 1
+            readerToDataSourceMap(reader) = source
             StreamingExecutionRelation(reader, output)(sparkSession)
           })
         case s @ StreamingRelationV2(_, sourceName, _, output, v1Relation) =>
@@ -405,12 +408,15 @@
           case v1: SerializedOffset => reader.deserializeOffset(v1.json)
           case v2: OffsetV2 => v2
         }
-        reader.setOffsetRange(
-          toJava(current),
-          Optional.of(availableV2))
+        reader.setOffsetRange(toJava(current), Optional.of(availableV2))
         logDebug(s"Retrieving data from $reader: $current -> $availableV2")
-        Some(reader ->
-          new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
+        Some(reader -> new DataSourceV2Relation(
+          reader.readSchema().toAttributes,
+          // Provide a fake value here just in case something went wrong, e.g. the reader gives
+          // a wrong `equals` implementation.
+          readerToDataSourceMap.getOrElse(reader, FakeDataSourceV2),
+          reader,
+          isStreaming = true))
       case _ => None
     }
   }
@@ -500,3 +506,5 @@
     Optional.ofNullable(scalaOption.orNull)
   }
 }
+
+object FakeDataSourceV2 extends DataSourceV2
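A standalone sketch of the bookkeeping this file now does, using local names rather than the real fields: each reader is recorded against the source that created it, and the later lookup falls back to a placeholder DataSourceV2, so a reader with a broken equals/hashCode can at worst degrade the relation's explain output, never fail the batch.

import scala.collection.mutable

import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader

// Stand-in for FakeDataSourceV2: only used when the lookup below misses.
object PlaceholderSource extends DataSourceV2

class ReaderRegistry {
  private val readerToSource = mutable.Map.empty[MicroBatchReader, DataSourceV2]

  // Called once, when a reader is created for a v2 source.
  def record(reader: MicroBatchReader, source: DataSourceV2): Unit =
    readerToSource(reader) = source

  // Called when building the per-batch relation; never throws.
  def sourceOf(reader: MicroBatchReader): DataSourceV2 =
    readerToSource.getOrElse(reader, PlaceholderSource)
}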