
Commit a5cdb63

Merge branch 'master' into fix_metric

2 parents: 7c879e0 + ed1980f

60 files changed: +772 additions, −248 deletions. This is a large commit, so only a subset of the changed files appears below.

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 0 additions & 1 deletion
@@ -49,7 +49,6 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD
 import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
 import org.apache.spark.scheduler.local.LocalBackend
 import org.apache.spark.storage._
-import org.apache.spark.SPARK_VERSION
 import org.apache.spark.ui.SparkUI
 import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}

core/src/main/scala/org/apache/spark/deploy/master/Master.scala

Lines changed: 15 additions & 3 deletions
@@ -487,13 +487,25 @@ private[spark] class Master(
     if (state != RecoveryState.ALIVE) { return }
 
     // First schedule drivers, they take strict precedence over applications
-    val shuffledWorkers = Random.shuffle(workers) // Randomization helps balance drivers
-    for (worker <- shuffledWorkers if worker.state == WorkerState.ALIVE) {
-      for (driver <- List(waitingDrivers: _*)) { // iterate over a copy of waitingDrivers
+    // Randomization helps balance drivers
+    val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
+    val aliveWorkerNum = shuffledAliveWorkers.size
+    var curPos = 0
+    for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers
+      // We assign workers to each waiting driver in a round-robin fashion. For each driver, we
+      // start from the last worker that was assigned a driver, and continue onwards until we have
+      // explored all alive workers.
+      curPos = (curPos + 1) % aliveWorkerNum
+      val startPos = curPos
+      var launched = false
+      while (curPos != startPos && !launched) {
+        val worker = shuffledAliveWorkers(curPos)
         if (worker.memoryFree >= driver.desc.mem && worker.coresFree >= driver.desc.cores) {
           launchDriver(worker, driver)
           waitingDrivers -= driver
+          launched = true
         }
+        curPos = (curPos + 1) % aliveWorkerNum
       }
     }
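A note on this new scheduling loop: it spreads waiting drivers over the alive workers in round-robin order instead of packing them onto whichever worker happens to be checked first. As committed, though, the walk never starts: startPos is assigned immediately after curPos is advanced, so the entry test curPos != startPos fails on the first check and the while body is skipped. A minimal standalone sketch of the intended behaviour, using a visited counter instead of a start marker (Worker and Driver here are hypothetical stand-ins, not Spark's WorkerInfo/DriverInfo):

    import scala.collection.mutable.ListBuffer
    import scala.util.Random

    case class Worker(id: Int, var memoryFree: Int, var coresFree: Int)
    case class Driver(id: Int, mem: Int, cores: Int)

    object RoundRobinDemo {
      def main(args: Array[String]): Unit = {
        // Shuffling balances load across otherwise-equivalent workers.
        val aliveWorkers = Random.shuffle(Seq(Worker(1, 4096, 4), Worker(2, 4096, 4)))
        val numAlive = aliveWorkers.size
        val waitingDrivers = ListBuffer(Driver(1, 1024, 1), Driver(2, 1024, 1), Driver(3, 8192, 1))

        var curPos = 0
        for (driver <- waitingDrivers.toList) { // iterate over a copy, as Master does
          var launched = false
          var visited = 0
          // Visit each alive worker at most once, resuming where the previous driver stopped.
          while (visited < numAlive && !launched) {
            val worker = aliveWorkers(curPos)
            visited += 1
            if (worker.memoryFree >= driver.mem && worker.coresFree >= driver.cores) {
              worker.memoryFree -= driver.mem
              worker.coresFree -= driver.cores
              waitingDrivers -= driver
              launched = true
              println(s"driver ${driver.id} -> worker ${worker.id}")
            }
            curPos = (curPos + 1) % numAlive
          }
        }
      }
    }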

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 18 additions & 1 deletion
@@ -20,9 +20,11 @@ package org.apache.spark.util
 import java.io._
 import java.net._
 import java.nio.ByteBuffer
-import java.util.{Locale, Random, UUID}
+import java.util.{Properties, Locale, Random, UUID}
 import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
 
+import org.apache.log4j.PropertyConfigurator
+
 import scala.collection.JavaConversions._
 import scala.collection.Map
 import scala.collection.mutable.ArrayBuffer
@@ -834,6 +836,7 @@ private[spark] object Utils extends Logging {
     val exitCode = process.waitFor()
     stdoutThread.join() // Wait for it to finish reading output
     if (exitCode != 0) {
+      logError(s"Process $command exited with code $exitCode: ${output}")
       throw new SparkException("Process " + command + " exited with code " + exitCode)
     }
     output.toString
@@ -1444,6 +1447,20 @@ private[spark] object Utils extends Logging {
     }
   }
 
+  /**
+   * config a log4j properties used for testsuite
+   */
+  def configTestLog4j(level: String): Unit = {
+    val pro = new Properties()
+    pro.put("log4j.rootLogger", s"$level, console")
+    pro.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender")
+    pro.put("log4j.appender.console.target", "System.err")
+    pro.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout")
+    pro.put("log4j.appender.console.layout.ConversionPattern",
+      "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n")
+    PropertyConfigurator.configure(pro)
+  }
+
 }
 
 /**
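The new configTestLog4j helper gives test driver processes a deterministic logging setup: everything goes to stderr at the requested level with a fixed pattern, regardless of which log4j.properties happens to be first on the classpath. This relies on log4j 1.x's PropertyConfigurator.configure(Properties), which applies the given properties to the live logging configuration at runtime. A minimal standalone sketch of the same mechanism (the class name and message are illustrative):

    import java.util.Properties
    import org.apache.log4j.{Logger, PropertyConfigurator}

    object TestLoggingDemo {
      def main(args: Array[String]): Unit = {
        val props = new Properties()
        props.put("log4j.rootLogger", "INFO, console")
        props.put("log4j.appender.console", "org.apache.log4j.ConsoleAppender")
        props.put("log4j.appender.console.target", "System.err")
        props.put("log4j.appender.console.layout", "org.apache.log4j.PatternLayout")
        props.put("log4j.appender.console.layout.ConversionPattern",
          "%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n")
        PropertyConfigurator.configure(props) // takes effect immediately
        Logger.getLogger(getClass).info("now logging to stderr with the test pattern")
      }
    }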

core/src/test/scala/org/apache/spark/DriverSuite.scala

Lines changed: 1 addition & 4 deletions
@@ -19,9 +19,6 @@ package org.apache.spark
 
 import java.io.File
 
-import org.apache.log4j.Logger
-import org.apache.log4j.Level
-
 import org.scalatest.FunSuite
 import org.scalatest.concurrent.Timeouts
 import org.scalatest.prop.TableDrivenPropertyChecks._
@@ -54,7 +51,7 @@ class DriverSuite extends FunSuite with Timeouts {
  */
 object DriverWithoutCleanup {
   def main(args: Array[String]) {
-    Logger.getRootLogger().setLevel(Level.WARN)
+    Utils.configTestLog4j("INFO")
     val sc = new SparkContext(args(0), "DriverWithoutCleanup")
     sc.parallelize(1 to 100, 4).count()
   }

core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala

Lines changed: 2 additions & 0 deletions
@@ -317,6 +317,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
 
 object JarCreationTest {
   def main(args: Array[String]) {
+    Utils.configTestLog4j("INFO")
     val conf = new SparkConf()
     val sc = new SparkContext(conf)
     val result = sc.makeRDD(1 to 100, 10).mapPartitions { x =>
@@ -338,6 +339,7 @@ object JarCreationTest {
 
 object SimpleApplicationTest {
   def main(args: Array[String]) {
+    Utils.configTestLog4j("INFO")
     val conf = new SparkConf()
     val sc = new SparkContext(conf)
     val configs = Seq("spark.master", "spark.app.name")

docs/running-on-yarn.md

Lines changed: 7 additions & 0 deletions
@@ -125,6 +125,13 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
   the environment of the executor launcher.
   </td>
 </tr>
+<tr>
+  <td><code>spark.yarn.containerLauncherMaxThreads</code></td>
+  <td>25</td>
+  <td>
+    The maximum number of threads to use in the application master for launching executor containers.
+  </td>
+</tr>
 </table>
 
 # Launching Spark on YARN
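Like other Spark configuration entries, the new spark.yarn.containerLauncherMaxThreads setting can be supplied through the usual channels, e.g. spark-submit's --conf flag or programmatically on the SparkConf. A sketch of the latter (the value 50 is illustrative, not a recommendation):

    import org.apache.spark.{SparkConf, SparkContext}

    // Raise the AM's container-launcher thread pool for a job that starts many executors.
    val conf = new SparkConf()
      .setAppName("ManyExecutorsApp")
      .set("spark.yarn.containerLauncherMaxThreads", "50")
    val sc = new SparkContext(conf)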

examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ package org.apache.spark.examples.graphx
 import org.apache.spark.SparkContext._
 import org.apache.spark._
 import org.apache.spark.graphx._
-import org.apache.spark.examples.graphx.Analytics
+
 
 /**
  * Uses GraphX to run PageRank on a LiveJournal social network graph. Download the dataset from

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 53 additions & 19 deletions
@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
 
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
-      val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+      val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
         DecisionTree.findBestSplits(treeInput, parentImpurities,
           metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       timer.start("extractNodeInfo")
       val split = nodeSplitStats._1
       val stats = nodeSplitStats._2
+      val predict = nodeSplitStats._3.predict
       val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-      val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+      val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
       logDebug("Node = " + node)
       nodes(nodeIndex) = node
       timer.stop("extractNodeInfo")
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
+      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
     // split into groups to avoid memory overflow during aggregation
     if (level > maxLevelForSingleGroup) {
       // When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
       val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
-      var bestSplits = new Array[(Split, InformationGainStats)](0)
+      var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
       while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
       bins: Array[Array[Bin]],
       timer: TimeTracker,
       numGroups: Int = 1,
-      groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
+      groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
 
     // Calculate best splits for all nodes at a given level
     timer.start("chooseSplits")
-    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+    val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
     // Iterating over all nodes at this level
     var nodeIndex = 0
     while (nodeIndex < numNodes) {
@@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
       topImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata): InformationGainStats = {
-
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
 
-    val totalCount = leftCount + rightCount
-    if (totalCount == 0) {
-      // Return arbitrary prediction.
-      return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+    // If left child or right child doesn't satisfy minimum instances per node,
+    // then this split is invalid, return invalid information gain stats.
+    if ((leftCount < metadata.minInstancesPerNode) ||
+        (rightCount < metadata.minInstancesPerNode)) {
+      return InformationGainStats.invalidInformationGainStats
     }
 
-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
+    val totalCount = leftCount + rightCount
+
     // impurity of parent node
     val impurity = if (level > 0) {
       topImpurity
     } else {
+      val parentNodeAgg = leftImpurityCalculator.copy
+      parentNodeAgg.add(rightImpurityCalculator)
      parentNodeAgg.calculate()
     }
 
-    val predict = parentNodeAgg.predict
-    val prob = parentNodeAgg.prob(predict)
-
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
@@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {
 
     val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+    // if information gain doesn't satisfy minimum information gain,
+    // then this split is invalid, return invalid information gain stats.
+    if (gain < metadata.minInfoGain) {
+      return InformationGainStats.invalidInformationGainStats
+    }
+
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+  }
+
+  /**
+   * Calculate predict value for current node, given stats of any split.
+   * Note that this function is called only once for each node.
+   * @param leftImpurityCalculator left node aggregates for a split
+   * @param rightImpurityCalculator right node aggregates for a node
+   * @return predict value for current node
+   */
+  private def calculatePredict(
+      leftImpurityCalculator: ImpurityCalculator,
+      rightImpurityCalculator: ImpurityCalculator): Predict = {
+    val parentNodeAgg = leftImpurityCalculator.copy
+    parentNodeAgg.add(rightImpurityCalculator)
+    val predict = parentNodeAgg.predict
+    val prob = parentNodeAgg.prob(predict)
+
+    new Predict(predict, prob)
   }
 
   /**
@@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
       nodeImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata,
-      splits: Array[Array[Split]]): (Split, InformationGainStats) = {
+      splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
 
     logDebug("node impurity = " + nodeImpurity)
 
+    // calculate predict only once
+    var predict: Option[Predict] = None
+
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    Range(0, metadata.numFeatures).map { featureIndex =>
+    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
       val numSplits = metadata.numSplits(featureIndex)
       if (metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
@@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
           val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
           val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
           rightChildStats.subtract(leftChildStats)
+          predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
           val gainStats =
             calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
           (splitIdx, gainStats)
@@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
         Range(0, numSplits).map { splitIndex =>
           val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
           val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+          predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
           val gainStats =
             calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
           (splitIndex, gainStats)
@@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
           val rightChildStats =
             binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
           rightChildStats.subtract(leftChildStats)
+          predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
           val gainStats =
             calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
           (splitIndex, gainStats)
@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
         (bestFeatureSplit, bestFeatureGainStats)
       }
     }.maxBy(_._2.gain)
+
+    require(predict.isDefined, "must calculate predict for each node")
+
+    (bestSplit, bestSplitStats, predict.get)
   }
 
   /**
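Two behavioural changes ride along with this refactor: calculateGainForSplit now declares a split invalid up front when either child would receive fewer than metadata.minInstancesPerNode training instances, and invalid afterwards when the impurity reduction falls below metadata.minInfoGain; the node prediction, meanwhile, is computed once per node (calculatePredict) rather than once per candidate split. A self-contained sketch of the validity rule, using Gini impurity over binary label counts (all names here are illustrative, not MLlib's API):

    object SplitValidityDemo {
      // Gini impurity of a two-class node given (negative, positive) counts.
      def gini(neg: Double, pos: Double): Double = {
        val total = neg + pos
        if (total == 0) 0.0 else { val p = pos / total; 2 * p * (1 - p) }
      }

      /** Some(gain) for a valid split, None for an invalid one. */
      def gainForSplit(
          left: (Double, Double),
          right: (Double, Double),
          minInstancesPerNode: Int,
          minInfoGain: Double): Option[Double] = {
        val leftCount = left._1 + left._2
        val rightCount = right._1 + right._2
        // Invalid if either child is too small.
        if (leftCount < minInstancesPerNode || rightCount < minInstancesPerNode) {
          return None
        }
        val totalCount = leftCount + rightCount
        val parentImpurity = gini(left._1 + right._1, left._2 + right._2)
        val gain = parentImpurity -
          (leftCount / totalCount) * gini(left._1, left._2) -
          (rightCount / totalCount) * gini(right._1, right._2)
        // Invalid if the impurity reduction is too small.
        if (gain < minInfoGain) None else Some(gain)
      }

      def main(args: Array[String]): Unit = {
        println(gainForSplit((30.0, 5.0), (5.0, 60.0), minInstancesPerNode = 10, minInfoGain = 0.01))
        // -> Some(~0.277): both children are large enough and the gain is clear
        println(gainForSplit((2.0, 1.0), (33.0, 64.0), minInstancesPerNode = 10, minInfoGain = 0.01))
        // -> None: the left child holds only 3 instances
      }
    }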

mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala

Lines changed: 9 additions & 0 deletions
@@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 *                                k) implies the feature n is categorical with k categories 0,
 *                                1, 2, ... , k-1. It's important to note that features are
 *                                zero-indexed.
+ * @param minInstancesPerNode Minimum number of instances each child must have after split.
+ *                            Default value is 1. If a split cause left or right child
+ *                            to have less than minInstancesPerNode,
+ *                            this split will not be considered as a valid split.
+ * @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
+ *                    If a split has less information gain than minInfoGain,
+ *                    this split will not be considered as a valid split.
 * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
 *                      256 MB.
 */
@@ -61,6 +68,8 @@ class Strategy (
     val maxBins: Int = 32,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+    val minInstancesPerNode: Int = 1,
+    val minInfoGain: Double = 0.0,
     val maxMemoryInMB: Int = 256) extends Serializable {
 
   if (algo == Classification) {
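End to end, the two new Strategy knobs let callers pre-prune a tree at training time. A usage sketch assuming a Spark build containing this change (the data, master URL and parameter values are illustrative; the Strategy argument names are taken from this diff):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    object PrePrunedTreeDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("PrePrunedTreeDemo"))
        val trainingData = sc.parallelize(Seq(
          LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
          LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
          LabeledPoint(0.0, Vectors.dense(0.1, 0.9)),
          LabeledPoint(1.0, Vectors.dense(0.9, 0.2))))
        // Reject any split that leaves a child with fewer than 2 instances
        // or improves impurity by less than 0.01.
        val strategy = new Strategy(
          algo = Algo.Classification,
          impurity = Gini,
          maxDepth = 5,
          minInstancesPerNode = 2,
          minInfoGain = 0.01)
        val model = DecisionTree.train(trainingData, strategy)
        println(model.predict(Vectors.dense(1.0, 0.0)))
        sc.stop()
      }
    }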

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala

Lines changed: 5 additions & 2 deletions
@@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
     val unorderedFeatures: Set[Int],
     val numBins: Array[Int],
     val impurity: Impurity,
-    val quantileStrategy: QuantileStrategy) extends Serializable {
+    val quantileStrategy: QuantileStrategy,
+    val minInstancesPerNode: Int,
+    val minInfoGain: Double) extends Serializable {
 
   def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
 
@@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy)
+      strategy.impurity, strategy.quantileCalculationStrategy,
+      strategy.minInstancesPerNode, strategy.minInfoGain)
   }
 
   /**
