Skip to content
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {

def render(request: HttpServletRequest): Seq[Node] = {
// stripXSS is called first to remove suspicious characters used in XSS attacks
val requestedIncomplete =
Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean

val allAppsSize = parent.getApplicationList().count(_.completed != requestedIncomplete)
val eventLogsUnderProcessCount = parent.getEventLogsUnderProcess()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val appId = UIUtils.stripXSS(request.getParameter("appId"))
val state = master.askSync[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId)
.getOrElse(state.completedApps.find(_.id == appId).orNull)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val id = Option(request.getParameter("id"))
// stripXSS is called first to remove suspicious characters used in XSS attacks
val killFlag =
Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
Expand Down
30 changes: 18 additions & 12 deletions core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
private val supportedLogTypes = Set("stderr", "stdout")
private val defaultBytes = 100 * 1024

// stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength =
Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
.getOrElse(defaultBytes)

val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand All @@ -55,13 +58,16 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength =
Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
.getOrElse(defaultBytes)

val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand Down
21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/ui/UIUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import scala.util.control.NonFatal
import scala.xml._
import scala.xml.transform.{RewriteRule, RuleTransformer}

import org.apache.commons.lang3.StringEscapeUtils

import org.apache.spark.internal.Logging
import org.apache.spark.ui.scope.RDDOperationGraph

Expand All @@ -34,6 +36,8 @@ private[spark] object UIUtils extends Logging {
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable"

private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r

// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat =
Expand Down Expand Up @@ -527,4 +531,21 @@ private[spark] object UIUtils extends Logging {
origHref
}
}

/**
* Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks
*
* For more information about XSS testing:
* https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
* https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
*/
def stripXSS(requestParameter: String): String = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you would, add a couple brief tests of this to UIUtilsSuite

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll work on some.

if (requestParameter == null) {
null
} else {
// Remove new lines and single quotes, followed by escaping HTML version 4.0
StringEscapeUtils.escapeHtml4(
NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, ""))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage

private val sc = parent.sc

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val executorId = Option(request.getParameter("executorId")).map { executorId =>
val executorId =
Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
Expand Down
14 changes: 8 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,18 +220,20 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
jobTag: String,
jobs: Seq[JobUIData],
killEnabled: Boolean): Seq[Node] = {
val allParameters = request.getParameterMap.asScala.toMap
// stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(jobTag))
.map(para => para._1 + "=" + para._2(0))

val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
val jobIdTitle = if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"

val parameterJobPage = request.getParameter(jobTag + ".page")
val parameterJobSortColumn = request.getParameter(jobTag + ".sort")
val parameterJobSortDesc = request.getParameter(jobTag + ".desc")
val parameterJobPageSize = request.getParameter(jobTag + ".pageSize")
val parameterJobPrevPageSize = request.getParameter(jobTag + ".prevPageSize")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterJobPage = UIUtils.stripXSS(request.getParameter(jobTag + ".page"))
val parameterJobSortColumn = UIUtils.stripXSS(request.getParameter(jobTag + ".sort"))
val parameterJobSortDesc = UIUtils.stripXSS(request.getParameter(jobTag + ".desc"))
val parameterJobPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".pageSize"))
val parameterJobPrevPageSize = UIUtils.stripXSS(request.getParameter(jobTag + ".prevPageSize"))

val jobPage = Option(parameterJobPage).map(_.toInt).getOrElse(1)
val jobSortColumn = Option(parameterJobSortColumn).map { sortColumn =>
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener

listener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val jobId = parameterId.toInt
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all jobs in the given SparkContext. */
private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
Expand All @@ -40,7 +40,8 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val jobId = Option(request.getParameter("id")).map(_.toInt)
// stripXSS is called first to remove suspicious characters used in XSS attacks
val jobId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
jobId.foreach { id =>
if (jobProgresslistener.activeJobs.contains(id)) {
sc.foreach(_.cancelJob(id))
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {

def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val poolName = Option(request.getParameter("poolname")).map { poolname =>
// stripXSS is called first to remove suspicious characters used in XSS attacks
val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,18 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {

def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterAttempt = request.getParameter("attempt")
val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")

val parameterTaskPage = request.getParameter("task.page")
val parameterTaskSortColumn = request.getParameter("task.sort")
val parameterTaskSortDesc = request.getParameter("task.desc")
val parameterTaskPageSize = request.getParameter("task.pageSize")
val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize")
val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))
val parameterTaskPrevPageSize = UIUtils.stripXSS(request.getParameter("task.prevPageSize"))

val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
Expand Down
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ private[ui] class StageTableBase(
isFairScheduler: Boolean,
killEnabled: Boolean,
isFailedStage: Boolean) {
val allParameters = request.getParameterMap().asScala.toMap
// stripXSS is called to remove suspicious characters used in XSS attacks
val allParameters = request.getParameterMap.asScala.toMap.mapValues(_.map(UIUtils.stripXSS))
val parameterOtherTable = allParameters.filterNot(_._1.startsWith(stageTag))
.map(para => para._1 + "=" + para._2(0))

val parameterStagePage = request.getParameter(stageTag + ".page")
val parameterStageSortColumn = request.getParameter(stageTag + ".sort")
val parameterStageSortDesc = request.getParameter(stageTag + ".desc")
val parameterStagePageSize = request.getParameter(stageTag + ".pageSize")
val parameterStagePrevPageSize = request.getParameter(stageTag + ".prevPageSize")
val parameterStagePage = UIUtils.stripXSS(request.getParameter(stageTag + ".page"))
val parameterStageSortColumn = UIUtils.stripXSS(request.getParameter(stageTag + ".sort"))
val parameterStageSortDesc = UIUtils.stripXSS(request.getParameter(stageTag + ".desc"))
val parameterStagePageSize = UIUtils.stripXSS(request.getParameter(stageTag + ".pageSize"))
val parameterStagePrevPageSize =
UIUtils.stripXSS(request.getParameter(stageTag + ".prevPageSize"))

val stagePage = Option(parameterStagePage).map(_.toInt).getOrElse(1)
val stageSortColumn = Option(parameterStageSortColumn).map { sortColumn =>
Expand Down Expand Up @@ -512,4 +514,3 @@ private[ui] class StageDataSource(
}
}
}

5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
Expand All @@ -39,7 +39,8 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val stageId = Option(request.getParameter("id")).map(_.toInt)
// stripXSS is called first to remove suspicious characters used in XSS attacks
val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).map(_.toInt)
stageId.foreach { id =>
if (progressListener.activeStages.contains(id)) {
sc.foreach(_.cancelStage(id, "killed via the Web UI"))
Expand Down
13 changes: 7 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener

def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterBlockPage = request.getParameter("block.page")
val parameterBlockSortColumn = request.getParameter("block.sort")
val parameterBlockSortDesc = request.getParameter("block.desc")
val parameterBlockPageSize = request.getParameter("block.pageSize")
val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize")
val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))
val parameterBlockPrevPageSize = UIUtils.stripXSS(request.getParameter("block.prevPageSize"))

val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
Expand Down
39 changes: 39 additions & 0 deletions core/src/test/scala/org/apache/spark/ui/UIUtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,45 @@ class UIUtilsSuite extends SparkFunSuite {
assert(decoded2 === decodeURLParameter(decoded2))
}

test("SPARK-20393: Prevent newline characters in parameters.") {
val encoding = "Encoding:base64%0d%0a%0d%0aPGh0bWw%2bjcmlwdD48L2h0bWw%2b"
val stripEncoding = "Encoding:base64PGh0bWw%2bjcmlwdD48L2h0bWw%2b"

assert(stripEncoding === stripXSS(encoding))
}

test("SPARK-20393: Prevent script from parameters running on page.") {
val scriptAlert = """>"'><script>alert(401)<%2Fscript>"""
val stripScriptAlert = "&gt;&quot;&gt;&lt;script&gt;alert(401)&lt;%2Fscript&gt;"

assert(stripScriptAlert === stripXSS(scriptAlert))
}

test("SPARK-20393: Prevent javascript from parameters running on page.") {
val javascriptAlert =
"""app-20161208133404-0002<iframe+src%3Djavascript%3Aalert(1705)>"""
val stripJavascriptAlert =
"app-20161208133404-0002&lt;iframe+src%3Djavascript%3Aalert(1705)&gt;"

assert(stripJavascriptAlert === stripXSS(javascriptAlert))
}

test("SPARK-20393: Prevent links from parameters on page.") {
val link =
"""stdout'"><iframe+id%3D1131+src%3Dhttp%3A%2F%2Fdemo.test.net%2Fphishing.html>"""
val stripLink =
"stdout&quot;&gt;&lt;iframe+id%3D1131+src%3Dhttp%3A%2F%2Fdemo.test.net%2Fphishing.html&gt;"

assert(stripLink === stripXSS(link))
}

test("SPARK-20393: Prevent popups from parameters on page.") {
val popup = """stdout'%2Balert(60)%2B'"""
val stripPopup = "stdout%2Balert(60)%2B"

assert(stripPopup === stripXSS(popup))
}

private def verify(
desc: String,
expected: Node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {

override def render(request: HttpServletRequest): Seq[Node] = {
val driverId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val driverId = UIUtils.stripXSS(request.getParameter("id"))
require(driverId != null && driverId.nonEmpty, "Missing id parameter")

val state = parent.scheduler.getDriverState(driverId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging
private val listener = parent.listener

override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized {
val parameterExecutionId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterExecutionId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterExecutionId != null && parameterExecutionId.nonEmpty,
"Missing execution id parameter")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)

/** Render the page */
def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val content =
Expand Down Expand Up @@ -197,4 +198,3 @@ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab)
UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true)
}
}

Loading