KafkaContinuousReader.scala
@@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.DataFormat
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
import org.apache.spark.sql.types.StructType
@@ -53,7 +54,7 @@ class KafkaContinuousReader(
metadataPath: String,
initialOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends ContinuousReader with SupportsScanUnsafeRow with Logging {
extends ContinuousReader with Logging {

private lazy val session = SparkSession.getActiveSession.get
private lazy val sc = session.sparkContext
@@ -86,7 +87,7 @@ class KafkaContinuousReader(
KafkaSourceOffset(JsonUtils.partitionOffsets(json))
}

override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
override def createDataReaderFactories(): ju.List[DataReaderFactory] = {
import scala.collection.JavaConverters._

val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(offset)
@@ -108,7 +109,7 @@
case (topicPartition, start) =>
KafkaContinuousDataReaderFactory(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
.asInstanceOf[DataReaderFactory[UnsafeRow]]
.asInstanceOf[DataReaderFactory]
}.asJava
}

@@ -161,17 +162,11 @@ case class KafkaContinuousDataReaderFactory(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousDataReaderFactory[UnsafeRow] {
failOnDataLoss: Boolean) extends ContinuousDataReaderFactory {

override def createDataReaderWithOffset(offset: PartitionOffset): DataReader[UnsafeRow] = {
val kafkaOffset = offset.asInstanceOf[KafkaSourcePartitionOffset]
require(kafkaOffset.topicPartition == topicPartition,
s"Expected topicPartition: $topicPartition, but got: ${kafkaOffset.topicPartition}")
new KafkaContinuousDataReader(
topicPartition, kafkaOffset.partitionOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
}
override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW

override def createDataReader(): KafkaContinuousDataReader = {
override def createUnsafeRowDataReader(): KafkaContinuousDataReader = {
new KafkaContinuousDataReader(
topicPartition, startOffset, kafkaParams, pollTimeoutMs, failOnDataLoss)
}
KafkaMicroBatchReader.scala
@@ -32,8 +32,8 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.{DataFormat, DataSourceOptions}
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.UninterruptibleThread
@@ -61,7 +61,7 @@ private[kafka010] class KafkaMicroBatchReader(
metadataPath: String,
startingOffsets: KafkaOffsetRangeLimit,
failOnDataLoss: Boolean)
extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
extends MicroBatchReader with Logging {

private var startPartitionOffsets: PartitionOffsetMap = _
private var endPartitionOffsets: PartitionOffsetMap = _
@@ -101,7 +101,7 @@ private[kafka010] class KafkaMicroBatchReader(
}
}

override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
override def createDataReaderFactories(): ju.List[DataReaderFactory] = {
// Find the new partitions, and get their earliest offsets
val newPartitions = endPartitionOffsets.keySet.diff(startPartitionOffsets.keySet)
val newPartitionInitialOffsets = kafkaOffsetReader.fetchEarliestOffsets(newPartitions.toSeq)
@@ -146,7 +146,7 @@ private[kafka010] class KafkaMicroBatchReader(
new KafkaMicroBatchDataReaderFactory(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}
factories.map(_.asInstanceOf[DataReaderFactory[UnsafeRow]]).asJava
factories.map(_.asInstanceOf[DataReaderFactory]).asJava
Contributor: Why is this cast necessary?
}

override def getStartOffset: Offset = {
@@ -305,11 +305,13 @@ private[kafka010] case class KafkaMicroBatchDataReaderFactory(
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
reuseKafkaConsumer: Boolean) extends DataReaderFactory[UnsafeRow] {
reuseKafkaConsumer: Boolean) extends DataReaderFactory {

override def preferredLocations(): Array[String] = offsetRange.preferredLoc.toArray

override def createDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader(
override def dataFormat(): DataFormat = DataFormat.UNSAFE_ROW

override def createUnsafeRowDataReader(): DataReader[UnsafeRow] = new KafkaMicroBatchDataReader(
offsetRange, executorKafkaParams, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer)
}

KafkaMicroBatchSourceSuite.scala
@@ -36,7 +36,6 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, ForeachWriter, SparkSession}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
@@ -673,7 +672,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase {
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 0L))),
Optional.of[OffsetV2](KafkaSourceOffset(Map(tp -> 100L)))
)
val factories = reader.createUnsafeRowReaderFactories().asScala
val factories = reader.createDataReaderFactories().asScala
.map(_.asInstanceOf[KafkaMicroBatchDataReaderFactory])
withClue(s"minPartitions = $minPartitions generated factories $factories\n\t") {
assert(factories.size == numPartitionsGenerated)
DataFormat.java (new file)
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.sources.v2;

import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;

/**
* An enum returned by {@link DataReaderFactory#dataFormat()}, representing the output data format of
* a data source scan.
*
* <ul>
* <li>{@link #ROW}</li>
* <li>{@link #UNSAFE_ROW}</li>
* <li>{@link #COLUMNAR_BATCH}</li>
* </ul>
*
* TODO: add INTERNAL_ROW
*/
public enum DataFormat {
/**
* Refers to {@link org.apache.spark.sql.Row}, which is very stable and guaranteed to be backward
* compatible. Spark needs to convert data in this format to its internal format, so data source
* developers should consider other formats for better performance.
*/
ROW,

/**
* Refers to {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow}, which is an unstable and
* internal API. It's already the internal format in Spark, so there is no extra conversion needed.
*/
UNSAFE_ROW,

/**
* Refers to {@link org.apache.spark.sql.vectorized.ColumnarBatch}, which is a public but experimental
* API. It's already the internal format in Spark, so there is no extra conversion needed. This format
* is recommended over the others, as the columnar format enables further optimizations such as
* vectorization to speed up data processing.
*/
COLUMNAR_BATCH
}
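
To make the enum concrete, here is a minimal, hypothetical factory that declares DataFormat.ROW and therefore overrides only the matching create method. The name FixedRangeReaderFactory, the integer range, and the single int column are illustrative assumptions, not code from this patch.

import java.io.IOException;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;

// Hypothetical sketch: a factory that emits plain Rows for a fixed integer range. Since the
// declared format is ROW, only createRowDataReader() is overridden; the other create methods
// keep their default (throwing) bodies and would never be called by Spark for this factory.
class FixedRangeReaderFactory implements DataReaderFactory {
  private final int start;
  private final int end;

  FixedRangeReaderFactory(int start, int end) {
    this.start = start;
    this.end = end;
  }

  @Override
  public DataFormat dataFormat() {
    return DataFormat.ROW;
  }

  @Override
  public DataReader<Row> createRowDataReader() {
    return new DataReader<Row>() {
      private int current = start - 1;

      @Override
      public boolean next() {
        current += 1;
        return current < end;
      }

      @Override
      public Row get() {
        // A single int column; a real source would match its reader's readSchema().
        return RowFactory.create(current);
      }

      @Override
      public void close() throws IOException {
        // nothing to release in this sketch
      }
    };
  }
}

Declaring UNSAFE_ROW or COLUMNAR_BATCH instead would route Spark to createUnsafeRowDataReader() or createColumnarBatchDataReader(), per the list in the DataReaderFactory javadoc below.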
ContinuousDataReaderFactory.java
@@ -18,18 +18,48 @@
package org.apache.spark.sql.sources.v2.reader;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader;
import org.apache.spark.sql.vectorized.ColumnarBatch;

/**
* A mix-in interface for {@link DataReaderFactory}. Continuous data reader factories implement
* this interface to create {@link ContinuousDataReader}s in the declared data format.
*/
@InterfaceStability.Evolving
public interface ContinuousDataReaderFactory<T> extends DataReaderFactory<T> {
public interface ContinuousDataReaderFactory extends DataReaderFactory {

/**
* Returns a row-formatted continuous data reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
default ContinuousDataReader<Row> createRowDataReader() {
throw new IllegalStateException(
"createRowDataReader must be implemented if the data format is ROW.");
}

/**
* Returns an unsafe-row-formatted continuous data reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
default ContinuousDataReader<UnsafeRow> createUnsafeRowDataReader() {
throw new IllegalStateException(
"createUnsafeRowDataReader must be implemented if the data format is UNSAFE_ROW.");
}

/**
* Create a DataReader with particular offset as its startOffset.
* Returns a columnar-batch-formatted continuous data reader to do the actual reading work.
*
* @param offset offset want to set as the DataReader's startOffset.
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
DataReader<T> createDataReaderWithOffset(PartitionOffset offset);
default ContinuousDataReader<ColumnarBatch> createColumnarBatchDataReader() {
throw new IllegalStateException(
"createColumnarBatchDataReader must be implemented if the data format is COLUMNAR_BATCH.");
}
}
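
The continuous path follows the same shape but returns ContinuousDataReader. A rough sketch under made-up names (CounterContinuousReaderFactory, LongPartitionOffset); the Kafka changes above are the real in-tree usage, where the factory carries its start offset in its constructor rather than receiving it through the removed createDataReaderWithOffset.

import java.io.IOException;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.sources.v2.reader.ContinuousDataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader;
import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset;

// Hypothetical continuous factory that emits an ever-increasing counter as plain Rows.
class CounterContinuousReaderFactory implements ContinuousDataReaderFactory {

  // Minimal PartitionOffset carrying the last emitted value.
  static class LongPartitionOffset implements PartitionOffset {
    final long value;
    LongPartitionOffset(long value) { this.value = value; }
  }

  @Override
  public DataFormat dataFormat() {
    return DataFormat.ROW;
  }

  @Override
  public ContinuousDataReader<Row> createRowDataReader() {
    return new ContinuousDataReader<Row>() {
      private long current = -1L;

      @Override
      public PartitionOffset getOffset() {
        return new LongPartitionOffset(current);
      }

      @Override
      public boolean next() {
        current += 1;
        return true; // a continuous reader keeps producing until the task is stopped
      }

      @Override
      public Row get() {
        return RowFactory.create(current);
      }

      @Override
      public void close() throws IOException {
        // nothing to release in this sketch
      }
    };
  }
}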
DataReader.java
@@ -23,12 +23,13 @@
import org.apache.spark.annotation.InterfaceStability;

/**
* A data reader returned by {@link DataReaderFactory#createDataReader()} and is responsible for
* outputting data for a RDD partition.
* A data reader returned by one of the create-data-reader methods in {@link DataReaderFactory},
* responsible for outputting data for an RDD partition.
*
* Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data
* source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source
* readers that mix in {@link SupportsScanUnsafeRow}.
* Note that currently the type `T` can only be {@link org.apache.spark.sql.Row},
* {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow}, or
* {@link org.apache.spark.sql.vectorized.ColumnarBatch}, depending on the return value of
* {@link DataReaderFactory#dataFormat()}.
*/
@InterfaceStability.Evolving
public interface DataReader<T> extends Closeable {
DataReaderFactory.java
@@ -20,6 +20,10 @@
import java.io.Serializable;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.sources.v2.DataFormat;
import org.apache.spark.sql.vectorized.ColumnarBatch;

/**
* A reader factory returned by {@link DataSourceReader#createDataReaderFactories()} and is
@@ -32,7 +36,7 @@
* serializable and {@link DataReader} doesn't need to be.
*/
@InterfaceStability.Evolving
public interface DataReaderFactory<T> extends Serializable {
public interface DataReaderFactory extends Serializable {

/**
* The preferred locations where the data reader returned by this reader factory can run faster,
@@ -52,10 +56,46 @@ default String[] preferredLocations() {
}

/**
* Returns a data reader to do the actual reading work.
* The output data format of this factory's data reader. Spark invokes the corresponding
* create-data-reader method according to the return value of this method:
* <ul>
* <li>{@link DataFormat#ROW}: {@link #createRowDataReader()}</li>
* <li>{@link DataFormat#UNSAFE_ROW}: {@link #createUnsafeRowDataReader()}</li>
* <li>{@link DataFormat#COLUMNAR_BATCH}: {@link #createColumnarBatchDataReader()}</li>
* </ul>
*/
DataFormat dataFormat();
Contributor: If the data format is determined when the factory is created, then I don't see why it is necessary to change the API. This just makes it more confusing.

/**
* Returns a row-formatted data reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
default DataReader<Row> createRowDataReader() {
throw new IllegalStateException(
"createRowDataReader must be implemented if the data format is ROW.");
}

/**
* Returns an unsafe-row-formatted data reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
DataReader<T> createDataReader();
default DataReader<UnsafeRow> createUnsafeRowDataReader() {
throw new IllegalStateException(
"createUnsafeRowDataReader must be implemented if the data format is UNSAFE_ROW.");
}

/**
* Returns a columnar-batch-formatted data reader to do the actual reading work.
*
* If this method fails (by throwing an exception), the corresponding Spark task would fail and
* get retried until hitting the maximum retry times.
*/
default DataReader<ColumnarBatch> createColumnarBatchDataReader() {
throw new IllegalStateException(
"createColumnarBatchDataReader must be implemented if the data format is COLUMNAR_BATCH.");
}
}
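
To make the dispatch contract in the javadoc concrete, here is an illustrative engine-side sketch of which create method corresponds to which DataFormat value; the class and method names are assumptions, and this is not code from this patch or from Spark's actual scan execution.

import java.io.IOException;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.sources.v2.reader.DataReader;
import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.vectorized.ColumnarBatch;

// Illustrative only: the caller inspects dataFormat() and calls the matching create method.
final class DataFormatDispatch {
  static void consume(DataReaderFactory factory) throws IOException {
    switch (factory.dataFormat()) {
      case ROW:
        try (DataReader<Row> reader = factory.createRowDataReader()) {
          while (reader.next()) {
            Row row = reader.get(); // converted to Spark's internal format downstream
          }
        }
        break;
      case UNSAFE_ROW:
        try (DataReader<UnsafeRow> reader = factory.createUnsafeRowDataReader()) {
          while (reader.next()) {
            UnsafeRow row = reader.get(); // already in the internal format
          }
        }
        break;
      case COLUMNAR_BATCH:
        try (DataReader<ColumnarBatch> reader = factory.createColumnarBatchDataReader()) {
          while (reader.next()) {
            ColumnarBatch batch = reader.get(); // e.g. iterate batch.rowIterator()
          }
        }
        break;
    }
  }
}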
DataSourceReader.java
@@ -20,7 +20,6 @@
import java.util.List;

import org.apache.spark.annotation.InterfaceStability;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.ReadSupport;
import org.apache.spark.sql.sources.v2.ReadSupportWithSchema;
@@ -34,23 +33,18 @@
* logic is delegated to {@link DataReaderFactory}s that are returned by
* {@link #createDataReaderFactories()}.
*
* There are mainly 3 kinds of query optimizations:
* There are mainly 2 kinds of query optimizations:
* 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column
* pruning), etc. Names of these interfaces start with `SupportsPushDown`.
* 2. Information Reporting. E.g., statistics reporting, ordering reporting, etc.
* Names of these interfaces start with `SupportsReporting`.
* 3. Special scans. E.g, columnar scan, unsafe row scan, etc.
* Names of these interfaces start with `SupportsScan`. Note that a reader should only
* implement at most one of the special scans, if more than one special scans are implemented,
* only one of them would be respected, according to the priority list from high to low:
* {@link SupportsScanColumnarBatch}, {@link SupportsScanUnsafeRow}.
*
* If an exception is thrown when applying any of these query optimizations, the action fails
* and no Spark job is submitted.
*
* Spark first applies all operator push-down optimizations that this data source supports. Then
* Spark collects information this data source reported for further optimizations. Finally Spark
* issues the scan request and does the actual data reading.
* issues the scan request, creates the {@link DataReaderFactory}s, and does the actual data reading.
*/
@InterfaceStability.Evolving
public interface DataSourceReader {
@@ -76,5 +70,5 @@ public interface DataSourceReader {
* If this method fails (by throwing an exception), the action fails and no Spark job is
* submitted.
*/
List<DataReaderFactory<Row>> createDataReaderFactories();
List<DataReaderFactory> createDataReaderFactories();
}
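
Finally, a hypothetical DataSourceReader tying the pieces together: it returns the now non-generic DataReaderFactory list, reusing the FixedRangeReaderFactory sketched after DataFormat above; the single-column schema and the two-partition split are illustrative assumptions.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.sources.v2.reader.DataReaderFactory;
import org.apache.spark.sql.sources.v2.reader.DataSourceReader;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

// Hypothetical reader that splits a fixed integer range into two partitions.
class FixedRangeDataSourceReader implements DataSourceReader {

  @Override
  public StructType readSchema() {
    // Must match what the data readers actually emit (a single int column here).
    return new StructType().add("value", DataTypes.IntegerType);
  }

  @Override
  public List<DataReaderFactory> createDataReaderFactories() {
    return Arrays.<DataReaderFactory>asList(
        new FixedRangeReaderFactory(0, 500),
        new FixedRangeReaderFactory(500, 1000));
  }
}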