
Commit 7f487c8

Sun Rui authored and shivaram committed
[SPARK-6797] [SPARKR] Add support for YARN cluster mode.
This PR enables SparkR to dynamically ship the SparkR binary package to the ApplicationMaster (AM) node in YARN cluster mode, so the SparkR package no longer has to be pre-installed on every worker node. The package is zipped with the JDK jar tool, which can be expected to be available on both Linux and Windows wherever a JDK is installed.

This PR does not address the R worker used by the RDD API, nor the SBT build (SparkR installation and packaging under SBT); both will be handled in separate JIRA issues. R/install-dev.bat is not tested; shivaram, could you help test it?

Author: Sun Rui <[email protected]>

Closes apache#6743 from sun-rui/SPARK-6797 and squashes the following commits:

ca63c86 [Sun Rui] Adjust MimaExcludes after rebase.
7313374 [Sun Rui] Fix unit test errors.
72695fb [Sun Rui] Fix unit test failures.
193882f [Sun Rui] Fix Mima test error.
fe25a33 [Sun Rui] Fix Mima test error.
35ecfa3 [Sun Rui] Fix comments.
c38a005 [Sun Rui] Unzipped SparkR binary package is still required for standalone and Mesos modes.
b05340c [Sun Rui] Fix scala style.
2ca5048 [Sun Rui] Fix comments.
1acefd1 [Sun Rui] Fix scala style.
0aa1e97 [Sun Rui] Fix scala style.
41d4f17 [Sun Rui] Add support for locating SparkR package for R workers required by RDD APIs.
49ff948 [Sun Rui] Invoke jar.exe with full path in install-dev.bat.
7b916c5 [Sun Rui] Use 'rem' consistently.
3bed438 [Sun Rui] Add a comment.
681afb0 [Sun Rui] Fix a bug that RRunner does not handle client deployment modes.
cedfbe2 [Sun Rui] [SPARK-6797][SPARKR] Add support for YARN cluster mode.
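For context, a minimal end-to-end sketch of what this change enables: submitting an R script to YARN with deploy mode "cluster". This example is not part of the commit; the script path is a placeholder, and it assumes the org.apache.spark.launcher.SparkLauncher API (Spark 1.4+) and a Spark distribution whose R/lib contains the sparkr.zip produced by the install scripts below.

import org.apache.spark.launcher.SparkLauncher

// Hedged sketch, not part of this commit: launch an R script on YARN in
// cluster mode. "/path/to/your_script.R" is a placeholder.
object SubmitSparkRClusterMode {
  def main(args: Array[String]): Unit = {
    val app = new SparkLauncher()
      .setSparkHome(sys.env("SPARK_HOME"))
      .setMaster("yarn")
      .setDeployMode("cluster")
      .setAppResource("/path/to/your_script.R")
      .launch()
    app.waitFor() // block until spark-submit exits
  }
}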
1 parent a5bc803 commit 7f487c8

15 files changed, +133 −54 lines

R/install-dev.bat

Lines changed: 5 additions & 0 deletions
@@ -25,3 +25,8 @@ set SPARK_HOME=%~dp0..
 MKDIR %SPARK_HOME%\R\lib
 
 R.exe CMD INSTALL --library="%SPARK_HOME%\R\lib" %SPARK_HOME%\R\pkg\
+
+rem Zip the SparkR package so that it can be distributed to worker nodes on YARN
+pushd %SPARK_HOME%\R\lib
+%JAVA_HOME%\bin\jar.exe cfM "%SPARK_HOME%\R\lib\sparkr.zip" SparkR
+popd

R/install-dev.sh

Lines changed: 6 additions & 2 deletions
@@ -34,12 +34,16 @@ LIB_DIR="$FWDIR/lib"
 
 mkdir -p $LIB_DIR
 
-pushd $FWDIR
+pushd $FWDIR > /dev/null
 
 # Generate Rd files if devtools is installed
 Rscript -e ' if("devtools" %in% rownames(installed.packages())) { library(devtools); devtools::document(pkg="./pkg", roclets=c("rd")) }'
 
 # Install SparkR to $LIB_DIR
 R CMD INSTALL --library=$LIB_DIR $FWDIR/pkg/
 
-popd
+# Zip the SparkR package so that it can be distributed to worker nodes on YARN
+cd $LIB_DIR
+jar cfM "$LIB_DIR/sparkr.zip" SparkR
+
+popd > /dev/null
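As a side note, the same zip step can be driven from the JVM. The sketch below is illustrative only (not part of this commit); it mirrors the scripts above and assumes SPARK_HOME is set and the JDK jar tool is on the PATH. The flags mean: c create an archive, f write it to the named file, M do not add a manifest.

import java.io.File
import scala.sys.process.Process

// Illustrative sketch: package R/lib/SparkR into sparkr.zip with the JDK jar
// tool, mirroring install-dev.sh / install-dev.bat.
object ZipSparkR {
  def main(args: Array[String]): Unit = {
    val sparkHome = sys.env.getOrElse("SPARK_HOME", sys.error("SPARK_HOME is not set"))
    val libDir = new File(new File(sparkHome, "R"), "lib")
    val zip = new File(libDir, "sparkr.zip")
    // Run `jar cfM sparkr.zip SparkR` with R/lib as the working directory.
    val exit = Process(Seq("jar", "cfM", zip.getAbsolutePath, "SparkR"), libDir).!
    require(exit == 0, s"jar exited with code $exit")
  }
}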

R/pkg/DESCRIPTION

Lines changed: 0 additions & 1 deletion
@@ -32,4 +32,3 @@ Collate:
 'serialize.R'
 'sparkR.R'
 'utils.R'
-'zzz.R'

R/pkg/R/RDD.R

Lines changed: 0 additions & 2 deletions
@@ -165,7 +165,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
 serializedFuncArr,
 rdd@env$prev_serializedMode,
 packageNamesArr,
-as.character(.sparkREnv[["libname"]]),
 broadcastArr,
 callJMethod(prev_jrdd, "classTag"))
 } else {

@@ -175,7 +174,6 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"),
 rdd@env$prev_serializedMode,
 serializedMode,
 packageNamesArr,
-as.character(.sparkREnv[["libname"]]),
 broadcastArr,
 callJMethod(prev_jrdd, "classTag"))
 }

R/pkg/R/pairRDD.R

Lines changed: 0 additions & 1 deletion
@@ -215,7 +215,6 @@ setMethod("partitionBy",
 serializedHashFuncBytes,
 getSerializedMode(x),
 packageNamesArr,
-as.character(.sparkREnv$libname),
 broadcastArr,
 callJMethod(jrdd, "classTag"))
 

R/pkg/R/sparkR.R

Lines changed: 0 additions & 10 deletions
@@ -17,10 +17,6 @@
 
 .sparkREnv <- new.env()
 
-sparkR.onLoad <- function(libname, pkgname) {
-  .sparkREnv$libname <- libname
-}
-
 # Utility function that returns TRUE if we have an active connection to the
 # backend and FALSE otherwise
 connExists <- function(env) {

@@ -80,7 +76,6 @@ sparkR.stop <- function() {
 #' @param sparkEnvir Named list of environment variables to set on worker nodes.
 #' @param sparkExecutorEnv Named list of environment variables to be used when launching executors.
 #' @param sparkJars Character string vector of jar files to pass to the worker nodes.
-#' @param sparkRLibDir The path where R is installed on the worker nodes.
 #' @param sparkPackages Character string vector of packages from spark-packages.org
 #' @export
 #' @examples

@@ -101,7 +96,6 @@ sparkR.init <- function(
   sparkEnvir = list(),
   sparkExecutorEnv = list(),
   sparkJars = "",
-  sparkRLibDir = "",
   sparkPackages = "") {
 
   if (exists(".sparkRjsc", envir = .sparkREnv)) {

@@ -170,10 +164,6 @@ sparkR.init <- function(
     sparkHome <- normalizePath(sparkHome)
   }
 
-  if (nchar(sparkRLibDir) != 0) {
-    .sparkREnv$libname <- sparkRLibDir
-  }
-
   sparkEnvirMap <- new.env()
   for (varname in names(sparkEnvir)) {
     sparkEnvirMap[[varname]] <- sparkEnvir[[varname]]

R/pkg/R/zzz.R

Lines changed: 0 additions & 20 deletions
This file was deleted.

R/pkg/inst/profile/general.R

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 #
 
 .First <- function() {
-  home <- Sys.getenv("SPARK_HOME")
-  .libPaths(c(file.path(home, "R", "lib"), .libPaths()))
+  packageDir <- Sys.getenv("SPARKR_PACKAGE_DIR")
+  .libPaths(c(packageDir, .libPaths()))
   Sys.setenv(NOAWT=1)
 }
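The profile now relies on whatever launches R to export SPARKR_PACKAGE_DIR. A minimal sketch of doing that from the JVM side follows; only the variable name comes from the profile change above — the launcher code itself is an assumption for illustration, not the code added by this commit.

import java.io.File

// Assumption-labelled sketch: export SPARKR_PACKAGE_DIR so the .First hook
// above can put the SparkR library on .libPaths(). "sparkr" stands in for
// the directory where the distributed archive was extracted.
object LaunchRWithSparkRPackageDir {
  def main(args: Array[String]): Unit = {
    val packageDir = new File("sparkr").getAbsolutePath
    val pb = new ProcessBuilder("R", "--vanilla")
    pb.environment().put("SPARKR_PACKAGE_DIR", packageDir)
    pb.inheritIO()
    val exit = pb.start().waitFor()
    println(s"R exited with status $exit")
  }
}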

core/src/main/scala/org/apache/spark/api/r/RRDD.scala

Lines changed: 9 additions & 12 deletions
@@ -39,7 +39,6 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     deserializer: String,
     serializer: String,
     packageNames: Array[Byte],
-    rLibDir: String,
     broadcastVars: Array[Broadcast[Object]])
   extends RDD[U](parent) with Logging {
   protected var dataStream: DataInputStream = _

@@ -60,7 +59,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
 
     // The stdout/stderr is shared by multiple tasks, because we use one daemon
     // to launch child process as worker.
-    val errThread = RRDD.createRWorker(rLibDir, listenPort)
+    val errThread = RRDD.createRWorker(listenPort)
 
     // We use two sockets to separate input and output, then it's easy to manage
     // the lifecycle of them to avoid deadlock.

@@ -235,11 +234,10 @@ private class PairwiseRRDD[T: ClassTag](
     hashFunc: Array[Byte],
     deserializer: String,
     packageNames: Array[Byte],
-    rLibDir: String,
     broadcastVars: Array[Object])
   extends BaseRRDD[T, (Int, Array[Byte])](
     parent, numPartitions, hashFunc, deserializer,
-    SerializationFormats.BYTE, packageNames, rLibDir,
+    SerializationFormats.BYTE, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
 
   override protected def readData(length: Int): (Int, Array[Byte]) = {

@@ -266,10 +264,9 @@ private class RRDD[T: ClassTag](
     deserializer: String,
     serializer: String,
     packageNames: Array[Byte],
-    rLibDir: String,
     broadcastVars: Array[Object])
   extends BaseRRDD[T, Array[Byte]](
-    parent, -1, func, deserializer, serializer, packageNames, rLibDir,
+    parent, -1, func, deserializer, serializer, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
 
   override protected def readData(length: Int): Array[Byte] = {

@@ -293,10 +290,9 @@ private class StringRRDD[T: ClassTag](
     func: Array[Byte],
     deserializer: String,
     packageNames: Array[Byte],
-    rLibDir: String,
     broadcastVars: Array[Object])
   extends BaseRRDD[T, String](
-    parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
+    parent, -1, func, deserializer, SerializationFormats.STRING, packageNames,
     broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
 
   override protected def readData(length: Int): String = {

@@ -392,9 +388,10 @@ private[r] object RRDD {
     thread
   }
 
-  private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = {
+  private def createRProcess(port: Int, script: String): BufferedStreamThread = {
     val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript")
     val rOptions = "--vanilla"
+    val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
     val rExecScript = rLibDir + "/SparkR/worker/" + script
     val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript))
     // Unset the R_TESTS environment variable for workers.

@@ -413,15 +410,15 @@ private[r] object RRDD {
   /**
    * ProcessBuilder used to launch worker R processes.
    */
-  def createRWorker(rLibDir: String, port: Int): BufferedStreamThread = {
+  def createRWorker(port: Int): BufferedStreamThread = {
     val useDaemon = SparkEnv.get.conf.getBoolean("spark.sparkr.use.daemon", true)
     if (!Utils.isWindows && useDaemon) {
       synchronized {
         if (daemonChannel == null) {
           // we expect one connections
           val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
           val daemonPort = serverSocket.getLocalPort
-          errThread = createRProcess(rLibDir, daemonPort, "daemon.R")
+          errThread = createRProcess(daemonPort, "daemon.R")
           // the socket used to send out the input of task
           serverSocket.setSoTimeout(10000)
           val sock = serverSocket.accept()

@@ -443,7 +440,7 @@ private[r] object RRDD {
         errThread
       }
     } else {
-      createRProcess(rLibDir, port, "worker.R")
+      createRProcess(port, "worker.R")
     }
   }
 
core/src/main/scala/org/apache/spark/api/r/RUtils.scala

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.File
+
+import org.apache.spark.{SparkEnv, SparkException}
+
+private[spark] object RUtils {
+  /**
+   * Get the SparkR package path in the local spark distribution.
+   */
+  def localSparkRPackagePath: Option[String] = {
+    val sparkHome = sys.env.get("SPARK_HOME")
+    sparkHome.map(
+      Seq(_, "R", "lib").mkString(File.separator)
+    )
+  }
+
+  /**
+   * Get the SparkR package path in various deployment modes.
+   * This assumes that Spark properties `spark.master` and `spark.submit.deployMode`
+   * and environment variable `SPARK_HOME` are set.
+   */
+  def sparkRPackagePath(isDriver: Boolean): String = {
+    val (master, deployMode) =
+      if (isDriver) {
+        (sys.props("spark.master"), sys.props("spark.submit.deployMode"))
+      } else {
+        val sparkConf = SparkEnv.get.conf
+        (sparkConf.get("spark.master"), sparkConf.get("spark.submit.deployMode"))
+      }
+
+    val isYarnCluster = master.contains("yarn") && deployMode == "cluster"
+    val isYarnClient = master.contains("yarn") && deployMode == "client"
+
+    // In YARN mode, the SparkR package is distributed as an archive symbolically
+    // linked to the "sparkr" file in the current directory. Note that this does not apply
+    // to the driver in client mode because it is run outside of the cluster.
+    if (isYarnCluster || (isYarnClient && !isDriver)) {
+      new File("sparkr").getAbsolutePath
+    } else {
+      // Otherwise, assume the package is local
+      // TODO: support this for Mesos
+      localSparkRPackagePath.getOrElse {
+        throw new SparkException("SPARK_HOME not set. Can't locate SparkR package.")
+      }
+    }
+  }
+}
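A hedged usage sketch of the new helper (not part of the commit): RUtils is private[spark], so the caller below is placed under the org.apache.spark package, and the property values are illustrative.

package org.apache.spark.examples.r

import org.apache.spark.api.r.RUtils

// Illustrative only: with a "yarn" master and "cluster" deploy mode, the
// driver-side branch resolves to the absolute path of ./sparkr, the archive
// YARN places in the container working directory; otherwise it falls back
// to $SPARK_HOME/R/lib.
object SparkRPackagePathDemo {
  def main(args: Array[String]): Unit = {
    sys.props("spark.master") = "yarn"
    sys.props("spark.submit.deployMode") = "cluster"
    println(RUtils.sparkRPackagePath(isDriver = true))
  }
}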
