Skip to content

Commit a2acea5

Browse files
committed
Small optimizations based on profiling
1 parent aa4e4df commit a2acea5

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -572,7 +572,7 @@ object DecisionTree extends Serializable with Logging {
572572
val label = treePoint.label
573573
val nodeOffset = agg.getNodeOffset(nodeIndex)
574574
// Iterate over all features.
575-
val numFeatures = treePoint.binnedFeatures.size
575+
val numFeatures = agg.numFeatures
576576
var featureIndex = 0
577577
while (featureIndex < numFeatures) {
578578
val binIndex = treePoint.binnedFeatures(featureIndex)

mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala

Lines changed: 9 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -73,8 +73,9 @@ private[tree] object TreePoint {
7373
val arr = new Array[Int](numFeatures)
7474
var featureIndex = 0
7575
while (featureIndex < numFeatures) {
76-
arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
77-
metadata.isUnordered(featureIndex), bins, metadata.featureArity)
76+
val featureArity = metadata.featureArity.getOrElse(featureIndex, 0)
77+
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity,
78+
metadata.isUnordered(featureIndex), bins)
7879
featureIndex += 1
7980
}
8081

@@ -84,17 +85,16 @@ private[tree] object TreePoint {
8485
/**
8586
* Find bin for one (labeledPoint, feature).
8687
*
88+
* @param featureArity 0 for continuous features; number of categories for categorical features.
8789
* @param isUnorderedFeature (only applies if feature is categorical)
8890
* @param bins Bins for features, of size (numFeatures, numBins).
89-
* @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
9091
*/
9192
private def findBin(
9293
featureIndex: Int,
9394
labeledPoint: LabeledPoint,
94-
isFeatureContinuous: Boolean,
95+
featureArity: Int,
9596
isUnorderedFeature: Boolean,
96-
bins: Array[Array[Bin]],
97-
categoricalFeaturesInfo: Map[Int, Int]): Int = {
97+
bins: Array[Array[Bin]]): Int = {
9898

9999
/**
100100
* Binary search helper method for continuous feature.
@@ -120,7 +120,7 @@ private[tree] object TreePoint {
120120
-1
121121
}
122122

123-
if (isFeatureContinuous) {
123+
if (featureArity == 0) {
124124
// Perform binary search for finding bin for continuous features.
125125
val binIndex = binarySearchForBins()
126126
if (binIndex == -1) {
@@ -131,13 +131,12 @@ private[tree] object TreePoint {
131131
binIndex
132132
} else {
133133
// Categorical feature bins are indexed by feature values.
134-
val featureCategories = categoricalFeaturesInfo(featureIndex)
135134
val featureValue = labeledPoint.features(featureIndex)
136-
if (featureValue < 0 || featureValue >= featureCategories) {
135+
if (featureValue < 0 || featureValue >= featureArity) {
137136
throw new IllegalArgumentException(
138137
s"DecisionTree given invalid data:" +
139138
s" Feature $featureIndex is categorical with values in" +
140-
s" {0,...,${featureCategories - 1}," +
139+
s" {0,...,${featureArity - 1}," +
141140
s" but a data point gives it value $featureValue.\n" +
142141
" Bad data point: " + labeledPoint.toString)
143142
}

0 commit comments

Comments
 (0)