
Commit f320766

Adds prepareForWrite() hook, refactored writer containers
1 parent 422ff4a commit f320766

File tree

3 files changed: +66, -26 lines


sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala

Lines changed: 62 additions & 26 deletions
@@ -90,7 +90,7 @@ private[sql] case class InsertIntoFSBasedRelation(
       needsConversion = false)

     if (partitionColumns.isEmpty) {
-      insert(new WriterContainer(relation, jobConf), df)
+      insert(new DefaultWriterContainer(relation, jobConf), df)
     } else {
       val writerContainer = new DynamicPartitionWriterContainer(
         relation, jobConf, partitionColumns, "__HIVE_DEFAULT_PARTITION__")
@@ -101,7 +101,7 @@ private[sql] case class InsertIntoFSBasedRelation(
     Seq.empty[Row]
   }

-  private def insert(writerContainer: WriterContainer, df: DataFrame): Unit = {
+  private def insert(writerContainer: BaseWriterContainer, df: DataFrame): Unit = {
     try {
       writerContainer.driverSideSetup()
       df.sqlContext.sparkContext.runJob(df.rdd, writeRows _)
@@ -128,7 +128,7 @@ private[sql] case class InsertIntoFSBasedRelation(
   }

   private def insertWithDynamicPartitions(
-      writerContainer: WriterContainer,
+      writerContainer: BaseWriterContainer,
       df: DataFrame,
       partitionColumns: Array[String]): Unit = {

@@ -191,7 +191,7 @@ private[sql] case class InsertIntoFSBasedRelation(
   }
 }

-private[sql] class WriterContainer(
+private[sql] abstract class BaseWriterContainer(
     @transient val relation: FSBasedRelation,
     @transient jobConf: JobConf)
   extends SparkHadoopMapRedUtil
@@ -223,11 +223,9 @@

   protected val outputWriterClass: Class[_ <: OutputWriter] = relation.outputWriterClass

-  // All output writers are created on executor side.
-  @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _
-
   def driverSideSetup(): Unit = {
     setupIDs(0, 0, 0)
+    relation.prepareForWrite(serializableJobConf.value)
     setupJobConf()
     jobContext = newJobContext(jobConf, jobId)
     outputCommitter = jobConf.getOutputCommitter
@@ -240,7 +238,7 @@
     taskAttemptContext = newTaskAttemptContext(serializableJobConf.value, taskAttemptId)
     outputCommitter = serializableJobConf.value.getOutputCommitter
     outputCommitter.setupTask(taskAttemptContext)
-    outputWriters = initWriters()
+    initWriters()
   }

   private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = {
@@ -258,22 +256,20 @@
   }

   // Called on executor side when writing rows
-  def outputWriterForRow(row: Row): OutputWriter = outputWriters.values.head
+  def outputWriterForRow(row: Row): OutputWriter

-  protected def initWriters(): mutable.Map[String, OutputWriter] = {
+  protected def initWriters(): Unit = {
     val writer = outputWriterClass.newInstance()
     writer.init(outputPath, dataSchema, serializableJobConf.value)
     mutable.Map(outputPath -> writer)
   }

   def commitTask(): Unit = {
-    outputWriters.values.foreach(_.close())
     SparkHadoopMapRedUtil.commitTask(
       outputCommitter, taskAttemptContext, jobId.getId, taskId.getId, taskAttemptId.getId)
   }

   def abortTask(): Unit = {
-    outputWriters.values.foreach(_.close())
     outputCommitter.abortTask(taskAttemptContext)
     logError(s"Task attempt $taskAttemptId aborted.")
   }
@@ -289,22 +285,52 @@
   }
 }

+private[sql] class DefaultWriterContainer(
+    @transient relation: FSBasedRelation,
+    @transient conf: JobConf)
+  extends BaseWriterContainer(relation, conf) {
+
+  @transient private var writer: OutputWriter = _
+
+  override protected def initWriters(): Unit = {
+    writer = relation.outputWriterClass.newInstance()
+    writer.init(outputPath, dataSchema, serializableJobConf.value)
+  }
+
+  override def outputWriterForRow(row: Row): OutputWriter = writer
+
+  override def commitTask(): Unit = {
+    writer.close()
+    super.commitTask()
+  }
+
+  override def abortTask(): Unit = {
+    writer.close()
+    super.abortTask()
+  }
+}
+
 private[sql] class DynamicPartitionWriterContainer(
     @transient relation: FSBasedRelation,
     @transient conf: JobConf,
     partitionColumns: Array[String],
     defaultPartitionName: String)
-  extends WriterContainer(relation, conf) {
+  extends BaseWriterContainer(relation, conf) {
+
+  // All output writers are created on executor side.
+  @transient protected var outputWriters: mutable.Map[String, OutputWriter] = _

-  override protected def initWriters() = mutable.Map.empty[String, OutputWriter]
+  override protected def initWriters(): Unit = {
+    outputWriters = mutable.Map.empty[String, OutputWriter]
+  }

   override def outputWriterForRow(row: Row): OutputWriter = {
     val partitionPath = partitionColumns.zip(row.toSeq).map { case (col, rawValue) =>
       val string = if (rawValue == null) null else String.valueOf(rawValue)
       val valueString = if (string == null || string.isEmpty) {
         defaultPartitionName
       } else {
-        escapePathName(string)
+        DynamicPartitionWriterContainer.escapePathName(string)
       }
       s"/$col=$valueString"
     }.mkString
@@ -317,18 +343,14 @@
     })
   }

-  private def escapePathName(path: String): String = {
-    val builder = new StringBuilder()
-    path.foreach { c =>
-      if (DynamicPartitionWriterContainer.needsEscaping(c)) {
-        builder.append('%')
-        builder.append(f"${c.asInstanceOf[Int]}%02x")
-      } else {
-        builder.append(c)
-      }
-    }
+  override def commitTask(): Unit = {
+    outputWriters.values.foreach(_.close())
+    super.commitTask()
+  }

-    builder.toString()
+  override def abortTask(): Unit = {
+    outputWriters.values.foreach(_.close())
+    super.abortTask()
   }
 }

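For illustration only (not part of the diff): with the refactored outputWriterForRow above, each distinct combination of partition column values maps to one partition directory, and null or empty values fall back to the defaultPartitionName passed in by InsertIntoFSBasedRelation. A minimal sketch of that mapping, with invented column names and values; it mirrors, rather than copies, the logic in the hunk above:

// Hypothetical example; "year", "month", and the row values are made up.
val partitionColumns = Array("year", "month")
val partitionValues: Seq[Any] = Seq(2015, null)
val partitionPath = partitionColumns.zip(partitionValues).map { case (col, rawValue) =>
  val string = Option(rawValue).map(_.toString).orNull
  val valueString =
    if (string == null || string.isEmpty) "__HIVE_DEFAULT_PARTITION__"
    else DynamicPartitionWriterContainer.escapePathName(string)
  s"/$col=$valueString"
}.mkString
// partitionPath == "/year=2015/month=__HIVE_DEFAULT_PARTITION__"; the container then
// looks up or creates the OutputWriter for that path (lookup code is outside this hunk).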

@@ -359,4 +381,18 @@ private[sql] object DynamicPartitionWriterContainer {
   def needsEscaping(c: Char): Boolean = {
     c >= 0 && c < charToEscape.size() && charToEscape.get(c);
   }
+
+  def escapePathName(path: String): String = {
+    val builder = new StringBuilder()
+    path.foreach { c =>
+      if (DynamicPartitionWriterContainer.needsEscaping(c)) {
+        builder.append('%')
+        builder.append(f"${c.asInstanceOf[Int]}%02x")
+      } else {
+        builder.append(c)
+      }
+    }
+
+    builder.toString()
+  }
 }
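A rough illustration of the now object-level escapePathName helper. The character set in charToEscape is defined earlier in this object and is not shown in this diff, so the expected outputs below assume '/' and '=' are among the escaped characters; escaped characters become '%' followed by two lowercase hex digits:

// Hypothetical usage; assumes '/' (0x2f) and '=' (0x3d) satisfy needsEscaping.
DynamicPartitionWriterContainer.escapePathName("2015/05")  // "2015%2f05"
DynamicPartitionWriterContainer.escapePathName("a=b")      // "a%3db"
DynamicPartitionWriterContainer.escapePathName("plain")    // unchanged: "plain"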

sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala

Lines changed: 2 additions & 0 deletions
@@ -409,6 +409,8 @@ abstract class FSBasedRelation private[sql](
     buildScan(requiredColumns, inputPaths)
   }

+  def prepareForWrite(conf: Configuration): Unit
+
   /**
    * This method is responsible for producing a new [[OutputWriter]] for each newly opened output
    * file on the executor side.
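Since prepareForWrite is abstract, every FSBasedRelation implementation now has to provide it. A hedged sketch (not from this commit) of what an implementation might do with the hook: per the commands.scala change above, it runs once on the driver in BaseWriterContainer.driverSideSetup() before the JobConf is serialized, so mutations made here are visible to executor-side writers. The Hadoop configuration key below is purely illustrative, and the relation's other members are elided:

// Sketch only; other abstract members of the enclosing FSBasedRelation subclass are omitted.
override def prepareForWrite(conf: Configuration): Unit = {
  conf.setBoolean("mapreduce.output.fileoutputformat.compress", true)
}

Implementations that need no driver-side setup can make it a no-op, as the test relation in FSBasedRelationSuite.scala below does.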

sql/hive/src/test/scala/org/apache/spark/sql/sources/FSBasedRelationSuite.scala

Lines changed: 2 additions & 0 deletions
@@ -107,6 +107,8 @@ class SimpleFSBasedRelation
   }

   override def outputWriterClass: Class[_ <: OutputWriter] = classOf[SimpleOutputWriter]
+
+  override def prepareForWrite(conf: Configuration): Unit = ()
 }

 object TestResult {
