
Commit 0c48e4d

Merge remote-tracking branch 'origin/master' into enable-more-mima-checks

Conflicts:
    project/MimaExcludes.scala

2 parents: e276cee + 874a2ca

File tree

33 files changed (+515, -228 lines)


R/pkg/NAMESPACE

Lines changed: 1 addition & 0 deletions
@@ -71,6 +71,7 @@ exportMethods(
               "unpersist",
               "value",
               "values",
+              "zipPartitions",
               "zipRDD",
               "zipWithIndex",
               "zipWithUniqueId"

R/pkg/R/RDD.R

Lines changed: 51 additions & 0 deletions
@@ -66,6 +66,11 @@ setMethod("initialize", "RDD", function(.Object, jrdd, serializedMode,
   .Object
 })
 
+setMethod("show", "RDD",
+          function(.Object) {
+              cat(paste(callJMethod(.Object@jrdd, "toString"), "\n", sep=""))
+          })
+
 setMethod("initialize", "PipelinedRDD", function(.Object, prev, func, jrdd_val) {
   .Object@env <- new.env()
   .Object@env$isCached <- FALSE
@@ -1590,3 +1595,49 @@ setMethod("intersection",
 
             keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction))
           })
+
+#' Zips an RDD's partitions with one (or more) RDD(s).
+#' Same as zipPartitions in Spark.
+#'
+#' @param ... RDDs to be zipped.
+#' @param func A function to transform zipped partitions.
+#' @return A new RDD by applying a function to the zipped partitions.
+#'         Assumes that all the RDDs have the *same number of partitions*, but
+#'         does *not* require them to have the same number of elements in each partition.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, 1:2, 2L)  # 1, 2
+#' rdd2 <- parallelize(sc, 1:4, 2L)  # 1:2, 3:4
+#' rdd3 <- parallelize(sc, 1:6, 2L)  # 1:3, 4:6
+#' collect(zipPartitions(rdd1, rdd2, rdd3,
+#'                       func = function(x, y, z) { list(list(x, y, z))} ))
+#' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6)))
+#'}
+#' @rdname zipRDD
+#' @aliases zipPartitions,RDD
+setMethod("zipPartitions",
+          "RDD",
+          function(..., func) {
+            rrdds <- list(...)
+            if (length(rrdds) == 1) {
+              return(rrdds[[1]])
+            }
+            nPart <- sapply(rrdds, numPartitions)
+            if (length(unique(nPart)) != 1) {
+              stop("Can only zipPartitions RDDs which have the same number of partitions.")
+            }
+
+            rrdds <- lapply(rrdds, function(rdd) {
+              mapPartitionsWithIndex(rdd, function(partIndex, part) {
+                print(length(part))
+                list(list(partIndex, part))
+              })
+            })
+            union.rdd <- Reduce(unionRDD, rrdds)
+            zipped.rdd <- values(groupByKey(union.rdd, numPartitions = nPart[1]))
+            res <- mapPartitions(zipped.rdd, function(plist) {
+              do.call(func, plist[[1]])
+            })
+            res
+          })
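The R method above emulates what RDD.zipPartitions does natively in Spark's Scala API: it pairs up the i-th partition of every input and hands the per-partition iterators to the user function. For comparison, a minimal Scala sketch, not part of this commit; it assumes an existing SparkContext named sc:

// Sketch: Scala zipPartitions with the same inputs as the roxygen example above.
val rdd1 = sc.parallelize(1 to 2, 2)   // partitions: [1], [2]
val rdd2 = sc.parallelize(1 to 4, 2)   // partitions: [1, 2], [3, 4]
val zipped = rdd1.zipPartitions(rdd2) { (it1, it2) =>
  // emit one combined record per partition, like the R `func` above
  Iterator((it1.toList, it2.toList))
}
zipped.collect()  // Array((List(1), List(1, 2)), (List(2), List(3, 4)))

Because the R API has no iterator-level zip, the implementation above instead keys each partition by its index, unions the inputs, and groups by that key before applying func.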

R/pkg/R/generics.R

Lines changed: 5 additions & 0 deletions
@@ -217,6 +217,11 @@ setGeneric("unpersist", function(x, ...) { standardGeneric("unpersist") })
 #' @export
 setGeneric("zipRDD", function(x, other) { standardGeneric("zipRDD") })
 
+#' @rdname zipRDD
+#' @export
+setGeneric("zipPartitions", function(..., func) { standardGeneric("zipPartitions") },
+           signature = "...")
+
 #' @rdname zipWithIndex
 #' @seealso zipWithUniqueId
 #' @export

R/pkg/inst/tests/test_binary_function.R

Lines changed: 33 additions & 0 deletions
@@ -66,3 +66,36 @@ test_that("cogroup on two RDDs", {
   expect_equal(sortKeyValueList(actual),
                sortKeyValueList(expected))
 })
+
+test_that("zipPartitions() on RDDs", {
+  rdd1 <- parallelize(sc, 1:2, 2L)  # 1, 2
+  rdd2 <- parallelize(sc, 1:4, 2L)  # 1:2, 3:4
+  rdd3 <- parallelize(sc, 1:6, 2L)  # 1:3, 4:6
+  actual <- collect(zipPartitions(rdd1, rdd2, rdd3,
+                                  func = function(x, y, z) { list(list(x, y, z))} ))
+  expect_equal(actual,
+               list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))))
+
+  mockFile = c("Spark is pretty.", "Spark is awesome.")
+  fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+  writeLines(mockFile, fileName)
+
+  rdd <- textFile(sc, fileName, 1)
+  actual <- collect(zipPartitions(rdd, rdd,
+                                  func = function(x, y) { list(paste(x, y, sep = "\n")) }))
+  expected <- list(paste(mockFile, mockFile, sep = "\n"))
+  expect_equal(actual, expected)
+
+  rdd1 <- parallelize(sc, 0:1, 1)
+  actual <- collect(zipPartitions(rdd1, rdd,
+                                  func = function(x, y) { list(x + nchar(y)) }))
+  expected <- list(0:1 + nchar(mockFile))
+  expect_equal(actual, expected)
+
+  rdd <- map(rdd, function(x) { x })
+  actual <- collect(zipPartitions(rdd, rdd1,
+                                  func = function(x, y) { list(y + nchar(x)) }))
+  expect_equal(actual, expected)
+
+  unlink(fileName)
+})

R/pkg/inst/tests/test_rdd.R

Lines changed: 5 additions & 0 deletions
@@ -759,6 +759,11 @@ test_that("collectAsMap() on a pairwise RDD", {
   expect_equal(vals, list(`1` = "a", `2` = "b"))
 })
 
+test_that("show()", {
+  rdd <- parallelize(sc, list(1:10))
+  expect_output(show(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+")
+})
+
 test_that("sampleByKey() on pairwise RDDs", {
   rdd <- parallelize(sc, 1:2000)
   pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) })

core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala

Lines changed: 17 additions & 9 deletions
@@ -76,13 +76,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
 
   private var timeoutCheckingTask: ScheduledFuture[_] = null
 
-  private val timeoutCheckingThread =
-    ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-timeout-checking-thread")
+  // "eventLoopThread" is used to run some pretty fast actions. The actions running in it should not
+  // block the thread for a long time.
+  private val eventLoopThread =
+    ThreadUtils.newDaemonSingleThreadScheduledExecutor("heartbeat-receiver-event-loop-thread")
 
   private val killExecutorThread = ThreadUtils.newDaemonSingleThreadExecutor("kill-executor-thread")
 
   override def onStart(): Unit = {
-    timeoutCheckingTask = timeoutCheckingThread.scheduleAtFixedRate(new Runnable {
+    timeoutCheckingTask = eventLoopThread.scheduleAtFixedRate(new Runnable {
       override def run(): Unit = Utils.tryLogNonFatalError {
         Option(self).foreach(_.send(ExpireDeadHosts))
       }
@@ -99,11 +101,15 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case heartbeat @ Heartbeat(executorId, taskMetrics, blockManagerId) =>
       if (scheduler != null) {
-        val unknownExecutor = !scheduler.executorHeartbeatReceived(
-          executorId, taskMetrics, blockManagerId)
-        val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
         executorLastSeen(executorId) = System.currentTimeMillis()
-        context.reply(response)
+        eventLoopThread.submit(new Runnable {
+          override def run(): Unit = Utils.tryLogNonFatalError {
+            val unknownExecutor = !scheduler.executorHeartbeatReceived(
+              executorId, taskMetrics, blockManagerId)
+            val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
+            context.reply(response)
+          }
+        })
       } else {
         // Because Executor will sleep several seconds before sending the first "Heartbeat", this
         // case rarely happens. However, if it really happens, log it and ask the executor to
@@ -125,7 +131,9 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
       if (sc.supportDynamicAllocation) {
         // Asynchronously kill the executor to avoid blocking the current thread
        killExecutorThread.submit(new Runnable {
-          override def run(): Unit = sc.killExecutor(executorId)
+          override def run(): Unit = Utils.tryLogNonFatalError {
+            sc.killExecutor(executorId)
+          }
        })
      }
      executorLastSeen.remove(executorId)
@@ -137,7 +145,7 @@ private[spark] class HeartbeatReceiver(sc: SparkContext)
     if (timeoutCheckingTask != null) {
       timeoutCheckingTask.cancel(true)
     }
-    timeoutCheckingThread.shutdownNow()
+    eventLoopThread.shutdownNow()
     killExecutorThread.shutdownNow()
   }
 }
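The pattern behind this change: the potentially slow executorHeartbeatReceived call is moved off the RPC handler, which now only records the timestamp inline and defers the reply to a dedicated single-threaded executor. A minimal sketch of that shape, not Spark code (HeartbeatSketch, slowLookup, and the reply callback are invented for illustration):

import java.util.concurrent.Executors

object HeartbeatSketch {
  private val eventLoop = Executors.newSingleThreadScheduledExecutor()

  def onHeartbeat(executorId: String)(reply: Boolean => Unit): Long = {
    val lastSeen = System.currentTimeMillis()   // fast bookkeeping stays on the calling thread
    eventLoop.submit(new Runnable {
      override def run(): Unit = {
        val unknown = slowLookup(executorId)    // stands in for executorHeartbeatReceived(...)
        reply(unknown)                          // reply is sent from the event loop, as in the patch
      }
    })
    lastSeen
  }

  private def slowLookup(id: String): Boolean = false  // placeholder for the scheduler work

  def shutdown(): Unit = eventLoop.shutdownNow()
}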

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 26 additions & 14 deletions
@@ -223,6 +223,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   private var _listenerBusStarted: Boolean = false
   private var _jars: Seq[String] = _
   private var _files: Seq[String] = _
+  private var _shutdownHookRef: AnyRef = _
 
   /* ------------------------------------------------------------------------------------- *
    | Accessors and public fields. These provide access to the internal state of the        |
@@ -517,6 +518,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
     _taskScheduler.postStartHook()
     _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler))
     _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager))
+
+    // Make sure the context is stopped if the user forgets about it. This avoids leaving
+    // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM
+    // is killed, though.
+    _shutdownHookRef = Utils.addShutdownHook(Utils.SPARK_CONTEXT_SHUTDOWN_PRIORITY) { () =>
+      logInfo("Invoking stop() from shutdown hook")
+      stop()
+    }
   } catch {
     case NonFatal(e) =>
       logError("Error initializing SparkContext.", e)
@@ -1055,7 +1064,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
   /** Build the union of a list of RDDs. */
   def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = {
     val partitioners = rdds.flatMap(_.partitioner).toSet
-    if (partitioners.size == 1) {
+    if (rdds.forall(_.partitioner.isDefined) && partitioners.size == 1) {
       new PartitionerAwareUnionRDD(this, rdds)
     } else {
       new UnionRDD(this, rdds)
@@ -1481,6 +1490,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
       logInfo("SparkContext already stopped.")
       return
     }
+    if (_shutdownHookRef != null) {
+      Utils.removeShutdownHook(_shutdownHookRef)
+    }
 
     postApplicationEnd()
     _ui.foreach(_.stop())
@@ -1891,7 +1903,7 @@ object SparkContext extends Logging {
    *
    * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK.
    */
-  private val activeContext: AtomicReference[SparkContext] = 
+  private val activeContext: AtomicReference[SparkContext] =
     new AtomicReference[SparkContext](null)
 
   /**
@@ -1944,11 +1956,11 @@ object SparkContext extends Logging {
   }
 
   /**
-   * This function may be used to get or instantiate a SparkContext and register it as a 
-   * singleton object. Because we can only have one active SparkContext per JVM, 
-   * this is useful when applications may wish to share a SparkContext. 
+   * This function may be used to get or instantiate a SparkContext and register it as a
+   * singleton object. Because we can only have one active SparkContext per JVM,
+   * this is useful when applications may wish to share a SparkContext.
    *
-   * Note: This function cannot be used to create multiple SparkContext instances 
+   * Note: This function cannot be used to create multiple SparkContext instances
    * even if multiple contexts are allowed.
    */
   def getOrCreate(config: SparkConf): SparkContext = {
@@ -1961,17 +1973,17 @@ object SparkContext extends Logging {
       activeContext.get()
     }
   }
-  
+
   /**
-   * This function may be used to get or instantiate a SparkContext and register it as a 
-   * singleton object. Because we can only have one active SparkContext per JVM, 
+   * This function may be used to get or instantiate a SparkContext and register it as a
+   * singleton object. Because we can only have one active SparkContext per JVM,
    * this is useful when applications may wish to share a SparkContext.
-   * 
+   *
    * This method allows not passing a SparkConf (useful if just retrieving).
-   * 
-   * Note: This function cannot be used to create multiple SparkContext instances 
-   * even if multiple contexts are allowed. 
-   */ 
+   *
+   * Note: This function cannot be used to create multiple SparkContext instances
+   * even if multiple contexts are allowed.
+   */
   def getOrCreate(): SparkContext = {
     getOrCreate(new SparkConf())
   }
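Taken together, the SparkContext changes register a stop() hook during initialization and unregister it again in stop(), so a clean shutdown does not run stop() a second time from the hook while a plain JVM exit still triggers it. A minimal sketch of that lifecycle with invented names (HookRegistrySketch, ContextSketch), not the Spark classes:

import scala.collection.mutable

object HookRegistrySketch {
  private val hooks = mutable.LinkedHashMap[AnyRef, () => Unit]()
  sys.addShutdownHook { hooks.values.foreach(h => h()) }   // run whatever is still registered at JVM exit

  def add(hook: () => Unit): AnyRef = synchronized { val ref = new Object; hooks(ref) = hook; ref }
  def remove(ref: AnyRef): Unit = synchronized { hooks.remove(ref) }
}

class ContextSketch {
  // register at construction time, mirroring the SparkContext init block above
  private val shutdownHookRef: AnyRef = HookRegistrySketch.add(() => stop())
  @volatile private var stopped = false

  def stop(): Unit = if (!stopped) {
    stopped = true
    HookRegistrySketch.remove(shutdownHookRef)   // unregister, mirroring the new code in stop()
    // ... release resources here ...
  }
}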

core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ class PartitionerAwareUnionRDD[T: ClassTag](
     var rdds: Seq[RDD[T]]
   ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
   require(rdds.length > 0)
+  require(rdds.forall(_.partitioner.isDefined))
   require(rdds.flatMap(_.partitioner).toSet.size == 1,
     "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
 
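This new require and the strengthened condition in SparkContext.union above guard the same case: a mix of partitioned and unpartitioned parents. A small sketch of why the old size-based check alone was not enough, using plain Option values in place of RDD partitioners:

// One parent has a partitioner, the other has none.
val parents: Seq[Option[String]] = Seq(Some("HashPartitioner(4)"), None)

val partitioners = parents.flatten.toSet
val oldCheck = partitioners.size == 1                                  // true: would pick PartitionerAwareUnionRDD
val newCheck = parents.forall(_.isDefined) && partitioners.size == 1   // false: falls back to UnionRDD
println((oldCheck, newCheck))  // (true,false)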
core/src/main/scala/org/apache/spark/util/SizeEstimator.scala

Lines changed: 30 additions & 15 deletions
@@ -179,7 +179,7 @@ private[spark] object SizeEstimator extends Logging {
   }
 
   // Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling.
-  private val ARRAY_SIZE_FOR_SAMPLING = 200
+  private val ARRAY_SIZE_FOR_SAMPLING = 400
   private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
 
   private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) {
@@ -204,25 +204,40 @@ private[spark] object SizeEstimator extends Logging {
        }
      } else {
        // Estimate the size of a large array by sampling elements without replacement.
-        var size = 0.0
+        // To exclude the shared objects that the array elements may link, sample twice
+        // and use the min one to caculate array size.
        val rand = new Random(42)
-        val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE)
-        var numElementsDrawn = 0
-        while (numElementsDrawn < ARRAY_SAMPLE_SIZE) {
-          var index = 0
-          do {
-            index = rand.nextInt(length)
-          } while (drawn.contains(index))
-          drawn.add(index)
-          val elem = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
-          size += SizeEstimator.estimate(elem, state.visited)
-          numElementsDrawn += 1
-        }
-        state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong
+        val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE)
+        val s1 = sampleArray(array, state, rand, drawn, length)
+        val s2 = sampleArray(array, state, rand, drawn, length)
+        val size = math.min(s1, s2)
+        state.size += math.max(s1, s2) +
+          (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong
      }
    }
  }
 
+  private def sampleArray(
+      array: AnyRef,
+      state: SearchState,
+      rand: Random,
+      drawn: OpenHashSet[Int],
+      length: Int): Long = {
+    var size = 0L
+    for (i <- 0 until ARRAY_SAMPLE_SIZE) {
+      var index = 0
+      do {
+        index = rand.nextInt(length)
+      } while (drawn.contains(index))
+      drawn.add(index)
+      val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
+      if (obj != null) {
+        size += SizeEstimator.estimate(obj, state.visited).toLong
+      }
+    }
+    size
+  }
+
  private def primitiveSize(cls: Class[_]): Long = {
    if (cls == classOf[Byte]) {
      BYTE_SIZE
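A worked example of the new estimate, with assumed per-sample totals (s1, s2) and array length; only the arithmetic mirrors the code above:

// An object array of length 1000 is sampled twice, ARRAY_SAMPLE_SIZE = 100 elements per pass.
val length = 1000
val sampleSize = 100        // stands in for ARRAY_SAMPLE_SIZE
val s1 = 12000L             // bytes seen in the first sample (assumed)
val s2 = 9000L              // bytes seen in the second sample (assumed)

// The larger sample is charged once in full; the smaller one extrapolates the remaining
// elements, which damps over-counting of shared objects reachable from only one sample.
val estimate = math.max(s1, s2) +
  (math.min(s1, s2) * ((length - sampleSize) / sampleSize)).toLong
// (1000 - 100) / 100 = 9 (integer division), so estimate = 12000 + 9000 * 9 = 93000 bytes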

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 8 additions & 2 deletions
@@ -67,6 +67,12 @@ private[spark] object Utils extends Logging {
 
   val DEFAULT_SHUTDOWN_PRIORITY = 100
 
+  /**
+   * The shutdown priority of the SparkContext instance. This is lower than the default
+   * priority, so that by default hooks are run before the context is shut down.
+   */
+  val SPARK_CONTEXT_SHUTDOWN_PRIORITY = 50
+
   private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
   @volatile private var localRootDirs: Array[String] = null
 
@@ -2116,7 +2122,7 @@ private[spark] object Utils extends Logging {
    * @return A handle that can be used to unregister the shutdown hook.
    */
   def addShutdownHook(hook: () => Unit): AnyRef = {
-    addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, hook)
+    addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY)(hook)
   }
 
   /**
@@ -2126,7 +2132,7 @@ private[spark] object Utils extends Logging {
    * @param hook The code to run during shutdown.
    * @return A handle that can be used to unregister the shutdown hook.
    */
-  def addShutdownHook(priority: Int, hook: () => Unit): AnyRef = {
+  def addShutdownHook(priority: Int)(hook: () => Unit): AnyRef = {
     shutdownHooks.add(priority, hook)
   }
 
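The point of the currying is the call-site syntax: with a second parameter list the hook can be passed as a trailing block, which is how SparkContext registers its hook above. A sketch with a stand-in function (sketchAddShutdownHook, built on scala.sys.addShutdownHook and not the real Utils method; the priority argument is ignored here):

def sketchAddShutdownHook(priority: Int)(hook: () => Unit): AnyRef =
  sys.addShutdownHook(hook())   // standard-library hook; returns a handle object

// Before the change:  addShutdownHook(DEFAULT_SHUTDOWN_PRIORITY, () => stop())
// After the change the hook reads as a trailing block:
val handle = sketchAddShutdownHook(50) { () =>
  println("Invoking stop() from shutdown hook")
}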