@@ -28,6 +28,35 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
2828class TreeSplitUtilsSuite
2929 extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
3030
31+ /**
32+ * Iterate over feature values and labels for a specific (node, feature), updating stats
33+ * aggregator for the current node.
34+ */
35+ private [impl] def updateAggregator (
36+ statsAggregator : DTStatsAggregator ,
37+ featureIndex : Int ,
38+ values : Array [Int ],
39+ indices : Array [Int ],
40+ instanceWeights : Array [Double ],
41+ labels : Array [Double ],
42+ from : Int ,
43+ to : Int ,
44+ featureIndexIdx : Int ,
45+ featureSplits : Array [Split ]): Unit = {
46+ val metadata = statsAggregator.metadata
47+ from.until(to).foreach { idx =>
48+ val rowIndex = indices(idx)
49+ if (metadata.isUnordered(featureIndex)) {
50+ AggUpdateUtils .updateUnorderedFeature(statsAggregator, values(idx), labels(rowIndex),
51+ featureIndex = featureIndex, featureIndexIdx, featureSplits,
52+ instanceWeight = instanceWeights(rowIndex))
53+ } else {
54+ AggUpdateUtils .updateOrderedFeature(statsAggregator, values(idx), labels(rowIndex),
55+ featureIndexIdx, instanceWeight = instanceWeights(rowIndex))
56+ }
57+ }
58+ }
59+
3160 /**
3261 * Get a DTStatsAggregator for sufficient stat collection/impurity calculation populated
3362 * with the data from the specified training points.
@@ -40,12 +69,13 @@ class TreeSplitUtilsSuite
4069 labels : Array [Double ],
4170 featureSplits : Array [Split ]): DTStatsAggregator = {
4271
72+ val featureIndex = 0
4373 val statsAggregator = new DTStatsAggregator (metadata, featureSubset = None )
4474 val instanceWeights = Array .fill[Double ](values.length)(1.0 )
4575 val indices = values.indices.toArray
4676 AggUpdateUtils .updateParentImpurity(statsAggregator, indices, from, to, instanceWeights, labels)
47- LocalDecisionTree . updateAggregator(statsAggregator, col , indices, instanceWeights, labels,
48- from, to, col. featureIndex, featureSplits)
77+ updateAggregator(statsAggregator, featureIndex = 0 , values , indices, instanceWeights, labels,
78+ from, to, featureIndex, featureSplits)
4979 statsAggregator
5080 }
5181
@@ -73,34 +103,34 @@ class TreeSplitUtilsSuite
73103 test(" chooseSplit: choose correct type of split (continuous split)" ) {
74104 // Construct (binned) continuous data
75105 val labels = Array (0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 )
76- val col = FeatureColumn (featureIndex = 0 , values = Array (8 , 1 , 1 , 2 , 3 , 5 , 6 ))
106+ val values = Array (8 , 1 , 1 , 2 , 3 , 5 , 6 )
107+ val featureIndex = 0
77108 // Get an array of continuous splits corresponding to values in our binned data
78109 val splits = TreeTests .getContinuousSplits(1 .to(8 ).toArray, featureIndex = 0 )
79110 // Construct DTStatsAggregator, compute sufficient stats
80111 val metadata = TreeTests .getMetadata(numExamples = 7 ,
81112 numFeatures = 1 , numClasses = 2 , Map .empty)
82- val statsAggregator = getAggregator(metadata, col , from = 1 , to = 4 , labels, splits)
113+ val statsAggregator = getAggregator(metadata, values , from = 1 , to = 4 , labels, splits)
83114 // Choose split, check that it's a valid ContinuousSplit
84- val (split1, stats1) = SplitUtils .chooseSplit(statsAggregator, col. featureIndex,
85- col.featureIndex, splits)
115+ val (split1, stats1) = SplitUtils .chooseSplit(statsAggregator, featureIndex, featureIndex,
116+ splits)
86117 assert(stats1.valid && split1.isInstanceOf [ContinuousSplit ])
87118 }
88119
89120 test(" chooseSplit: choose correct type of split (categorical split)" ) {
90121 // Construct categorical data
91122 val labels = Array (0.0 , 0.0 , 0.0 , 1.0 , 1.0 , 1.0 , 1.0 )
92- val featureIndex = 0
93123 val featureArity = 3
94124 val values = Array (0 , 0 , 1 , 1 , 1 , 2 , 2 )
95- val col = FeatureColumn (featureIndex, values)
125+ val featureIndex = 0
96126 // Construct DTStatsAggregator, compute sufficient stats
97127 val metadata = TreeTests .getMetadata(numExamples = 7 ,
98128 numFeatures = 1 , numClasses = 2 , Map (featureIndex -> featureArity))
99129 val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
100- val statsAggregator = getAggregator(metadata, col , from = 1 , to = 4 , labels, splits)
130+ val statsAggregator = getAggregator(metadata, values , from = 1 , to = 4 , labels, splits)
101131 // Choose split, check that it's a valid categorical split
102132 val (split2, stats2) = SplitUtils .chooseSplit(statsAggregator = statsAggregator,
103- featureIndex = col. featureIndex, featureIndexIdx = col. featureIndex,
133+ featureIndex = featureIndex, featureIndexIdx = featureIndex,
104134 featureSplits = splits)
105135 assert(stats2.valid && split2.isInstanceOf [CategoricalSplit ])
106136 }
@@ -117,16 +147,14 @@ class TreeSplitUtilsSuite
117147 // Construct FeatureVector to store categorical data
118148 val featureArity = values.max + 1
119149 val arityMap = Map [Int , Int ](featureIndex -> featureArity)
120- val col = FeatureColumn (featureIndex = 0 , values = values)
121150 // Construct DTStatsAggregator, compute sufficient stats
122151 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
123152 numClasses = 2 , arityMap, unorderedFeatures = Some (Set .empty))
124- val statsAggregator = getAggregator(metadata, col , from = 0 , to = values.length,
153+ val statsAggregator = getAggregator(metadata, values , from = 0 , to = values.length,
125154 labels, featureSplits = Array .empty)
126155 // Choose split
127156 val (split, stats) =
128- SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, col.featureIndex,
129- col.featureIndex)
157+ SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, featureIndex, featureIndex)
130158 // Verify that split has the expected left-side/right-side categories
131159 val expectedRightCategories = Range (0 , featureArity)
132160 .filter(c => ! expectedLeftCategories.contains(c)).map(_.toDouble).toArray
@@ -156,15 +184,14 @@ class TreeSplitUtilsSuite
156184 val values = Array (0 , 0 , 1 , 2 , 2 , 2 , 2 )
157185 val featureArity = values.max + 1
158186 val labels = Array (1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 , 1.0 )
159- val col = FeatureColumn (featureIndex, values)
160187 // Construct DTStatsAggregator, compute sufficient stats
161188 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
162189 numClasses = 2 , Map (featureIndex -> featureArity), unorderedFeatures = Some (Set .empty))
163- val statsAggregator = getAggregator(metadata, col , from = 0 , to = values.length,
190+ val statsAggregator = getAggregator(metadata, values , from = 0 , to = values.length,
164191 labels, featureSplits = Array .empty)
165192 // Choose split, verify that it's invalid
166- val (_, stats) = SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, col. featureIndex,
167- col. featureIndex)
193+ val (_, stats) = SplitUtils .chooseOrderedCategoricalSplit(statsAggregator, featureIndex,
194+ featureIndex)
168195 assert(! stats.valid)
169196 }
170197
@@ -177,17 +204,16 @@ class TreeSplitUtilsSuite
177204 val values = Array (1 , 1 , 0 , 2 , 2 )
178205 val featureArity = values.max + 1
179206 val labels = Array (0.0 , 0.0 , 1.0 , 1.0 , 2.0 )
180- val col = FeatureColumn (featureIndex, values)
181207 // Construct DTStatsAggregator, compute sufficient stats
182208 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
183209 numClasses = 3 , Map (featureIndex -> featureArity))
184210 val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
185- val statsAggregator = getAggregator(metadata, col , from = 0 , to = values.length,
211+ val statsAggregator = getAggregator(metadata, values , from = 0 , to = values.length,
186212 labels, splits)
187213 // Choose split
188214 val (split, stats) =
189- SplitUtils .chooseUnorderedCategoricalSplit(statsAggregator, col. featureIndex,
190- col.featureIndex, splits)
215+ SplitUtils .chooseUnorderedCategoricalSplit(statsAggregator, featureIndex, featureIndex,
216+ splits)
191217 // Verify that split has the expected left-side/right-side categories
192218 split match {
193219 case s : CategoricalSplit =>
@@ -208,12 +234,12 @@ class TreeSplitUtilsSuite
208234 val featureArity = 4
209235 val values = Array (3 , 1 , 0 , 2 , 2 )
210236 val labels = Array (1.0 , 1.0 , 1.0 , 1.0 , 1.0 )
211- val col = FeatureColumn (featureIndex, values)
212237 // Construct DTStatsAggregator, compute sufficient stats
213238 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
214239 numClasses = 2 , Map (featureIndex -> featureArity))
215240 val splits = RandomForest .findUnorderedSplits(metadata, featureIndex)
216- val statsAggregator = getAggregator(metadata, col, from = 0 , to = values.length, labels, splits)
241+ val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
242+ splits)
217243 // Choose split, verify that it's invalid
218244 val (_, stats) = SplitUtils .chooseUnorderedCategoricalSplit(statsAggregator, featureIndex,
219245 featureIndex, splits)
@@ -226,13 +252,12 @@ class TreeSplitUtilsSuite
226252 val thresholds = Array (0 , 1 , 2 , 3 )
227253 val values = thresholds.indices.toArray
228254 val labels = Array (0.0 , 0.0 , 1.0 , 1.0 )
229- val col = FeatureColumn (featureIndex = featureIndex, values = values)
230-
231255 // Construct DTStatsAggregator, compute sufficient stats
232256 val splits = TreeTests .getContinuousSplits(thresholds, featureIndex)
233257 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
234258 numClasses = 2 , Map .empty)
235- val statsAggregator = getAggregator(metadata, col, from = 0 , to = values.length, labels, splits)
259+ val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
260+ splits)
236261
237262 // Choose split, verify that it has expected threshold
238263 val (split, stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
@@ -256,12 +281,12 @@ class TreeSplitUtilsSuite
256281 val thresholds = Array (0 , 1 , 2 , 3 )
257282 val values = thresholds.indices.toArray
258283 val labels = Array (0.0 , 0.0 , 0.0 , 0.0 , 0.0 )
259- val col = FeatureColumn (featureIndex = featureIndex, values = values)
260284 // Construct DTStatsAggregator, compute sufficient stats
261285 val splits = TreeTests .getContinuousSplits(thresholds, featureIndex)
262286 val metadata = TreeTests .getMetadata(numExamples = values.length, numFeatures = 1 ,
263287 numClasses = 2 , Map .empty[Int , Int ])
264- val statsAggregator = getAggregator(metadata, col, from = 0 , to = values.length, labels, splits)
288+ val statsAggregator = getAggregator(metadata, values, from = 0 , to = values.length, labels,
289+ splits)
265290 // Choose split, verify that it's invalid
266291 val (split, stats) = SplitUtils .chooseContinuousSplit(statsAggregator, featureIndex,
267292 featureIndex, splits)
0 commit comments