Skip to content

Commit a0426c8

Browse files
committed
port RDD API to use commit protocol.
1 parent 9dc9f9a commit a0426c8

File tree

4 files changed

+160
-37
lines changed

4 files changed

+160
-37
lines changed

core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,14 @@ class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable {
162162
private[spark]
163163
object SparkHadoopWriter {
164164
def createJobID(time: Date, id: Int): JobID = {
165-
val formatter = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
166-
val jobtrackerID = formatter.format(time)
165+
val jobtrackerID = createJobTrackerID(time)
167166
new JobID(jobtrackerID, id)
168167
}
169168

169+
def createJobTrackerID(time: Date): String = {
170+
new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(time)
171+
}
172+
170173
def createPathFromString(path: String, conf: JobConf): Path = {
171174
if (path == null) {
172175
throw new IllegalArgumentException("Output path is null")
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark
19+
20+
import java.util.Date
21+
22+
import org.apache.hadoop.conf.Configuration
23+
import org.apache.hadoop.mapreduce._
24+
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
25+
import org.apache.spark.internal.Logging
26+
import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
27+
import org.apache.spark.util.SerializableConfiguration
28+
29+
/**
30+
* Internal helper class that saves an RDD using a Hadoop OutputFormat
31+
* (from the newer mapreduce API, not the old mapred API).
32+
*
33+
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
34+
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
35+
*
36+
* Use a [[HadoopMapReduceCommitProtocol]] to handle output commit, which, unlike Hadoop's
37+
* OutputCommitter, is serializable.
38+
*/
39+
private[spark]
40+
class SparkNewHadoopWriter(
41+
jobConf: Configuration,
42+
committer: HadoopMapReduceCommitProtocol) extends Logging with Serializable {
43+
44+
private val now = new Date()
45+
private val conf = new SerializableConfiguration(jobConf)
46+
47+
private val jobtrackerID = SparkHadoopWriter.createJobTrackerID(new Date())
48+
private var jobId = 0
49+
private var splitId = 0
50+
private var attemptId = 0
51+
52+
@transient private var writer: RecordWriter[AnyRef, AnyRef] = null
53+
@transient private var jobContext: JobContext = null
54+
@transient private var taskContext: TaskAttemptContext = null
55+
56+
def setupJob(): Unit = {
57+
// Committer setup a job
58+
committer.setupJob(getJobContext)
59+
}
60+
61+
def setupTask(context: TaskContext): Unit = {
62+
// Set jobID/taskID
63+
jobId = context.stageId
64+
splitId = context.partitionId
65+
attemptId = (context.taskAttemptId % Int.MaxValue).toInt
66+
// Committer setup a task
67+
committer.setupTask(getTaskContext(context))
68+
}
69+
70+
def write(context: TaskContext, key: AnyRef, value: AnyRef): Unit = {
71+
getWriter(context).write(key, value)
72+
}
73+
74+
def abortTask(context: TaskContext): Unit = {
75+
// Close writer
76+
getWriter(context).close(getTaskContext(context))
77+
// Committer abort a task
78+
committer.abortTask(getTaskContext(context))
79+
}
80+
81+
def commitTask(context: TaskContext): Unit = {
82+
// Close writer
83+
getWriter(context).close(getTaskContext(context))
84+
// Committer commit a task
85+
committer.commitTask(getTaskContext(context))
86+
}
87+
88+
def abortJob(): Unit = {
89+
committer.abortJob(getJobContext)
90+
}
91+
92+
def commitJob() {
93+
committer.commitJob(getJobContext, Seq.empty)
94+
}
95+
96+
// ********* Private Functions *********
97+
98+
/*
99+
* Generate jobContext. Since jobContext is transient, it may be null after serialization.
100+
*/
101+
private def getJobContext(): JobContext = {
102+
if (jobContext == null) {
103+
val jobAttemptId = new TaskAttemptID(jobtrackerID, jobId, TaskType.MAP, 0, 0)
104+
jobContext = new TaskAttemptContextImpl(conf.value, jobAttemptId)
105+
}
106+
jobContext
107+
}
108+
109+
/*
110+
* Generate taskContext. Since taskContext is transient, it may be null after serialization.
111+
*/
112+
private def getTaskContext(context: TaskContext): TaskAttemptContext = {
113+
if (taskContext == null) {
114+
val attemptId = new TaskAttemptID(jobtrackerID, jobId, TaskType.REDUCE, splitId,
115+
context.attemptNumber)
116+
taskContext = new TaskAttemptContextImpl(conf.value, attemptId)
117+
}
118+
taskContext
119+
}
120+
121+
/*
122+
* Generate writer. Since writer is transient, it may be null after serialization.
123+
*/
124+
private def getWriter(context: TaskContext): RecordWriter[AnyRef, AnyRef] = {
125+
if (writer == null) {
126+
val format = getJobContext.getOutputFormatClass.newInstance
127+
writer = format.getRecordWriter(getTaskContext(context))
128+
.asInstanceOf[RecordWriter[AnyRef, AnyRef]]
129+
}
130+
writer
131+
}
132+
}

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,21 @@ import java.nio.ByteBuffer
2121
import java.text.SimpleDateFormat
2222
import java.util.{Date, HashMap => JHashMap, Locale}
2323

24+
import org.apache.spark.internal.io.{HadoopMapReduceCommitProtocol, FileCommitProtocol}
25+
2426
import scala.collection.{mutable, Map}
2527
import scala.collection.JavaConverters._
2628
import scala.collection.mutable.ArrayBuffer
2729
import scala.reflect.ClassTag
2830
import scala.util.DynamicVariable
2931

3032
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
31-
import org.apache.hadoop.conf.{Configurable, Configuration}
33+
import org.apache.hadoop.conf.Configuration
3234
import org.apache.hadoop.fs.FileSystem
3335
import org.apache.hadoop.io.SequenceFile.CompressionType
3436
import org.apache.hadoop.io.compress.CompressionCodec
3537
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
36-
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType}
38+
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, TaskAttemptID, TaskType}
3739
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
3840

3941
import org.apache.spark._
@@ -1092,37 +1094,38 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
10921094
jobFormat.checkOutputSpecs(job)
10931095
}
10941096

1097+
// Instantiate writer
1098+
val committer = FileCommitProtocol.instantiate(
1099+
className = classOf[HadoopMapReduceCommitProtocol].getName,
1100+
jobId = stageId.toString,
1101+
outputPath = jobConfiguration.get("mapred.output.dir"),
1102+
isAppend = false
1103+
).asInstanceOf[HadoopMapReduceCommitProtocol]
1104+
val writer = new SparkNewHadoopWriter(hadoopConf, committer)
1105+
10951106
val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => {
1096-
val config = wrappedConf.value
1097-
/* "reduce task" <split #> <attempt # = spark task #> */
1098-
val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId,
1099-
context.attemptNumber)
1100-
val hadoopContext = new TaskAttemptContextImpl(config, attemptId)
1101-
val format = outfmt.newInstance
1102-
format match {
1103-
case c: Configurable => c.setConf(config)
1104-
case _ => ()
1105-
}
1106-
val committer = format.getOutputCommitter(hadoopContext)
1107-
committer.setupTask(hadoopContext)
1107+
writer.setupTask(context)
11081108

11091109
val outputMetricsAndBytesWrittenCallback: Option[(OutputMetrics, () => Long)] =
11101110
initHadoopOutputMetrics(context)
11111111

1112-
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K, V]]
11131112
require(writer != null, "Unable to obtain RecordWriter")
11141113
var recordsWritten = 0L
11151114
Utils.tryWithSafeFinallyAndFailureCallbacks {
11161115
while (iter.hasNext) {
11171116
val pair = iter.next()
1118-
writer.write(pair._1, pair._2)
1117+
writer.write(context, pair._1.asInstanceOf[AnyRef], pair._2.asInstanceOf[AnyRef])
11191118

11201119
// Update bytes written metric every few records
11211120
maybeUpdateOutputMetrics(outputMetricsAndBytesWrittenCallback, recordsWritten)
11221121
recordsWritten += 1
11231122
}
1124-
}(finallyBlock = writer.close(hadoopContext))
1125-
committer.commitTask(hadoopContext)
1123+
1124+
writer.commitTask(context)
1125+
}(catchBlock = {
1126+
writer.abortTask(context)
1127+
writer.abortJob()
1128+
})
11261129
outputMetricsAndBytesWrittenCallback.foreach { case (om, callback) =>
11271130
om.setBytesWritten(callback())
11281131
om.setRecordsWritten(recordsWritten)
@@ -1147,9 +1150,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
11471150
logWarning(warningMessage)
11481151
}
11491152

1150-
jobCommitter.setupJob(jobTaskContext)
1153+
writer.setupJob()
11511154
self.context.runJob(self, writeShard)
1152-
jobCommitter.commitJob(jobTaskContext)
1155+
writer.commitJob()
11531156
}
11541157

11551158
/**

core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -509,21 +509,6 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
509509
(2, ArrayBuffer(1))))
510510
}
511511

512-
test("saveNewAPIHadoopFile should call setConf if format is configurable") {
513-
val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
514-
515-
// No error, non-configurable formats still work
516-
pairs.saveAsNewAPIHadoopFile[NewFakeFormat]("ignored")
517-
518-
/*
519-
Check that configurable formats get configured:
520-
ConfigTestFormat throws an exception if we try to write
521-
to it when setConf hasn't been called first.
522-
Assertion is in ConfigTestFormat.getRecordWriter.
523-
*/
524-
pairs.saveAsNewAPIHadoopFile[ConfigTestFormat]("ignored")
525-
}
526-
527512
test("saveAsHadoopFile should respect configured output committers") {
528513
val pairs = sc.parallelize(Array((new Integer(1), new Integer(1))))
529514
val conf = new JobConf()

0 commit comments

Comments
 (0)