@@ -27,7 +27,7 @@ import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
-import org.apache.spark.sql.connector.catalog.{SupportsWrite, TableProvider}
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
@@ -45,6 +45,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
  */
 @Evolving
 final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
+  import DataStreamWriter._
 
   private val df = ds.toDF()
 
@@ -294,60 +295,75 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
   @throws[TimeoutException]
   def start(): StreamingQuery = startInternal(None)
 
+  /**
+   * Starts the execution of the streaming query, which will continually output results to the given
+   * table as new data arrives. The returned [[StreamingQuery]] object can be used to interact with
+   * the stream.
+   *
+   * @since 3.1.0
+   */
+  @throws[TimeoutException]
+  def saveAsTable(tableName: String): StreamingQuery = {
+    this.source = SOURCE_NAME_TABLE
+    this.tableName = tableName
+    startInternal(None)
+  }
+
   private def startInternal(path: Option[String]): StreamingQuery = {
     if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
       throw new AnalysisException("Hive data source can only be used with tables, you can not " +
         "write files of Hive data source directly.")
     }
 
-    if (source == "memory") {
-      assertNotPartitioned("memory")
+    if (source == SOURCE_NAME_TABLE) {
+      assertNotPartitioned(SOURCE_NAME_TABLE)
+
+      import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
+
+      import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+      val originalMultipartIdentifier = df.sparkSession.sessionState.sqlParser
+        .parseMultipartIdentifier(tableName)
+      val CatalogAndIdentifier(catalog, identifier) = originalMultipartIdentifier
+
+      // Currently we don't create a logical streaming writer node in logical plan, so cannot rely
+      // on analyzer to resolve it. Directly lookup only for temp view to provide clearer message.
+      // TODO (SPARK-27484): we should add the writing node before the plan is analyzed.
+      if (df.sparkSession.sessionState.catalog.isTempView(originalMultipartIdentifier)) {
+        throw new AnalysisException(s"Temporary view $tableName doesn't support streaming write")
+      }
+
+      val tableInstance = catalog.asTableCatalog.loadTable(identifier)
+
+      import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+      val sink = tableInstance match {
+        case t: SupportsWrite if t.supports(STREAMING_WRITE) => t
+        case t => throw new AnalysisException(s"Table $tableName doesn't support streaming " +
+          s"write - $t")
+      }
+
+      startQuery(sink, extraOptions)
+    } else if (source == SOURCE_NAME_MEMORY) {
+      assertNotPartitioned(SOURCE_NAME_MEMORY)
       if (extraOptions.get("queryName").isEmpty) {
         throw new AnalysisException("queryName must be specified for memory sink")
       }
       val sink = new MemorySink()
       val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes))
-      val chkpointLoc = extraOptions.get("checkpointLocation")
       val recoverFromChkpoint = outputMode == OutputMode.Complete()
-      val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        chkpointLoc,
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        recoverFromCheckpointLocation = recoverFromChkpoint,
-        trigger = trigger)
+      val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromChkpoint)
       resultDf.createOrReplaceTempView(query.name)
       query
-    } else if (source == "foreach") {
-      assertNotPartitioned("foreach")
+    } else if (source == SOURCE_NAME_FOREACH) {
+      assertNotPartitioned(SOURCE_NAME_FOREACH)
       val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc)
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        trigger = trigger)
-    } else if (source == "foreachBatch") {
-      assertNotPartitioned("foreachBatch")
+      startQuery(sink, extraOptions)
+    } else if (source == SOURCE_NAME_FOREACH_BATCH) {
+      assertNotPartitioned(SOURCE_NAME_FOREACH_BATCH)
       if (trigger.isInstanceOf[ContinuousTrigger]) {
-        throw new AnalysisException("'foreachBatch' is not supported with continuous trigger")
+        throw new AnalysisException(s"'$source' is not supported with continuous trigger")
       }
       val sink = new ForeachBatchSink[T](foreachBatchWriter, ds.exprEnc)
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        extraOptions.toMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = true,
-        trigger = trigger)
+      startQuery(sink, extraOptions)
     } else {
       val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
       val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
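The hunk above adds the saveAsTable sink path, which resolves the target through the catalog, rejects temporary views and partitioning, and requires the table to support the STREAMING_WRITE capability. A minimal usage sketch under those assumptions; the "events_sink" table, its schema, and the checkpoint path are hypothetical, and the "rate" source is only a stand-in:

// Sketch only: assumes an existing catalog table "events_sink" whose catalog
// implementation supports STREAMING_WRITE; names and paths are illustrative.
import org.apache.spark.sql.SparkSession

object SaveAsTableSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("saveAsTable-sketch")
      .getOrCreate()

    // Built-in test source emitting (timestamp, value) rows.
    val stream = spark.readStream
      .format("rate")
      .option("rowsPerSecond", "5")
      .load()

    // Write the stream into an existing catalog table instead of a path.
    // partitionBy() is rejected for this sink, and a temporary view as the
    // target raises an AnalysisException.
    val query = stream.writeStream
      .option("checkpointLocation", "/tmp/checkpoints/events_sink")
      .outputMode("append")
      .saveAsTable("events_sink")

    query.awaitTermination()
  }
}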
@@ -380,19 +396,28 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         createV1Sink(optionsWithPath)
       }
 
-      df.sparkSession.sessionState.streamingQueryManager.startQuery(
-        extraOptions.get("queryName"),
-        extraOptions.get("checkpointLocation"),
-        df,
-        optionsWithPath.originalMap,
-        sink,
-        outputMode,
-        useTempCheckpointLocation = source == "console" || source == "noop",
-        recoverFromCheckpointLocation = true,
-        trigger = trigger)
+      startQuery(sink, optionsWithPath)
     }
   }
 
+  private def startQuery(
+      sink: Table,
+      newOptions: CaseInsensitiveMap[String],
+      recoverFromCheckpoint: Boolean = true): StreamingQuery = {
+    val useTempCheckpointLocation = SOURCES_ALLOW_ONE_TIME_QUERY.contains(source)
+
+    df.sparkSession.sessionState.streamingQueryManager.startQuery(
+      newOptions.get("queryName"),
+      newOptions.get("checkpointLocation"),
+      df,
+      newOptions.originalMap,
+      sink,
+      outputMode,
+      useTempCheckpointLocation = useTempCheckpointLocation,
+      recoverFromCheckpointLocation = recoverFromCheckpoint,
+      trigger = trigger)
+  }
+
   private def createV1Sink(optionsWithPath: CaseInsensitiveMap[String]): Sink = {
     val ds = DataSource(
       df.sparkSession,
@@ -409,7 +434,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    * @since 2.0.0
    */
   def foreach(writer: ForeachWriter[T]): DataStreamWriter[T] = {
-    this.source = "foreach"
+    this.source = SOURCE_NAME_FOREACH
     this.foreachWriter = if (writer != null) {
       ds.sparkSession.sparkContext.clean(writer)
     } else {
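For reference alongside the constant swap above, a minimal ForeachWriter sketch; the writer just prints rows and every name is illustrative. Because "foreach" is in the new SOURCES_ALLOW_ONE_TIME_QUERY list, no explicit checkpointLocation is required:

// Sketch only: a trivial ForeachWriter that prints each record; a real writer
// would manage an external connection in open/process/close.
import org.apache.spark.sql.{ForeachWriter, Row, SparkSession}

object ForeachSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("foreach-sketch").getOrCreate()

    val stream = spark.readStream.format("rate").load()

    val query = stream.writeStream
      .foreach(new ForeachWriter[Row] {
        def open(partitionId: Long, epochId: Long): Boolean = true // open per partition/epoch
        def process(row: Row): Unit = println(row)                 // handle one record
        def close(errorOrNull: Throwable): Unit = ()               // release resources
      })
      .start()

    query.awaitTermination()
  }
}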
@@ -433,7 +458,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
    */
   @Evolving
   def foreachBatch(function: (Dataset[T], Long) => Unit): DataStreamWriter[T] = {
-    this.source = "foreachBatch"
+    this.source = SOURCE_NAME_FOREACH_BATCH
     if (function == null) throw new IllegalArgumentException("foreachBatch function cannot be null")
     this.foreachBatchWriter = function
     this
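A matching sketch for the foreachBatch sink renamed above; the parquet output path is hypothetical, and the typed function value simply keeps overload resolution unambiguous between the Scala and Java (VoidFunction2) variants:

// Sketch only: each micro-batch arrives as a plain DataFrame that can be
// written with the batch APIs. The diff above rejects this sink when a
// continuous trigger is configured.
import org.apache.spark.sql.{DataFrame, SparkSession}

object ForeachBatchSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("foreachBatch-sketch").getOrCreate()

    val stream = spark.readStream.format("rate").load()

    // Explicitly typed Function2 so only the Scala overload of foreachBatch applies.
    val writeBatch: (DataFrame, Long) => Unit = (batch, batchId) =>
      batch.write.mode("append").parquet(s"/tmp/batches/batch-$batchId") // hypothetical path

    val query = stream.writeStream
      .foreachBatch(writeBatch)
      .start()

    query.awaitTermination()
  }
}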
@@ -485,6 +510,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
 
   private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName
 
+  private var tableName: String = null
+
   private var outputMode: OutputMode = OutputMode.Append
 
   private var trigger: Trigger = Trigger.ProcessingTime(0L)
@@ -497,3 +524,16 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
 
   private var partitioningColumns: Option[Seq[String]] = None
 }
+
+object DataStreamWriter {
+  val SOURCE_NAME_MEMORY = "memory"
+  val SOURCE_NAME_FOREACH = "foreach"
+  val SOURCE_NAME_FOREACH_BATCH = "foreachBatch"
+  val SOURCE_NAME_CONSOLE = "console"
+  val SOURCE_NAME_TABLE = "table"
+  val SOURCE_NAME_NOOP = "noop"
+
+  // these writer sources are also used for one-time query, hence allow temp checkpoint location
+  val SOURCES_ALLOW_ONE_TIME_QUERY = Seq(SOURCE_NAME_MEMORY, SOURCE_NAME_FOREACH,
+    SOURCE_NAME_FOREACH_BATCH, SOURCE_NAME_CONSOLE, SOURCE_NAME_NOOP)
+}
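The new SOURCES_ALLOW_ONE_TIME_QUERY list is what startQuery() consults for useTempCheckpointLocation. A short behavioral sketch, with hypothetical query names:

// Sketch only: sinks in SOURCES_ALLOW_ONE_TIME_QUERY (memory, foreach,
// foreachBatch, console, noop) may start without a checkpointLocation option,
// falling back to a temporary checkpoint directory.
import org.apache.spark.sql.SparkSession

object TempCheckpointSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("temp-checkpoint-sketch").getOrCreate()

    val stream = spark.readStream.format("rate").load()

    // Console sink: no checkpointLocation given, a temp checkpoint is used.
    stream.writeStream
      .format("console")
      .start()

    // Memory sink: additionally requires a queryName, which becomes the temp
    // view that startInternal() registers so results can be queried back.
    stream.writeStream
      .format("memory")
      .queryName("rate_snapshot")
      .start()

    spark.streams.awaitAnyTermination()
  }
}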