@@ -73,8 +73,9 @@ private[tree] object TreePoint {
7373 val arr = new Array [Int ](numFeatures)
7474 var featureIndex = 0
7575 while (featureIndex < numFeatures) {
76- arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex),
77- metadata.isUnordered(featureIndex), bins, metadata.featureArity)
76+ val featureArity = metadata.featureArity.getOrElse(featureIndex, 0 )
77+ arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity,
78+ metadata.isUnordered(featureIndex), bins)
7879 featureIndex += 1
7980 }
8081
@@ -84,17 +85,16 @@ private[tree] object TreePoint {
8485 /**
8586 * Find bin for one (labeledPoint, feature).
8687 *
88+ * @param featureArity 0 for continuous features; number of categories for categorical features.
8789 * @param isUnorderedFeature (only applies if feature is categorical)
8890 * @param bins Bins for features, of size (numFeatures, numBins).
89- * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
9091 */
9192 private def findBin (
9293 featureIndex : Int ,
9394 labeledPoint : LabeledPoint ,
94- isFeatureContinuous : Boolean ,
95+ featureArity : Int ,
9596 isUnorderedFeature : Boolean ,
96- bins : Array [Array [Bin ]],
97- categoricalFeaturesInfo : Map [Int , Int ]): Int = {
97+ bins : Array [Array [Bin ]]): Int = {
9898
9999 /**
100100 * Binary search helper method for continuous feature.
@@ -120,7 +120,7 @@ private[tree] object TreePoint {
120120 - 1
121121 }
122122
123- if (isFeatureContinuous ) {
123+ if (featureArity == 0 ) {
124124 // Perform binary search for finding bin for continuous features.
125125 val binIndex = binarySearchForBins()
126126 if (binIndex == - 1 ) {
@@ -131,13 +131,12 @@ private[tree] object TreePoint {
131131 binIndex
132132 } else {
133133 // Categorical feature bins are indexed by feature values.
134- val featureCategories = categoricalFeaturesInfo(featureIndex)
135134 val featureValue = labeledPoint.features(featureIndex)
136- if (featureValue < 0 || featureValue >= featureCategories ) {
135+ if (featureValue < 0 || featureValue >= featureArity ) {
137136 throw new IllegalArgumentException (
138137 s " DecisionTree given invalid data: " +
139138 s " Feature $featureIndex is categorical with values in " +
140- s " {0,..., ${featureCategories - 1 }, " +
139+ s " {0,..., ${featureArity - 1 }, " +
141140 s " but a data point gives it value $featureValue. \n " +
142141 " Bad data point: " + labeledPoint.toString)
143142 }
0 commit comments