
Commit eff3af2

Merge remote-tracking branch 'upstream/master' into SPARK-24781

2 parents: 38a935d + 5ad4735

109 files changed: +2910 −528 lines


R/pkg/vignettes/sparkr-vignettes.Rmd

Lines changed: 5 additions & 0 deletions

````diff
@@ -590,6 +590,7 @@ summary(model)
 Predict values on training data
 ```{r}
 prediction <- predict(model, training)
+head(select(prediction, "Class", "Sex", "Age", "Freq", "Survived", "prediction"))
 ```
 
 #### Logistic Regression
@@ -613,6 +614,7 @@ summary(model)
 Predict values on training data
 ```{r}
 fitted <- predict(model, training)
+head(select(fitted, "Class", "Sex", "Age", "Freq", "Survived", "prediction"))
 ```
 
 Multinomial logistic regression against three classes
@@ -807,6 +809,7 @@ df <- createDataFrame(t)
 dtModel <- spark.decisionTree(df, Survived ~ ., type = "classification", maxDepth = 2)
 summary(dtModel)
 predictions <- predict(dtModel, df)
+head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction"))
 ```
 
 #### Gradient-Boosted Trees
@@ -822,6 +825,7 @@ df <- createDataFrame(t)
 gbtModel <- spark.gbt(df, Survived ~ ., type = "classification", maxDepth = 2, maxIter = 2)
 summary(gbtModel)
 predictions <- predict(gbtModel, df)
+head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction"))
 ```
 
 #### Random Forest
@@ -837,6 +841,7 @@ df <- createDataFrame(t)
 rfModel <- spark.randomForest(df, Survived ~ ., type = "classification", maxDepth = 2, numTrees = 2)
 summary(rfModel)
 predictions <- predict(rfModel, df)
+head(select(predictions, "Class", "Sex", "Age", "Freq", "Survived", "prediction"))
 ```
 
 #### Bisecting k-Means
````

core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala

Lines changed: 3 additions & 2 deletions

```diff
@@ -385,7 +385,7 @@ private[spark] class SparkSubmit extends Logging {
     val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES)
 
     def shouldDownload(scheme: String): Boolean = {
-      forceDownloadSchemes.contains(scheme) ||
+      forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme) ||
         Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure
     }
 
@@ -578,7 +578,8 @@ private[spark] class SparkSubmit extends Logging {
     }
     // Add the main application jar and any added jars to classpath in case YARN client
     // requires these jars.
-    // This assumes both primaryResource and user jars are local jars, otherwise it will not be
+    // This assumes both primaryResource and user jars are local jars, or already downloaded
+    // to local by configuring "spark.yarn.dist.forceDownloadSchemes", otherwise it will not be
     // added to the classpath of YARN client.
     if (isYarnCluster) {
       if (isUserJar(args.primaryResource)) {
```
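The `shouldDownload` change adds a wildcard: a literal `*` entry in `spark.yarn.dist.forceDownloadSchemes` now forces download for every scheme. A minimal sketch of just the matching logic — the standalone harness is hypothetical, and the real method additionally downloads when Hadoop has no `FileSystem` registered for the scheme:

```scala
// Sketch of the wildcard matching added to shouldDownload. Only the
// contains("*") || contains(scheme) check mirrors the diff; the real method
// also downloads when FileSystem.getFileSystemClass fails for the scheme.
object ForceDownloadSketch {
  def shouldDownload(forceDownloadSchemes: Seq[String], scheme: String): Boolean =
    forceDownloadSchemes.contains("*") || forceDownloadSchemes.contains(scheme)

  def main(args: Array[String]): Unit = {
    assert(shouldDownload(Seq("http", "ftp"), "http"))  // explicit scheme match
    assert(!shouldDownload(Seq("http", "ftp"), "s3"))   // no match, no wildcard
    assert(shouldDownload(Seq("*"), "s3"))              // "*" matches every scheme
  }
}
```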

core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala

Lines changed: 37 additions & 23 deletions

```diff
@@ -233,30 +233,44 @@ private[spark] class RestSubmissionClient(master: String) extends Logging {
   private[rest] def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
     import scala.concurrent.ExecutionContext.Implicits.global
     val responseFuture = Future {
-      val dataStream =
-        if (connection.getResponseCode == HttpServletResponse.SC_OK) {
-          connection.getInputStream
-        } else {
-          connection.getErrorStream
+      val responseCode = connection.getResponseCode
+
+      if (responseCode != HttpServletResponse.SC_OK) {
+        val errString = Some(Source.fromInputStream(connection.getErrorStream())
+          .getLines().mkString("\n"))
+        if (responseCode == HttpServletResponse.SC_INTERNAL_SERVER_ERROR &&
+          !connection.getContentType().contains("application/json")) {
+          throw new SubmitRestProtocolException(s"Server responded with exception:\n${errString}")
+        }
+        logError(s"Server responded with error:\n${errString}")
+        val error = new ErrorResponse
+        if (responseCode == RestSubmissionServer.SC_UNKNOWN_PROTOCOL_VERSION) {
+          error.highestProtocolVersion = RestSubmissionServer.PROTOCOL_VERSION
+        }
+        error.message = errString.get
+        error
+      } else {
+        val dataStream = connection.getInputStream
+
+        // If the server threw an exception while writing a response, it will not have a body
+        if (dataStream == null) {
+          throw new SubmitRestProtocolException("Server returned empty body")
+        }
+        val responseJson = Source.fromInputStream(dataStream).mkString
+        logDebug(s"Response from the server:\n$responseJson")
+        val response = SubmitRestProtocolMessage.fromJson(responseJson)
+        response.validate()
+        response match {
+          // If the response is an error, log the message
+          case error: ErrorResponse =>
+            logError(s"Server responded with error:\n${error.message}")
+            error
+          // Otherwise, simply return the response
+          case response: SubmitRestProtocolResponse => response
+          case unexpected =>
+            throw new SubmitRestProtocolException(
+              s"Message received from server was not a response:\n${unexpected.toJson}")
         }
-      // If the server threw an exception while writing a response, it will not have a body
-      if (dataStream == null) {
-        throw new SubmitRestProtocolException("Server returned empty body")
-      }
-      val responseJson = Source.fromInputStream(dataStream).mkString
-      logDebug(s"Response from the server:\n$responseJson")
-      val response = SubmitRestProtocolMessage.fromJson(responseJson)
-      response.validate()
-      response match {
-        // If the response is an error, log the message
-        case error: ErrorResponse =>
-          logError(s"Server responded with error:\n${error.message}")
-          error
-        // Otherwise, simply return the response
-        case response: SubmitRestProtocolResponse => response
-        case unexpected =>
-          throw new SubmitRestProtocolException(
-            s"Message received from server was not a response:\n${unexpected.toJson}")
       }
     }
 
```
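The rewrite inspects the HTTP status code before reading any stream: a 500 whose body is not JSON is surfaced as a `SubmitRestProtocolException` instead of being fed to the JSON parser, and other error codes are wrapped in an `ErrorResponse`. A condensed, standalone sketch of that decision order — the types here are stand-ins, not Spark's classes:

```scala
// Stand-in type; Spark uses SubmitRestProtocolException / ErrorResponse.
final case class RestError(message: String) extends Exception(message)

// Mirrors the branch order of the new readResponse: status code first,
// non-JSON 500s fail fast, everything else is classified by code.
def classify(code: Int, contentType: String, body: String): Either[RestError, String] =
  if (code != 200) {
    if (code == 500 && !contentType.contains("application/json")) {
      Left(RestError(s"Server responded with exception:\n$body")) // fatal: not parseable
    } else {
      Left(RestError(s"Server responded with error:\n$body")) // wrapped as an error response
    }
  } else {
    Right(body) // 200 OK: parsed and validated as a protocol message
  }
```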

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 3 additions & 2 deletions

```diff
@@ -486,10 +486,11 @@ package object config {
 
   private[spark] val FORCE_DOWNLOAD_SCHEMES =
     ConfigBuilder("spark.yarn.dist.forceDownloadSchemes")
-      .doc("Comma-separated list of schemes for which files will be downloaded to the " +
+      .doc("Comma-separated list of schemes for which resources will be downloaded to the " +
         "local disk prior to being added to YARN's distributed cache. For use in cases " +
         "where the YARN service does not support schemes that are supported by Spark, like http, " +
-        "https and ftp.")
+        "https and ftp, or jars required to be in the local YARN client's classpath. Wildcard " +
+        "'*' is denoted to download resources for all the schemes.")
       .stringConf
       .toSequence
       .createWithDefault(Nil)
```
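With the wildcard in place, forcing every remote resource to be downloaded locally before submission is a one-line setting. A usage sketch in programmatic form — the same value can equally be passed as `--conf` on the `spark-submit` command line, as the test suite below does:

```scala
import org.apache.spark.SparkConf

// Download resources for all schemes before adding them to YARN's
// distributed cache; use e.g. "http,https" to target specific schemes instead.
val conf = new SparkConf()
  .set("spark.yarn.dist.forceDownloadSchemes", "*")
```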

core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala

Lines changed: 1 addition & 1 deletion

```diff
@@ -30,7 +30,7 @@ private[spark]
 class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[BlockId])
   extends RDD[T](sc, Nil) {
 
-  @transient lazy val _locations = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
+  @transient lazy val _locations = BlockManager.blockIdsToLocations(blockIds, SparkEnv.get)
   @volatile private var _isValid = true
 
   override def getPartitions: Array[Partition] = {
```

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 5 additions & 2 deletions

```diff
@@ -45,6 +45,7 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager}
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.storage.memory._
@@ -1554,7 +1555,7 @@ private[spark] class BlockManager(
 private[spark] object BlockManager {
   private val ID_GENERATOR = new IdGenerator
 
-  def blockIdsToHosts(
+  def blockIdsToLocations(
       blockIds: Array[BlockId],
       env: SparkEnv,
       blockManagerMaster: BlockManagerMaster = null): Map[BlockId, Seq[String]] = {
@@ -1569,7 +1570,9 @@ private[spark] object BlockManager {
 
     val blockManagers = new HashMap[BlockId, Seq[String]]
     for (i <- 0 until blockIds.length) {
-      blockManagers(blockIds(i)) = blockLocations(i).map(_.host)
+      blockManagers(blockIds(i)) = blockLocations(i).map { loc =>
+        ExecutorCacheTaskLocation(loc.host, loc.executorId).toString
+      }
     }
     blockManagers.toMap
   }
```
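The renamed method now returns executor-qualified location strings rather than bare hostnames, so `BlockRDD`'s preferred locations can express executor-level (not just host-level) locality. Judging from the expectations in `BlockManagerSuite` below, `ExecutorCacheTaskLocation` renders as `executor_<host>_<executorId>`; a small sketch of that format, with a stand-in for the real scheduler class:

```scala
// Stand-in mirroring the string format asserted in BlockManagerSuite:
// BlockManagerId("1", "host1", 100) maps to "executor_host1_1".
final case class ExecutorCacheTaskLocation(host: String, executorId: String) {
  override def toString: String = s"executor_${host}_$executorId"
}

assert(ExecutorCacheTaskLocation("host1", "1").toString == "executor_host1_1")
```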

core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -368,8 +368,8 @@ private[spark] class ExternalSorter[K, V, C](
     val bufferedIters = iterators.filter(_.hasNext).map(_.buffered)
     type Iter = BufferedIterator[Product2[K, C]]
     val heap = new mutable.PriorityQueue[Iter]()(new Ordering[Iter] {
-      // Use the reverse of comparator.compare because PriorityQueue dequeues the max
-      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+      // Use the reverse order because PriorityQueue dequeues the max
+      override def compare(x: Iter, y: Iter): Int = comparator.compare(y.head._1, x.head._1)
     })
     heap.enqueue(bufferedIters: _*) // Will contain only the iterators with hasNext = true
     new Iterator[Product2[K, C]] {
```
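Swapping the comparator's operands and negating its result are not equivalent: `compare` may legally return `Int.MinValue`, and `-Int.MinValue` overflows back to `Int.MinValue`, so negation can silently leave the ordering unreversed. A small sketch of the hazard and the safe reversal — `riskyCompare` is a contrived example, not Spark code:

```scala
import scala.collection.mutable

// A contrived comparator that (legally) signals "less than" with Int.MinValue.
def riskyCompare(x: Int, y: Int): Int =
  if (x < y) Int.MinValue else if (x > y) 1 else 0

// Negation overflows and fails to flip the sign...
assert(-riskyCompare(0, 1) == Int.MinValue)
// ...while swapping the operands reverses the order safely.
assert(riskyCompare(1, 0) > 0)

// The same operand swap turns Scala's max-oriented PriorityQueue into a min-heap:
val minHeap = new mutable.PriorityQueue[Int]()(Ordering.fromLessThan[Int](_ > _))
minHeap.enqueue(3, 1, 2)
assert(minHeap.dequeue() == 1) // smallest element comes out first
```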

core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala

Lines changed: 19 additions & 10 deletions

```diff
@@ -995,20 +995,24 @@ class SparkSubmitSuite
   }
 
   test("download remote resource if it is not supported by yarn service") {
-    testRemoteResources(enableHttpFs = false, blacklistHttpFs = false)
+    testRemoteResources(enableHttpFs = false)
   }
 
   test("avoid downloading remote resource if it is supported by yarn service") {
-    testRemoteResources(enableHttpFs = true, blacklistHttpFs = false)
+    testRemoteResources(enableHttpFs = true)
   }
 
   test("force download from blacklisted schemes") {
-    testRemoteResources(enableHttpFs = true, blacklistHttpFs = true)
+    testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("http"))
+  }
+
+  test("force download for all the schemes") {
+    testRemoteResources(enableHttpFs = true, blacklistSchemes = Seq("*"))
   }
 
   private def testRemoteResources(
       enableHttpFs: Boolean,
-      blacklistHttpFs: Boolean): Unit = {
+      blacklistSchemes: Seq[String] = Nil): Unit = {
     val hadoopConf = new Configuration()
     updateConfWithFakeS3Fs(hadoopConf)
     if (enableHttpFs) {
@@ -1025,8 +1029,8 @@ class SparkSubmitSuite
     val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir)
     val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}"
 
-    val forceDownloadArgs = if (blacklistHttpFs) {
-      Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http")
+    val forceDownloadArgs = if (blacklistSchemes.nonEmpty) {
+      Seq("--conf", s"spark.yarn.dist.forceDownloadSchemes=${blacklistSchemes.mkString(",")}")
     } else {
       Nil
     }
@@ -1044,14 +1048,19 @@ class SparkSubmitSuite
 
     val jars = conf.get("spark.yarn.dist.jars").split(",").toSet
 
-    // The URI of remote S3 resource should still be remote.
-    assert(jars.contains(tmpS3JarPath))
+    def isSchemeBlacklisted(scheme: String) = {
+      blacklistSchemes.contains("*") || blacklistSchemes.contains(scheme)
+    }
+
+    if (!isSchemeBlacklisted("s3")) {
+      assert(jars.contains(tmpS3JarPath))
+    }
 
-    if (enableHttpFs && !blacklistHttpFs) {
+    if (enableHttpFs && blacklistSchemes.isEmpty) {
       // If Http FS is supported by yarn service, the URI of remote http resource should
       // still be remote.
       assert(jars.contains(tmpHttpJarPath))
-    } else {
+    } else if (!enableHttpFs || isSchemeBlacklisted("http")) {
       // If Http FS is not supported by yarn service, or http scheme is configured to be force
       // downloading, the URI of remote http resource should be changed to a local one.
       val jarName = new File(tmpHttpJar.toURI).getName
```

core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala

Lines changed: 13 additions & 0 deletions

```diff
@@ -1422,6 +1422,19 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager)
   }
 
+  test("query locations of blockIds") {
+    val mockBlockManagerMaster = mock(classOf[BlockManagerMaster])
+    val blockLocations = Seq(BlockManagerId("1", "host1", 100), BlockManagerId("2", "host2", 200))
+    when(mockBlockManagerMaster.getLocations(mc.any[Array[BlockId]]))
+      .thenReturn(Array(blockLocations))
+    val env = mock(classOf[SparkEnv])
+
+    val blockIds: Array[BlockId] = Array(StreamBlockId(1, 2))
+    val locs = BlockManager.blockIdsToLocations(blockIds, env, mockBlockManagerMaster)
+    val expectedLocs = Seq("executor_host1_1", "executor_host2_2")
+    assert(locs(blockIds(0)) == expectedLocs)
+  }
+
   class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
     var numCalls = 0
     var tempFileManager: TempFileManager = null
```

dev/requirements.txt

Lines changed: 1 addition & 0 deletions

```diff
@@ -2,3 +2,4 @@ jira==1.0.3
 PyGithub==1.26.0
 Unidecode==0.04.19
 pypandoc==1.3.3
+sphinx
```
