Commit 1edd283

[SPARK-32847][SS] Add DataStreamWriterV2 API
1 parent a22871f commit 1edd283

4 files changed, +349 -3 lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala

Lines changed: 41 additions & 1 deletion
@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog._
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform}
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.write._
+import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
 import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}
 import org.apache.spark.sql.types.{DataType, DateType, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -145,6 +146,7 @@ class InMemoryTable(
   override def capabilities: util.Set[TableCapability] = Set(
     TableCapability.BATCH_READ,
     TableCapability.BATCH_WRITE,
+    TableCapability.STREAMING_WRITE,
     TableCapability.OVERWRITE_BY_FILTER,
     TableCapability.OVERWRITE_DYNAMIC,
     TableCapability.TRUNCATE).asJava
@@ -169,26 +171,32 @@ class InMemoryTable(

     new WriteBuilder with SupportsTruncate with SupportsOverwrite with SupportsDynamicOverwrite {
       private var writer: BatchWrite = Append
+      private var streamingWriter: StreamingWrite = StreamingAppend

       override def truncate(): WriteBuilder = {
         assert(writer == Append)
         writer = TruncateAndAppend
+        streamingWriter = StreamingTruncateAndAppend
         this
       }

       override def overwrite(filters: Array[Filter]): WriteBuilder = {
         assert(writer == Append)
         writer = new Overwrite(filters)
+        // streaming writer doesn't have equivalent semantic
         this
       }

       override def overwriteDynamicPartitions(): WriteBuilder = {
         assert(writer == Append)
         writer = DynamicOverwrite
+        // streaming writer doesn't have equivalent semantic
         this
       }

       override def buildForBatch(): BatchWrite = writer
+
+      override def buildForStreaming(): StreamingWrite = streamingWriter
     }
   }
@@ -231,6 +239,31 @@ class InMemoryTable(
     }
   }

+  private abstract class TestStreamingWrite extends StreamingWrite {
+    def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
+      BufferedRowsWriterFactory
+    }
+
+    def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
+  }
+
+  private object StreamingAppend extends TestStreamingWrite {
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+      dataMap.synchronized {
+        withData(messages.map(_.asInstanceOf[BufferedRows]))
+      }
+    }
+  }
+
+  private object StreamingTruncateAndAppend extends TestStreamingWrite {
+    override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
+      dataMap.synchronized {
+        dataMap.clear
+        withData(messages.map(_.asInstanceOf[BufferedRows]))
+      }
+    }
+  }
+
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
@@ -310,10 +343,17 @@ private class BufferedRowsReader(partition: BufferedRows) extends PartitionReader
   override def close(): Unit = {}
 }

-private object BufferedRowsWriterFactory extends DataWriterFactory {
+private object BufferedRowsWriterFactory extends DataWriterFactory with StreamingDataWriterFactory {
   override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
     new BufferWriter
   }
+
+  override def createWriter(
+      partitionId: Int,
+      taskId: Long,
+      epochId: Long): DataWriter[InternalRow] = {
+    new BufferWriter
+  }
 }

 private class BufferWriter extends DataWriter[InternalRow] {
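
The additions above give the test table the full DSv2 streaming write contract: STREAMING_WRITE is advertised through capabilities, buildForStreaming returns StreamingAppend or StreamingTruncateAndAppend, and BufferedRowsWriterFactory now also serves the epoch-tagged StreamingDataWriterFactory overload. A rough sketch of how an engine drives that contract once per micro-batch (this is not Spark's actual planner code; task ids and abort-on-failure handling are simplified):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.write.{PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite

// One epoch (micro-batch) against a StreamingWrite such as StreamingAppend above:
// build a writer factory once, run one epoch-tagged DataWriter per partition,
// then commit the collected messages for that epoch on the driver.
def runEpoch(
    write: StreamingWrite,
    epochId: Long,
    partitions: Seq[Iterator[InternalRow]],
    info: PhysicalWriteInfo): Unit = {
  val factory = write.createStreamingWriterFactory(info)
  val messages: Array[WriterCommitMessage] = partitions.zipWithIndex.map {
    case (rows, partitionId) =>
      // A real task id comes from the task context; 0L is a placeholder here.
      val writer = factory.createWriter(partitionId, 0L, epochId)
      try {
        rows.foreach(row => writer.write(row))
        writer.commit() // per-task message, e.g. the BufferedRows produced by BufferWriter
      } finally {
        writer.close()
      }
  }.toArray
  write.commit(epochId, messages) // handled by StreamingAppend / StreamingTruncateAndAppend
}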

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 9 additions & 2 deletions
@@ -58,7 +58,7 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation,
 import org.apache.spark.sql.execution.python.EvaluatePython
 import org.apache.spark.sql.execution.stat.StatFunctions
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.streaming.DataStreamWriter
+import org.apache.spark.sql.streaming.{DataStreamWriter, DataStreamWriterV2}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SchemaUtils
 import org.apache.spark.storage.StorageLevel
@@ -3380,14 +3380,21 @@ class Dataset[T] private[sql](
    * @since 3.0.0
    */
   def writeTo(table: String): DataFrameWriterV2[T] = {
-    // TODO: streaming could be adapted to use this interface
     if (isStreaming) {
       logicalPlan.failAnalysis(
         "'writeTo' can not be called on streaming Dataset/DataFrame")
     }
     new DataFrameWriterV2[T](table, this)
   }

+  def writeStreamTo(table: String): DataStreamWriterV2[T] = {
+    if (!isStreaming) {
+      logicalPlan.failAnalysis(
+        "'writeStreamTo' can be called only on streaming Dataset/DataFrame")
+    }
+    new DataStreamWriterV2[T](table, this)
+  }
+
   /**
    * Interface for saving the content of the streaming Dataset out into external storage.
    *
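
This hunk splits the two entry points: writeTo keeps rejecting streaming Datasets, while the new writeStreamTo accepts only them. A hedged usage sketch, assuming a SparkSession named spark and a v2 test catalog; the catalog, namespace, table names, and checkpoint path are placeholders, not part of this commit:

// Batch path: unchanged, still goes through DataFrameWriterV2.
val batchDf = spark.read.table("testcat.ns.source")
batchDf.writeTo("testcat.ns.target").append()

// Streaming path: the new DataStreamWriterV2, reachable only from a streaming Dataset.
val streamDf = spark.readStream.format("rate").load()
val query = streamDf
  .writeStreamTo("testcat.ns.target")
  .checkpointLocation("/tmp/checkpoints/target")
  .append() // returns a StreamingQuery

// Calling writeTo on streamDf (or writeStreamTo on batchDf) fails analysis, per the checks above.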
sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriterV2.scala

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import java.util.concurrent.TimeoutException
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
+import org.apache.spark.sql.connector.catalog.TableCapability.{STREAMING_WRITE, TRUNCATE}
+
+@Experimental
+final class DataStreamWriterV2[T] private[sql](table: String, ds: Dataset[T]) {
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+  import org.apache.spark.sql.connector.catalog.CatalogV2Util._
+  import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
+
+  private val df: DataFrame = ds.toDF()
+
+  private val sparkSession = ds.sparkSession
+
+  private var trigger: Trigger = Trigger.ProcessingTime(0L)
+
+  private var extraOptions = new mutable.HashMap[String, String]()
+
+  private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
+
+  private val (catalog, identifier) = {
+    val CatalogAndIdentifier(catalog, identifier) = tableName
+    (catalog.asTableCatalog, identifier)
+  }
+
+  def trigger(trigger: Trigger): DataStreamWriterV2[T] = {
+    this.trigger = trigger
+    this
+  }
+
+  def queryName(queryName: String): DataStreamWriterV2[T] = {
+    this.extraOptions += ("queryName" -> queryName)
+    this
+  }
+
+  def option(key: String, value: String): DataStreamWriterV2[T] = {
+    this.extraOptions += (key -> value)
+    this
+  }
+
+  def option(key: String, value: Boolean): DataStreamWriterV2[T] = option(key, value.toString)
+
+  def option(key: String, value: Long): DataStreamWriterV2[T] = option(key, value.toString)
+
+  def option(key: String, value: Double): DataStreamWriterV2[T] = option(key, value.toString)
+
+  def options(options: scala.collection.Map[String, String]): DataStreamWriterV2[T] = {
+    this.extraOptions ++= options
+    this
+  }
+
+  def options(options: java.util.Map[String, String]): DataStreamWriterV2[T] = {
+    this.options(options.asScala)
+    this
+  }
+
+  def checkpointLocation(location: String): DataStreamWriterV2[T] = {
+    this.extraOptions += "checkpointLocation" -> location
+    this
+  }
+
+  @throws[NoSuchTableException]
+  @throws[TimeoutException]
+  def append(): StreamingQuery = {
+    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+    loadTable(catalog, identifier) match {
+      case Some(t: SupportsWrite) if t.supports(STREAMING_WRITE) =>
+        start(t, OutputMode.Append())
+
+      case Some(t) =>
+        throw new IllegalArgumentException(s"Table ${t.name()} doesn't support streaming" +
+          " write!")
+
+      case _ =>
+        throw new NoSuchTableException(identifier)
+    }
+  }
+
+  @throws[NoSuchTableException]
+  @throws[TimeoutException]
+  def truncateAndAppend(): StreamingQuery = {
+    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+    loadTable(catalog, identifier) match {
+      case Some(t: SupportsWrite) if t.supports(STREAMING_WRITE) && t.supports(TRUNCATE) =>
+        start(t, OutputMode.Complete())
+
+      case Some(t) =>
+        throw new IllegalArgumentException(s"Table ${t.name()} doesn't support streaming" +
+          " write with truncate!")
+
+      case _ =>
+        throw new NoSuchTableException(identifier)
+    }
+  }
+
+  private def start(table: Table, outputMode: OutputMode): StreamingQuery = {
+    df.sparkSession.sessionState.streamingQueryManager.startQuery(
+      extraOptions.get("queryName"),
+      extraOptions.get("checkpointLocation"),
+      df,
+      extraOptions.toMap,
+      table,
+      outputMode,
+      // Here we simply use default values of `useTempCheckpointLocation` and
+      // `recoverFromCheckpointLocation`, which is required to be changed for some special built-in
+      // data sources. They're not available in catalog, hence it's safe as of now, but once the
+      // condition is broken we should take care of that.
+      trigger = trigger)
+  }
+}
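
A sketch of driving the new builder end to end, assuming a SparkSession named spark, a rate source, and a catalog table that reports STREAMING_WRITE (plus TRUNCATE for truncateAndAppend); the table name, query name, trigger interval, and checkpoint path are illustrative only:

import org.apache.spark.sql.streaming.Trigger

val events = spark.readStream.format("rate").load()

val query = events
  .writeStreamTo("testcat.ns.events")
  .queryName("rate-to-events")
  .trigger(Trigger.ProcessingTime("10 seconds"))
  .checkpointLocation("/tmp/checkpoints/events")
  .option("someSinkOption", "value") // forwarded to the sink through extraOptions
  .append() // OutputMode.Append; truncateAndAppend() instead yields OutputMode.Complete

query.awaitTermination()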
