@@ -31,6 +31,7 @@ import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.util.Utils.memoryStringToMb
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 
 /**
@@ -79,7 +80,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // Calculate level for single group construction
 
     // Max memory usage for aggregates
-    val maxMemoryUsage = scala.math.pow(2, 27).toInt // 128MB
+    val maxMemoryUsage = strategy.maxMemory * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage)
     val numElementsPerNode = {
       strategy.algo match {
@@ -1158,10 +1159,13 @@ object DecisionTree extends Serializable with Logging {
 
     val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt
     val maxBins = options.getOrElse('maxBins, "100").toString.toInt
+    val maxMemUsage = memoryStringToMb(options.getOrElse('maxMemory, "128m").toString)
 
-    val strategy = new Strategy(algo, impurity, maxDepth, maxBins)
+    val strategy = new Strategy(algo, impurity, maxDepth, maxBins, maxMemory = maxMemUsage)
     val model = DecisionTree.train(trainData, strategy)
 
+
+
     // Load test data.
     val testData = loadLabeledData(sc, options.get('testDataDir).get.toString)
 
0 commit comments