[historypage-template.html hunks: the HTML markup was stripped during extraction and is not recoverable. The summary table header cells are App ID, App Name, Attempt ID, Started, Completed, Duration, Spark User, Last Updated, and Event Log; the second hunk (@@ -73,11 +73,11 @@) adjusts the per-attempt row cells rendering {{attemptId}}, {{startTime}}, {{endTime}}, {{duration}}, {{sparkUser}}, {{lastUpdated}}, and an event log Download link.]
diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
index 54810edaf146..5ec1ce15a212 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js
@@ -120,6 +120,9 @@ $(document).ready(function() {
attempt["startTime"] = formatDate(attempt["startTime"]);
attempt["endTime"] = formatDate(attempt["endTime"]);
attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]);
+ attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" +
+ (attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "logs";
+
var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]};
array.push(app_clone);
}
@@ -174,6 +177,13 @@ $(document).ready(function() {
}
}
+ if (requestedIncomplete) {
+ var completedCells = document.getElementsByClassName("completedColumn");
+ for (i = 0; i < completedCells.length; i++) {
+ completedCells[i].style.display='none';
+ }
+ }
+
var durationCells = document.getElementsByClassName("durationClass");
for (i = 0; i < durationCells.length; i++) {
var timeInMilliseconds = parseInt(durationCells[i].title);
@@ -185,7 +195,7 @@ $(document).ready(function() {
}
$(selector).DataTable(conf);
- $('#hisotry-summary [data-toggle="tooltip"]').tooltip();
+ $('#history-summary [data-toggle="tooltip"]').tooltip();
});
});
});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
index ff241470f32d..9960d5c34d1f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/sorttable.js
@@ -207,8 +207,8 @@ sorttable = {
hasInputs = (typeof node.getElementsByTagName == 'function') &&
node.getElementsByTagName('input').length;
-
- if (node.getAttribute("sorttable_customkey") != null) {
+
+ if (node.nodeType == 1 && node.getAttribute("sorttable_customkey") != null) {
return node.getAttribute("sorttable_customkey");
}
else if (typeof node.textContent != 'undefined' && !hasInputs) {
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 261b3329a7b9..fcc72ff49276 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager(
val delta = addExecutors(maxNeeded)
logDebug(s"Starting timer to add more executors (to " +
s"expire in $sustainedSchedulerBacklogTimeoutS seconds)")
- addTime += sustainedSchedulerBacklogTimeoutS * 1000
+ addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000)
delta
} else {
0
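
For context on the one-line fix above: the old `addTime += timeout` rule could leave the next trigger in the past whenever the periodic check ran late, so more executors kept being requested every round. A minimal standalone sketch of the difference (illustrative names only, not Spark's scheduler loop):

    object AddTimerSketch {
      def main(args: Array[String]): Unit = {
        val timeoutMs = 1000L
        val addTime = 0L                      // when the backlog timer was last armed
        val now = addTime + timeoutMs + 500   // the periodic check runs 500 ms late

        val drifting = addTime + timeoutMs    // old rule: already in the past, fires again immediately
        val anchored = now + timeoutMs        // new rule: a full timeout from the current time
        println(s"now=$now drifting=$drifting anchored=$anchored")
      }
    }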
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index a50600f1488c..089969398801 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -261,7 +261,7 @@ class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S
private def getImpl(timeout: Duration): T = {
// This will throw TimeoutException on timeout:
- Await.ready(futureAction, timeout)
+ ThreadUtils.awaitReady(futureAction, timeout)
futureAction.value.get match {
case scala.util.Success(value) => converter(value)
case scala.util.Failure(exception) =>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 99efc4893fda..1a2443f7ee78 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1350,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging {
@deprecated("use AccumulatorV2", "2.0.0")
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T])
: Accumulator[T] = {
- val acc = new Accumulator(initialValue, param, Some(name))
+ val acc = new Accumulator(initialValue, param, Option(name))
cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging {
@deprecated("use AccumulatorV2", "2.0.0")
def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T])
: Accumulable[R, T] = {
- val acc = new Accumulable(initialValue, param, Some(name))
+ val acc = new Accumulable(initialValue, param, Option(name))
cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc))
acc
}
@@ -1414,7 +1414,7 @@ class SparkContext(config: SparkConf) extends Logging {
* @note Accumulators must be registered before use, or it will throw exception.
*/
def register(acc: AccumulatorV2[_, _], name: String): Unit = {
- acc.register(this, name = Some(name))
+ acc.register(this, name = Option(name))
}
/**
@@ -1734,6 +1734,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Return information about blocks stored in all of the slaves
*/
@DeveloperApi
+ @deprecated("This method may change or be removed in a future release.", "2.2.0")
def getExecutorStorageStatus: Array[StorageStatus] = {
assertNotStopped()
env.blockManager.master.getStorageStatus
@@ -1800,40 +1801,39 @@ class SparkContext(config: SparkConf) extends Logging {
* an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
+ def addJarFile(file: File): String = {
+ try {
+ if (!file.exists()) {
+ throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found")
+ }
+ if (file.isDirectory) {
+ throw new IllegalArgumentException(
+ s"Directory ${file.getAbsoluteFile} is not allowed for addJar")
+ }
+ env.rpcEnv.fileServer.addJar(file)
+ } catch {
+ case NonFatal(e) =>
+ logError(s"Failed to add $path to Spark environment", e)
+ null
+ }
+ }
+
if (path == null) {
logWarning("null specified as parameter to addJar")
} else {
- var key = ""
- if (path.contains("\\")) {
+ val key = if (path.contains("\\")) {
// For local paths with backslashes on Windows, URI throws an exception
- key = env.rpcEnv.fileServer.addJar(new File(path))
+ addJarFile(new File(path))
} else {
val uri = new URI(path)
// SPARK-17650: Make sure this is a valid URL before adding it to the list of dependencies
Utils.validateURL(uri)
- key = uri.getScheme match {
+ uri.getScheme match {
// A JAR file which exists only on the driver node
- case null | "file" =>
- try {
- val file = new File(uri.getPath)
- if (!file.exists()) {
- throw new FileNotFoundException(s"Jar ${file.getAbsolutePath} not found")
- }
- if (file.isDirectory) {
- throw new IllegalArgumentException(
- s"Directory ${file.getAbsoluteFile} is not allowed for addJar")
- }
- env.rpcEnv.fileServer.addJar(new File(uri.getPath))
- } catch {
- case NonFatal(e) =>
- logError(s"Failed to add $path to Spark environment", e)
- null
- }
+ case null | "file" => addJarFile(new File(uri.getPath))
// A JAR file which exists locally on every worker node
- case "local" =>
- "file:" + uri.getPath
- case _ =>
- path
+ case "local" => "file:" + uri.getPath
+ case _ => path
}
}
if (key != null) {
@@ -1938,6 +1938,9 @@ class SparkContext(config: SparkConf) extends Logging {
}
SparkEnv.set(null)
}
+ // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even after
+ // this `SparkContext` is stopped.
+ localProperties.remove()
// Unset YARN mode system env variable, to allow switching between cluster types.
System.clearProperty("SPARK_YARN_MODE")
SparkContext.clearActiveContext()
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index 8cd1d1c96aa0..01d8973e1bb0 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -110,10 +110,10 @@ private[spark] class TaskContextImpl(
/** Marks the task as completed and triggers the completion listeners. */
@GuardedBy("this")
- private[spark] def markTaskCompleted(): Unit = synchronized {
+ private[spark] def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
if (completed) return
completed = true
- invokeListeners(onCompleteCallbacks, "TaskCompletionListener", None) {
+ invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
_.onTaskCompletion(this)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index ac09c6c497f8..fa35e4568819 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -24,7 +24,7 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
-import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef}
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
@@ -34,6 +34,16 @@ private[deploy] object DeployMessages {
// Worker to Master
+ /**
+ * @param id the worker id
+ * @param host the worker host
+ * @param port the worker port
+ * @param worker the worker endpoint ref
+ * @param cores the number of cores of the worker
+ * @param memory the memory size of the worker
+ * @param workerWebUiUrl the worker Web UI address
+ * @param masterAddress the master address used by the worker to connect
+ */
case class RegisterWorker(
id: String,
host: String,
@@ -41,7 +51,8 @@ private[deploy] object DeployMessages {
worker: RpcEndpointRef,
cores: Int,
memory: Int,
- workerWebUiUrl: String)
+ workerWebUiUrl: String,
+ masterAddress: RpcAddress)
extends DeployMessage {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
@@ -80,8 +91,16 @@ private[deploy] object DeployMessages {
sealed trait RegisterWorkerResponse
- case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage
- with RegisterWorkerResponse
+ /**
+ * @param master the master ref
+ * @param masterWebUiUrl the master Web UI address
+ * @param masterAddress the master address used by the worker to connect. It should be
+ * [[RegisterWorker.masterAddress]].
+ */
+ case class RegisteredWorker(
+ master: RpcEndpointRef,
+ masterWebUiUrl: String,
+ masterAddress: RpcAddress) extends DeployMessage with RegisterWorkerResponse
case class RegisterWorkerFailed(message: String) extends DeployMessage with RegisterWorkerResponse
diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
index 050778a895c0..7d356e8fc1c0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala
@@ -92,6 +92,9 @@ private[deploy] object RPackageUtils extends Logging {
* Exposed for testing.
*/
private[deploy] def checkManifestForR(jar: JarFile): Boolean = {
+ if (jar.getManifest == null) {
+ return false
+ }
val manifest = jar.getManifest.getMainAttributes
manifest.getValue(hasRPackage) != null && manifest.getValue(hasRPackage).trim == "true"
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index bae7a3f307f5..6afe58bff522 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -23,11 +23,13 @@ import java.text.DateFormat
import java.util.{Arrays, Comparator, Date, Locale}
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.control.NonFatal
import com.google.common.primitives.Longs
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
+import org.apache.hadoop.fs.permission.FsAction
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.{Credentials, UserGroupInformation}
import org.apache.hadoop.security.token.{Token, TokenIdentifier}
@@ -142,14 +144,29 @@ class SparkHadoopUtil extends Logging {
* Returns a function that can be called to find Hadoop FileSystem bytes read. If
* getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
* return the bytes read on r since t.
- *
- * @return None if the required method can't be found.
*/
private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
- val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
- val f = () => threadStats.map(_.getBytesRead).sum
- val baselineBytesRead = f()
- () => f() - baselineBytesRead
+ val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
+ val baseline = (Thread.currentThread().getId, f())
+
+ /**
+ * This function may be called from both spawned child threads and the parent task thread (in
+ * PythonRDD), and Hadoop FileSystem uses thread-local variables to track the statistics.
+ * So we need a map to track the bytes read from the child threads and the parent thread,
+ * summing them together to get the bytes read of this task.
+ */
+ new Function0[Long] {
+ private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+ override def apply(): Long = {
+ bytesReadMap.synchronized {
+ bytesReadMap.put(Thread.currentThread().getId, f())
+ bytesReadMap.map { case (k, v) =>
+ v - (if (k == baseline._1) baseline._2 else 0)
+ }.sum
+ }
+ }
+ }
}
/**
@@ -353,6 +370,28 @@ class SparkHadoopUtil extends Logging {
}
buffer.toString
}
+
+ private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = {
+ val perm = status.getPermission
+ val ugi = UserGroupInformation.getCurrentUser
+
+ if (ugi.getShortUserName == status.getOwner) {
+ if (perm.getUserAction.implies(mode)) {
+ return true
+ }
+ } else if (ugi.getGroupNames.contains(status.getGroup)) {
+ if (perm.getGroupAction.implies(mode)) {
+ return true
+ }
+ } else if (perm.getOtherAction.implies(mode)) {
+ return true
+ }
+
+ logDebug(s"Permission denied: user=${ugi.getShortUserName}, " +
+ s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" +
+ s"${if (status.isDirectory) "d" else "-"}$perm")
+ false
+ }
}
object SparkHadoopUtil {
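
To make the precedence in `checkAccessPermission` concrete: the owner bits are consulted if the current user owns the file, otherwise the group bits if the user is in the file's group, otherwise the "other" bits. A rough standalone sketch of that order using Hadoop's `FsPermission` (hypothetical helper, not Spark code):

    import org.apache.hadoop.fs.permission.{FsAction, FsPermission}

    object PermissionSketch {
      def canRead(perm: FsPermission, isOwner: Boolean, inGroup: Boolean): Boolean = {
        if (isOwner) perm.getUserAction.implies(FsAction.READ)
        else if (inGroup) perm.getGroupAction.implies(FsAction.READ)
        else perm.getOtherAction.implies(FsAction.READ)
      }

      def main(args: Array[String]): Unit = {
        // rw- r-- --- : owner and group can read, others cannot
        val perm = new FsPermission(FsAction.READ_WRITE, FsAction.READ, FsAction.NONE)
        println(canRead(perm, isOwner = false, inGroup = true))  // true
        println(canRead(perm, isOwner = false, inGroup = false)) // false
      }
    }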
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 77005aa9040b..c60a2a1706d5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy
import java.io.{File, IOException}
import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
+import java.nio.file.Files
import java.security.PrivilegedExceptionAction
import java.text.ParseException
@@ -28,7 +29,8 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.Properties
import org.apache.commons.lang3.StringUtils
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.UserGroupInformation
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
@@ -308,6 +310,15 @@ object SparkSubmit extends CommandLineUtils {
RPackageUtils.checkAndBuildRPackage(args.jars, printStream, args.verbose)
}
+ // In client mode, download remote files.
+ if (deployMode == CLIENT) {
+ val hadoopConf = new HadoopConfiguration()
+ args.primaryResource = Option(args.primaryResource).map(downloadFile(_, hadoopConf)).orNull
+ args.jars = Option(args.jars).map(downloadFileList(_, hadoopConf)).orNull
+ args.pyFiles = Option(args.pyFiles).map(downloadFileList(_, hadoopConf)).orNull
+ args.files = Option(args.files).map(downloadFileList(_, hadoopConf)).orNull
+ }
+
// Require all python files to be local, so we can add them to the PYTHONPATH
// In YARN cluster mode, python files are distributed as regular files, which can be non-local.
// In Mesos cluster mode, non-local python files are automatically downloaded by Mesos.
@@ -825,6 +836,41 @@ object SparkSubmit extends CommandLineUtils {
.mkString(",")
if (merged == "") null else merged
}
+
+ /**
+ * Download a list of remote files to temp local files. If the file is local, the original file
+ * will be returned.
+ * @param fileList A comma-separated file list.
+ * @return A comma-separated list of local files.
+ */
+ private[deploy] def downloadFileList(
+ fileList: String,
+ hadoopConf: HadoopConfiguration): String = {
+ require(fileList != null, "fileList cannot be null.")
+ fileList.split(",").map(downloadFile(_, hadoopConf)).mkString(",")
+ }
+
+ /**
+ * Download a file from a remote location to a local temporary directory. If the input path
+ * points to a local path, it is returned unchanged with no operation performed.
+ */
+ private[deploy] def downloadFile(path: String, hadoopConf: HadoopConfiguration): String = {
+ require(path != null, "path cannot be null.")
+ val uri = Utils.resolveURI(path)
+ uri.getScheme match {
+ case "file" | "local" =>
+ path
+
+ case _ =>
+ val fs = FileSystem.get(uri, hadoopConf)
+ val tmpFile = new File(Files.createTempDirectory("tmp").toFile, uri.getPath)
+ // scalastyle:off println
+ printStream.println(s"Downloading ${uri.toString} to ${tmpFile.getAbsolutePath}.")
+ // scalastyle:on println
+ fs.copyToLocalFile(new Path(uri), new Path(tmpFile.getAbsolutePath))
+ Utils.resolveURI(tmpFile.getAbsolutePath).toString
+ }
+ }
}
/** Provides utility functions to be used inside SparkSubmit. */
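
A simplified sketch of the scheme check that drives `downloadFile`: `file` and `local` URIs are used as-is, while any other scheme would be copied to a temporary directory first. This illustrates the dispatch only, using plain `java.net.URI`, and does not perform the Hadoop copy:

    import java.net.URI

    object SchemeSketch {
      def needsDownload(path: String): Boolean = {
        // A missing scheme is treated as a local file here, roughly matching Utils.resolveURI.
        val scheme = Option(new URI(path).getScheme).getOrElse("file")
        scheme match {
          case "file" | "local" => false
          case _ => true
        }
      }

      def main(args: Array[String]): Unit = {
        println(needsDownload("/tmp/app.jar"))           // false
        println(needsDownload("hdfs://nn:8020/app.jar")) // true
      }
    }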
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index d7d82800b8b5..6d8758a3d3b1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider {
* @return Count of application event logs that are currently under process
*/
def getEventLogsUnderProcess(): Int = {
- return 0;
+ 0
}
/**
@@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider {
* @return 0 if this is undefined or unsupported, otherwise the last updated time in millis
*/
def getLastUpdatedTime(): Long = {
- return 0;
+ 0
}
/**
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 9012736bc274..f4235df24512 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -27,7 +27,8 @@ import scala.xml.Node
import com.google.common.io.ByteStreams
import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder}
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.permission.FsAction
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.apache.hadoop.hdfs.protocol.HdfsConstants
import org.apache.hadoop.security.AccessControlException
@@ -318,21 +319,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
// scan for modified applications, replay and merge them
val logInfos: Seq[FileStatus] = statusList
.filter { entry =>
- try {
- val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L)
- !entry.isDirectory() &&
- // FsHistoryProvider generates a hidden file which can't be read. Accidentally
- // reading a garbage file is safe, but we would log an error which can be scary to
- // the end-user.
- !entry.getPath().getName().startsWith(".") &&
- prevFileSize < entry.getLen()
- } catch {
- case e: AccessControlException =>
- // Do not use "logInfo" since these messages can get pretty noisy if printed on
- // every poll.
- logDebug(s"No permission to read $entry, ignoring.")
- false
- }
+ val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L)
+ !entry.isDirectory() &&
+ // FsHistoryProvider generates a hidden file which can't be read. Accidentally
+ // reading a garbage file is safe, but we would log an error which can be scary to
+ // the end-user.
+ !entry.getPath().getName().startsWith(".") &&
+ prevFileSize < entry.getLen() &&
+ SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ)
}
.flatMap { entry => Some(entry) }
.sortWith { case (entry1, entry2) =>
@@ -445,7 +439,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock)
/**
* Replay the log files in the list and merge the list of old applications with new ones
*/
- private def mergeApplicationListing(fileStatus: FileStatus): Unit = {
+ protected def mergeApplicationListing(fileStatus: FileStatus): Unit = {
val newAttempts = try {
val eventsFilter: ReplayEventsFilter = { eventString =>
eventString.startsWith(APPL_START_EVENT_PREFIX) ||
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index 0e7a6c24d4fa..af1471763340 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
def render(request: HttpServletRequest): Seq[Node] = {
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
val requestedIncomplete =
- Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
+ Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean
val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete)
val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess()
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 54f39f7620e5..d9c8fda99ef9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -301,6 +301,14 @@ object HistoryServer extends Logging {
logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}")
config.set(SecurityManager.SPARK_AUTH_CONF, "false")
}
+
+ if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) {
+ logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " +
+ "only using spark.history.ui.acl.enable")
+ config.set("spark.acls.enable", "false")
+ config.set("spark.ui.acls.enable", "false")
+ }
+
new SecurityManager(config)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 816bf37e39fe..96b53c624232 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -231,7 +231,8 @@ private[deploy] class Master(
logError("Leadership has been revoked -- master shutting down.")
System.exit(0)
- case RegisterWorker(id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) =>
+ case RegisterWorker(
+ id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl, masterAddress) =>
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
@@ -243,7 +244,7 @@ private[deploy] class Master(
workerRef, workerWebUiUrl)
if (registerWorker(worker)) {
persistenceEngine.addWorker(worker)
- workerRef.send(RegisteredWorker(self, masterWebUiUrl))
+ workerRef.send(RegisteredWorker(self, masterWebUiUrl, masterAddress))
schedule()
} else {
val workerAddress = worker.endpoint.address
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 946a92882141..94ff81c1a68e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
- val appId = request.getParameter("appId")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val appId = UIUtils.stripXSS(request.getParameter("appId"))
val state = master.askSync[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId)
.getOrElse(state.completedApps.find(_.id == appId).orNull)
@@ -83,7 +84,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")
Executor Memory:
{Utils.megabytesToString(app.desc.memoryPerExecutorMB)}
- Submit Date: {app.submitDate}
+ Submit Date: {UIUtils.formatDate(app.submitDate)}
State: {app.state}
{
if (!app.isFinished) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index e722a24d4a89..ce71300e9097 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
- val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
- val id = Option(request.getParameter("id"))
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val killFlag =
+ Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
+ val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
@@ -252,7 +254,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
| {driver.id} {killLink} |
- {driver.submitDate} |
+ {UIUtils.formatDate(driver.submitDate)} |
{driver.worker.map(w =>
if (w.isAlive()) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index e878c10183f6..58a181128eb4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -57,7 +57,8 @@ private[deploy] class DriverRunner(
@volatile private[worker] var finalException: Option[Exception] = None
// Timeout to wait for when trying to terminate a driver.
- private val DRIVER_TERMINATE_TIMEOUT_MS = 10 * 1000
+ private val DRIVER_TERMINATE_TIMEOUT_MS =
+ conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s")
// Decoupled for testing
def setClock(_clock: Clock): Unit = {
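
A small check, on a plain `SparkConf`, of how the new `spark.worker.driverTerminateTimeout` value is parsed; "10s" is the default used above and the "30s" setting is just an example:

    import org.apache.spark.SparkConf

    object DriverTimeoutSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().set("spark.worker.driverTerminateTimeout", "30s")
        // getTimeAsMs parses duration strings such as "10s" or "30s" into milliseconds.
        println(conf.getTimeAsMs("spark.worker.driverTerminateTimeout", "10s")) // 30000
      }
    }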
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 00b9d1af373d..ca9243e39c0a 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -99,6 +99,20 @@ private[deploy] class Worker(
private val testing: Boolean = sys.props.contains("spark.testing")
private var master: Option[RpcEndpointRef] = None
+
+ /**
+ * Whether to use the master address in `masterRpcAddresses` if possible. If it's disabled, Worker
+ * will just use the address received from Master.
+ */
+ private val preferConfiguredMasterAddress =
+ conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false)
+ /**
+ * The master address to connect in case of failure. When the connection is broken, worker will
+ * use this address to connect. This is usually just one of `masterRpcAddresses`. However, when
+ * a master is restarted or takes over leadership, it will be an address sent from master, which
+ * may not be in `masterRpcAddresses`.
+ */
+ private var masterAddressToConnect: Option[RpcAddress] = None
private var activeMasterUrl: String = ""
private[worker] var activeMasterWebUiUrl : String = ""
private var workerWebUiUrl: String = ""
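
A minimal sketch, assuming a standalone worker's configuration, of opting in to the behaviour described in the new scaladoc; the key and its `false` default mirror the code above:

    import org.apache.spark.SparkConf

    object WorkerConfSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .set("spark.worker.preferConfiguredMasterAddress", "true")
        println(conf.getBoolean("spark.worker.preferConfiguredMasterAddress", false))
      }
    }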
@@ -196,10 +210,19 @@ private[deploy] class Worker(
metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}
- private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) {
+ /**
+ * Change to use the new master.
+ *
+ * @param masterRef the new master ref
+ * @param uiUrl the new master Web UI address
+ * @param masterAddress the new master address which the worker should use to connect in case of
+ * failure
+ */
+ private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String, masterAddress: RpcAddress) {
// activeMasterUrl it's a valid Spark url since we receive it from master.
activeMasterUrl = masterRef.address.toSparkURL
activeMasterWebUiUrl = uiUrl
+ masterAddressToConnect = Some(masterAddress)
master = Some(masterRef)
connected = true
if (conf.getBoolean("spark.ui.reverseProxy", false)) {
@@ -266,7 +289,8 @@ private[deploy] class Worker(
if (registerMasterFutures != null) {
registerMasterFutures.foreach(_.cancel(true))
}
- val masterAddress = masterRef.address
+ val masterAddress =
+ if (preferConfiguredMasterAddress) masterAddressToConnect.get else masterRef.address
registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable {
override def run(): Unit = {
try {
@@ -342,15 +366,27 @@ private[deploy] class Worker(
}
private def sendRegisterMessageToMaster(masterEndpoint: RpcEndpointRef): Unit = {
- masterEndpoint.send(RegisterWorker(workerId, host, port, self, cores, memory, workerWebUiUrl))
+ masterEndpoint.send(RegisterWorker(
+ workerId,
+ host,
+ port,
+ self,
+ cores,
+ memory,
+ workerWebUiUrl,
+ masterEndpoint.address))
}
private def handleRegisterResponse(msg: RegisterWorkerResponse): Unit = synchronized {
msg match {
- case RegisteredWorker(masterRef, masterWebUiUrl) =>
- logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
+ case RegisteredWorker(masterRef, masterWebUiUrl, masterAddress) =>
+ if (preferConfiguredMasterAddress) {
+ logInfo("Successfully registered with master " + masterAddress.toSparkURL)
+ } else {
+ logInfo("Successfully registered with master " + masterRef.address.toSparkURL)
+ }
registered = true
- changeMaster(masterRef, masterWebUiUrl)
+ changeMaster(masterRef, masterWebUiUrl, masterAddress)
forwordMessageScheduler.scheduleAtFixedRate(new Runnable {
override def run(): Unit = Utils.tryLogNonFatalError {
self.send(SendHeartbeat)
@@ -419,7 +455,7 @@ private[deploy] class Worker(
case MasterChanged(masterRef, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL)
- changeMaster(masterRef, masterWebUiUrl)
+ changeMaster(masterRef, masterWebUiUrl, masterRef.address)
val execs = executors.values.
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
@@ -561,7 +597,8 @@ private[deploy] class Worker(
}
override def onDisconnected(remoteAddress: RpcAddress): Unit = {
- if (master.exists(_.address == remoteAddress)) {
+ if (master.exists(_.address == remoteAddress) ||
+ masterAddressToConnect.exists(_ == remoteAddress)) {
logInfo(s"$remoteAddress Disassociated !")
masterDisconnected()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 80dc9bf8779d..2f5a5642d3ca 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -33,13 +33,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
private val supportedLogTypes = Set("stderr", "stdout")
private val defaultBytes = 100 * 1024
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
- val appId = Option(request.getParameter("appId"))
- val executorId = Option(request.getParameter("executorId"))
- val driverId = Option(request.getParameter("driverId"))
- val logType = request.getParameter("logType")
- val offset = Option(request.getParameter("offset")).map(_.toLong)
- val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
+ val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
+ val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
+ val logType = UIUtils.stripXSS(request.getParameter("logType"))
+ val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
+ val byteLength =
+ Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ .getOrElse(defaultBytes)
val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
@@ -55,13 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
- val appId = Option(request.getParameter("appId"))
- val executorId = Option(request.getParameter("executorId"))
- val driverId = Option(request.getParameter("driverId"))
- val logType = request.getParameter("logType")
- val offset = Option(request.getParameter("offset")).map(_.toLong)
- val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
+ val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
+ val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
+ val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
+ val logType = UIUtils.stripXSS(request.getParameter("logType"))
+ val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
+ val byteLength =
+ Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
+ .getOrElse(defaultBytes)
val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 83469c5ff060..d54dd2d46482 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory
import java.net.{URI, URL}
import java.nio.ByteBuffer
import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent._
import javax.annotation.concurrent.GuardedBy
import scala.collection.JavaConverters._
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.control.NonFatal
+import com.google.common.util.concurrent.ThreadFactoryBuilder
+
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
@@ -84,7 +86,20 @@ private[spark] class Executor(
}
// Start worker thread pool
- private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker")
+ private val threadPool = {
+ val threadFactory = new ThreadFactoryBuilder()
+ .setDaemon(true)
+ .setNameFormat("Executor task launch worker-%d")
+ .setThreadFactory(new ThreadFactory {
+ override def newThread(r: Runnable): Thread =
+ // Use UninterruptibleThread to run tasks so that we can allow running code without being
+ // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622,
+ // will hang forever if some methods are interrupted.
+ new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder
+ })
+ .build()
+ Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor]
+ }
private val executorSource = new ExecutorSource(threadPool, executorId)
// Pool used for threads that supervise task killing / cancellation
private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper")
@@ -410,6 +425,7 @@ private[spark] class Executor(
}
}
+ setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
@@ -432,7 +448,8 @@ private[spark] class Executor(
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))
- case NonFatal(_) if task != null && task.reasonIfKilled.isDefined =>
+ case _: InterruptedException | NonFatal(_) if
+ task != null && task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
setTaskFinishedAndClearInterruptStatus()
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index dfd2f818acda..a3ce3d1ccc5e 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable {
private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums
- /**
- * Looks for a registered accumulator by accumulator name.
- */
- private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
- accumulators.find { acc =>
- acc.name.isDefined && acc.name.get == name
- }
+ private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = {
+ // The RESULT_SIZE accumulator is always zero at the executor; we still need to send it back
+ // because its value will be updated at the driver side.
+ internalAccums.filter(a => !a.isZero || a == _resultSize)
}
}
@@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging {
*/
def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = {
val tm = new TaskMetrics
- val (internalAccums, externalAccums) =
- accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get))
-
- internalAccums.foreach { acc =>
- val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]]
- tmAcc.metadata = acc.metadata
- tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+ for (acc <- accums) {
+ val name = acc.name
+ if (name.isDefined && tm.nameToAccums.contains(name.get)) {
+ val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]]
+ tmAcc.metadata = acc.metadata
+ tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]])
+ } else {
+ tm.externalAccums += acc
+ }
}
-
- tm.externalAccums ++= externalAccums
tm
}
}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 89aeea493908..f8139b706a7c 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -244,8 +244,8 @@ package object config {
ConfigBuilder("spark.redaction.regex")
.doc("Regex to decide which Spark configuration properties and environment variables in " +
"driver and executor environments contain sensitive information. When this regex matches " +
- "a property, its value is redacted from the environment UI and various logs like YARN " +
- "and event logs.")
+ "a property key or value, the value is redacted from the environment UI and various logs " +
+ "like YARN and event logs.")
.regexConf
.createWithDefault("(?i)secret|password".r)
@@ -272,4 +272,25 @@ package object config {
.booleanConf
.createWithDefault(false)
+ private[spark] val CHECKPOINT_COMPRESS =
+ ConfigBuilder("spark.checkpoint.compress")
+ .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " +
+ "spark.io.compression.codec.")
+ .booleanConf
+ .createWithDefault(false)
+
+ private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
+ ConfigBuilder("spark.shuffle.accurateBlockThreshold")
+ .doc("When we compress the size of shuffle blocks in HighlyCompressedMapStatus, we will " +
+ "record the size accurately if it's above this config. This helps to prevent OOM by " +
+ "avoiding underestimating shuffle block size when fetch shuffle blocks.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefault(100 * 1024 * 1024)
+
+ private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM =
+ ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem")
+ .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " +
+ "above this threshold. This is to avoid a giant request takes too much memory.")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("200m")
}
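
A minimal sketch, using a plain `SparkConf`, of how the three settings added above could be supplied; the values shown are arbitrary examples, not recommendations:

    import org.apache.spark.SparkConf

    object NewConfigSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .set("spark.checkpoint.compress", "true")
          .set("spark.shuffle.accurateBlockThreshold", "100m")
          .set("spark.reducer.maxReqSizeShuffleToMem", "200m")
        conf.getAll.sortBy(_._1).foreach { case (k, v) => println(s"$k=$v") }
      }
    }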
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index 8f83668d7902..b3f8bfe8b1d4 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -46,5 +46,5 @@ trait BlockDataManager {
/**
* Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
*/
- def releaseLock(blockId: BlockId): Unit
+ def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index cb9d389dd7ea..6860214c7fe3 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,7 +17,7 @@
package org.apache.spark.network
-import java.io.Closeable
+import java.io.{Closeable, File}
import java.nio.ByteBuffer
import scala.concurrent.{Future, Promise}
@@ -67,7 +67,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
port: Int,
execId: String,
blockIds: Array[String],
- listener: BlockFetchingListener): Unit
+ listener: BlockFetchingListener,
+ shuffleFiles: Array[File]): Unit
/**
* Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -100,7 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
ret.flip()
result.success(new NioManagedBuffer(ret))
}
- })
+ }, shuffleFiles = null)
ThreadUtils.awaitResult(result.future, Duration.Inf)
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 2ed8a00df702..305fd9a6de10 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -56,11 +56,12 @@ class NettyBlockRpcServer(
message match {
case openBlocks: OpenBlocks =>
- val blocks: Seq[ManagedBuffer] =
- openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
+ val blocksNum = openBlocks.blockIds.length
+ val blocks = for (i <- (0 until blocksNum).view)
+ yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
- logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
- responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer)
+ logTrace(s"Registered streamId $streamId with $blocksNum buffers")
+ responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
case uploadBlock: UploadBlock =>
// StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
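
The rewrite above builds the buffers as a lazy view, so each block is materialized only when the stream manager actually pulls it rather than all at once when `OpenBlocks` arrives. A tiny standalone illustration of that `.view` behaviour (the `open` function is a stand-in for `blockManager.getBlockData`):

    object ViewSketch {
      def main(args: Array[String]): Unit = {
        def open(i: Int): Int = { println(s"opening block $i"); i }

        val eager = (0 until 3).map(open)           // opens every block immediately
        val onDemand = (0 until 3).view.map(open)   // opens nothing yet

        println("consuming the view")
        onDemand.iterator.foreach(_ => ())          // blocks are opened one at a time here
        println(eager.sum)
      }
    }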
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index b75e91b66096..b13a9c681e54 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,6 +17,7 @@
package org.apache.spark.network.netty
+import java.io.File
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
@@ -88,13 +89,15 @@ private[spark] class NettyBlockTransferService(
port: Int,
execId: String,
blockIds: Array[String],
- listener: BlockFetchingListener): Unit = {
+ listener: BlockFetchingListener,
+ shuffleFiles: Array[File]): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
val client = clientFactory.createClient(host, port)
- new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+ new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener,
+ transportConf, shuffleFiles).start()
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 4bf8ecc38354..76ea8b86c53d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -251,7 +251,13 @@ class HadoopRDD[K, V](
null
}
// Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener{ context => closeIfNeeded() }
+ context.addTaskCompletionListener { context =>
+ // Update the bytes read before closing to make sure lingering bytesRead statistics in
+ // this thread get correctly added.
+ updateBytesRead()
+ closeIfNeeded()
+ }
+
private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
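
A minimal sketch of registering a completion callback from inside a task, the same hook the change above uses to flush the bytes-read counter before closing the reader (the local-mode job and the println are illustrative only):

    import org.apache.spark.{SparkConf, SparkContext, TaskContext}

    object CompletionListenerSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[1]").setAppName("completion-listener-sketch"))
        sc.parallelize(1 to 4, 2).foreachPartition { iter =>
          TaskContext.get().addTaskCompletionListener { _ =>
            println("partition finished: flush metrics and close resources here")
          }
          iter.foreach(_ => ())
        }
        sc.stop()
      }
    }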
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index ce3a9a2a1e2a..482875e6c1ac 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -191,7 +191,13 @@ class NewHadoopRDD[K, V](
}
// Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener(context => close())
+ context.addTaskCompletionListener { context =>
+ // Update the bytesRead before closing to make sure lingering bytesRead statistics in
+ // this thread get correctly added.
+ updateBytesRead()
+ close()
+ }
+
private var havePair = false
private var recordsSinceMetricsUpdate = 0
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e524675332d1..63a87e7f09d8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
-import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}
@@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag](
val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
- queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
+ queue ++= collectionUtils.takeOrdered(items, num)(ord)
Iterator.single(queue)
}
if (mapRDDs.partitions.length == 0) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
index e0a29b48314f..37c67cee55f9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.rdd
import java.io.{FileNotFoundException, IOException}
+import java.util.concurrent.TimeUnit
import scala.reflect.ClassTag
import scala.util.control.NonFatal
@@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CHECKPOINT_COMPRESS
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
@@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
originalRDD: RDD[T],
checkpointDir: String,
blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+ val checkpointStartTimeNs = System.nanoTime()
val sc = originalRDD.sparkContext
@@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging {
writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
}
+ val checkpointDurationMs =
+ TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs)
+ logInfo(s"Checkpointing took $checkpointDurationMs ms.")
+
val newRDD = new ReliableCheckpointRDD[T](
sc, checkpointDirPath.toString, originalRDD.partitioner)
if (newRDD.partitions.length != originalRDD.partitions.length) {
@@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileOutputStream = if (blockSize < 0) {
- fs.create(tempOutputPath, false, bufferSize)
+ val fileStream = fs.create(tempOutputPath, false, bufferSize)
+ if (env.conf.get(CHECKPOINT_COMPRESS)) {
+ CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream)
+ } else {
+ fileStream
+ }
} else {
// This is mainly for testing purpose
fs.create(tempOutputPath, false, bufferSize,
@@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val env = SparkEnv.get
val fs = path.getFileSystem(broadcastedConf.value.value)
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
- val fileInputStream = fs.open(path, bufferSize)
+ val fileInputStream = {
+ val fileStream = fs.open(path, bufferSize)
+ if (env.conf.get(CHECKPOINT_COMPRESS)) {
+ CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream)
+ } else {
+ fileStream
+ }
+ }
val serializer = env.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
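
An end-to-end sketch, assuming a local `SparkContext`, of turning on the new `spark.checkpoint.compress` flag before checkpointing an RDD; the checkpoint directory path is just an example:

    import org.apache.spark.{SparkConf, SparkContext}

    object CheckpointCompressSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setMaster("local[2]")
          .setAppName("checkpoint-compress-sketch")
          .set("spark.checkpoint.compress", "true")
        val sc = new SparkContext(conf)
        sc.setCheckpointDir("/tmp/checkpoint-sketch") // example directory
        val rdd = sc.parallelize(1 to 1000)
        rdd.checkpoint()        // checkpoint files will be written through the compression codec
        println(rdd.count())    // the action triggers the actual checkpoint write
        sc.stop()
      }
    }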
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
similarity index 97%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
index 145dc22b7428..ab72addb2466 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala
@@ -15,11 +15,12 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.rdd.util
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index aab177f257a8..35f6b365eca8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -618,12 +618,7 @@ class DAGScheduler(
properties: Properties): Unit = {
val start = System.nanoTime
val waiter = submitJob(rdd, func, partitions, callSite, resultHandler, properties)
- // Note: Do not call Await.ready(future) because that calls `scala.concurrent.blocking`,
- // which causes concurrent SQL executions to fail if a fork-join pool is used. Note that
- // due to idiosyncrasies in Scala, `awaitPermission` is not actually used anywhere so it's
- // safe to pass in null here. For more detail, see SPARK-13747.
- val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
- waiter.completionFuture.ready(Duration.Inf)(awaitPermission)
+ ThreadUtils.awaitReady(waiter.completionFuture, Duration.Inf)
waiter.completionFuture.value.get match {
case scala.util.Success(_) =>
logInfo("Job %d finished: %s, took %f s".format
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index aecb3a980e7c..a7dbf87915b2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -252,11 +252,17 @@ private[spark] class EventLoggingListener(
private[spark] def redactEvent(
event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = {
- // "Spark Properties" entry will always exist because the map is always populated with it.
- val redactedProps = Utils.redact(sparkConf, event.environmentDetails("Spark Properties"))
- val redactedEnvironmentDetails = event.environmentDetails +
- ("Spark Properties" -> redactedProps)
- SparkListenerEnvironmentUpdate(redactedEnvironmentDetails)
+ // environmentDetails maps a string descriptor to a set of properties
+ // Similar to:
+ // "JVM Information" -> jvmInformation,
+ // "Spark Properties" -> sparkProperties,
+ // ...
+ // where jvmInformation, sparkProperties, etc. are sequences of tuples.
+ // We go through the various property groups and redact sensitive information from them.
+ val redactedProps = event.environmentDetails.map{ case (name, props) =>
+ name -> Utils.redact(sparkConf, props)
+ }
+ SparkListenerEnvironmentUpdate(redactedProps)
}
}
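
A small standalone illustration of the broadened redaction: every property group in the environment update is scanned, and matching entries have their values masked. The regex, mask string, and helper below are illustrative stand-ins, not Spark's `Utils.redact`:

    object RedactSketch {
      def main(args: Array[String]): Unit = {
        val redactionRegex = "(?i)secret|password".r
        def redact(props: Seq[(String, String)]): Seq[(String, String)] = props.map {
          case (k, _) if redactionRegex.findFirstIn(k).isDefined => (k, "*********(redacted)")
          case kv => kv
        }

        val environmentDetails = Map(
          "Spark Properties" -> Seq("spark.hadoop.fs.s3a.secret.key" -> "abc123"),
          "System Properties" -> Seq("java.version" -> "1.8.0_131"))

        // Redact every group, not just "Spark Properties".
        environmentDetails.map { case (name, props) => name -> redact(props) }.foreach(println)
      }
    }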
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index b2e9a97129f0..048e0d018659 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,8 +19,13 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.roaringbitmap.RoaringBitmap
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.config
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
@@ -121,34 +126,41 @@ private[spark] class CompressedMapStatus(
}
/**
- * A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
+ * A [[MapStatus]] implementation that stores the accurate size of huge blocks, which are larger
+ * than spark.shuffle.accurateBlockThreshold. It stores the average size of other non-empty blocks,
* plus a bitmap for tracking which blocks are empty.
*
* @param loc location where the task is being executed
* @param numNonEmptyBlocks the number of non-empty blocks
* @param emptyBlocks a bitmap tracking which blocks are empty
- * @param avgSize average size of the non-empty blocks
+ * @param avgSize average size of the non-empty and non-huge blocks
+ * @param hugeBlockSizes sizes of huge blocks by their reduceId.
*/
private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
private[this] var numNonEmptyBlocks: Int,
private[this] var emptyBlocks: RoaringBitmap,
- private[this] var avgSize: Long)
+ private[this] var avgSize: Long,
+ @transient private var hugeBlockSizes: Map[Int, Byte])
extends MapStatus with Externalizable {
// loc could be null when the default constructor is called during deserialization
- require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0,
+ require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0,
"Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, -1, null, -1) // For deserialization only
+ protected def this() = this(null, -1, null, -1, null) // For deserialization only
override def location: BlockManagerId = loc
override def getSizeForBlock(reduceId: Int): Long = {
+ assert(hugeBlockSizes != null)
if (emptyBlocks.contains(reduceId)) {
0
} else {
- avgSize
+ hugeBlockSizes.get(reduceId) match {
+ case Some(size) => MapStatus.decompressSize(size)
+ case None => avgSize
+ }
}
}
@@ -156,6 +168,11 @@ private[spark] class HighlyCompressedMapStatus private (
loc.writeExternal(out)
emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
+ out.writeInt(hugeBlockSizes.size)
+ hugeBlockSizes.foreach { kv =>
+ out.writeInt(kv._1)
+ out.writeByte(kv._2)
+ }
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -163,6 +180,14 @@ private[spark] class HighlyCompressedMapStatus private (
emptyBlocks = new RoaringBitmap()
emptyBlocks.readExternal(in)
avgSize = in.readLong()
+ val count = in.readInt()
+ val hugeBlockSizesArray = mutable.ArrayBuffer[Tuple2[Int, Byte]]()
+ (0 until count).foreach { _ =>
+ val block = in.readInt()
+ val size = in.readByte()
+ hugeBlockSizesArray += Tuple2(block, size)
+ }
+ hugeBlockSizes = hugeBlockSizesArray.toMap
}
}
@@ -178,11 +203,21 @@ private[spark] object HighlyCompressedMapStatus {
// we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
val emptyBlocks = new RoaringBitmap()
val totalNumBlocks = uncompressedSizes.length
+ val threshold = Option(SparkEnv.get)
+ .map(_.conf.get(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD))
+ .getOrElse(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.defaultValue.get)
+ val hugeBlockSizesArray = ArrayBuffer[Tuple2[Int, Byte]]()
while (i < totalNumBlocks) {
- var size = uncompressedSizes(i)
+ val size = uncompressedSizes(i)
if (size > 0) {
numNonEmptyBlocks += 1
- totalSize += size
+ // Huge blocks are not included in the calculation for average size, so the average size
+ // reported for the smaller blocks is more accurate.
+ if (size < threshold) {
+ totalSize += size
+ } else {
+ hugeBlockSizesArray += Tuple2(i, MapStatus.compressSize(uncompressedSizes(i)))
+ }
} else {
emptyBlocks.add(i)
}
@@ -195,6 +230,7 @@ private[spark] object HighlyCompressedMapStatus {
}
emptyBlocks.trim()
emptyBlocks.runOptimize()
- new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
+ hugeBlockSizesArray.toMap)
}
}
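
The new HighlyCompressedMapStatus bookkeeping can be summarized as: blocks at or above the accurate-block threshold keep an individually recorded size, everything else shares one average. A simplified, self-contained model of that logic (it skips the one-byte log-scale compression and the empty-block bitmap of the real class):

object MapStatusSketch {
  final case class Sizes(avgSize: Long, hugeBlockSizes: Map[Int, Long]) {
    // Empty blocks report 0, huge blocks report their recorded size, the rest the average.
    def sizeForBlock(reduceId: Int, isEmpty: Int => Boolean): Long =
      if (isEmpty(reduceId)) 0L else hugeBlockSizes.getOrElse(reduceId, avgSize)
  }

  def compress(uncompressed: Array[Long], threshold: Long): Sizes = {
    val huge = scala.collection.mutable.Map.empty[Int, Long]
    var totalSmall = 0L
    var numSmall = 0
    uncompressed.zipWithIndex.foreach { case (size, i) =>
      if (size >= threshold) huge(i) = size
      else if (size > 0) { totalSmall += size; numSmall += 1 }
    }
    val avg = if (numSmall > 0) totalSmall / numSmall else 0L
    Sizes(avg, huge.toMap)
  }
}
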
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 7fd2918960cd..7767ef1803a0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -115,26 +115,33 @@ private[spark] abstract class Task[T](
case t: Throwable =>
e.addSuppressed(t)
}
+ context.markTaskCompleted(Some(e))
throw e
} finally {
- // Call the task completion callbacks.
- context.markTaskCompleted()
try {
- Utils.tryLogNonFatalError {
- // Release memory used by this thread for unrolling blocks
- SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
- SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.OFF_HEAP)
- // Notify any tasks waiting for execution memory to be freed to wake up and try to
- // acquire memory again. This makes impossible the scenario where a task sleeps forever
- // because there are no other tasks left to notify it. Since this is safe to do but may
- // not be strictly necessary, we should revisit whether we can remove this in the future.
- val memoryManager = SparkEnv.get.memoryManager
- memoryManager.synchronized { memoryManager.notifyAll() }
- }
+ // Call the task completion callbacks. If "markTaskCompleted" is called twice, the second
+ // call is a no-op.
+ context.markTaskCompleted(None)
} finally {
- // Though we unset the ThreadLocal here, the context member variable itself is still queried
- // directly in the TaskRunner to check for FetchFailedExceptions.
- TaskContext.unset()
+ try {
+ Utils.tryLogNonFatalError {
+ // Release memory used by this thread for unrolling blocks
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP)
+ SparkEnv.get.blockManager.memoryStore.releaseUnrollMemoryForThisTask(
+ MemoryMode.OFF_HEAP)
+ // Notify any tasks waiting for execution memory to be freed to wake up and try to
+ // acquire memory again. This makes impossible the scenario where a task sleeps forever
+ // because there are no other tasks left to notify it. Since this is safe to do but may
+ // not be strictly necessary, we should revisit whether we can remove this in the
+ // future.
+ val memoryManager = SparkEnv.get.memoryManager
+ memoryManager.synchronized { memoryManager.notifyAll() }
+ }
+ } finally {
+ // Though we unset the ThreadLocal here, the context member variable itself is still
+ // queried directly in the TaskRunner to check for FetchFailedExceptions.
+ TaskContext.unset()
+ }
}
}
}
@@ -182,14 +189,11 @@ private[spark] abstract class Task[T](
*/
def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = {
if (context != null) {
- context.taskMetrics.internalAccums.filter { a =>
- // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its
- // value will be updated at driver side.
- // Note: internal accumulators representing task metrics always count failed values
- !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE)
- // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter
- // them out.
- } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
+ // Note: internal accumulators representing task metrics always count failed values
+ context.taskMetrics.nonZeroInternalAccums() ++
+ // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not
+ // filter them out.
+ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues)
} else {
Seq.empty
}
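
The restructured `finally` blocks above are about ordering and idempotence: completion callbacks run first (at most once, even when the failure path already ran them), then unroll memory is released, then the thread-local context is cleared. A toy sketch of that shape; `ToyTaskContext` is a stand-in for illustration, not Spark's TaskContext:

import java.util.concurrent.atomic.AtomicBoolean

class ToyTaskContext {
  private val completed = new AtomicBoolean(false)

  def markTaskCompleted(error: Option[Throwable]): Unit = {
    // Second and later calls are no-ops, mirroring the comment in the patch above.
    if (completed.compareAndSet(false, true)) {
      println(s"running completion callbacks, error = $error")
    }
  }
}

object TaskCleanupSketch {
  def run(context: ToyTaskContext)(body: => Unit): Unit = {
    try {
      try body
      catch { case e: Throwable => context.markTaskCompleted(Some(e)); throw e }
    } finally {
      try context.markTaskCompleted(None)              // 1. callbacks (idempotent)
      finally {
        try println("release unroll memory")           // 2. memory cleanup
        finally println("unset thread-local context")  // 3. TaskContext.unset()
      }
    }
  }
}
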
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 4eedaaea6119..dc82bb770472 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// `CoarseGrainedSchedulerBackend.this`.
private val executorDataMap = new HashMap[String, ExecutorData]
+ // Total number of executors requested from the cluster manager, e.g. by [[ExecutorAllocationManager]]
+ @GuardedBy("CoarseGrainedSchedulerBackend.this")
+ private var requestedTotalExecutors = 0
+
// Number of executors requested from the cluster manager that have not registered yet
@GuardedBy("CoarseGrainedSchedulerBackend.this")
private var numPendingExecutors = 0
@@ -413,6 +417,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
* */
protected def reset(): Unit = {
val executors = synchronized {
+ requestedTotalExecutors = 0
numPendingExecutors = 0
executorsPendingToRemove.clear()
Set() ++ executorDataMap.keys
@@ -487,12 +492,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
val response = synchronized {
+ requestedTotalExecutors += numAdditionalExecutors
numPendingExecutors += numAdditionalExecutors
logDebug(s"Number of pending executors is now $numPendingExecutors")
+ if (requestedTotalExecutors !=
+ (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) {
+ logDebug(
+ s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match:
+ |requestedTotalExecutors = $requestedTotalExecutors
+ |numExistingExecutors = $numExistingExecutors
+ |numPendingExecutors = $numPendingExecutors
+ |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin)
+ }
// Account for executors pending to be added or removed
- doRequestTotalExecutors(
- numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)
+ doRequestTotalExecutors(requestedTotalExecutors)
}
defaultAskTimeout.awaitResult(response)
@@ -524,6 +538,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
}
val response = synchronized {
+ this.requestedTotalExecutors = numExecutors
this.localityAwareTasks = localityAwareTasks
this.hostToLocalTaskCount = hostToLocalTaskCount
@@ -589,8 +604,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
// take into account executors that are pending to be added or removed.
val adjustTotalExecutors =
if (!replace) {
- doRequestTotalExecutors(
- numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)
+ requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0)
+ if (requestedTotalExecutors !=
+ (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) {
+ logDebug(
+ s"""killExecutors($executorIds, $replace, $force): Executor counts do not match:
+ |requestedTotalExecutors = $requestedTotalExecutors
+ |numExistingExecutors = $numExistingExecutors
+ |numPendingExecutors = $numPendingExecutors
+ |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin)
+ }
+ doRequestTotalExecutors(requestedTotalExecutors)
} else {
numPendingExecutors += knownExecutors.size
Future.successful(true)
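
The idea behind `requestedTotalExecutors` is to track the target sent to the cluster manager explicitly, instead of re-deriving it from existing + pending - pendingToRemove, which can drift when those counters change independently. A toy model of that bookkeeping, with illustrative method names:

class ExecutorTargetSketch {
  private var requestedTotal = 0

  // requestTotalExecutors-style call: the target is set directly.
  def requestTotal(n: Int): Int = synchronized { requestedTotal = n; requestedTotal }

  // requestExecutors-style call: the target grows by the number of additional executors.
  def requestAdditional(n: Int): Int = synchronized {
    requestedTotal += n
    requestedTotal // this value is what doRequestTotalExecutors would now be called with
  }

  // killExecutors-style call: only a non-replacing kill lowers the target, never below zero.
  def kill(numKilled: Int, replace: Boolean): Int = synchronized {
    if (!replace) requestedTotal = math.max(requestedTotal - numKilled, 0)
    requestedTotal
  }
}
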
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index ba3e0e395e95..2fbac79a2305 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -18,7 +18,7 @@
package org.apache.spark.shuffle
import org.apache.spark._
-import org.apache.spark.internal.Logging
+import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer.SerializerManager
import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
import org.apache.spark.util.CompletionIterator
@@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
val serializerInstance = dep.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index 00f918c09c66..f17b63775482 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext {
@Path("applications/{appId}/logs")
def getEventLogs(
@PathParam("appId") appId: String): EventLogDownloadResource = {
- new EventLogDownloadResource(uiRoot, appId, None)
+ try {
+ // withSparkUI will throw NotFoundException if the application has attempt IDs, so we
+ // need to try again with attempt id "1".
+ withSparkUI(appId, None) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, None)
+ }
+ } catch {
+ case _: NotFoundException =>
+ withSparkUI(appId, Some("1")) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, None)
+ }
+ }
}
@Path("applications/{appId}/{attemptId}/logs")
def getEventLogs(
@PathParam("appId") appId: String,
@PathParam("attemptId") attemptId: String): EventLogDownloadResource = {
- new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
+ withSparkUI(appId, Some(attemptId)) { _ =>
+ new EventLogDownloadResource(uiRoot, appId, Some(attemptId))
+ }
}
@Path("version")
@@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext {
case None => throw new NotFoundException("no such app: " + appId)
}
}
-
}
private[v1] class ForbiddenException(msg: String) extends WebApplicationException(
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
index d159b9450ef5..56d8e51732ff 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala
@@ -76,10 +76,13 @@ class ExecutorSummary private[spark](
val isBlacklisted: Boolean,
val maxMemory: Long,
val executorLogs: Map[String, String],
- val onHeapMemoryUsed: Option[Long],
- val offHeapMemoryUsed: Option[Long],
- val maxOnHeapMemory: Option[Long],
- val maxOffHeapMemory: Option[Long])
+ val memoryMetrics: Option[MemoryMetrics])
+
+class MemoryMetrics private[spark](
+ val usedOnHeapStorageMemory: Long,
+ val usedOffHeapStorageMemory: Long,
+ val totalOnHeapStorageMemory: Long,
+ val totalOffHeapStorageMemory: Long)
class JobData private[spark](
val jobId: Int,
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
index 3db59837fbeb..7064872ec1c7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging {
/**
* Release a lock on the given block.
+ * If the TaskContext is not propagated properly to all child threads of the task, we cannot get
+ * the TID from it, so the TID value has to be passed explicitly to release the lock.
+ *
+ * See SPARK-18406 for more discussion of this issue.
*/
- def unlock(blockId: BlockId): Unit = synchronized {
- logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId")
+ def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
+ val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
+ logTrace(s"Task $taskId releasing lock for $blockId")
val info = get(blockId).getOrElse {
throw new IllegalStateException(s"Block $blockId not found")
}
if (info.writerTask != BlockInfo.NO_WRITER) {
info.writerTask = BlockInfo.NO_WRITER
- writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
+ writeLocksByTask.removeBinding(taskId, blockId)
} else {
assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
info.readerCount -= 1
- val countsForTask = readLocksByTask(currentTaskAttemptId)
+ val countsForTask = readLocksByTask(taskId)
val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
assert(newPinCountForTask >= 0,
- s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")
+ s"Task $taskId release lock on block $blockId more times than it acquired it")
}
notifyAll()
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 3219969bcd06..5f067191070e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -23,14 +23,12 @@ import java.nio.channels.Channels
import scala.collection.mutable
import scala.collection.mutable.HashMap
-import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.{ExecutionContext, Future}
import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.Random
import scala.util.control.NonFatal
-import com.google.common.io.ByteStreams
-
import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
import org.apache.spark.internal.Logging
@@ -41,7 +39,6 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage.memory._
@@ -337,7 +334,7 @@ private[spark] class BlockManager(
val task = asyncReregisterTask
if (task != null) {
try {
- Await.ready(task, Duration.Inf)
+ ThreadUtils.awaitReady(task, Duration.Inf)
} catch {
case NonFatal(t) =>
throw new Exception("Error occurred while waiting for async. reregistration", t)
@@ -504,6 +501,7 @@ private[spark] class BlockManager(
case Some(info) =>
val level = info.level
logDebug(s"Level for block $blockId is $level")
+ val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
if (level.useMemory && memoryStore.contains(blockId)) {
val iter: Iterator[Any] = if (level.deserialized) {
memoryStore.getValues(blockId).get
@@ -511,7 +509,12 @@ private[spark] class BlockManager(
serializerManager.dataDeserializeStream(
blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
}
- val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
+ // We need to capture the current taskId in case the iterator completion is triggered
+ // from a different thread which does not have TaskContext set; see SPARK-18406 for
+ // discussion.
+ val ci = CompletionIterator[Any, Iterator[Any]](iter, {
+ releaseLock(blockId, taskAttemptId)
+ })
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
} else if (level.useDisk && diskStore.contains(blockId)) {
val diskData = diskStore.getBytes(blockId)
@@ -528,8 +531,9 @@ private[spark] class BlockManager(
serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
}
}
- val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn,
- releaseLockAndDispose(blockId, diskData))
+ val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
+ releaseLockAndDispose(blockId, diskData, taskAttemptId)
+ })
Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
} else {
handleLocalReadFailure(blockId)
@@ -707,10 +711,13 @@ private[spark] class BlockManager(
}
/**
- * Release a lock on the given block.
+ * Release a lock on the given block, with the TID passed explicitly.
+ * The `taskAttemptId` parameter should be passed when we cannot get the correct TID from
+ * TaskContext, for example when the input iterator of a cached RDD is consumed to the end in a
+ * child thread.
*/
- def releaseLock(blockId: BlockId): Unit = {
- blockInfoManager.unlock(blockId)
+ def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
+ blockInfoManager.unlock(blockId, taskAttemptId)
}
/**
@@ -912,7 +919,7 @@ private[spark] class BlockManager(
if (level.replication > 1) {
// Wait for asynchronous replication to finish
try {
- Await.ready(replicationFuture, Duration.Inf)
+ ThreadUtils.awaitReady(replicationFuture, Duration.Inf)
} catch {
case NonFatal(t) =>
throw new Exception("Error occurred while waiting for replication to finish", t)
@@ -1463,8 +1470,11 @@ private[spark] class BlockManager(
}
}
- def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = {
- blockInfoManager.unlock(blockId)
+ def releaseLockAndDispose(
+ blockId: BlockId,
+ data: BlockData,
+ taskAttemptId: Option[Long] = None): Unit = {
+ releaseLock(blockId, taskAttemptId)
data.dispose()
}
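
Both the BlockInfoManager and BlockManager changes serve the same SPARK-18406 pattern: capture the task attempt id on the thread that acquires the lock, and release with that captured id even if the completion callback fires on a child thread. A simplified, standalone illustration (the lock manager here is a toy, not BlockInfoManager):

class LockManagerSketch {
  def unlock(blockId: String, taskAttemptId: Option[Long]): Unit =
    println(s"releasing $blockId for task ${taskAttemptId.getOrElse(-1L)}")
}

object CompletionSketch {
  def readBlock(locks: LockManagerSketch, blockId: String,
      currentTaskAttemptId: Option[Long], data: Iterator[Any]): Iterator[Any] = {
    val captured = currentTaskAttemptId // captured eagerly, on the reading thread
    new Iterator[Any] {
      private var released = false
      def hasNext: Boolean = {
        val more = data.hasNext
        if (!more && !released) {
          released = true
          locks.unlock(blockId, captured) // may run on a child thread with no TaskContext set
        }
        more
      }
      def next(): Any = data.next()
    }
  }
}
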
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 467c3e0e6b51..6f85b9e4d6c7 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -497,11 +497,17 @@ private[spark] class BlockManagerInfo(
updateLastSeenMs()
- if (_blocks.containsKey(blockId)) {
+ val blockExists = _blocks.containsKey(blockId)
+ var originalMemSize: Long = 0
+ var originalDiskSize: Long = 0
+ var originalLevel: StorageLevel = StorageLevel.NONE
+
+ if (blockExists) {
// The block exists on the slave already.
val blockStatus: BlockStatus = _blocks.get(blockId)
- val originalLevel: StorageLevel = blockStatus.storageLevel
- val originalMemSize: Long = blockStatus.memSize
+ originalLevel = blockStatus.storageLevel
+ originalMemSize = blockStatus.memSize
+ originalDiskSize = blockStatus.diskSize
if (originalLevel.useMemory) {
_remainingMem += originalMemSize
@@ -520,32 +526,44 @@ private[spark] class BlockManagerInfo(
blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0)
_blocks.put(blockId, blockStatus)
_remainingMem -= memSize
- logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(memSize),
- Utils.bytesToString(_remainingMem)))
+ if (blockExists) {
+ logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" +
+ s" (current size: ${Utils.bytesToString(memSize)}," +
+ s" original size: ${Utils.bytesToString(originalMemSize)}," +
+ s" free: ${Utils.bytesToString(_remainingMem)})")
+ } else {
+ logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" +
+ s" (size: ${Utils.bytesToString(memSize)}," +
+ s" free: ${Utils.bytesToString(_remainingMem)})")
+ }
}
if (storageLevel.useDisk) {
blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize)
_blocks.put(blockId, blockStatus)
- logInfo("Added %s on disk on %s (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize)))
+ if (blockExists) {
+ logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" +
+ s" (current size: ${Utils.bytesToString(diskSize)}," +
+ s" original size: ${Utils.bytesToString(originalDiskSize)})")
+ } else {
+ logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" +
+ s" (size: ${Utils.bytesToString(diskSize)})")
+ }
}
if (!blockId.isBroadcast && blockStatus.isCached) {
_cachedBlocks += blockId
}
- } else if (_blocks.containsKey(blockId)) {
+ } else if (blockExists) {
// If isValid is not true, drop the block.
- val blockStatus: BlockStatus = _blocks.get(blockId)
_blocks.remove(blockId)
_cachedBlocks -= blockId
- if (blockStatus.storageLevel.useMemory) {
- logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize),
- Utils.bytesToString(_remainingMem)))
+ if (originalLevel.useMemory) {
+ logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" +
+ s" (size: ${Utils.bytesToString(originalMemSize)}," +
+ s" free: ${Utils.bytesToString(_remainingMem)})")
}
- if (blockStatus.storageLevel.useDisk) {
- logInfo("Removed %s on %s on disk (size: %s)".format(
- blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize)))
+ if (originalLevel.useDisk) {
+ logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" +
+ s" (size: ${Utils.bytesToString(originalDiskSize)})")
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index f8906117638b..bded3a1e4eb5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{InputStream, IOException}
+import java.io.{File, InputStream, IOException}
import java.nio.ByteBuffer
import java.util.concurrent.LinkedBlockingQueue
import javax.annotation.concurrent.GuardedBy
@@ -52,6 +52,7 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
*/
private[spark]
@@ -63,6 +64,7 @@ final class ShuffleBlockFetcherIterator(
streamWrapper: (BlockId, InputStream) => InputStream,
maxBytesInFlight: Long,
maxReqsInFlight: Int,
+ maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean)
extends Iterator[(BlockId, InputStream)] with Logging {
@@ -129,6 +131,12 @@ final class ShuffleBlockFetcherIterator(
@GuardedBy("this")
private[this] var isZombie = false
+ /**
+ * A set to store the files used for shuffling remote huge blocks. Files in this set will be
+ * deleted during cleanup. This is a layer of defensiveness against disk file leaks.
+ */
+ val shuffleFilesSet = mutable.HashSet[File]()
+
initialize()
// Decrements the buffer reference count.
@@ -163,6 +171,11 @@ final class ShuffleBlockFetcherIterator(
case _ =>
}
}
+ shuffleFilesSet.foreach { file =>
+ if (!file.delete()) {
+ logInfo("Failed to clean up shuffle fetch temp file " + file.getAbsolutePath())
+ }
+ }
}
private[this] def sendRequest(req: FetchRequest) {
@@ -175,33 +188,46 @@ final class ShuffleBlockFetcherIterator(
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
val blockIds = req.blocks.map(_._1.toString)
-
val address = req.address
- shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
- new BlockFetchingListener {
- override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
- // Only add the buffer to results queue if the iterator is not zombie,
- // i.e. cleanup() has not been called yet.
- ShuffleBlockFetcherIterator.this.synchronized {
- if (!isZombie) {
- // Increment the ref count because we need to pass this to a different thread.
- // This needs to be released after use.
- buf.retain()
- remainingBlocks -= blockId
- results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
- remainingBlocks.isEmpty))
- logDebug("remainingBlocks: " + remainingBlocks)
- }
+
+ val blockFetchingListener = new BlockFetchingListener {
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ ShuffleBlockFetcherIterator.this.synchronized {
+ if (!isZombie) {
+ // Increment the ref count because we need to pass this to a different thread.
+ // This needs to be released after use.
+ buf.retain()
+ remainingBlocks -= blockId
+ results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
+ remainingBlocks.isEmpty))
+ logDebug("remainingBlocks: " + remainingBlocks)
}
- logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
+ logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ }
- override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
- logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- results.put(new FailureFetchResult(BlockId(blockId), address, e))
- }
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+ logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
+ results.put(new FailureFetchResult(BlockId(blockId), address, e))
}
- )
+ }
+
+ // Fetch remote shuffle blocks to disk when the request is too large. Since the shuffle data is
+ // already encrypted and compressed over the wire (with respect to the relevant configs), we can
+ // just fetch the data and write it to a file directly.
+ if (req.size > maxReqSizeShuffleToMem) {
+ val shuffleFiles = blockIds.map { _ =>
+ blockManager.diskBlockManager.createTempLocalBlock()._2
+ }.toArray
+ shuffleFilesSet ++= shuffleFiles
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+ blockFetchingListener, shuffleFiles)
+ } else {
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
+ blockFetchingListener, null)
+ }
}
private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
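
The new branch in `sendRequest` decides, per fetch request, whether blocks are buffered in memory as before or streamed to tracked temporary files. A stripped-down sketch of that decision, assuming hypothetical temp-file naming and a caller-supplied tracking set:

import java.io.File
import java.nio.file.Files

object FetchDestinationSketch {
  sealed trait Destination
  case object InMemory extends Destination
  final case class ToDisk(tempFiles: Seq[File]) extends Destination

  def chooseDestination(reqSizeBytes: Long, maxReqSizeShuffleToMem: Long,
      blockIds: Seq[String], trackedFiles: scala.collection.mutable.Set[File]): Destination = {
    if (reqSizeBytes > maxReqSizeShuffleToMem) {
      // One temp file per block; the real code uses DiskBlockManager.createTempLocalBlock().
      val files = blockIds.map(id => Files.createTempFile(s"shuffle-$id-", ".tmp").toFile)
      trackedFiles ++= files // remembered so cleanup() can delete them defensively
      ToDisk(files)
    } else {
      InMemory
    }
  }
}
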
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index 1b30d4fa93bc..ac60f795915a 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -30,6 +30,7 @@ import org.apache.spark.scheduler._
* This class is thread-safe (unlike JobProgressListener)
*/
@DeveloperApi
+@deprecated("This class will be removed in a future release.", "2.2.0")
class StorageStatusListener(conf: SparkConf) extends SparkListener {
// This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE)
private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]()
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
index 8f0d181fc8fe..e9694fdbca2d 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala
@@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging
* class cannot mutate the source of the information. Accesses are not thread-safe.
*/
@DeveloperApi
+@deprecated("This class may be removed or made private in a future release.", "2.2.0")
class StorageStatus(
val blockManagerId: BlockManagerId,
val maxMemory: Long,
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index bdbdba578085..edf328b5ae53 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -29,8 +29,8 @@ import org.eclipse.jetty.client.api.Response
import org.eclipse.jetty.proxy.ProxyServlet
import org.eclipse.jetty.server._
import org.eclipse.jetty.server.handler._
+import org.eclipse.jetty.server.handler.gzip.GzipHandler
import org.eclipse.jetty.servlet._
-import org.eclipse.jetty.servlets.gzip.GzipHandler
import org.eclipse.jetty.util.component.LifeCycle
import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler}
import org.json4s.JValue
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 7d31ac54a717..bf4cf79e9faa 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -117,7 +117,7 @@ private[spark] class SparkUI private (
endTime = new Date(-1),
duration = 0,
lastUpdated = new Date(startTime),
- sparkUser = "",
+ sparkUser = getSparkUser,
completed = false
))
))
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index e53d6907bc40..4bc7fb6185e6 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -25,6 +25,8 @@ import scala.util.control.NonFatal
import scala.xml._
import scala.xml.transform.{RewriteRule, RuleTransformer}
+import org.apache.commons.lang3.StringEscapeUtils
+
import org.apache.spark.internal.Logging
import org.apache.spark.ui.scope.RDDOperationGraph
@@ -34,6 +36,8 @@ private[spark] object UIUtils extends Logging {
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable"
+ private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r
+
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat =
@@ -446,7 +450,7 @@ private[spark] object UIUtils extends Logging {
val xml = XML.loadString(s"""$desc""")
// Verify that this has only anchors and span (we are wrapping in span)
- val allowedNodeLabels = Set("a", "span")
+ val allowedNodeLabels = Set("a", "span", "br")
val illegalNodes = xml \\ "_" filterNot { case node: Node =>
allowedNodeLabels.contains(node.label)
}
@@ -527,4 +531,21 @@ private[spark] object UIUtils extends Logging {
origHref
}
}
+
+ /**
+ * Remove suspicious characters from user input to prevent cross-site scripting (XSS) attacks.
+ *
+ * For more information about XSS testing:
+ * https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
+ * https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
+ */
+ def stripXSS(requestParameter: String): String = {
+ if (requestParameter == null) {
+ null
+ } else {
+ // Remove newlines and single quotes, then escape the rest using HTML 4.0 entities
+ StringEscapeUtils.escapeHtml4(
+ NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, ""))
+ }
+ }
}
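
As a quick illustration of what the sanitizer above does to a typical request parameter, here is a standalone copy of the same logic plus an example input; only the object name and the `main` method are additions:

import org.apache.commons.lang3.StringEscapeUtils

object StripXssSketch {
  private val NewlineAndQuote = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r

  def stripXSS(requestParameter: String): String =
    if (requestParameter == null) null
    else StringEscapeUtils.escapeHtml4(NewlineAndQuote.replaceAllIn(requestParameter, ""))

  def main(args: Array[String]): Unit = {
    // Prints: &lt;script&gt;alert(pwned)&lt;/script&gt;
    // (single quotes stripped, angle brackets escaped)
    println(stripXSS("<script>alert('pwned')</script>"))
  }
}
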
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
index 70b3ffd95e60..8c18464e6477 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
@@ -32,6 +32,7 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en
* A SparkListener that prepares information to be displayed on the EnvironmentTab
*/
@DeveloperApi
+@deprecated("This class will be removed in a future release.", "2.2.0")
class EnvironmentListener extends SparkListener {
var jvmInformation = Seq[(String, String)]()
var sparkProperties = Seq[(String, String)]()
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index 6ce3f511e89c..7b211ea5199c 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -28,8 +28,10 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage
private val sc = parent.sc
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
- val executorId = Option(request.getParameter("executorId")).map { executorId =>
+ val executorId =
+ Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
index 0a3c63d14ca8..b7cbed468517 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala
@@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import org.apache.spark.status.api.v1.ExecutorSummary
+import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics}
import org.apache.spark.ui.{UIUtils, WebUIPage}
// This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive
@@ -114,10 +114,16 @@ private[spark] object ExecutorsPage {
val rddBlocks = status.numBlocks
val memUsed = status.memUsed
val maxMem = status.maxMem
- val onHeapMemUsed = status.onHeapMemUsed
- val offHeapMemUsed = status.offHeapMemUsed
- val maxOnHeapMem = status.maxOnHeapMem
- val maxOffHeapMem = status.maxOffHeapMem
+ val memoryMetrics = for {
+ onHeapUsed <- status.onHeapMemUsed
+ offHeapUsed <- status.offHeapMemUsed
+ maxOnHeap <- status.maxOnHeapMem
+ maxOffHeap <- status.maxOffHeapMem
+ } yield {
+ new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap)
+ }
+
val diskUsed = status.diskUsed
val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId))
@@ -142,10 +148,7 @@ private[spark] object ExecutorsPage {
taskSummary.isBlacklisted,
maxMem,
taskSummary.executorLogs,
- onHeapMemUsed,
- offHeapMemUsed,
- maxOnHeapMem,
- maxOffHeapMem
+ memoryMetrics
)
}
}
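
The for-comprehension above yields a `MemoryMetrics` only when all four underlying values are defined and `None` otherwise. A minimal standalone illustration of that Option-aggregation pattern, with illustrative names:

object OptionAggregationSketch {
  final case class MemMetrics(usedOnHeap: Long, usedOffHeap: Long, maxOnHeap: Long, maxOffHeap: Long)

  def combine(usedOnHeap: Option[Long], usedOffHeap: Option[Long],
      maxOnHeap: Option[Long], maxOffHeap: Option[Long]): Option[MemMetrics] =
    for {
      uon <- usedOnHeap
      uoff <- usedOffHeap
      mon <- maxOnHeap
      moff <- maxOffHeap
    } yield MemMetrics(uon, uoff, mon, moff)

  // combine(Some(1L), Some(2L), Some(3L), Some(4L)) == Some(MemMetrics(1, 2, 3, 4))
  // combine(Some(1L), None, Some(3L), Some(4L))     == None
}
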
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 03851293eb2f..aabf6e0c63c0 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -62,6 +62,7 @@ private[ui] case class ExecutorTaskSummary(
* A SparkListener that prepares information to be displayed on the ExecutorsTab
*/
@DeveloperApi
+@deprecated("This class will be removed in a future release.", "2.2.0")
class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf)
extends SparkListener {
val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]()
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 18be0870746e..a0fd29c22ddc 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
jobTag: String,
jobs: Seq[JobUIData],
killEnabled: Boolean): Seq[Node] = {
- val allParameters = request.getParameterMap.asScala.toMap
+ // stripXSS is called to remove suspicious characters used in XSS attacks
+ val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
.map(para => para._1 + "=" + para._2(0))
val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"
- val parameterJobPage = request.getParameter(jobTag + ".page")
- val parameterJobSortColumn = request.getParameter(jobTag + ".sort")
- val parameterJobSortDesc = request.getParameter(jobTag + ".desc")
- val parameterJobPageSize = request.getParameter(jobTag + ".pageSize")
- val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page"))
+ val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort"))
+ val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc"))
+ val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize"))
+ val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize"))
val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1)
val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
index 3131c4a1eb7d..9fb011a049b7 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -187,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener
listener.synchronized {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
val jobId = parameterId.toInt
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index f78db5ab80d1..7370f9feb68c 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -41,6 +41,7 @@ import org.apache.spark.ui.jobs.UIData._
* updating the internal data structures concurrently.
*/
@DeveloperApi
+@deprecated("This class will be removed in a future release.", "2.2.0")
class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
// Define a handful of type aliases so that data structures' types can serve as documentation.
@@ -328,13 +329,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
val taskInfo = taskStart.taskInfo
if (taskInfo != null) {
- val metrics = TaskMetrics.empty
val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), {
logWarning("Task start for unknown stage " + taskStart.stageId)
new StageUIData
})
stageData.numActiveTasks += 1
- stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo, Some(metrics)))
+ stageData.taskData.put(taskInfo.taskId, TaskUIData(taskInfo))
}
for (
activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
@@ -404,7 +404,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
updateAggregateMetrics(stageData, info.executorId, m, oldMetrics)
}
- val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info, None))
+ val taskData = stageData.taskData.getOrElseUpdate(info.taskId, TaskUIData(info))
taskData.updateTaskInfo(info)
taskData.updateTaskMetrics(taskMetrics)
taskData.errorMessage = errorMessage
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
index 620c54c2dc0a..cc173381879a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest
import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.{SparkUI, SparkUITab}
+import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}
/** Web UI showing progress status of all jobs in the given SparkContext. */
private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
@@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
- val jobId = Option(request.getParameter("id")).map(_.toInt)
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
jobId.foreach { id =>
if (jobProgresslistener.activeJobs.contains(id)) {
sc.foreach(_.cancelJob(id))
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index 8ee70d27cc09..b164f32b62e9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
- val poolName = Option(request.getParameter("poolname")).map { poolname =>
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 19325a2dc916..6b3dadc33331 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -87,17 +87,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterAttempt = request.getParameter("attempt")
+ val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")
- val parameterTaskPage = request.getParameter("task.page")
- val parameterTaskSortColumn = request.getParameter("task.sort")
- val parameterTaskSortDesc = request.getParameter("task.desc")
- val parameterTaskPageSize = request.getParameter("task.pageSize")
- val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize")
+ val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
+ val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
+ val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
+ val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
+ val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize"))
val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
index 256b726fa7ee..a28daf7f9045 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
@@ -42,15 +42,17 @@ private[ui] class StageTableBase(
isFairScheduler: Boolean,
killEnabled: Boolean,
isFailedStage: Boolean) {
- val allParameters = request.getParameterMap().asScala.toMap
+ // stripXSS is called to remove suspicious characters used in XSS attacks
+ val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))
- val parameterStagePage = request.getParameter(stageTag + ".page")
- val parameterStageSortColumn = request.getParameter(stageTag + ".sort")
- val parameterStageSortDesc = request.getParameter(stageTag + ".desc")
- val parameterStagePageSize = request.getParameter(stageTag + ".pageSize")
- val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize")
+ val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page"))
+ val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort"))
+ val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc"))
+ val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize"))
+ val parameterStagePrevPageSize =
+ UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize"))
val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1)
val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn =>
@@ -512,4 +514,3 @@ private[ui] class StageDataSource(
}
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
index 181465bdf960..799d76962639 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest
import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.{SparkUI, SparkUITab}
+import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}
/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
@@ -39,7 +39,8 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"
def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
- val stageId = Option(request.getParameter("id")).map(_.toInt)
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
stageId.foreach { id =>
if (progressListener.activeStages.contains(id)) {
sc.foreach(_.cancelStage(id, "killed via the Web UI"))
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index ac1a74ad8029..8bedd071a2c1 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -20,6 +20,8 @@ package org.apache.spark.ui.jobs
import scala.collection.mutable
import scala.collection.mutable.{HashMap, LinkedHashMap}
+import com.google.common.collect.Interners
+
import org.apache.spark.JobExecutionStatus
import org.apache.spark.executor._
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
@@ -112,9 +114,9 @@ private[spark] object UIData {
/**
* These are kept mutable and reused throughout a task's lifetime to avoid excessive reallocation.
*/
- class TaskUIData private(
- private var _taskInfo: TaskInfo,
- private var _metrics: Option[TaskMetricsUIData]) {
+ class TaskUIData private(private var _taskInfo: TaskInfo) {
+
+ private[this] var _metrics: Option[TaskMetricsUIData] = Some(TaskMetricsUIData.EMPTY)
var errorMessage: Option[String] = None
@@ -127,7 +129,7 @@ private[spark] object UIData {
}
def updateTaskMetrics(metrics: Option[TaskMetrics]): Unit = {
- _metrics = TaskUIData.toTaskMetricsUIData(metrics)
+ _metrics = metrics.map(TaskMetricsUIData.fromTaskMetrics)
}
def taskDuration: Option[Long] = {
@@ -140,28 +142,16 @@ private[spark] object UIData {
}
object TaskUIData {
- def apply(taskInfo: TaskInfo, metrics: Option[TaskMetrics]): TaskUIData = {
- new TaskUIData(dropInternalAndSQLAccumulables(taskInfo), toTaskMetricsUIData(metrics))
+
+ private val stringInterner = Interners.newWeakInterner[String]()
+
+ /** String interning to reduce the memory usage. */
+ private def weakIntern(s: String): String = {
+ stringInterner.intern(s)
}
- private def toTaskMetricsUIData(metrics: Option[TaskMetrics]): Option[TaskMetricsUIData] = {
- metrics.map { m =>
- TaskMetricsUIData(
- executorDeserializeTime = m.executorDeserializeTime,
- executorDeserializeCpuTime = m.executorDeserializeCpuTime,
- executorRunTime = m.executorRunTime,
- executorCpuTime = m.executorCpuTime,
- resultSize = m.resultSize,
- jvmGCTime = m.jvmGCTime,
- resultSerializationTime = m.resultSerializationTime,
- memoryBytesSpilled = m.memoryBytesSpilled,
- diskBytesSpilled = m.diskBytesSpilled,
- peakExecutionMemory = m.peakExecutionMemory,
- inputMetrics = InputMetricsUIData(m.inputMetrics),
- outputMetrics = OutputMetricsUIData(m.outputMetrics),
- shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics),
- shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics))
- }
+ def apply(taskInfo: TaskInfo): TaskUIData = {
+ new TaskUIData(dropInternalAndSQLAccumulables(taskInfo))
}
/**
@@ -174,8 +164,8 @@ private[spark] object UIData {
index = taskInfo.index,
attemptNumber = taskInfo.attemptNumber,
launchTime = taskInfo.launchTime,
- executorId = taskInfo.executorId,
- host = taskInfo.host,
+ executorId = weakIntern(taskInfo.executorId),
+ host = weakIntern(taskInfo.host),
taskLocality = taskInfo.taskLocality,
speculative = taskInfo.speculative
)
@@ -206,6 +196,28 @@ private[spark] object UIData {
shuffleReadMetrics: ShuffleReadMetricsUIData,
shuffleWriteMetrics: ShuffleWriteMetricsUIData)
+ object TaskMetricsUIData {
+ def fromTaskMetrics(m: TaskMetrics): TaskMetricsUIData = {
+ TaskMetricsUIData(
+ executorDeserializeTime = m.executorDeserializeTime,
+ executorDeserializeCpuTime = m.executorDeserializeCpuTime,
+ executorRunTime = m.executorRunTime,
+ executorCpuTime = m.executorCpuTime,
+ resultSize = m.resultSize,
+ jvmGCTime = m.jvmGCTime,
+ resultSerializationTime = m.resultSerializationTime,
+ memoryBytesSpilled = m.memoryBytesSpilled,
+ diskBytesSpilled = m.diskBytesSpilled,
+ peakExecutionMemory = m.peakExecutionMemory,
+ inputMetrics = InputMetricsUIData(m.inputMetrics),
+ outputMetrics = OutputMetricsUIData(m.outputMetrics),
+ shuffleReadMetrics = ShuffleReadMetricsUIData(m.shuffleReadMetrics),
+ shuffleWriteMetrics = ShuffleWriteMetricsUIData(m.shuffleWriteMetrics))
+ }
+
+ val EMPTY: TaskMetricsUIData = fromTaskMetrics(TaskMetrics.empty)
+ }
+
case class InputMetricsUIData(bytesRead: Long, recordsRead: Long)
object InputMetricsUIData {
def apply(metrics: InputMetrics): InputMetricsUIData = {
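
Executor ids and host names repeat across many thousands of `TaskUIData` entries, which is why the patch interns them through a Guava weak interner. A small self-contained demonstration of that API (the same `Interners.newWeakInterner` call the patch uses):

import com.google.common.collect.Interners

object InternSketch {
  private val interner = Interners.newWeakInterner[String]()

  def weakIntern(s: String): String = interner.intern(s)

  def main(args: Array[String]): Unit = {
    val a = weakIntern(new String("host-1.example.com"))
    val b = weakIntern(new String("host-1.example.com"))
    println(a eq b) // true: both references point at the same interned instance
  }
}
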
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index a1a0c729b924..317e0aa5ea25 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
- val parameterId = request.getParameter("id")
+ // stripXSS is called first to remove suspicious characters used in XSS attacks
+ val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")
- val parameterBlockPage = request.getParameter("block.page")
- val parameterBlockSortColumn = request.getParameter("block.sort")
- val parameterBlockSortDesc = request.getParameter("block.desc")
- val parameterBlockPageSize = request.getParameter("block.pageSize")
- val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize")
+ val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
+ val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
+ val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
+ val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
+ val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize"))
val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index c212362557be..148efb134e14 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -39,6 +39,7 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag
* This class is thread-safe (unlike JobProgressListener)
*/
@DeveloperApi
+@deprecated("This class will be removed in a future release.", "2.2.0")
class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener {
private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 7479de55140e..603c23abb689 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -68,7 +68,7 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
private def assertMetadataNotNull(): Unit = {
if (metadata == null) {
- throw new IllegalAccessError("The metadata of this accumulator has not been assigned yet.")
+ throw new IllegalStateException("The metadata of this accumulator has not been assigned yet.")
}
}
@@ -85,7 +85,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
*/
final def name: Option[String] = {
assertMetadataNotNull()
- metadata.name
+
+ if (atDriverSide) {
+ metadata.name.orElse(AccumulatorContext.get(id).flatMap(_.metadata.name))
+ } else {
+ metadata.name
+ }
}
/**
@@ -161,7 +166,17 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable {
}
val copyAcc = copyAndReset()
assert(copyAcc.isZero, "copyAndReset must return a zero value copy")
- copyAcc.metadata = metadata
+ val isInternalAcc = name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)
+ if (isInternalAcc) {
+ // Do not serialize the name of an internal accumulator when sending it to the executor.
+ copyAcc.metadata = metadata.copy(name = None)
+ } else {
+ // For non-internal accumulators, we still need to send the name because users may need to
+ // access the accumulator name on the executor side, or they may keep the accumulators sent from
+ // executors and access the name after the registered accumulator has already been garbage
+ // collected (e.g. SQLMetrics).
+ copyAcc.metadata = metadata
+ }
copyAcc
} else {
this
@@ -250,7 +265,7 @@ private[spark] object AccumulatorContext {
// Since we are storing weak references, we must check whether the underlying data is valid.
val acc = ref.get
if (acc eq null) {
- throw new IllegalAccessError(s"Attempted to access garbage collected accumulator $id")
+ throw new IllegalStateException(s"Attempted to access garbage collected accumulator $id")
}
acc
}
@@ -263,16 +278,6 @@ private[spark] object AccumulatorContext {
originals.clear()
}
- /**
- * Looks for a registered accumulator by accumulator name.
- */
- private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = {
- originals.values().asScala.find { ref =>
- val acc = ref.get
- acc != null && acc.name.isDefined && acc.name.get == name
- }.map(_.get)
- }
-
// Identifier for distinguishing SQL metrics from other accumulators
private[spark] val SQL_ACCUM_IDENTIFIER = "sql"
}
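
For user-facing accumulators the net effect of the name handling above is that the name still round-trips through task serialization and can be resolved on the driver. A minimal sketch, assuming a local SparkContext named sc:

    // Hedged sketch: a named, non-internal accumulator keeps its name end to end.
    val acc = sc.longAccumulator("records")          // registers the accumulator with AccumulatorContext
    sc.parallelize(1 to 100, 4).foreach(acc.add(_))  // the name is still serialized for non-internal accumulators
    assert(acc.name.contains("records"))             // on the driver, the name can fall back to the registry
    assert(acc.value == (1 to 100).sum)
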
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
similarity index 95%
rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
index 4dd498cd91b4..ce06e18879a4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.impl
+package org.apache.spark.util
import scala.collection.mutable
@@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
-private[mllib] abstract class PeriodicCheckpointer[T](
+private[spark] abstract class PeriodicCheckpointer[T](
val checkpointInterval: Int,
val sc: SparkContext) extends Logging {
@@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T](
/** Get list of checkpoint files for this given Dataset */
protected def getCheckpointFiles(data: T): Iterable[String]
+ /**
+ * Call this to unpersist the Dataset.
+ */
+ def unpersistDataSet(): Unit = {
+ while (persistedQueue.nonEmpty) {
+ val dataToUnpersist = persistedQueue.dequeue()
+ unpersist(dataToUnpersist)
+ }
+ }
+
/**
* Call this at the end to delete any remaining checkpoint files.
*/
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 1aa4456ed01b..81aaf79db0c1 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -206,4 +206,25 @@ private[spark] object ThreadUtils {
}
}
// scalastyle:on awaitresult
+
+ // scalastyle:off awaitready
+ /**
+ * Preferred alternative to `Await.ready()`.
+ *
+ * @see [[awaitResult]]
+ */
+ @throws(classOf[SparkException])
+ def awaitReady[T](awaitable: Awaitable[T], atMost: Duration): awaitable.type = {
+ try {
+ // `awaitPermission` is not actually used anywhere so it's safe to pass in null here.
+ // See SPARK-13747.
+ val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait]
+ awaitable.ready(atMost)(awaitPermission)
+ } catch {
+ // TimeoutException is thrown in the current thread, so there is no need to wrap the exception.
+ case NonFatal(t) if !t.isInstanceOf[TimeoutException] =>
+ throw new SparkException("Exception thrown in awaitResult: ", t)
+ }
+ }
+ // scalastyle:on awaitready
}
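
Call sites later in this patch switch from Await.ready to this helper; a minimal usage sketch, assuming a SparkContext named sc:

    // Hedged sketch: block on an async action without tripping the scalastyle ban on Await.ready;
    // non-timeout failures are rethrown wrapped in a SparkException from the calling thread.
    import scala.concurrent.duration._
    import org.apache.spark.util.ThreadUtils

    val future = sc.parallelize(1 to 10).countAsync()
    ThreadUtils.awaitReady(future, 10.seconds)
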
diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
index f0b68f0cb7e2..27922b31949b 100644
--- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
+++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala
@@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy
*
* Note: "runUninterruptibly" should be called only in `this` thread.
*/
-private[spark] class UninterruptibleThread(name: String) extends Thread(name) {
+private[spark] class UninterruptibleThread(
+ target: Runnable,
+ name: String) extends Thread(target, name) {
+
+ def this(name: String) {
+ this(null, name)
+ }
/** A monitor to protect "uninterruptible" and "interrupted" */
private val uninterruptibleLock = new Object
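
The widened constructor mirrors java.lang.Thread(Runnable, String). A sketch of the new form; note the class is private[spark], so this only compiles inside Spark's own packages:

    // Hedged sketch: supply the body as a Runnable instead of subclassing and overriding run().
    val worker = new UninterruptibleThread(new Runnable {
      override def run(): Unit = {
        // code that must not be interrupted mid-flight would wrap itself in runUninterruptibly { ... }
      }
    }, "example-uninterruptible-thread")
    worker.start()
    worker.join()
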
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 943dde072327..67497bbba150 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -22,7 +22,7 @@ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInf
import java.math.{MathContext, RoundingMode}
import java.net._
import java.nio.ByteBuffer
-import java.nio.channels.Channels
+import java.nio.channels.{Channels, FileChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
import java.util.{Locale, Properties, Random, UUID}
@@ -60,7 +60,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
-import org.apache.spark.util.logging.RollingFileAppender
/** CallSite represents a place in user code. It can have a short and a long form. */
private[spark] case class CallSite(shortForm: String, longForm: String)
@@ -319,41 +318,22 @@ private[spark] object Utils extends Logging {
* copying is disabled by default unless explicitly set transferToEnabled as true,
* the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false].
*/
- def copyStream(in: InputStream,
- out: OutputStream,
- closeStreams: Boolean = false,
- transferToEnabled: Boolean = false): Long =
- {
- var count = 0L
+ def copyStream(
+ in: InputStream,
+ out: OutputStream,
+ closeStreams: Boolean = false,
+ transferToEnabled: Boolean = false): Long = {
tryWithSafeFinally {
if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
&& transferToEnabled) {
// When both streams are File stream, use transferTo to improve copy performance.
val inChannel = in.asInstanceOf[FileInputStream].getChannel()
val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
- val initialPos = outChannel.position()
val size = inChannel.size()
-
- // In case transferTo method transferred less data than we have required.
- while (count < size) {
- count += inChannel.transferTo(count, size - count, outChannel)
- }
-
- // Check the position after transferTo loop to see if it is in the right position and
- // give user information if not.
- // Position will not be increased to the expected length after calling transferTo in
- // kernel version 2.6.32, this issue can be seen in
- // https://bugs.openjdk.java.net/browse/JDK-7052359
- // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
- val finalPos = outChannel.position()
- assert(finalPos == initialPos + size,
- s"""
- |Current position $finalPos do not equal to expected position ${initialPos + size}
- |after transferTo, please check your kernel version to see if it is 2.6.32,
- |this is a kernel bug which will lead to unexpected behavior when using transferTo.
- |You can set spark.file.transferTo = false to disable this NIO feature.
- """.stripMargin)
+ copyFileStreamNIO(inChannel, outChannel, 0, size)
+ size
} else {
+ var count = 0L
val buf = new Array[Byte](8192)
var n = 0
while (n != -1) {
@@ -363,8 +343,8 @@ private[spark] object Utils extends Logging {
count += n
}
}
+ count
}
- count
} {
if (closeStreams) {
try {
@@ -376,6 +356,37 @@ private[spark] object Utils extends Logging {
}
}
+ def copyFileStreamNIO(
+ input: FileChannel,
+ output: FileChannel,
+ startPosition: Long,
+ bytesToCopy: Long): Unit = {
+ val initialPos = output.position()
+ var count = 0L
+ // Keep looping in case transferTo transferred less data than was requested.
+ while (count < bytesToCopy) {
+ count += input.transferTo(count + startPosition, bytesToCopy - count, output)
+ }
+ assert(count == bytesToCopy,
+ s"Requested to copy $bytesToCopy bytes, but actually copied $count bytes.")
+
+ // Check the position after the transferTo loop to see if it ended up where expected, and
+ // report to the user if it did not.
+ // The position will not be advanced to the expected length after calling transferTo on
+ // kernel version 2.6.32; this issue can be seen in
+ // https://bugs.openjdk.java.net/browse/JDK-7052359
+ // It leads to stream corruption when using sort-based shuffle (SPARK-3948).
+ val finalPos = output.position()
+ val expectedPos = initialPos + bytesToCopy
+ assert(finalPos == expectedPos,
+ s"""
+ |Current position $finalPos does not equal the expected position $expectedPos
+ |after transferTo. Please check whether your kernel version is 2.6.32:
+ |that kernel has a bug which leads to unexpected behavior when using transferTo.
+ |You can set spark.file.transferTo = false to disable this NIO feature.
+ """.stripMargin)
+ }
+
/**
* Construct a URI container information used for authentication.
* This also sets the default authenticator to properly negotiation the
@@ -740,7 +751,11 @@ private[spark] object Utils extends Logging {
* always return a single directory.
*/
def getLocalDir(conf: SparkConf): String = {
- getOrCreateLocalRootDirs(conf)(0)
+ getOrCreateLocalRootDirs(conf).headOption.getOrElse {
+ val configuredLocalDirs = getConfiguredLocalDirs(conf)
+ throw new IOException(
+ s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].")
+ }
}
private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = {
@@ -2606,10 +2621,24 @@ private[spark] object Utils extends Logging {
}
private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = {
- kvs.map { kv =>
- redactionPattern.findFirstIn(kv._1)
- .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) }
- .getOrElse(kv)
+ // If the sensitive-information regex matches either the key or the value, redact the value.
+ // While the original intent was to redact the value only if the key matched the regex,
+ // we've found that, especially in verbose mode, the value of a property may itself contain
+ // sensitive information like so:
+ // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \
+ // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ...
+ //
+ // In such cases, simply searching for the sensitive-information regex in the key name is
+ // not sufficient. The values themselves have to be searched as well and redacted if matched.
+ // This does mean we may see more false positives - for example, if the value of an
+ // arbitrary property contains the term 'password', we may redact the value from the UI and
+ // logs. To work around this, users would have to make the spark.redaction.regex property
+ // more specific.
+ kvs.map { case (key, value) =>
+ redactionPattern.findFirstIn(key)
+ .orElse(redactionPattern.findFirstIn(value))
+ .map { _ => (key, REDACTION_REPLACEMENT_TEXT) }
+ .getOrElse((key, value))
}
}
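
With the NIO copy extracted into its own method, callers that already hold two FileChannels can reuse it directly. A minimal sketch with hypothetical paths; Utils is private[spark], so this lives inside Spark's own code:

    // Hedged sketch: copy an entire file channel-to-channel via transferTo, relying on the
    // position checks above to guard against the 2.6.32 kernel bug.
    import java.io.{FileInputStream, FileOutputStream}
    import org.apache.spark.util.Utils

    val inChannel = new FileInputStream("/tmp/source.bin").getChannel    // hypothetical path
    val outChannel = new FileOutputStream("/tmp/dest.bin").getChannel    // hypothetical path
    try {
      Utils.copyFileStreamNIO(inChannel, outChannel, startPosition = 0, bytesToCopy = inChannel.size())
    } finally {
      inChannel.close()
      outChannel.close()
    }
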
diff --git a/core/src/main/scala/org/apache/spark/util/taskListeners.scala b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
index 1be31e88ab68..51feccfb8342 100644
--- a/core/src/main/scala/org/apache/spark/util/taskListeners.scala
+++ b/core/src/main/scala/org/apache/spark/util/taskListeners.scala
@@ -55,14 +55,16 @@ class TaskCompletionListenerException(
extends RuntimeException {
override def getMessage: String = {
- if (errorMessages.size == 1) {
- errorMessages.head
- } else {
- errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
- } +
- previousError.map { e =>
+ val listenerErrorMessage =
+ if (errorMessages.size == 1) {
+ errorMessages.head
+ } else {
+ errorMessages.zipWithIndex.map { case (msg, i) => s"Exception $i: $msg" }.mkString("\n")
+ }
+ val previousErrorMessage = previousError.map { e =>
"\n\nPrevious exception in task: " + e.getMessage + "\n" +
e.getStackTrace.mkString("\t", "\n\t", "")
}.getOrElse("")
+ listenerErrorMessage + previousErrorMessage
}
}
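
With the message assembled from two named parts, callers see both the listener failures and the original task error. A small sketch, mirroring the test added further down in this patch:

    // Hedged sketch: the combined message contains the listener error and, when present, the task's own error.
    val e = new TaskCompletionListenerException(
      Seq("exception in listener"),
      Some(new RuntimeException("exception in task")))
    assert(e.getMessage.contains("exception in listener"))
    assert(e.getMessage.contains("exception in task"))
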
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json
index e732af266350..0f94e3b255db 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json
@@ -22,10 +22,12 @@
"stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout",
"stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr"
},
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "driver",
"hostPort" : "172.22.0.167:51475",
@@ -47,10 +49,12 @@
"isBlacklisted" : true,
"maxMemory" : 908381388,
"executorLogs" : { },
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "1",
"hostPort" : "172.22.0.167:51490",
@@ -75,11 +79,12 @@
"stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout",
"stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr"
},
-
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "0",
"hostPort" : "172.22.0.167:51491",
@@ -104,10 +109,12 @@
"stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout",
"stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr"
},
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "3",
"hostPort" : "172.22.0.167:51485",
@@ -132,8 +139,10 @@
"stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout",
"stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr"
},
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
} ]
diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json
index e732af266350..0f94e3b255db 100644
--- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json
+++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json
@@ -22,10 +22,12 @@
"stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout",
"stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr"
},
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "driver",
"hostPort" : "172.22.0.167:51475",
@@ -47,10 +49,12 @@
"isBlacklisted" : true,
"maxMemory" : 908381388,
"executorLogs" : { },
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "1",
"hostPort" : "172.22.0.167:51490",
@@ -75,11 +79,12 @@
"stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout",
"stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr"
},
-
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "0",
"hostPort" : "172.22.0.167:51491",
@@ -104,10 +109,12 @@
"stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout",
"stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr"
},
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
}, {
"id" : "3",
"hostPort" : "172.22.0.167:51485",
@@ -132,8 +139,10 @@
"stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout",
"stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr"
},
- "onHeapMemoryUsed" : 0,
- "offHeapMemoryUsed" : 0,
- "maxOnHeapMemory" : 384093388,
- "maxOffHeapMemory" : 524288000
+ "memoryMetrics": {
+ "usedOnHeapStorageMemory": 0,
+ "usedOffHeapStorageMemory": 0,
+ "totalOnHeapStorageMemory": 384093388,
+ "totalOffHeapStorageMemory": 524288000
+ }
} ]
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index ddbcb2d19dcb..3990ee1ec326 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -210,7 +210,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex
assert(ref.get.isEmpty)
// Getting a garbage collected accum should throw error
- intercept[IllegalAccessError] {
+ intercept[IllegalStateException] {
AccumulatorContext.get(accId)
}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index b117c7709b46..ee70a3399efe 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -21,8 +21,10 @@ import java.io.File
import scala.reflect.ClassTag
+import com.google.common.io.ByteStreams
import org.apache.hadoop.fs.Path
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils
@@ -580,3 +582,42 @@ object CheckpointSuite {
).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
}
}
+
+class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
+
+ test("checkpoint compression") {
+ val checkpointDir = Utils.createTempDir()
+ try {
+ val conf = new SparkConf()
+ .set("spark.checkpoint.compress", "true")
+ .set("spark.ui.enabled", "false")
+ sc = new SparkContext("local", "test", conf)
+ sc.setCheckpointDir(checkpointDir.toString)
+ val rdd = sc.makeRDD(1 to 20, numSlices = 1)
+ rdd.checkpoint()
+ assert(rdd.collect().toSeq === (1 to 20))
+
+ // Verify that RDD is checkpointed
+ assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]])
+
+ val checkpointPath = new Path(rdd.getCheckpointFile.get)
+ val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
+ val checkpointFile =
+ fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get
+
+ // Verify the checkpoint file is compressed, i.e. that it can be decompressed
+ val compressedInputStream = CompressionCodec.createCodec(conf)
+ .compressedInputStream(fs.open(checkpointFile))
+ try {
+ ByteStreams.toByteArray(compressedInputStream)
+ } finally {
+ compressedInputStream.close()
+ }
+
+ // Verify that the compressed content can be read back
+ assert(rdd.collect().toSeq === (1 to 20))
+ } finally {
+ Utils.deleteRecursively(checkpointDir)
+ }
+ }
+}
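
From an application's point of view the feature exercised by this suite is a single opt-in setting. A minimal sketch, assuming the codec used is the one Spark is configured with generally (spark.io.compression.codec):

    // Hedged sketch: enable compressed reliable checkpoints for an application.
    import org.apache.spark.{SparkConf, SparkContext}

    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("checkpoint-compression-example")
      .set("spark.checkpoint.compress", "true")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/checkpoints")    // hypothetical directory
    val rdd = sc.makeRDD(1 to 20, numSlices = 1)
    rdd.checkpoint()
    rdd.count()                                // materializes the (now compressed) checkpoint files
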
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index bb24c6ce4d33..71bedda5ac89 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark
import scala.collection.mutable.ArrayBuffer
-import org.mockito.Matchers.{any, isA}
+import org.mockito.Matchers.any
import org.mockito.Mockito._
import org.apache.spark.broadcast.BroadcastManager
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 735f4454e299..979270a527a6 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -23,7 +23,6 @@ import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit
import scala.concurrent.duration._
-import scala.concurrent.Await
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
@@ -35,7 +34,7 @@ import org.scalatest.concurrent.Eventually
import org.scalatest.Matchers._
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerTaskEnd, SparkListenerTaskStart}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventually {
@@ -301,13 +300,13 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
sc.addJar(tmpJar.getAbsolutePath)
- // Invaid jar path will only print the error log, will not add to file server.
+ // Invalid jar path will only print the error log, will not add to file server.
sc.addJar("dummy.jar")
sc.addJar("")
sc.addJar(tmpDir.getAbsolutePath)
- sc.listJars().size should be (1)
- sc.listJars().head should include (tmpJar.getName)
+ assert(sc.listJars().size == 1)
+ assert(sc.listJars().head.contains(tmpJar.getName))
}
test("Cancelling job group should not cause SparkContext to shutdown (SPARK-6414)") {
@@ -315,7 +314,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
val future = sc.parallelize(Seq(0)).foreachAsync(_ => {Thread.sleep(1000L)})
sc.cancelJobGroup("nonExistGroupId")
- Await.ready(future, Duration(2, TimeUnit.SECONDS))
+ ThreadUtils.awaitReady(future, Duration(2, TimeUnit.SECONDS))
// In SPARK-6414, sc.cancelJobGroup will cause NullPointerException and cause
// SparkContext to shutdown, so the following assertion will fail.
@@ -540,10 +539,24 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
}
}
- // Launches one task that will run forever. Once the SparkListener detects the task has
+ testCancellingTasks("that raise interrupted exception on cancel") {
+ Thread.sleep(9999999)
+ }
+
+ // SPARK-20217 should not fail stage if task throws non-interrupted exception
+ testCancellingTasks("that raise runtime exception on cancel") {
+ try {
+ Thread.sleep(9999999)
+ } catch {
+ case t: Throwable =>
+ throw new RuntimeException("killed")
+ }
+ }
+
+ // Launches one task that will block forever. Once the SparkListener detects the task has
// started, kill and re-schedule it. The second run of the task will complete immediately.
// If this test times out, then the first version of the task wasn't killed successfully.
- test("Killing tasks") {
+ def testCancellingTasks(desc: String)(blockFn: => Unit): Unit = test(s"Killing tasks $desc") {
sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
SparkContextSuite.isTaskStarted = false
@@ -572,13 +585,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu
// first attempt will hang
if (!SparkContextSuite.isTaskStarted) {
SparkContextSuite.isTaskStarted = true
- try {
- Thread.sleep(9999999)
- } catch {
- case t: Throwable =>
- // SPARK-20217 should not fail stage if task throws non-interrupted exception
- throw new RuntimeException("killed")
- }
+ blockFn
}
// second attempt succeeds immediately
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala
index f50cb38311db..42b8cde65039 100644
--- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala
@@ -243,16 +243,22 @@ private[deploy] object IvyTestUtils {
withManifest: Option[Manifest] = None): File = {
val jarFile = new File(dir, artifactName(artifact, useIvyLayout))
val jarFileStream = new FileOutputStream(jarFile)
- val manifest = withManifest.getOrElse {
- val mani = new Manifest()
+ val manifest: Manifest = withManifest.getOrElse {
if (withR) {
+ val mani = new Manifest()
val attr = mani.getMainAttributes
attr.put(Name.MANIFEST_VERSION, "1.0")
attr.put(new Name("Spark-HasRPackage"), "true")
+ mani
+ } else {
+ null
}
- mani
}
- val jarStream = new JarOutputStream(jarFileStream, manifest)
+ val jarStream = if (manifest != null) {
+ new JarOutputStream(jarFileStream, manifest)
+ } else {
+ new JarOutputStream(jarFileStream)
+ }
for (file <- files) {
val jarEntry = new JarEntry(file._1)
diff --git a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
index 005587051b6a..5e0bf6d438dc 100644
--- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala
@@ -133,6 +133,16 @@ class RPackageUtilsSuite
}
}
+ test("jars without manifest return false") {
+ IvyTestUtils.withRepository(main, None, None) { repo =>
+ val jar = IvyTestUtils.packJar(new File(new URI(repo)), dep1, Nil,
+ useIvyLayout = false, withR = false, None)
+ val jarFile = new JarFile(jar)
+ assert(jarFile.getManifest == null, "jar file should have null manifest")
+ assert(!RPackageUtils.checkManifestForR(jarFile), "null manifest should return false")
+ }
+ }
+
test("SparkR zipping works properly") {
val tempDir = Files.createTempDir()
Utils.tryWithSafeFinally {
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala
new file mode 100644
index 000000000000..ab24a76e20a3
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import java.security.PrivilegedExceptionAction
+
+import scala.util.Random
+
+import org.apache.hadoop.fs.FileStatus
+import org.apache.hadoop.fs.permission.{FsAction, FsPermission}
+import org.apache.hadoop.security.UserGroupInformation
+import org.scalatest.Matchers
+
+import org.apache.spark.SparkFunSuite
+
+class SparkHadoopUtilSuite extends SparkFunSuite with Matchers {
+ test("check file permission") {
+ import FsAction._
+ val testUser = s"user-${Random.nextInt(100)}"
+ val testGroups = Array(s"group-${Random.nextInt(100)}")
+ val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups)
+
+ testUgi.doAs(new PrivilegedExceptionAction[Void] {
+ override def run(): Void = {
+ val sparkHadoopUtil = new SparkHadoopUtil
+
+ // If file is owned by user and user has access permission
+ var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE)
+ sparkHadoopUtil.checkAccessPermission(status, READ) should be(true)
+ sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true)
+
+ // If file is owned by user but user has no access permission
+ status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE)
+ sparkHadoopUtil.checkAccessPermission(status, READ) should be(false)
+ sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false)
+
+ val otherUser = s"test-${Random.nextInt(100)}"
+ val otherGroup = s"test-${Random.nextInt(100)}"
+
+ // If file is owned by user's group and user's group has access permission
+ status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE)
+ sparkHadoopUtil.checkAccessPermission(status, READ) should be(true)
+ sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true)
+
+ // If file is owned by user's group but user's group has no access permission
+ status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE)
+ sparkHadoopUtil.checkAccessPermission(status, READ) should be(false)
+ sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false)
+
+ // If file is owned by other user and this user has access permission
+ status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE)
+ sparkHadoopUtil.checkAccessPermission(status, READ) should be(true)
+ sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true)
+
+ // If file is owned by other user but this user has no access permission
+ status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE)
+ sparkHadoopUtil.checkAccessPermission(status, READ) should be(false)
+ sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false)
+
+ null
+ }
+ })
+ }
+
+ private def fileStatus(
+ owner: String,
+ group: String,
+ userAction: FsAction,
+ groupAction: FsAction,
+ otherAction: FsAction): FileStatus = {
+ new FileStatus(0L,
+ false,
+ 0,
+ 0L,
+ 0L,
+ 0L,
+ new FsPermission(userAction, groupAction, otherAction),
+ owner,
+ group,
+ null)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 7c2ec01a03d0..6e9721c45931 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -18,11 +18,16 @@
package org.apache.spark.deploy
import java.io._
+import java.net.URI
import java.nio.charset.StandardCharsets
import scala.collection.mutable.ArrayBuffer
+import scala.io.Source
import com.google.common.io.ByteStreams
+import org.apache.commons.io.{FilenameUtils, FileUtils}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
import org.scalatest.{BeforeAndAfterEach, Matchers}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
@@ -34,6 +39,7 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate
import org.apache.spark.internal.config._
import org.apache.spark.internal.Logging
import org.apache.spark.TestUtils.JavaSourceFromString
+import org.apache.spark.scheduler.EventLoggingListener
import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils}
@@ -404,6 +410,37 @@ class SparkSubmitSuite
runSparkSubmit(args)
}
+ test("launch simple application with spark-submit with redaction") {
+ val testDir = Utils.createTempDir()
+ testDir.deleteOnExit()
+ val testDirPath = new Path(testDir.getAbsolutePath())
+ val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
+ val fileSystem = Utils.getHadoopFileSystem("/",
+ SparkHadoopUtil.get.newConfiguration(new SparkConf()))
+ try {
+ val args = Seq(
+ "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"),
+ "--name", "testApp",
+ "--master", "local",
+ "--conf", "spark.ui.enabled=false",
+ "--conf", "spark.master.rest.enabled=false",
+ "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password",
+ "--conf", "spark.eventLog.enabled=true",
+ "--conf", "spark.eventLog.testing=true",
+ "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}",
+ "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com",
+ unusedJar.toString)
+ runSparkSubmit(args)
+ val listStatus = fileSystem.listStatus(testDirPath)
+ val logData = EventLoggingListener.openEventLog(listStatus.last.getPath, fileSystem)
+ Source.fromInputStream(logData).getLines().foreach { line =>
+ assert(!line.contains("secret_password"))
+ }
+ } finally {
+ Utils.deleteRecursively(testDir)
+ }
+ }
+
test("includes jars passed in through --jars") {
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
@@ -501,7 +538,7 @@ class SparkSubmitSuite
test("resolves command line argument paths correctly") {
val jars = "/jar1,/jar2" // --jars
- val files = "hdfs:/file1,file2" // --files
+ val files = "local:/file1,file2" // --files
val archives = "file:/archive1,archive2" // --archives
val pyFiles = "py-file1,py-file2" // --py-files
@@ -553,7 +590,7 @@ class SparkSubmitSuite
test("resolves config paths correctly") {
val jars = "/jar1,/jar2" // spark.jars
- val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files
+ val files = "local:/file1,file2" // spark.files / spark.yarn.dist.files
val archives = "file:/archive1,archive2" // spark.yarn.dist.archives
val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles
@@ -671,6 +708,87 @@ class SparkSubmitSuite
}
// scalastyle:on println
+ private def checkDownloadedFile(sourcePath: String, outputPath: String): Unit = {
+ if (sourcePath == outputPath) {
+ return
+ }
+
+ val sourceUri = new URI(sourcePath)
+ val outputUri = new URI(outputPath)
+ assert(outputUri.getScheme === "file")
+
+ // The path and filename are preserved.
+ assert(outputUri.getPath.endsWith(sourceUri.getPath))
+ assert(FileUtils.readFileToString(new File(outputUri.getPath)) ===
+ FileUtils.readFileToString(new File(sourceUri.getPath)))
+ }
+
+ private def deleteTempOutputFile(outputPath: String): Unit = {
+ val outputFile = new File(new URI(outputPath).getPath)
+ if (outputFile.exists) {
+ outputFile.delete()
+ }
+ }
+
+ test("downloadFile - invalid url") {
+ intercept[IOException] {
+ SparkSubmit.downloadFile("abc:/my/file", new Configuration())
+ }
+ }
+
+ test("downloadFile - file doesn't exist") {
+ val hadoopConf = new Configuration()
+ // Set s3a implementation to local file system for testing.
+ hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
+ // Disable file system impl cache to make sure the test file system is picked up.
+ hadoopConf.set("fs.s3a.impl.disable.cache", "true")
+ intercept[FileNotFoundException] {
+ SparkSubmit.downloadFile("s3a:/no/such/file", hadoopConf)
+ }
+ }
+
+ test("downloadFile does not download local file") {
+ // An empty path is considered a local file.
+ assert(SparkSubmit.downloadFile("", new Configuration()) === "")
+ assert(SparkSubmit.downloadFile("/local/file", new Configuration()) === "/local/file")
+ }
+
+ test("download one file to local") {
+ val jarFile = File.createTempFile("test", ".jar")
+ jarFile.deleteOnExit()
+ val content = "hello, world"
+ FileUtils.write(jarFile, content)
+ val hadoopConf = new Configuration()
+ // Set s3a implementation to local file system for testing.
+ hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
+ // Disable file system impl cache to make sure the test file system is picked up.
+ hadoopConf.set("fs.s3a.impl.disable.cache", "true")
+ val sourcePath = s"s3a://${jarFile.getAbsolutePath}"
+ val outputPath = SparkSubmit.downloadFile(sourcePath, hadoopConf)
+ checkDownloadedFile(sourcePath, outputPath)
+ deleteTempOutputFile(outputPath)
+ }
+
+ test("download list of files to local") {
+ val jarFile = File.createTempFile("test", ".jar")
+ jarFile.deleteOnExit()
+ val content = "hello, world"
+ FileUtils.write(jarFile, content)
+ val hadoopConf = new Configuration()
+ // Set s3a implementation to local file system for testing.
+ hadoopConf.set("fs.s3a.impl", "org.apache.spark.deploy.TestFileSystem")
+ // Disable file system impl cache to make sure the test file system is picked up.
+ hadoopConf.set("fs.s3a.impl.disable.cache", "true")
+ val sourcePaths = Seq("/local/file", s"s3a://${jarFile.getAbsolutePath}")
+ val outputPaths = SparkSubmit.downloadFileList(sourcePaths.mkString(","), hadoopConf).split(",")
+
+ assert(outputPaths.length === sourcePaths.length)
+ sourcePaths.zip(outputPaths).foreach { case (sourcePath, outputPath) =>
+ checkDownloadedFile(sourcePath, outputPath)
+ deleteTempOutputFile(outputPath)
+ }
+ }
+
// NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
private def runSparkSubmit(args: Seq[String]): Unit = {
val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
@@ -773,3 +891,10 @@ object UserClasspathFirstTest {
}
}
}
+
+class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem {
+ override def copyToLocalFile(src: Path, dst: Path): Unit = {
+ // Ignore the scheme for testing.
+ super.copyToLocalFile(new Path(src.toUri.getPath), dst)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
index 9839dcf8535d..bf7480d79f8a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala
@@ -356,12 +356,13 @@ class StandaloneDynamicAllocationSuite
test("kill the same executor twice (SPARK-9795)") {
sc = new SparkContext(appConf)
val appId = sc.applicationId
+ sc.requestExecutors(2)
eventually(timeout(10.seconds), interval(10.millis)) {
val apps = getApplications()
assert(apps.size === 1)
assert(apps.head.id === appId)
assert(apps.head.executors.size === 2)
- assert(apps.head.getExecutorLimit === Int.MaxValue)
+ assert(apps.head.getExecutorLimit === 2)
}
// sync executors between the Master and the driver, needed because
// the driver refuses to kill executors it does not know about
@@ -380,12 +381,13 @@ class StandaloneDynamicAllocationSuite
test("the pending replacement executors should not be lost (SPARK-10515)") {
sc = new SparkContext(appConf)
val appId = sc.applicationId
+ sc.requestExecutors(2)
eventually(timeout(10.seconds), interval(10.millis)) {
val apps = getApplications()
assert(apps.size === 1)
assert(apps.head.id === appId)
assert(apps.head.executors.size === 2)
- assert(apps.head.getExecutorLimit === Int.MaxValue)
+ assert(apps.head.getExecutorLimit === 2)
}
// sync executors between the Master and the driver, needed because
// the driver refuses to kill executors it does not know about
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index ec580a44b8e7..456158d41b93 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -27,6 +27,7 @@ import scala.concurrent.duration._
import scala.language.postfixOps
import com.google.common.io.{ByteStreams, Files}
+import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.hdfs.DistributedFileSystem
import org.json4s.jackson.JsonMethods._
import org.mockito.Matchers.any
@@ -130,9 +131,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
}
}
- test("SPARK-3697: ignore directories that cannot be read.") {
+ test("SPARK-3697: ignore files that cannot be read.") {
// setReadable(...) does not work on Windows. Please refer JDK-6728842.
assume(!Utils.isWindows)
+
+ class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) {
+ var mergeApplicationListingCall = 0
+ override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = {
+ super.mergeApplicationListing(fileStatus)
+ mergeApplicationListingCall += 1
+ }
+ }
+ val provider = new TestFsHistoryProvider
+
val logFile1 = newLogFile("new1", None, inProgress = false)
writeFile(logFile1, true, None,
SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None),
@@ -145,10 +156,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc
)
logFile2.setReadable(false, false)
- val provider = new FsHistoryProvider(createTestConf())
updateAndCheck(provider) { list =>
list.size should be (1)
}
+
+ provider.mergeApplicationListingCall should be (1)
}
test("history file is renamed from inprogress to completed") {
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
index 764156c3edc4..95acb9a54440 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -565,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
assert(jobcount === getNumJobs("/jobs"))
// no need to retain the test dir now the tests complete
- logDir.deleteOnExit();
-
+ logDir.deleteOnExit()
}
test("ui and api authorization checks") {
- val appId = "app-20161115172038-0000"
- val owner = "jose"
+ val appId = "local-1430917381535"
+ val owner = "irashid"
val admin = "root"
val other = "alice"
@@ -590,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers
val port = server.boundPort
val testUrls = Seq(
- s"http://localhost:$port/api/v1/applications/$appId/jobs",
- s"http://localhost:$port/history/$appId/jobs/")
+ s"http://localhost:$port/api/v1/applications/$appId/1/jobs",
+ s"http://localhost:$port/history/$appId/1/jobs/",
+ s"http://localhost:$port/api/v1/applications/$appId/logs",
+ s"http://localhost:$port/api/v1/applications/$appId/1/logs",
+ s"http://localhost:$port/api/v1/applications/$appId/2/logs")
tests.foreach { case (user, expectedCode) =>
testUrls.foreach { url =>
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 2127da48ece4..539264652d7d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -34,7 +34,7 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy._
import org.apache.spark.deploy.DeployMessages._
-import org.apache.spark.rpc.{RpcEndpoint, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv}
class MasterSuite extends SparkFunSuite
with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter {
@@ -447,8 +447,15 @@ class MasterSuite extends SparkFunSuite
}
})
- master.self.send(
- RegisterWorker("1", "localhost", 9999, fakeWorker, 10, 1024, "http://localhost:8080"))
+ master.self.send(RegisterWorker(
+ "1",
+ "localhost",
+ 9999,
+ fakeWorker,
+ 10,
+ 1024,
+ "http://localhost:8080",
+ RpcAddress("localhost", 9999)))
val executors = (0 until 3).map { i =>
new ExecutorDescription(appId = i.toString, execId = i, 2, ExecutorState.RUNNING)
}
@@ -459,4 +466,37 @@ class MasterSuite extends SparkFunSuite
assert(killedDrivers.asScala.toList.sorted === List("0", "1", "2"))
}
}
+
+ test("SPARK-20529: Master should reply the address received from worker") {
+ val master = makeMaster()
+ master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master)
+ eventually(timeout(10.seconds)) {
+ val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
+ assert(masterState.status === RecoveryState.ALIVE, "Master is not alive")
+ }
+
+ @volatile var receivedMasterAddress: RpcAddress = null
+ val fakeWorker = master.rpcEnv.setupEndpoint("worker", new RpcEndpoint {
+ override val rpcEnv: RpcEnv = master.rpcEnv
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case RegisteredWorker(_, _, masterAddress) =>
+ receivedMasterAddress = masterAddress
+ }
+ })
+
+ master.self.send(RegisterWorker(
+ "1",
+ "localhost",
+ 9999,
+ fakeWorker,
+ 10,
+ 1024,
+ "http://localhost:8080",
+ RpcAddress("localhost2", 10000)))
+
+ eventually(timeout(10.seconds)) {
+ assert(receivedMasterAddress === RpcAddress("localhost2", 10000))
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
index f47e574b4fc4..efcad140350b 100644
--- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala
@@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.UninterruptibleThread
class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually {
@@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug
assert(failReason.isInstanceOf[FetchFailed])
}
+ test("Executor's worker threads should be UninterruptibleThread") {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("executor thread test")
+ .set("spark.ui.enabled", "false")
+ sc = new SparkContext(conf)
+ val executorThread = sc.parallelize(Seq(1), 1).map { _ =>
+ Thread.currentThread.getClass.getName
+ }.collect().head
+ assert(executorThread === classOf[UninterruptibleThread].getName)
+ }
+
test("SPARK-19276: OOMs correctly handled with a FetchFailure") {
// when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it
// may be a false positive. And we should call the uncaught exception handler.
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 5d522189a0c2..6f4203da1d86 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter
import org.apache.spark.{SharedSparkContext, SparkFunSuite}
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
with BeforeAndAfter {
@@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
}
assert(bytesRead >= tmpFile.length())
}
+
+ test("input metrics with old Hadoop API in different thread") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.textFile(tmpFilePath, 4).mapPartitions { iter =>
+ val buf = new ArrayBuffer[String]()
+ ThreadUtils.runInNewThread("testThread", false) {
+ iter.flatMap(_.split(" ")).foreach(buf.append(_))
+ }
+
+ buf.iterator
+ }.count()
+ }
+ assert(bytesRead >= tmpFile.length())
+ }
+
+ test("input metrics with new Hadoop API in different thread") {
+ val bytesRead = runAndReturnBytesRead {
+ sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+ classOf[Text]).mapPartitions { iter =>
+ val buf = new ArrayBuffer[String]()
+ ThreadUtils.runInNewThread("testThread", false) {
+ iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_))
+ }
+
+ buf.iterator
+ }.count()
+ }
+ assert(bytesRead >= tmpFile.length())
+ }
}
/**
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index fe8955840d72..474e30144f62 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -22,7 +22,7 @@ import java.nio._
import java.nio.charset.StandardCharsets
import java.util.concurrent.TimeUnit
-import scala.concurrent.{Await, Promise}
+import scala.concurrent.Promise
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}
@@ -36,6 +36,7 @@ import org.apache.spark.network.{BlockDataManager, BlockTransferService}
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, ShuffleBlockId}
+import org.apache.spark.util.ThreadUtils
class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers {
test("security default off") {
@@ -164,9 +165,9 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
promise.success(data.retain())
}
- })
+ }, null)
- Await.ready(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
+ ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
promise.future.value.get
}
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ad56715656c8..8d06f5468f4f 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
import org.apache.spark._
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDDSuiteUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
class RDDSuite extends SparkFunSuite with SharedSparkContext {
var tempDir: File = _
@@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
assert(totalPartitionCount == 10)
}
+ test("SPARK-18406: race between end-of-task and completion iterator read lock release") {
+ val rdd = sc.parallelize(1 to 1000, 10)
+ rdd.cache()
+
+ rdd.mapPartitions { iter =>
+ ThreadUtils.runInNewThread("TestThread") {
+ // Iterate to the end of the input iterator, to cause the CompletionIterator completion to
+ // fire outside of the task's main thread.
+ while (iter.hasNext) {
+ iter.next()
+ }
+ iter
+ }
+ }.collect()
+ }
+
// NOTE
// Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
// running after them and if they access sc those tests will fail as sc is already closed, because
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index f9a7f151823a..7f20206202cb 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w
}
test("get a range of elements in an array not partitioned by a range partitioner") {
- val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
+ val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
val pairs = sc.parallelize(pairArr, 10)
val range = pairs.filterByRange(200, 800).collect()
assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index 759d52fca5ce..3ec37f674c77 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -17,11 +17,15 @@
package org.apache.spark.scheduler
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+
import scala.util.Random
+import org.mockito.Mockito._
import org.roaringbitmap.RoaringBitmap
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkEnv, SparkFunSuite}
+import org.apache.spark.internal.config
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.BlockManagerId
@@ -128,4 +132,26 @@ class MapStatusSuite extends SparkFunSuite {
assert(size1 === size2)
assert(!success)
}
+
+ test("Blocks which are bigger than SHUFFLE_ACCURATE_BLOCK_THRESHOLD should not be " +
+ "underestimated.") {
+ val conf = new SparkConf().set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "1000")
+ val env = mock(classOf[SparkEnv])
+ doReturn(conf).when(env).conf
+ SparkEnv.set(env)
+ // Value of element in sizes is equal to the corresponding index.
+ val sizes = (0L to 2000L).toArray
+ val status1 = MapStatus(BlockManagerId("exec-0", "host-0", 100), sizes)
+ val arrayStream = new ByteArrayOutputStream(102400)
+ val objectOutputStream = new ObjectOutputStream(arrayStream)
+ assert(status1.isInstanceOf[HighlyCompressedMapStatus])
+ objectOutputStream.writeObject(status1)
+ objectOutputStream.flush()
+ val array = arrayStream.toByteArray
+ val objectInput = new ObjectInputStream(new ByteArrayInputStream(array))
+ val status2 = objectInput.readObject().asInstanceOf[HighlyCompressedMapStatus]
+ (1001 to 2000).foreach {
+ case part => assert(status2.getSizeForBlock(part) >= sizes(part))
+ }
+ }
}
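
The knob exercised here is the accurate-block threshold: blocks at or above it keep their exact sizes instead of being folded into the average kept by HighlyCompressedMapStatus. A hedged sketch of raising it (org.apache.spark.internal.config is internal, so outside Spark's own packages the corresponding string key would be set instead):

    // Hedged sketch: record exact sizes for shuffle blocks of ~100 MB and above.
    import org.apache.spark.SparkConf
    import org.apache.spark.internal.config

    val conf = new SparkConf()
      .set(config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, (100L * 1024 * 1024).toString)
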
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
index 8300607ea888..37b08980db87 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala
@@ -21,7 +21,7 @@ import java.util.concurrent.{TimeoutException, TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
-import scala.concurrent.{Await, Future}
+import scala.concurrent.Future
import scala.concurrent.duration.{Duration, SECONDS}
import scala.language.existentials
import scala.reflect.ClassTag
@@ -260,7 +260,7 @@ abstract class SchedulerIntegrationSuite[T <: MockBackend: ClassTag] extends Spa
*/
def awaitJobTermination(jobFuture: Future[_], duration: Duration): Unit = {
try {
- Await.ready(jobFuture, duration)
+ ThreadUtils.awaitReady(jobFuture, duration)
} catch {
case te: TimeoutException if backendException.get() != null =>
val msg = raw"""
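The switch from Await.ready to ThreadUtils.awaitReady here (and in BlockInfoManagerSuite further down) follows a "single await helper" pattern; a hedged sketch of the idea is below — the real Spark utility may add extra handling, for example around the blocking context.

import scala.concurrent.{Await, Awaitable}
import scala.concurrent.duration.Duration

// Route every wait through one helper so timeout and exception policy can be
// adjusted in a single place instead of at each Await.ready call site.
def awaitReadySketch[T](awaitable: Awaitable[T], atMost: Duration): Awaitable[T] =
  Await.ready(awaitable, atMost)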
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
index 8f576daa77d1..992d3396d203 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -100,7 +100,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
context.addTaskCompletionListener(_ => throw new Exception("blah"))
intercept[TaskCompletionListenerException] {
- context.markTaskCompleted()
+ context.markTaskCompleted(None)
}
verify(listener, times(1)).onTaskCompletion(any())
@@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
sc = new SparkContext("local", "test")
// Create a dummy task. We won't end up running this; we just want to collect
// accumulator updates from it.
- val taskMetrics = TaskMetrics.empty
+ val taskMetrics = TaskMetrics.registered
val task = new Task[Int](0, 0, 0) {
context = new TaskContextImpl(0, 0, 0L, 0,
new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
@@ -231,10 +231,10 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
test("immediately call a completion listener if the context is completed") {
var invocations = 0
val context = TaskContext.empty()
- context.markTaskCompleted()
+ context.markTaskCompleted(None)
context.addTaskCompletionListener(_ => invocations += 1)
assert(invocations == 1)
- context.markTaskCompleted()
+ context.markTaskCompleted(None)
assert(invocations == 1)
}
@@ -254,6 +254,36 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(lastError == error)
assert(invocations == 1)
}
+
+ test("TaskCompletionListenerException.getMessage should include previousError") {
+ val listenerErrorMessage = "exception in listener"
+ val taskErrorMessage = "exception in task"
+ val e = new TaskCompletionListenerException(
+ Seq(listenerErrorMessage),
+ Some(new RuntimeException(taskErrorMessage)))
+ assert(e.getMessage.contains(listenerErrorMessage) && e.getMessage.contains(taskErrorMessage))
+ }
+
+ test("all TaskCompletionListeners should be called even if some fail or a task") {
+ val context = TaskContext.empty()
+ val listener = mock(classOf[TaskCompletionListener])
+ context.addTaskCompletionListener(_ => throw new Exception("exception in listener1"))
+ context.addTaskCompletionListener(listener)
+ context.addTaskCompletionListener(_ => throw new Exception("exception in listener3"))
+
+ val e = intercept[TaskCompletionListenerException] {
+ context.markTaskCompleted(Some(new Exception("exception in task")))
+ }
+
+ // Make sure listener 2 was called.
+ verify(listener, times(1)).onTaskCompletion(any())
+
+    // Also check that a failure in a TaskCompletionListener does not mask the earlier task exception.
+ assert(e.getMessage.contains("exception in listener1"))
+ assert(e.getMessage.contains("exception in listener3"))
+ assert(e.getMessage.contains("exception in task"))
+ }
+
}
private object TaskContextSuite {
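The two new tests above pin down a run-every-listener-and-aggregate-errors contract. A minimal sketch of that contract follows (plain Scala with hypothetical names — not TaskContextImpl's actual implementation): every callback runs even if earlier ones throw, and the collected messages keep both the listener errors and the original task error instead of masking it.

import scala.collection.mutable.ArrayBuffer

def completeWithListeners(listeners: Seq[() => Unit], taskError: Option[Throwable]): Unit = {
  val errorMessages = ArrayBuffer[String]()
  listeners.foreach { listener =>
    // Keep going after a failure so later listeners still run.
    try listener() catch { case e: Exception => errorMessages += e.getMessage }
  }
  if (errorMessages.nonEmpty) {
    val previous = taskError.map(e => s"; previous error: ${e.getMessage}").getOrElse("")
    throw new RuntimeException(errorMessages.mkString(", ") + previous)
  }
}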
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 9ca6b8b0fe63..db14c9acfdce 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1070,11 +1070,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
sched.dagScheduler = mockDAGScheduler
val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0)
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1))
- when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] {
- override def answer(invocationOnMock: InvocationOnMock): Unit = {
- assert(manager.isZombie === true)
- }
- })
+ when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer(
+ new Answer[Unit] {
+ override def answer(invocationOnMock: InvocationOnMock): Unit = {
+ assert(manager.isZombie)
+ }
+ })
val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
assert(taskOption.isDefined)
// this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon
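For reference, the when(...).thenAnswer(...) stubbing style this hunk switches to looks like the sketch below in isolation (the Lookup trait and values are illustrative, not the suite's mocks).

import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

trait Lookup { def get(key: String): Int }

val lookup = mock(classOf[Lookup])
when(lookup.get("spark")).thenAnswer(new Answer[Int] {
  override def answer(invocation: InvocationOnMock): Int = {
    // State assertions at call time go here, as the test does with manager.isZombie.
    invocation.getArguments()(0).asInstanceOf[String].length
  }
})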
diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala
index 1bfb0c1547ec..82bd7c4ff660 100644
--- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala
@@ -31,7 +31,7 @@ class AllStagesResourceSuite extends SparkFunSuite {
val tasks = new LinkedHashMap[Long, TaskUIData]
taskLaunchTimes.zipWithIndex.foreach { case (time, idx) =>
tasks(idx.toLong) = TaskUIData(
- new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false), None)
+ new TaskInfo(idx, idx, 1, time, "", "", TaskLocality.ANY, false))
}
val stageUiData = new StageUIData()
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
index 1b325801e27f..917db766f7f1 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala
@@ -152,7 +152,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
// one should acquire the write lock. The second thread should block until the winner of the
// write race releases its lock.
val winningFuture: Future[Boolean] =
- Await.ready(Future.firstCompletedOf(Seq(lock1Future, lock2Future)), 1.seconds)
+ ThreadUtils.awaitReady(Future.firstCompletedOf(Seq(lock1Future, lock2Future)), 1.seconds)
assert(winningFuture.value.get.get)
val winningTID = blockInfoManager.get("block").get.writerTask
assert(winningTID === 1 || winningTID === 2)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index a8b960489983..9d7a8696818f 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.storage
+import java.io.File
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
@@ -1265,7 +1266,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
port: Int,
execId: String,
blockIds: Array[String],
- listener: BlockFetchingListener): Unit = {
+ listener: BlockFetchingListener,
+ shuffleFiles: Array[File]): Unit = {
listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala
index dfecd04c1b96..4000218e71a8 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.storage
import scala.collection.mutable
+import scala.language.implicitConversions
import scala.util.Random
import org.scalatest.{BeforeAndAfter, Matchers}
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index bbfd6df3b699..7859b0bba2b4 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -19,8 +19,6 @@ package org.apache.spark.storage
import java.io.{File, FileWriter}
-import scala.language.reflectiveCalls
-
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.apache.spark.{SparkConf, SparkFunSuite}
diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
index c7074078d8fd..6883eb211efd 100644
--- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.File
+import java.io.{File, IOException}
import org.scalatest.BeforeAndAfter
@@ -33,22 +33,66 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter {
Utils.clearLocalRootDirs()
}
+ after {
+ Utils.clearLocalRootDirs()
+ }
+
+ private def assumeNonExistentAndNotCreatable(f: File): Unit = {
+ try {
+ assume(!f.exists() && !f.mkdirs())
+ } finally {
+ Utils.deleteRecursively(f)
+ }
+ }
+
test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
// Regression test for SPARK-2974
- assert(!new File("/NONEXISTENT_DIR").exists())
+ val f = new File("/NONEXISTENT_PATH")
+ assumeNonExistentAndNotCreatable(f)
+
val conf = new SparkConf(false)
.set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}")
assert(new File(Utils.getLocalDir(conf)).exists())
+
+ // This directory should not be created.
+ assert(!f.exists())
}
test("SPARK_LOCAL_DIRS override also affects driver") {
- // Regression test for SPARK-2975
- assert(!new File("/NONEXISTENT_DIR").exists())
+ // Regression test for SPARK-2974
+ val f = new File("/NONEXISTENT_PATH")
+ assumeNonExistentAndNotCreatable(f)
+
// spark.local.dir only contains invalid directories, but that's not a problem since
// SPARK_LOCAL_DIRS will override it on both the driver and workers:
val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir")))
.set("spark.local.dir", "/NONEXISTENT_PATH")
assert(new File(Utils.getLocalDir(conf)).exists())
+
+ // This directory should not be created.
+ assert(!f.exists())
}
+ test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") {
+ val path1 = "/NONEXISTENT_PATH_ONE"
+ val path2 = "/NONEXISTENT_PATH_TWO"
+ val f1 = new File(path1)
+ val f2 = new File(path2)
+ assumeNonExistentAndNotCreatable(f1)
+ assumeNonExistentAndNotCreatable(f2)
+
+ assert(!new File(path1).exists())
+ assert(!new File(path2).exists())
+ val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2")
+ val message = intercept[IOException] {
+ Utils.getLocalDir(conf)
+ }.getMessage
+    // If no temporary directory can be retrieved under the given paths, an exception whose
+    // message includes the paths should be thrown.
+ assert(message.contains(s"$path1,$path2"))
+
+ // These directories should not be created.
+ assert(!f1.exists())
+ assert(!f2.exists())
+ }
}
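These LocalDirsSuite tests fix three behaviours of Utils.getLocalDir: skip roots that are missing, never create the roots themselves, and fail with an IOException that names the whole configured list. A hedged, self-contained sketch of that contract (hypothetical helper, not Spark's implementation):

import java.io.{File, IOException}
import java.util.UUID

def firstUsableLocalDir(confValue: String): File = {
  val roots = confValue.split(",").map(_.trim).filter(_.nonEmpty)
  // Try to create a scratch subdirectory under each root. mkdir() (not mkdirs())
  // never creates a missing root, matching the "should not be created" assertions.
  val usable = roots.iterator.flatMap { root =>
    val candidate = new File(root, s"spark-scratch-${UUID.randomUUID()}")
    if (candidate.mkdir()) Iterator.single(candidate) else Iterator.empty
  }
  if (usable.hasNext) usable.next()
  else throw new IOException(s"Failed to get a temp directory under [$confValue].")
}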
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
index 3050f9a25023..535105379963 100644
--- a/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallySerializedBlockSuite.scala
@@ -145,7 +145,7 @@ class PartiallySerializedBlockSuite
try {
TaskContext.setTaskContext(TaskContext.empty())
val partiallySerializedBlock = partiallyUnroll((1 to 10).iterator, 2)
- TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted()
+ TaskContext.get().asInstanceOf[TaskContextImpl].markTaskCompleted(None)
Mockito.verify(partiallySerializedBlock.getUnrolledChunkedByteBuffer).dispose()
Mockito.verifyNoMoreInteractions(memoryStore)
} finally {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index e56e440380a5..559b3faab8fd 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.storage
import java.io.{File, InputStream, IOException}
+import java.util.UUID
import java.util.concurrent.Semaphore
import scala.concurrent.ExecutionContext.Implicits.global
@@ -35,6 +36,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.network.util.LimitedInputStream
import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.util.Utils
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
@@ -44,7 +46,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
/** Creates a mock [[BlockTransferService]] that returns data from the given map. */
private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
@@ -106,6 +109,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
+ Int.MaxValue,
true)
// 3 local blocks fetched in initialization
@@ -134,7 +138,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// 3 local blocks, and 2 remote blocks
// (but from the same block manager so one call to fetchBlocks)
verify(blockManager, times(3)).getBlockData(any())
- verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
+ verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
}
test("release current unexhausted buffer in case the task completes early") {
@@ -153,7 +157,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val sem = new Semaphore(0)
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -181,6 +186,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
+ Int.MaxValue,
true)
verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
@@ -192,7 +198,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// Complete the task; then the 2nd block buffer should be exhausted
verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
- taskContext.markTaskCompleted()
+ taskContext.markTaskCompleted(None)
verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()
// The 3rd block should not be retained because the iterator is already in zombie state
@@ -218,7 +224,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val sem = new Semaphore(0)
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -246,6 +253,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => in,
48 * 1024 * 1024,
Int.MaxValue,
+ Int.MaxValue,
true)
// Continue only after the mock calls onBlockFetchFailure
@@ -281,7 +289,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -309,6 +318,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => new LimitedInputStream(in, 100),
48 * 1024 * 1024,
Int.MaxValue,
+ Int.MaxValue,
true)
// Continue only after the mock calls onBlockFetchFailure
@@ -318,7 +328,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
val (id1, _) = iterator.next()
assert(id1 === ShuffleBlockId(0, 0, 0))
- when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -359,7 +370,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
Future {
@@ -387,6 +399,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
(_, in) => new LimitedInputStream(in, 100),
48 * 1024 * 1024,
Int.MaxValue,
+ Int.MaxValue,
false)
// Continue only after the mock calls onBlockFetchFailure
@@ -401,4 +414,65 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
assert(id3 === ShuffleBlockId(0, 2, 0))
}
+ test("Blocks should be shuffled to disk when size of the request is above the" +
+ " threshold(maxReqSizeShuffleToMem).") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ val diskBlockManager = mock(classOf[DiskBlockManager])
+ val tmpDir = Utils.createTempDir()
+ doReturn{
+ val blockId = TempLocalBlockId(UUID.randomUUID())
+ (blockId, new File(tmpDir, blockId.name))
+ }.when(diskBlockManager).createTempLocalBlock()
+ doReturn(diskBlockManager).when(blockManager).diskBlockManager
+
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val remoteBlocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
+ val transfer = mock(classOf[BlockTransferService])
+ var shuffleFiles: Array[File] = null
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+ .thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]]
+ Future {
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
+ }
+ }
+ })
+
+ def fetchShuffleBlock(blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])]): Unit = {
+ // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the
+ // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks
+ // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
+ new ShuffleBlockFetcherIterator(
+ TaskContext.empty(),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ (_, in) => in,
+ maxBytesInFlight = Int.MaxValue,
+ maxReqsInFlight = Int.MaxValue,
+ maxReqSizeShuffleToMem = 200,
+ detectCorrupt = true)
+ }
+
+ val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
+ fetchShuffleBlock(blocksByAddress1)
+    // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so the shuffle
+    // block is not fetched to disk.
+ assert(shuffleFiles === null)
+
+ val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
+ fetchShuffleBlock(blocksByAddress2)
+    // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so the shuffle
+    // block is fetched to disk.
+ assert(shuffleFiles != null)
+ }
}
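The last test above boils down to a single threshold rule. Below is a minimal sketch of that rule with illustrative names (not ShuffleBlockFetcherIterator's internals): requests whose total size exceeds maxReqSizeShuffleToMem are fetched into temporary files, while smaller ones stay in memory — which is why the test sees shuffleFiles === null for the 100-byte block and a non-null array for the 300-byte block.

import java.io.File

final case class FetchRequestSketch(blocks: Seq[(String, Long)]) {
  def size: Long = blocks.map(_._2).sum
}

// Returns the temp files to stream into when the request is too large for memory.
def shuffleFilesFor(
    req: FetchRequestSketch,
    maxReqSizeShuffleToMem: Long,
    tmpDir: File): Option[Array[File]] =
  if (req.size > maxReqSizeShuffleToMem) {
    Some(req.blocks.map { case (name, _) => new File(tmpDir, s"$name.shuffle") }.toArray)
  } else {
    None  // small enough to fetch into memory
  }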
diff --git a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
index c770fd5da76f..423daacc0f5a 100644
--- a/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
@@ -133,6 +133,45 @@ class UIUtilsSuite extends SparkFunSuite {
assert(decoded2 === decodeURLParameter(decoded2))
}
+ test("SPARK-20393: Prevent newline characters in parameters.") {
+ val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b"
+ val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b"
+
+ assert(stripEncoding === stripXSS(encoding))
+ }
+
+ test("SPARK-20393: Prevent script from parameters running on page.") {
+ val scriptAlert = """>"'> |