core/src/main/scala/org/apache/spark/util/Utils.scala (20 additions, 6 deletions)
@@ -2650,15 +2650,29 @@ private[spark] object Utils extends Logging {
     redact(redactionPattern, kvs)
   }
 
+  /**
+   * Redact the sensitive values in the given map. If a map key matches the redaction pattern then
+   * its value is replaced with a dummy text.
+   */
+  def redact(regex: Option[Regex], kvs: Seq[(String, String)]): Seq[(String, String)] = {
+    regex match {
+      case None => kvs
+      case Some(r) => redact(r, kvs)
+    }
+  }
+
   /**
    * Redact the sensitive information in the given string.
    */
-  def redact(conf: SparkConf, text: String): String = {
-    if (text == null || text.isEmpty || conf == null || !conf.contains(STRING_REDACTION_PATTERN)) {
-      text
-    } else {
-      val regex = conf.get(STRING_REDACTION_PATTERN).get
-      regex.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
+  def redact(regex: Option[Regex], text: String): String = {
+    regex match {
+      case None => text
+      case Some(r) =>
+        if (text == null || text.isEmpty) {
+          text
+        } else {
+          r.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
+        }
     }
   }

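For reference, a minimal, self-contained sketch of the new Option[Regex]-based behavior. The replacement constant is inlined here as an assumption (Spark keeps it in Utils as REDACTION_REPLACEMENT_TEXT); the pattern and sample strings are illustrative only:

import scala.util.matching.Regex

object RedactionSketch {
  // Assumption: stand-in for Spark's Utils.REDACTION_REPLACEMENT_TEXT constant.
  val REDACTION_REPLACEMENT_TEXT = "*********(redacted)"

  // Mirrors the new Utils.redact(Option[Regex], String): None means no pattern is configured.
  def redact(regex: Option[Regex], text: String): String = regex match {
    case None => text
    case Some(r) =>
      if (text == null || text.isEmpty) text
      else r.replaceAllIn(text, REDACTION_REPLACEMENT_TEXT)
  }

  def main(args: Array[String]): Unit = {
    val pattern: Option[Regex] = Some("file:/[\\w_/]+".r)
    // Matching substrings are replaced with the dummy value.
    println(redact(pattern, "FileScan parquet, Location: file:/tmp/my_table"))
    // With no pattern configured, the text passes through untouched.
    println(redact(None, "Location: file:/tmp/my_table"))
  }
}

Passing Option[Regex] instead of a SparkConf lets callers resolve the pattern from whichever conf is appropriate (core or SQL) before calling into Utils.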
sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicReference

 import scala.collection.JavaConverters._
 import scala.collection.immutable
+import scala.util.matching.Regex
 
 import org.apache.hadoop.fs.Path

@@ -1035,6 +1036,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val SQL_STRING_REDACTION_PATTERN =
+    ConfigBuilder("spark.sql.redaction.string.regex")
+      .doc("Regex to decide which parts of strings produced by Spark contain sensitive " +
+        "information. When this regex matches a string part, that string part is replaced by a " +
+        "dummy value. This is currently used to redact the output of SQL explain commands. " +
+        "When this conf is not set, the value from `spark.redaction.string.regex` is used.")
+      .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -1173,6 +1182,8 @@ class SQLConf extends Serializable with Logging {

   def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS)
 
+  def stringRedationPattern: Option[Regex] = SQL_STRING_REDACTION_PATTERN.readFrom(reader)
+
   /**
    * Returns the [[Resolver]] for the current configuration, which can be used to determine if two
    * identifiers are equal.
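Taken together, the SQL-side conf takes precedence when set, and otherwise the value of the core conf spark.redaction.string.regex applies via fallbackConf. A hedged sketch of exercising this from user code (assumes a running SparkSession named spark; the pattern and path are illustrative):

// The core conf is a Spark-core setting, typically supplied at launch, e.g.:
//   spark-submit --conf "spark.redaction.string.regex=file:/[\\w_/]+" ...

// The SQL conf is a runtime SQLConf entry and can be set per session; when set,
// it takes precedence over the core conf for redacting SQL strings.
spark.conf.set("spark.sql.redaction.string.regex", "(?i)FileScan")

// Matching parts of the rendered plan are replaced with a dummy value.
spark.read.parquet("/tmp/my_table").explain()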
sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -69,7 +69,7 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
    * Shorthand for calling redact() without specifying redaction rules
    */
   private def redact(text: String): String = {
-    Utils.redact(SparkSession.getActiveSession.map(_.sparkContext.conf).orNull, text)
+    Utils.redact(sqlContext.sessionState.conf.stringRedationPattern, text)
   }
 }

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -194,13 +194,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     }
   }
 
-  def simpleString: String = {
+  def simpleString: String = withRedaction {
     s"""== Physical Plan ==
        |${stringOrError(executedPlan.treeString(verbose = false))}
     """.stripMargin.trim
   }
 
-  override def toString: String = {
+  override def toString: String = withRedaction {
     def output = Utils.truncatedString(
       analyzed.output.map(o => s"${o.name}: ${o.dataType.simpleString}"), ", ")
     val analyzedPlan = Seq(
@@ -219,7 +219,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
""".stripMargin.trim
}

def stringWithStats: String = {
def stringWithStats: String = withRedaction {
// trigger to compute stats for logical plans
optimizedPlan.stats

@@ -231,6 +231,13 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
""".stripMargin.trim
}

/**
* Redact the sensitive information in the given string.
*/
private def withRedaction(message: String): String = {
Utils.redact(sparkSession.sessionState.conf.stringRedationPattern, message)
}

/** A special namespace for commands that can be used to debug query execution. */
// scalastyle:off
object debug {
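With withRedaction in place, every user-facing plan rendering from QueryExecution passes through a single redaction point instead of each operator scrubbing its own output. A short usage sketch (assumes a running SparkSession named spark with one of the redaction confs set; values are illustrative):

// Each rendered plan string below is filtered through withRedaction.
val qe = spark.range(10).queryExecution
qe.simpleString    // physical plan only, redacted
qe.toString        // parsed/analyzed/optimized/physical plans, redacted
qe.stringWithStats // like toString but computes logical-plan stats first, also redacted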
sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala
@@ -20,6 +20,7 @@ import org.apache.hadoop.fs.Path

 import org.apache.spark.SparkConf
 import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext

/**
@@ -52,4 +53,34 @@ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {
       assert(df.queryExecution.simpleString.contains(replacement))
     }
   }
+
+  private def isIncluded(queryExecution: QueryExecution, msg: String): Boolean = {
+    queryExecution.toString.contains(msg) ||
+      queryExecution.simpleString.contains(msg) ||
+      queryExecution.stringWithStats.contains(msg)
+  }
+
+  test("explain is redacted using SQLConf") {
+    withTempDir { dir =>
+      val basePath = dir.getCanonicalPath
+      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+      val df = spark.read.parquet(basePath)
+      val replacement = "*********"
+
+      // Respect SparkConf and replace file:/
+      assert(isIncluded(df.queryExecution, replacement))
+
+      assert(isIncluded(df.queryExecution, "FileScan"))
+      assert(!isIncluded(df.queryExecution, "file:/"))
+
+      withSQLConf(SQLConf.SQL_STRING_REDACTION_PATTERN.key -> "(?i)FileScan") {
+        // Respect SQLConf and replace FileScan
+        assert(isIncluded(df.queryExecution, replacement))
+
+        assert(!isIncluded(df.queryExecution, "FileScan"))
+        assert(isIncluded(df.queryExecution, "file:/"))
+      }
+    }
+  }
 }
sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -28,7 +28,7 @@ import com.google.common.util.concurrent.UncheckedExecutionException
 import org.apache.commons.io.FileUtils
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.Range
@@ -418,6 +418,37 @@ class StreamSuite extends StreamTest {
     assert(OutputMode.Update === InternalOutputModes.Update)
   }
 
+  override protected def sparkConf: SparkConf = super.sparkConf
+    .set("spark.redaction.string.regex", "file:/[\\w_]+")
+
+  test("explain - redaction") {
Review thread:
  Author: This is just a SS example. I believe we have more such cases.
  Reviewer: Is it because you want to use explainInternal, so you write a SS test?
  Author: Just because I found a potential security issue in SS.
+    val replacement = "*********"
+
+    val inputData = MemoryStream[String]
+    val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))
+    // Test StreamingQuery.display
+    val q = df.writeStream.queryName("memory_explain").outputMode("complete").format("memory")
+      .start()
+      .asInstanceOf[StreamingQueryWrapper]
+      .streamingQuery
+    try {
+      inputData.addData("abc")
+      q.processAllAvailable()
+
+      val explainWithoutExtended = q.explainInternal(false)
+      assert(explainWithoutExtended.contains(replacement))
+      assert(explainWithoutExtended.contains("StateStoreRestore"))
Review thread:
  Reviewer: nit: add assert(!explainWithoutExtended.contains("file:/")) to verify it does replace the correct content.
  Author: done.

+      assert(!explainWithoutExtended.contains("file:/"))
+
+      val explainWithExtended = q.explainInternal(true)
+      assert(explainWithExtended.contains(replacement))
+      assert(explainWithExtended.contains("StateStoreRestore"))
+      assert(!explainWithExtended.contains("file:/"))
+    } finally {
+      q.stop()
+    }
+  }
 
   test("explain") {
     val inputData = MemoryStream[String]
     val df = inputData.toDS().map(_ + "foo").groupBy("value").agg(count("*"))