@@ -30,52 +30,51 @@ class TreeSplitUtilsSuite
3030
3131 /**
3232 * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
33- * with the data from the specified training points.
33+ * with the data from the specified training points. Assumes a feature index of 0 and that
34+ * all training points have the same weights (1.0).
3435 */
3536 private def getAggregator (
3637 metadata : DecisionTreeMetadata ,
3738 values : Array [Int ],
38- from : Int ,
39- to : Int ,
4039 labels : Array [Double ],
4140 featureSplits : Array [Split ]): DTStatsAggregator = {
42-
43- val featureIndex = 0
41+ // Create stats aggregator
4442 val statsAggregator = new DTStatsAggregator (metadata, featureSubset = None )
45- val indices = values.indices.toArray
46- val instanceWeights = Array .fill[Double ](values.length)(1.0 )
4743 // Update parent impurity stats
48- AggUpdateUtils .updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
44+ val featureIndex = 0
45+ val instanceWeights = Array .fill[Double ](values.length)(1.0 )
46+ AggUpdateUtils .updateParentImpurity(statsAggregator, indices = values.indices.toArray,
47+ from = 0 , to = values.length, instanceWeights, labels)
4948 // Update current aggregator's impurity stats
50- from.until(to).foreach { idx =>
51- val rowIndex = indices(idx)
49+ values.zip(labels).foreach { case (value : Int , label : Double ) =>
5250 if (metadata.isUnordered(featureIndex)) {
53- AggUpdateUtils .updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
54- featureIndex = featureIndex, featureIndexIdx, featureSplits,
55- instanceWeight = 1.0 )
51+ AggUpdateUtils .updateUnorderedFeature(statsAggregator, value, label,
52+ featureIndex = featureIndex, featureIndexIdx = 0 , featureSplits, instanceWeight = 1.0 )
5653 } else {
57- AggUpdateUtils .updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex) ,
58- featureIndexIdx, instanceWeight = 1.0 )
54+ AggUpdateUtils .updateOrderedFeature(statsAggregator, value, label, featureIndexIdx = 0 ,
55+ instanceWeight = 1.0 )
5956 }
6057 }
61-
62- updateAggregator(statsAggregator, featureIndex = 0 , featureIndexIdx = 0 , values, indices,
63- labels, from, to, featureSplits)
6458 statsAggregator
6559 }
6660
67- /** Check that left/right impurities match what we'd expect for a split. */
61+ /**
62+ * Check that left/right impurities match what we'd expect for a split.
63+ * @param labels Labels whose impurity information should be reflected in stats
64+ * @param stats ImpurityStats object containing impurity info for the left/right sides of a split
65+ */
6866 private def validateImpurityStats (
6967 impurity : Impurity ,
7068 labels : Array [Double ],
7169 stats : ImpurityStats ,
7270 expectedLeftStats : Array [Double ],
7371 expectedRightStats : Array [Double ]): Unit = {
74- // Verify that impurity stats were computed correctly for split
72+ // Compute impurity for our data points manually
7573 val numClasses = (labels.max + 1 ).toInt
7674 val fullImpurityStatsArray
7775 = Array .tabulate[Double ](numClasses)((label : Int ) => labels.count(_ == label).toDouble)
7876 val fullImpurity = Entropy .calculate(fullImpurityStatsArray, labels.length)
77+ // Verify that impurity stats were computed correctly for split
7978 assert(stats.impurityCalculator.stats === fullImpurityStatsArray)
8079 assert(stats.impurity === fullImpurity)
8180 assert(stats.leftImpurityCalculator.stats === expectedLeftStats)
@@ -87,37 +86,37 @@ class TreeSplitUtilsSuite
8786
8887 test(" chooseSplit: choose correct type of split (continuous split)" ) {
8988 // Construct (binned) continuous data
90- val labels = Array (0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 )
91- val values = Array (8 , 1 , 1 , 2 , 3 , 5 , 6 )
89+ val labels = Array (0.0 , 0.0 , 1.0 )
90+ val values = Array (1 , 2 , 3 )
9291 val featureIndex = 0
9392 // Get an array of continuous splits corresponding to values in our binned data
94- val splits = TreeTests .getContinuousSplits(1 .to(8 ).toArray, featureIndex = 0 )
93+ val splits = TreeTests .getContinuousSplits(thresholds = values.distinct.sorted,
94+ featureIndex = 0 )
9595 // Construct DTStatsAggregator, compute sufficient stats
96- val metadata = TreeTests .getMetadata(numExamples = 7 ,
97- numFeatures = 1 , numClasses = 2 , Map .empty)
98- val statsAggregator = getAggregator(metadata, values, from = 1 , to = 4 , labels, splits)
96+ val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
97+ numClasses = 2 , Map .empty)
98+ val statsAggregator = getAggregator(metadata, values, labels, splits)
9999 // Choose split, check that it's a valid ContinuousSplit
100- val (split1, stats1 ) = SplitUtils .chooseSplit(statsAggregator, featureIndex, featureIndex,
100+ val (split, stats ) = SplitUtils .chooseSplit(statsAggregator, featureIndex, featureIndex,
101101 splits)
102- assert(stats1 .valid && split1 .isInstanceOf [ContinuousSplit ])
102+ assert(stats .valid && split .isInstanceOf [ContinuousSplit ])
103103 }
104104
105105 test(" chooseSplit: choose correct type of split (categorical split)" ) {
106106 // Construct categorical data
107- val labels = Array (0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 )
107+ val labels = Array (0.0 , 0.0 , 1.0 , 1.0 , 1.0 )
108108 val featureArity = 3
109- val values = Array (0 , 0 , 1 , 1 , 1 , 2 , 2 )
109+ val values = Array (0 , 0 , 1 , 2 , 2 )
110110 val featureIndex = 0
111111 // Construct DTStatsAggregator, compute sufficient stats
112- val metadata = TreeTests .getMetadata(numExamples = 7 ,
113- numFeatures = 1 , numClasses = 2 , Map (featureIndex -> featureArity))
112+ val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
113+ numClasses = 2 , Map (featureIndex -> featureArity))
114114 val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
115- val statsAggregator = getAggregator(metadata, values, from = 1 , to = 4 , labels, splits)
115+ val statsAggregator = getAggregator(metadata, values, labels, splits)
116116 // Choose split, check that it's a valid categorical split
117- val (split2, stats2) = SplitUtils .chooseSplit(statsAggregator = statsAggregator,
118- featureIndex = featureIndex, featureIndexIdx = featureIndex,
119- featureSplits = splits)
120- assert(stats2.valid && split2.isInstanceOf [CategoricalSplit ])
117+ val (split, stats) = SplitUtils .chooseSplit(statsAggregator = statsAggregator,
118+ featureIndex = featureIndex, featureIndexIdx = featureIndex, featureSplits = splits)
119+ assert(stats.valid && split.isInstanceOf [CategoricalSplit ])
121120 }
122121
123122 test(" chooseOrderedCategoricalSplit: basic case" ) {
@@ -128,15 +127,14 @@ class TreeSplitUtilsSuite
128127 expectedLeftCategories : Array [Double ],
129128 expectedLeftStats : Array [Double ],
130129 expectedRightStats : Array [Double ]): Unit = {
130+ // Set up metadata for ordered categorical feature
131131 val featureIndex = 0
132- // Construct FeatureVector to store categorical data
133132 val featureArity = values.max + 1
134133 val arityMap = Map [Int , Int ](featureIndex -> featureArity)
135- // Construct DTStatsAggregator, compute sufficient stats
136134 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
137135 numClasses = 2 , arityMap, unorderedFeatures = Some (Set .empty))
138- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length,
139- labels, featureSplits = Array .empty)
136+ // Construct DTStatsAggregator, compute sufficient stats
137+ val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array .empty)
140138 // Choose split
141139 val (split, stats) =
142140 SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
@@ -155,12 +153,18 @@ class TreeSplitUtilsSuite
155153 validateImpurityStats(Entropy , labels, stats, expectedLeftStats, expectedRightStats)
156154 }
157155
156+ // Test a single split: The left side of our split should contain the two points with label 0,
157+ // the right side of our split should contain the five points with label 1
158158 val values = Array (0 , 0 , 1 , 2 , 2 , 2 , 2 )
159159 val labels1 = Array (0 , 0 , 1 , 1 , 1 , 1 , 1 ).map(_.toDouble)
160- testHelper(values, labels1, Array (0.0 ), Array (2.0 , 0.0 ), Array (0.0 , 5.0 ))
160+ testHelper(values, labels1, expectedLeftCategories = Array (0.0 ),
161+ expectedLeftStats = Array (2.0 , 0.0 ), expectedRightStats = Array (0.0 , 5.0 ))
161162
163+ // Test a single split: The left side of our split should contain the three points with label 0,
164+ // the right side of our split should contain the four points with label 1
162165 val labels2 = Array (0 , 0 , 0 , 1 , 1 , 1 , 1 ).map(_.toDouble)
163- testHelper(values, labels2, Array (0.0 , 1.0 ), Array (3.0 , 0.0 ), Array (0.0 , 4.0 ))
166+ testHelper(values, labels2, expectedLeftCategories = Array (0.0 , 1.0 ),
167+ expectedLeftStats = Array (3.0 , 0.0 ), expectedRightStats = Array (0.0 , 4.0 ))
164168 }
165169
166170 test(" chooseOrderedCategoricalSplit: return bad stats if we should not split" ) {
@@ -172,8 +176,7 @@ class TreeSplitUtilsSuite
172176 // Construct DTStatsAggregator, compute sufficient stats
173177 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
174178 numClasses = 2 , Map (featureIndex -> featureArity), unorderedFeatures = Some (Set .empty))
175- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length,
176- labels, featureSplits = Array .empty)
179+ val statsAggregator = getAggregator(metadata, values, labels, featureSplits = Array .empty)
177180 // Choose split, verify that it's invalid
178181 val (_, stats) = SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
179182 featureIndex)
@@ -186,15 +189,15 @@ class TreeSplitUtilsSuite
186189 // label: 0 --> values: 1
187190 // label: 1 --> values: 0, 2
188191 // label: 2 --> values: 2
192+ // Expected split: feature value 1 on the left, values (0, 2) on the right
189193 val values = Array (1 , 1 , 0 , 2 , 2 )
190194 val featureArity = values.max + 1
191195 val labels = Array (0.0 , 0.0 , 1.0 , 1.0 , 2.0 )
192196 // Construct DTStatsAggregator, compute sufficient stats
193197 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
194198 numClasses = 3 , Map (featureIndex -> featureArity))
195199 val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
196- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length,
197- labels, splits)
200+ val statsAggregator = getAggregator(metadata, values, labels, splits)
198201 // Choose split
199202 val (split, stats) =
200203 SplitUtils .chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex,
@@ -214,7 +217,7 @@ class TreeSplitUtilsSuite
214217 }
215218
216219 test(" chooseUnorderedCategoricalSplit: return bad stats if we should not split" ) {
217- // Construct data for unordered categorical feature
220+ // Construct data for unordered categorical feature; all points have label 1
218221 val featureIndex = 0
219222 val featureArity = 4
220223 val values = Array (3 , 1 , 0 , 2 , 2 )
@@ -223,8 +226,7 @@ class TreeSplitUtilsSuite
223226 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
224227 numClasses = 2 , Map (featureIndex -> featureArity))
225228 val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
226- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
227- splits)
229+ val statsAggregator = getAggregator(metadata, values, labels, splits)
228230 // Choose split, verify that it's invalid
229231 val (_, stats) = SplitUtils .chooseUnorderedCategoricalSplit(statsAggregator, featureIndex,
230232 featureIndex, splits)
@@ -241,8 +243,7 @@ class TreeSplitUtilsSuite
241243 val splits = TreeTests .getContinuousSplits(thresholds, featureIndex)
242244 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
243245 numClasses = 2 , Map .empty)
244- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
245- splits)
246+ val statsAggregator = getAggregator(metadata, values, labels, splits)
246247
247248 // Choose split, verify that it has expected threshold
248249 val (split, stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
@@ -261,7 +262,7 @@ class TreeSplitUtilsSuite
261262 }
262263
263264 test(" chooseContinuousSplit: return bad stats if we should not split" ) {
264- // Construct data for continuous feature
265+ // Construct data for continuous feature; all points have label 0
265266 val featureIndex = 0
266267 val thresholds = Array (0 , 1 , 2 , 3 )
267268 val values = thresholds.indices.toArray
@@ -270,10 +271,9 @@ class TreeSplitUtilsSuite
270271 val splits = TreeTests .getContinuousSplits(thresholds, featureIndex)
271272 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
272273 numClasses = 2 , Map .empty[Int , Int ])
273- val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
274- splits)
274+ val statsAggregator = getAggregator(metadata, values, labels, splits)
275275 // Choose split, verify that it's invalid
276- val (split , stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
276+ val (_ , stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
277277 featureIndex, splits)
278278 assert(! stats.valid)
279279 }
0 commit comments