2 changes: 1 addition & 1 deletion bin/spark-submit.cmd
@@ -20,4 +20,4 @@ rem
rem This is the entry point for running Spark submit. To avoid polluting the
rem environment, it just launches a new cmd to do the real work.

cmd /V /E /C "%~dp0spark-submit2.cmd" %*
cmd /V /E /C ""%~dp0spark-submit2.cmd" %*"
@@ -181,6 +181,11 @@ private[deploy] class DriverRunner(
Files.append(header, stderr, StandardCharsets.UTF_8)
CommandUtils.redirectStream(process.getErrorStream, stderr)
}

// Windows-only: cmd.exe refuses command lines beyond ~8K characters, so
// collapse the -cp entries into a single manifest jar before launching.
if (Utils.isWindows) {
Utils.shortenClasspath(builder)
}

runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
}

@@ -145,6 +145,11 @@ private[deploy] class ExecutorRunner(
val builder = CommandUtils.buildProcessBuilder(appDesc.command, new SecurityManager(conf),
memory, sparkHome.getAbsolutePath, substituteVariables)
val command = builder.command()

if (Utils.isWindows) {
Utils.shortenClasspath(builder)
}

val formattedCommand = command.asScala.mkString("\"", "\" \"", "\"")
logInfo(s"Launch command: $formattedCommand")

52 changes: 51 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -24,6 +24,7 @@ import java.nio.ByteBuffer
import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.util
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
import javax.net.ssl.HttpsURLConnection
@@ -109,7 +110,7 @@ private[spark] object Utils extends Logging {
}

/** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */
def deserializeLongValue(bytes: Array[Byte]) : Long = {
def deserializeLongValue(bytes: Array[Byte]): Long = {
// Note: we assume that we are given a Long value encoded in network (big-endian) byte order
var result = bytes(7) & 0xFFL
result = result + ((bytes(6) & 0xFFL) << 8)
@@ -1090,6 +1091,50 @@ private[spark] object Utils extends Logging {
bytesToString(megabytes * 1024L * 1024L)
}


/**
* On Windows, create a jar under `tempDir` whose manifest Class-Path
* attribute references every entry of `classPath`, and return that jar's
* path as the shortened classpath; on other platforms, return `classPath`
* unchanged.
*/
def createShortClassPath(tempDir: File, classPath: String): String = {
if (isWindows) {
val env = new util.HashMap[String, String](System.getenv())
val javaCps = FileUtil
.createJarWithClassPath(classPath, new Path(tempDir.getAbsolutePath), env)
val javaCpStr = javaCps(0) + javaCps(1)
logInfo("Shorten the class path to: " + javaCpStr)
javaCpStr
} else {
classPath
}
}

def createShortClassPath(classPath: String): String = {
val tempDir = createTempDir(namePrefix = "classpaths")
createShortClassPath(tempDir, classPath)
}
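// A hypothetical use on Windows (paths purely illustrative): the result is
// the path of one small jar in a temp directory, so its length no longer
// depends on how many entries the original classpath contained, e.g.
//   val cp = Seq("C:\\spark\\jars\\spark-core.jar",
//     "C:\\spark\\jars\\guava.jar").mkString(File.pathSeparator)
//   val shortCp = createShortClassPath(cp)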

/**
* If the serialized command line of `builder` exceeds the Windows
* command-length limit, replace its `-cp` argument with a single jar
* created by [[createShortClassPath]].
*/
def shortenClasspath(builder: ProcessBuilder): Unit = {
// cmd.exe rejects command lines longer than 8191 characters.
if (builder.command.asScala.mkString("\"", "\" \"", "\"").length > 8190) {
logWarning("Command line too long; trying to shorten the classpath")
// Look for the -cp argument. Note that the environment set in the
// ProcessBuilder is process-local, so rewriting the command here won't
// pollute the parent environment.
val command = builder.command()
val idxCp = command.indexOf("-cp")
if (idxCp > 0 && idxCp + 1 < command.size()) {
val classPath = command.get(idxCp + 1)
val shortPath = createShortClassPath(classPath)
command.set(idxCp + 1, shortPath)
}
}
}
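// For reference, a minimal standalone sketch (not the implementation this
// patch calls) of the "pathing jar" trick behind
// FileUtil.createJarWithClassPath: a manifest's Class-Path attribute has
// no length limit, so an arbitrarily long -cp value can collapse to one
// small jar that points at the real entries.
private def createPathingJar(entries: Seq[File], jar: File): File = {
import java.util.jar.{Attributes, JarOutputStream, Manifest}
val manifest = new Manifest()
val attrs = manifest.getMainAttributes
attrs.put(Attributes.Name.MANIFEST_VERSION, "1.0")
// Class-Path entries are space-separated URIs; File.toURI already appends
// the trailing '/' that marks a directory entry.
attrs.put(Attributes.Name.CLASS_PATH, entries.map(_.toURI.toString).mkString(" "))
val out = new JarOutputStream(new java.io.FileOutputStream(jar), manifest)
out.close() // the jar contains nothing but META-INF/MANIFEST.MF
jar
}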


/**
* Execute a command and return the process running the command.
*/
@@ -1103,6 +1148,11 @@
for ((key, value) <- extraEnvironment) {
environment.put(key, value)
}

if (Utils.isWindows) {
Utils.shortenClasspath(builder)
}

val process = builder.start()
if (redirectStderr) {
val threadName = "redirect stderr for command " + command(0)
@@ -17,16 +17,20 @@

package org.apache.spark.launcher;

import java.io.File;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.lang.SystemUtils;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.slf4j.bridge.SLF4JBridgeHandler;
import static org.junit.Assert.*;

import org.apache.spark.util.Utils;

/**
* These tests require the Spark assembly to be built before they can be run.
*/
@@ -99,15 +103,28 @@ public void testChildProcLauncher() throws Exception {
String.format("%s=-Dfoo=ShouldBeOverriddenBelow", SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS))
.setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS,
"-Dfoo=bar -Dtest.appender=childproc")
.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"))
.addSparkArg(opts.CLASS, "ShouldBeOverriddenBelow")
.setMainClass(SparkLauncherTestApp.class.getName())
.addAppArgs("proc");
final Process app = launcher.launch();

new OutputRedirector(app.getInputStream(), TF);
new OutputRedirector(app.getErrorStream(), TF);
assertEquals(0, app.waitFor());
File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
try {
if (SystemUtils.IS_OS_WINDOWS) {
launcher.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH,
Utils.createShortClassPath(tempDir, System.getProperty("java.class.path")));
} else {
launcher.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"));
}
final Process app = launcher.launch();

new OutputRedirector(app.getInputStream(), TF);
new OutputRedirector(app.getErrorStream(), TF);
assertEquals(0, app.waitFor());
} finally {
if (tempDir.exists()) {
Utils.deleteRecursively(tempDir);
}
}
}

public static class SparkLauncherTestApp {
@@ -26,13 +26,16 @@ import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually._

import org.apache.spark._
import org.apache.spark.util.Utils

class LauncherBackendSuite extends SparkFunSuite with Matchers {

private val tests = Seq(
"local" -> "local",
"standalone/client" -> "local-cluster[1,1,1024]")

val tempDir = Utils.createTempDir()

tests.foreach { case (name, master) =>
test(s"$name: launcher handle") {
testWithMaster(master)
@@ -42,16 +45,22 @@ class LauncherBackendSuite extends SparkFunSuite with Matchers {
private def testWithMaster(master: String): Unit = {
val env = new java.util.HashMap[String, String]()
env.put("SPARK_PRINT_LAUNCH_COMMAND", "1")
val handle = new SparkLauncher(env)

val launcher = new SparkLauncher(env)
.setSparkHome(sys.props("spark.test.home"))
.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"))
.setConf("spark.ui.enabled", "false")
.setConf(SparkLauncher.DRIVER_EXTRA_JAVA_OPTIONS, s"-Dtest.appender=console")
.setMaster(master)
.setAppResource("spark-internal")
.setMainClass(TestApp.getClass.getName().stripSuffix("$"))
.startApplication()
if (Utils.isWindows) {
launcher.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH,
Utils.createShortClassPath(tempDir, System.getProperty("java.class.path")))
} else {
launcher.setConf(SparkLauncher.DRIVER_EXTRA_CLASSPATH, System.getProperty("java.class.path"))
}

val handle = launcher.startApplication()
try {
eventually(timeout(30 seconds), interval(100 millis)) {
handle.getAppId() should not be (null)