Commit 1114207

jonsondag authored and mengxr committed
[SPARK-2152][MLlib] fix bin offset in DecisionTree node aggregations (also resolves SPARK-2160)
Hi, this pull request fixes what I believe to be a bug in DecisionTree.scala.

In the extractLeftRightNodeAggregates function, the first set of rightNodeAgg values for Regression is set in line 792 as follows:

    rightNodeAgg(featureIndex)(2 * (numBins - 2)) = binData(shift + (2 * (numBins - 1)))

Then a loop sets the rest of the values, as in line 809:

    rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
      binData(shift + (2 *(numBins - 2 - splitIndex))) +
      rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))

But since splitIndex starts at 1, the first iteration reads the bin at numBins - 3 rather than the bin at numBins - 2, so one set of binData values is skipped. The changes here address this issue for both the Regression and Classification cases.

Author: johnnywalleye <[email protected]>

Closes apache#1316 from johnnywalleye/master and squashes the following commits:

73809da [johnnywalleye] fix bin offset in DecisionTree node aggregations
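The skipped bin is easier to see by listing which bins the right-node aggregation reads. The following is a minimal sketch, not code from the patch: `BinOffsetTrace` and `binsRead` are hypothetical names, and the loop bounds (splitIndex from 1 up to but excluding numBins - 1) are an assumption, since the diff does not show the loop header.

```scala
object BinOffsetTrace {
  // Hypothetical helper: bin indices read while building the right-node aggregates.
  // `offset` is 2 for the pre-fix indexing and 1 for the post-fix indexing.
  def binsRead(numBins: Int, offset: Int): Seq[Int] = {
    // The aggregate for the highest split is initialised from the last bin.
    val first = numBins - 1
    // Assumed loop bounds: splitIndex = 1 until numBins - 1.
    val rest = (1 until numBins - 1).map(splitIndex => numBins - offset - splitIndex)
    first +: rest
  }

  def main(args: Array[String]): Unit = {
    val numBins = 5
    println(s"before: ${binsRead(numBins, 2).mkString(", ")}") // before: 4, 2, 1, 0
    println(s"after:  ${binsRead(numBins, 1).mkString(", ")}") // after:  4, 3, 2, 1
  }
}
```

With five bins, the pre-fix indexing never reads bin 3 and pulls bin 0 into the right-hand side instead; the post-fix indexing reads bins 4 down to 1 without a gap.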
1 parent ac9cdc1 commit 1114207

1 file changed: +5 -5 lines changed

mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala

Lines changed: 5 additions & 5 deletions
@@ -807,10 +807,10 @@ object DecisionTree extends Serializable with Logging {
           // calculating right node aggregate for a split as a sum of right node aggregate of a
           // higher split and the right bin aggregate of a bin where the split is a low split
           rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
-            binData(shift + (2 *(numBins - 2 - splitIndex))) +
+            binData(shift + (2 *(numBins - 1 - splitIndex))) +
             rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
           rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
-            binData(shift + (2* (numBins - 2 - splitIndex) + 1)) +
+            binData(shift + (2* (numBins - 1 - splitIndex) + 1)) +
             rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
 
           splitIndex += 1
@@ -855,13 +855,13 @@ object DecisionTree extends Serializable with Logging {
           // calculating right node aggregate for a split as a sum of right node aggregate of a
           // higher split and the right bin aggregate of a bin where the split is a low split
           rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
-            binData(shift + (3 * (numBins - 2 - splitIndex))) +
+            binData(shift + (3 * (numBins - 1 - splitIndex))) +
             rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
           rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
-            binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) +
+            binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
             rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
           rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
-            binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) +
+            binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
             rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
 
           splitIndex += 1
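
The corrected recurrence can be checked against a direct sum. The sketch below is a simplified stand-in rather than the MLlib code: `RightAggCheck`, `rightAggs`, and `bruteForce` are hypothetical names, a single feature is used so that `shift` is 0, `stride` stands for the number of statistics stored per bin (2 or 3 in the hunks above), and the loop bound is assumed because the diff does not show it.

```scala
object RightAggCheck {
  // Right-node aggregates for one feature (shift = 0) using the post-fix indexing.
  // `stride` is the number of statistics stored per bin.
  def rightAggs(binData: Array[Double], numBins: Int, stride: Int): Array[Double] = {
    val agg = new Array[Double](stride * (numBins - 1))
    // Highest split: everything to its right is the last bin.
    for (k <- 0 until stride)
      agg(stride * (numBins - 2) + k) = binData(stride * (numBins - 1) + k)
    var splitIndex = 1
    while (splitIndex < numBins - 1) { // assumed loop bound
      for (k <- 0 until stride)
        agg(stride * (numBins - 2 - splitIndex) + k) =
          binData(stride * (numBins - 1 - splitIndex) + k) +
            agg(stride * (numBins - 1 - splitIndex) + k)
      splitIndex += 1
    }
    agg
  }

  // Reference: for each split, sum the statistic over all bins to its right.
  def bruteForce(binData: Array[Double], numBins: Int, stride: Int): Array[Double] =
    Array.tabulate(stride * (numBins - 1)) { i =>
      val (split, k) = (i / stride, i % stride)
      (split + 1 until numBins).map(b => binData(stride * b + k)).sum
    }

  def main(args: Array[String]): Unit = {
    val (numBins, stride) = (4, 3)
    // Small integer-valued doubles, so both summation orders give exact results.
    val binData = Array.tabulate(stride * numBins)(_.toDouble)
    val fast = rightAggs(binData, numBins, stride)
    val slow = bruteForce(binData, numBins, stride)
    assert(fast.sameElements(slow), "recurrence disagrees with the direct sum")
    println(fast.mkString(", "))
  }
}
```

Replacing `numBins - 1 - splitIndex` with `numBins - 2 - splitIndex` in the `binData` index reproduces the pre-fix behaviour, and the assertion then fails.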
