
Commit edb140e

HeartSaVioR authored and dongjoon-hyun committed
[SPARK-32896][SS] Add DataStreamWriter.table API
### What changes were proposed in this pull request?

This PR proposes to add `DataStreamWriter.table` to specify the output "table" to which the streaming query writes.

### Why are the changes needed?

For now, there is no way to write to a table (especially a catalog table), even when the table is capable of handling streaming writes. Even with Spark 3, writing to a catalog table via Structured Streaming has to go through `DataStreamWriter.format(provider)` and hope the provider handles it the same way the catalog table would. With the new API, we can point directly to a catalog table that supports streaming writes. Some usages are covered with tests; simply put, end users can do the following:

```scala
// assuming `testcat` is a custom catalog, and `ns` is a namespace in the catalog
spark.sql("CREATE TABLE testcat.ns.table1 (id bigint, data string) USING foo")

val query = inputDF
  .writeStream
  .table("testcat.ns.table1")
  .option(...)
  .start()
```

### Does this PR introduce _any_ user-facing change?

Yes, this adds a new public API in DataStreamWriter. It does not introduce a backward-incompatible change.

### How was this patch tested?

New unit tests.

Closes #29767 from HeartSaVioR/SPARK-32896.

Authored-by: Jungtaek Lim (HeartSaVioR) <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent e1909c9 commit edb140e
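For reference, a minimal end-to-end sketch of how the API added in this commit would be used. The method name `saveAsTable` matches the diff below; the catalog (`testcat`), table, schema, input source, and checkpoint path are illustrative assumptions, not part of the commit:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().appName("stream-to-table").getOrCreate()

// Assumed setup: `testcat` is a registered V2 catalog whose tables report STREAMING_WRITE.
spark.sql("CREATE TABLE testcat.ns.table1 (id BIGINT, data STRING) USING foo")

// Assumed input: the built-in `rate` source, reshaped to the table schema.
val inputDF = spark.readStream
  .format("rate")
  .load()
  .selectExpr("value AS id", "CAST(value AS STRING) AS data")

// New API from this commit: resolve the identifier through the catalog and start
// a streaming write into the table.
val query = inputDF.writeStream
  .option("checkpointLocation", "/tmp/checkpoints/table1")  // assumed path
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .saveAsTable("testcat.ns.table1")

query.awaitTermination()
```

Note that `saveAsTable` starts the query directly (it calls `startInternal`), so it replaces the usual `.start()` call.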

File tree: 3 files changed, +299 -60 lines changed


sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 58 additions & 1 deletion
@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
 import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -145,6 +146,7 @@ class InMemoryTable(
   override def capabilities: util.Set[TableCapability] = Set(
     TableCapability.BATCH_READ,
     TableCapability.BATCH_WRITE,
+    TableCapability.STREAMING_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.OVERWRITE_DYNAMIC,
     TableCapability.TRUNCATE).asJava
@@ -169,26 +171,35 @@

     new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
       private var writer: BatchWrite = Append
+      private var streamingWriter: StreamingWrite = StreamingAppend

       override def truncate(): WriteBuilder = {
         assert(writer == Append)
         writer = TruncateAndAppend
+        streamingWriter = StreamingTruncateAndAppend
         this
       }

       override def overwrite(filters: Array[Filter]): WriteBuilder = {
         assert(writer == Append)
         writer = new Overwrite(filters)
+        streamingWriter = new StreamingNotSupportedOperation(s"overwrite ($filters)")
         this
       }

       override def overwriteDynamicPartitions(): WriteBuilder = {
         assert(writer == Append)
         writer = DynamicOverwrite
+        streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
         this
       }

       override def buildForBatch(): BatchWrite = writer
+
+      override def buildForStreaming(): StreamingWrite = streamingWriter match {
+        case exc: StreamingNotSupportedOperation => exc.throwsException()
+        case s => s
+      }
     }
   }

@@ -231,6 +242,45 @@
     }
   }

+  private abstract class TestStreamingWrite extends StreamingWrite {
+    def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
+      BufferedRowsWriterFactory
+    }
+
+    def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  }
+
+  private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
+    override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
+      throwsException()
+
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
+      throwsException()
+
+    override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
+      throwsException()
+
+    def throwsException[T](): T = throw new IllegalStateException("The operation " +
+      s"${operation} isn't supported for streaming query.")
+  }
+
+  private object StreamingAppend extends TestStreamingWrite {
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+      dataMap.synchronized {
+        withData(messages.map(_.asInstanceOf[BufferedRows]))
+      }
+    }
+  }
+
+  private object StreamingTruncateAndAppend extends TestStreamingWrite {
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+      dataMap.synchronized {
+        dataMap.clear
+        withData(messages.map(_.asInstanceOf[BufferedRows]))
+      }
+    }
+  }
+
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
@@ -310,10 +360,17 @@ private class BufferedRowsReader(partition: BufferedRows) extends PartitionReade
   override def close(): Unit = {}
 }

-private object BufferedRowsWriterFactory extends DataWriterFactory {
+private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory {
   override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
     new BufferWriter
   }
+
+  override def createWriter(
+      partitionId: Int,
+      taskId: Long,
+      epochId: Long): DataWriter[InternalRow] = {
+    new BufferWriter
+  }
 }

 private class BufferWriter extends DataWriter[InternalRow] {
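The test table above routes `buildForStreaming` to small `StreamingWrite` objects that buffer committed rows in `dataMap`. To make that shape concrete outside the test harness, here is a minimal, self-contained sketch (not from the commit; names such as `BufferingStreamingWrite` and `RowsMessage` are assumptions) of a DSv2 streaming write that collects committed rows on the driver:

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}

// Commit message carrying the rows written by one task for one epoch.
case class RowsMessage(rows: Seq[InternalRow]) extends WriterCommitMessage

// Driver-side write: gathers the per-task messages when an epoch commits.
class BufferingStreamingWrite extends StreamingWrite {
  val committed = mutable.ArrayBuffer.empty[InternalRow]

  override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
    BufferingWriterFactory

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
    committed.synchronized {
      messages.foreach { case RowsMessage(rows) => committed ++= rows }
    }

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = ()
}

// Executor-side factory: one DataWriter per (partition, task, epoch).
object BufferingWriterFactory extends StreamingDataWriterFactory {
  override def createWriter(
      partitionId: Int,
      taskId: Long,
      epochId: Long): DataWriter[InternalRow] = new DataWriter[InternalRow] {
    private val buffer = mutable.ArrayBuffer.empty[InternalRow]
    override def write(record: InternalRow): Unit = buffer += record.copy()
    override def commit(): WriterCommitMessage = RowsMessage(buffer.toSeq)
    override def abort(): Unit = buffer.clear()
    override def close(): Unit = ()
  }
}
```

A table would return such a `StreamingWrite` from `WriteBuilder.buildForStreaming` and advertise `TableCapability.STREAMING_WRITE`, which is exactly the check `DataStreamWriter` performs in the next file.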

sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala

Lines changed: 89 additions & 49 deletions
@@ -27,7 +27,7 @@ import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
@@ -45,6 +45,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  */
 @Evolving
 final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
+  import DataStreamWriter._

   private val df = ds.toDF()

@@ -294,60 +295,75 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
   @throws[TimeoutException]
   def start(): StreamingQuery = startInternal(None)

+  /**
+   * Starts the execution of the streaming query, which will continually output results to the given
+   * table as new data arrives. The returned [[StreamingQuery]] object can be used to interact with
+   * the stream.
+   *
+   * @since 3.1.0
+   */
+  @throws[TimeoutException]
+  def saveAsTable(tableName: String): StreamingQuery = {
+    this.source = SOURCE_NAME_TABLE
+    this.tableName = tableName
+    startInternal(None)
+  }
+
   private def startInternal(path: Option[String]): StreamingQuery = {
     if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
       throw new AnalysisException("Hive data source can only be used with tables, you can not " +
         "write files of Hive data source directly.")
     }

-    if (source == "memory") {
-      assertNotPartitioned("memory")
+    if (source == SOURCE_NAME_TABLE) {
+      assertNotPartitioned(SOURCE_NAME_TABLE)
+
+      import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
+
+      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+      val originalMultipartIdentifier = df.sparkSession.sessionState.sqlParser
+        .parseMultipartIdentifier(tableName)
+      val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier
+
+      // Currently we don't create a logical streaming writer node in logical plan, so cannot rely
+      // on analyzer to resolve it. Directly lookup only for temp view to provide clearer message.
+      // TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
+      if (df.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) {
+        throw new AnalysisException(s"Temporary view $tableName doesn't support streaming write")
+      }
+
+      val tableInstance = catalog.asTableCatalog.loadTable(identifier)
+
+      import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+      val sink = tableInstance match {
+        case t: SupportsWrite if t.supports(STREAMING_WRITE) => t
+        case t => throw new AnalysisException(s"Table $tableName doesn't support streaming " +
+          s"write - $t")
+      }
+
+      startQuery(sink, extraOptions)
+    } else if (source == SOURCE_NAME_MEMORY) {
+      assertNotPartitioned(SOURCE_NAME_MEMORY)
       if (extraOptions.get("queryName").isEmpty) {
         throw new AnalysisException("queryName must be specified for memory sink")
       }
       val sink = new MemorySink()
       val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
-      val chkpointLoc = extraOptions.get("checkpointLocation")
       val recoverFromChkpoint = outputMode == OutputMode.Complete()
-      val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        chkpointLoc,
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        recoverFromCheckpointLocation = recoverFromChkpoint,
-        trigger = trigger)
+      val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromChkpoint)
       resultDf.createOrReplaceTempView(query.name)
       query
-    } else if (source == "foreach") {
-      assertNotPartitioned("foreach")
+    } else if (source == SOURCE_NAME_FOREACH) {
+      assertNotPartitioned(SOURCE_NAME_FOREACH)
       val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc)
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        trigger = trigger)
-    } else if (source == "foreachBatch") {
-      assertNotPartitioned("foreachBatch")
+      startQuery(sink, extraOptions)
+    } else if (source == SOURCE_NAME_FOREACH_BATCH) {
+      assertNotPartitioned(SOURCE_NAME_FOREACH_BATCH)
       if (trigger.isInstanceOf[ContinuousTrigger]) {
-        throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
+        throw new AnalysisException(s"'$source' is not supported with continuous trigger")
       }
       val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        trigger = trigger)
+      startQuery(sink, extraOptions)
     } else {
       val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
       val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
@@ -380,19 +396,28 @@
         createV1Sink(optionsWithPath)
       }

-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        optionsWithPath.originalMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = source == "console" || source == "noop",
-        recoverFromCheckpointLocation = true,
-        trigger = trigger)
+      startQuery(sink, optionsWithPath)
     }
   }

+  private def startQuery(
+      sink: Table,
+      newOptions: CaseInsensitiveMap[String],
+      recoverFromCheckpoint: Boolean = true): StreamingQuery = {
+    val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source)
+
+    df.sparkSession.sessionState.streamingQueryManager.startQuery(
+      newOptions.get("queryName"),
+      newOptions.get("checkpointLocation"),
+      df,
+      newOptions.originalMap,
+      sink,
+      outputMode,
+      useTempCheckpointLocation = useTempCheckpointLocation,
+      recoverFromCheckpointLocation = recoverFromCheckpoint,
+      trigger = trigger)
+  }
+
   private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
     val ds = DataSource(
       df.sparkSession,
@@ -409,7 +434,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    * @since 2.0.0
    */
   def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
-    this.source = "foreach"
+    this.source = SOURCE_NAME_FOREACH
     this.foreachWriter = if (writer != null) {
       ds.sparkSession.sparkContext.clean(writer)
     } else {
@@ -433,7 +458,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    */
   @Evolving
   def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
-    this.source = "foreachBatch"
+    this.source = SOURCE_NAME_FOREACH_BATCH
     if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
     this.foreachBatchWriter = function
     this
@@ -485,6 +510,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

   private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName

+  private var tableName: String = null
+
   private var outputMode: OutputMode = OutputMode.Append

   private var trigger: Trigger = Trigger.ProcessingTime(0L)
@@ -497,3 +524,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {

   private var partitioningColumns: Option[Seq[String]] = None
 }
+
+object DataStreamWriter {
+  val SOURCE_NAME_MEMORY = "memory"
+  val SOURCE_NAME_FOREACH = "foreach"
+  val SOURCE_NAME_FOREACH_BATCH = "foreachBatch"
+  val SOURCE_NAME_CONSOLE = "console"
+  val SOURCE_NAME_TABLE = "table"
+  val SOURCE_NAME_NOOP = "noop"
+
+  // these writer sources are also used for one-time query, hence allow temp checkpoint location
+  val SOURCES_ALLOW_ONE_TIME_QUERY = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH,
+    SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP)
+}
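One behavioral detail worth noting from the consolidated `startQuery` helper: `useTempCheckpointLocation` is now derived from `SOURCES_ALLOW_ONE_TIME_QUERY`, so the memory, foreach, foreachBatch, console, and noop sinks keep getting a temporary checkpoint location, while the new table sink does not. A rough usage sketch (reusing the assumed `inputDF`, catalog, and paths from the earlier example):

```scala
// Console sink: listed in SOURCES_ALLOW_ONE_TIME_QUERY, so a temporary checkpoint
// location is created when the option is omitted.
val consoleQuery = inputDF.writeStream
  .format("console")
  .start()

// Table sink: not in the list, so supply a checkpoint location explicitly
// (or set spark.sql.streaming.checkpointLocation).
val tableQuery = inputDF.writeStream
  .option("checkpointLocation", "/tmp/checkpoints/ns_table1")  // assumed path
  .saveAsTable("testcat.ns.table1")                            // assumed catalog/table
```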
