core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -63,6 +63,9 @@ class PartitionerAwareUnionRDD[T: ClassTag](
require(rdds.forall(_.partitioner.isDefined))
require(rdds.flatMap(_.partitioner).toSet.size == 1,
"Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
require(rdds.map(_.partitioner.get.numPartitions).toSet.size == 1,
Member Author:
The partitioners of the RDDs might have different numPartitions, which will cause errors later.

"Parent RDDs have different number of partitions: " +
rdds.map(_.partitioner.get.numPartitions))
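
For illustration, a sketch of the gap this new require closes, using a hypothetical custom partitioner whose equals ignores numPartitions (two HashPartitioners with different partition counts would already fail the existing same-partitioner check):

// Hypothetical partitioner: instances with different numPartitions compare equal,
// so rdds.flatMap(_.partitioner).toSet.size == 1 still passes.
class TensPartitioner(override val numPartitions: Int) extends org.apache.spark.Partitioner {
  def getPartition(key: Any): Int = math.min(key.asInstanceOf[Int] / 10, numPartitions - 1)
  override def equals(other: Any): Boolean = other.isInstanceOf[TensPartitioner]
  override def hashCode: Int = getClass.hashCode
}
val a = sc.parallelize(1 to 40).map(x => (x, x)).partitionBy(new TensPartitioner(2))
val b = sc.parallelize(1 to 40).map(x => (x, x)).partitionBy(new TensPartitioner(4))
// A PartitionerAwareUnionRDD over a and b now fails fast on the partition-count
// mismatch (2 vs 4) instead of erroring later.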

override val partitioner = rdds.head.partitioner

9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.{BoundedPriorityQueue, CheckpointingIterator, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
SamplingUtils}
@@ -238,11 +238,16 @@ abstract class RDD[T: ClassTag](
* subclasses of RDD.
*/
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
val iter = if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
computeOrReadCheckpoint(split, context)
}
if (checkpointData.isDefined) {
checkpointData.get.getCheckpointIterator(iter, context, split.index)
} else {
iter
}
Contributor:
this could be

checkpointData
  .map(_.getCheckpointIterator(iter, context, split.index))
  .getOrElse(iter)

}

/**
61 changes: 46 additions & 15 deletions core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -19,9 +19,11 @@ package org.apache.spark.rdd

import scala.reflect.ClassTag

import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path

import org.apache.spark._
import org.apache.spark.util.{CheckpointingIterator, SerializableConfiguration}
import org.apache.spark.util.SerializableConfiguration

/**
@@ -44,6 +46,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])

import CheckpointState._

// Because SparkContext is transient in RDD, we can't get the id and checkpointDir later.
// So keep a copy of the id and checkpointDir.
Contributor:
Not sure if I understand this comment. How did it work before then? Even before this patch doCheckpoint directly calls rdd.id

// The id of RDD
val rddId: Int = rdd.id

// The path the checkpoint data will be written to.
val checkpointDir = rdd.context.checkpointDir
@transient var checkpointPath: Path = null
@transient var fs: FileSystem = null
if (checkpointDir.isDefined) {
checkpointPath = new Path(checkpointDir.get, "rdd-" + rddId)
fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(checkpointPath)) {
throw new SparkException("Failed to create checkpoint path " + checkpointPath)
}
}
Contributor:
These don't need to be vars, right? In fact, fs, rddId and checkpointDir don't even need to exist. You can just do:

@transient private val checkpointPath: Path = {
  val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id)
  val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
  if (!fs.mkdirs(path)) {
    ...
  }
  path
}


val broadcastedConf = rdd.context.broadcast(
Contributor:
please make all of these private

new SerializableConfiguration(rdd.context.hadoopConfiguration))

// The checkpoint state of the associated RDD.
private var cpState = Initialized

@@ -66,6 +88,27 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cpFile
}

// Get the iterator used to write checkpoint data to HDFS
Contributor:
can you make this a real java doc:

/**
 * Wrap the given iterator in a checkpointing iterator, which checkpoints values
 * as the original iterator is consumed. This allows us to checkpoint the RDD
 * without computing it more than once (SPARK-8582).
 */

def getCheckpointIterator(
rddIterator: Iterator[T],
context: TaskContext,
partitionId: Int): Iterator[T] = {
RDDCheckpointData.synchronized {
if (cpState == Initialized) {
// Create the output path for the checkpoint
val path = new Path(checkpointDir.get, "rdd-" + rddId)
Contributor:
shouldn't this read from RDDCheckpointData.rddCheckpointDataPath?

Contributor:
actually, the path here should just be checkpointPath. Right now this duplicates some code.
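
For reference, a minimal sketch of what that would look like (reusing the existing checkpointPath field instead of rebuilding the path):

CheckpointingIterator[T](rddIterator, checkpointPath.toString, broadcastedConf, partitionId, context)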

CheckpointingIterator[T](
rddIterator,
path.toString,
broadcastedConf,
partitionId,
context)
} else {
rddIterator
}
}
}

/**
* Materialize this RDD and write its content to a reliable DFS.
* This is called immediately after the first action invoked on this RDD has completed.
@@ -82,25 +125,13 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}

// Create the output path for the checkpoint
val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException(s"Failed to create checkpoint path $path")
}

// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableConfiguration(rdd.context.hadoopConfiguration))
val path = checkpointPath
Contributor:
no need to declare another variable here? Just use checkpointPath

val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
rdd.context.cleaner.foreach { cleaner =>
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rddId)
}
}

// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
"Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
@@ -114,7 +145,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
Contributor:
The state transition here is incorrect. At this point the RDD has not been checkpointed yet. It's not safe to truncate the RDD's lineage until we drain the iterator.

}
logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
logInfo(s"Done checkpointing RDD ${rddId} to $path, new parent is RDD ${newRDD.id}")
}

def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
151 changes: 151 additions & 0 deletions core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import java.io.IOException

import scala.reflect.ClassTag

import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.CheckpointRDD
import org.apache.spark.serializer.SerializationStream

/**
* Wrapper around an iterator which writes checkpoint data to HDFS while running action on
* a RDD to support checkpointing RDD.
*/
private[spark] class CheckpointingIterator[A: ClassTag](
values: Iterator[A],
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
partitionId: Int,
context: TaskContext,
blockSize: Int = -1) extends Iterator[A] with Logging {
Contributor:
In the javadoc, could you document what each of these variables represents?
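
A possible sketch of such a Scaladoc, based only on how the parameters are used in the body below:

/**
 * Wraps `values` and writes each element to a checkpoint file as the iterator is consumed.
 *
 * @param values          iterator computed for this partition by the parent RDD
 * @param path            checkpoint directory for this RDD on the checkpoint filesystem
 * @param broadcastedConf broadcast Hadoop configuration used to obtain the FileSystem
 * @param partitionId     index of the partition being checkpointed
 * @param context         task context, used for the attempt number in the temporary file name
 * @param blockSize       block size for the checkpoint file; -1 means the filesystem default (mainly for tests)
 */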


private val env = SparkEnv.get
private var fs: FileSystem = null
private val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
private var serializeStream: SerializationStream = null

private var finalOutputPath: Path = null
private var tempOutputPath: Path = null

/**
* Initialize this iterator by creating the temporary output path and the serializer instance.
*/
def init(): this.type = {
Contributor:
can we remove this init method? It doesn't seem necessary since we can just do all of this in the constructor. The advantage of removing it is that we won't have a bunch of vars initialized to null.
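
A rough sketch of what constructor-time initialization could look like (not a drop-in replacement, since the early-return case that disables checkpointing still needs handling):

private val outputDir = new Path(path)
private val fs: FileSystem = outputDir.getFileSystem(broadcastedConf.value.value)
private val finalOutputPath: Path = new Path(outputDir, CheckpointRDD.splitIdToFile(partitionId))
private val tempOutputPath: Path =
  new Path(outputDir, "." + finalOutputPath.getName + "-attempt-" + context.attemptNumber)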

val outputDir = new Path(path)
fs = outputDir.getFileSystem(broadcastedConf.value.value)

val finalOutputName = CheckpointRDD.splitIdToFile(partitionId)
finalOutputPath = new Path(outputDir, finalOutputName)
tempOutputPath =
new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber)

if (fs.exists(tempOutputPath)) {
Member Author:
It is possible that more than one iterator for the same split is created and consumed, e.g., by CartesianRDD. We only need one of them to write checkpoint data to disk.
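
For context, CartesianRDD.compute looks roughly like the sketch below; rdd2.iterator is requested once per element of rdd1's partition, so several iterators for the same split of rdd2 can be created within one task:

override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
  val currSplit = split.asInstanceOf[CartesianPartition]
  for (x <- rdd1.iterator(currSplit.s1, context);
       y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}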

// Another iterator for the same partition has already created the temporary output file.
// Don't write checkpoint data from this iterator.
doCheckpoint = false
return this
}

val fileOutputStream = if (blockSize < 0) {
fs.create(tempOutputPath, false, bufferSize)
} else {
// This is mainly for testing purpose
fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
}
val serializer = env.serializer.newInstance()
serializeStream = serializer.serializeStream(fileOutputStream)
this
}

/**
* Called by `hasNext` once this iterator has been fully consumed.
* This method renames the temporary output path to the final output path of the checkpoint data.
*/
def completion(): Unit = {
Contributor:
what does completion mean? Can you add some java docs here and everywhere?

if (!doCheckpoint) {
return
}

serializeStream.close()

if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ context.attemptNumber + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
fs.delete(tempOutputPath, false)
}
}
}

def checkpointing(item: A): Unit = {
serializeStream.writeObject(item)
}

override def next(): A = {
val item = values.next()
if (doCheckpoint) {
checkpointing(item)
}
// If this is the last item, calling hasNext here writes to the final output path early.
hasNext
Member Author:
Sometimes the consumer of rdd.iterator never calls hasNext after the last element, e.g., when it already knows the number of elements and calls next() exactly that many times. In that case, we need to write the checkpoint data to the final output path as soon as the last element is returned.
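
A sketch of such a consumer (names hypothetical): it calls next() exactly n times and never sees hasNext return false, so completion() would never run without the eager hasNext inside next():

val out = new Array[Int](n)
var i = 0
while (i < n) {
  out(i) = iter.next()  // hasNext is never called after the last element
  i += 1
}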

item
}

private[this] var doCheckpoint = true
private[this] var completed = false
Contributor:
please declare all variables at the top


override def hasNext: Boolean = {
val r = values.hasNext
if (!r && !completed) {
completed = true
completion()
}
r
}
}

private[spark] object CheckpointingIterator {
def apply[A: ClassTag](
values: Iterator[A],
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
partitionId: Int,
context: TaskContext,
blockSize: Int = -1) : CheckpointingIterator[A] = {
new CheckpointingIterator[A](
values,
path,
broadcastedConf,
partitionId,
context,
blockSize).init()
}
}
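
A usage sketch (surrounding names are hypothetical) of how a task is expected to drive this wrapper:

// rddIter, checkpointDirForRdd, broadcastedConf, partition and context come from the caller.
val wrapped = CheckpointingIterator(
  rddIter, checkpointDirForRdd.toString, broadcastedConf, partition.index, context)
// Draining the iterator feeds the action and writes the checkpoint file as a side effect;
// after the last element, the temporary file is renamed to the final output path.
wrapped.foreach(println)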
2 changes: 1 addition & 1 deletion core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -359,7 +359,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
* have large size.
*/
def generateFatPairRDD(): RDD[(Int, Int)] = {
new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
new FatPairRDD(sc.makeRDD(1 to 100, 2), partitioner).mapValues(x => x)
Member Author:
The partitioner in CheckpointSuite is a HashPartitioner with 2 partitions, so to make PartitionerAwareUnionRDD work, FatPairRDD should have 2 partitions too.

}

/**