Skip to content

Commit 8b33d88

Browse files
authored
[ML] binary classification per-class feature importance for model inference (#61597) (#61746)
This commit addresses two issues: - per class feature importance is now written out for binary classification (logistic regression) - The `class_name` in per class feature importance now matches what is written in the `top_classes` array. backport of #61597
1 parent 2858e1e commit 8b33d88

File tree

6 files changed

+107
-43
lines changed

6 files changed

+107
-43
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FeatureImportance.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ public static FeatureImportance forRegression(String featureName, double importa
3939
return new FeatureImportance(featureName, importance, null);
4040
}
4141

42+
public static FeatureImportance forBinaryClassification(String featureName, double importance, List<ClassImportance> classImportance) {
43+
return new FeatureImportance(featureName,
44+
importance,
45+
classImportance);
46+
}
47+
4248
public static FeatureImportance forClassification(String featureName, List<ClassImportance> classImportance) {
4349
return new FeatureImportance(featureName,
4450
classImportance.stream().mapToDouble(ClassImportance::getImportance).map(Math::abs).sum(),
@@ -170,27 +176,27 @@ private static List<ClassImportance> fromMap(Map<String, Double> classImportance
170176
}
171177

172178
private static Map<String, Double> toMap(List<ClassImportance> importances) {
173-
return importances.stream().collect(Collectors.toMap(i -> i.className, i -> i.importance));
179+
return importances.stream().collect(Collectors.toMap(i -> i.className.toString(), i -> i.importance));
174180
}
175181

176182
public static ClassImportance fromXContent(XContentParser parser) {
177183
return PARSER.apply(parser, null);
178184
}
179185

180-
private final String className;
186+
private final Object className;
181187
private final double importance;
182188

183-
public ClassImportance(String className, double importance) {
189+
public ClassImportance(Object className, double importance) {
184190
this.className = className;
185191
this.importance = importance;
186192
}
187193

188194
public ClassImportance(StreamInput in) throws IOException {
189-
this.className = in.readString();
195+
this.className = in.readGenericValue();
190196
this.importance = in.readDouble();
191197
}
192198

193-
public String getClassName() {
199+
public Object getClassName() {
194200
return className;
195201
}
196202

@@ -207,7 +213,7 @@ public Map<String, Object> toMap() {
207213

208214
@Override
209215
public void writeTo(StreamOutput out) throws IOException {
210-
out.writeString(className);
216+
out.writeGenericValue(className);
211217
out.writeDouble(importance);
212218
}
213219

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1313

1414
import java.util.ArrayList;
15+
import java.util.Arrays;
1516
import java.util.Collections;
1617
import java.util.Comparator;
1718
import java.util.HashMap;
@@ -129,21 +130,46 @@ public static Map<String, double[]> decodeFeatureImportances(Map<String, String>
129130
return originalFeatureImportance;
130131
}
131132

132-
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
133-
@Nullable List<String> classificationLabels) {
133+
public static List<FeatureImportance> transformFeatureImportanceRegression(Map<String, double[]> featureImportance) {
134134
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
135+
featureImportance.forEach((k, v) -> importances.add(FeatureImportance.forRegression(k, v[0])));
136+
return importances;
137+
}
138+
139+
public static List<FeatureImportance> transformFeatureImportanceClassification(Map<String, double[]> featureImportance,
140+
final int predictedValue,
141+
@Nullable List<String> classificationLabels,
142+
@Nullable PredictionFieldType predictionFieldType) {
143+
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
144+
final PredictionFieldType fieldType = predictionFieldType == null ? PredictionFieldType.STRING : predictionFieldType;
135145
featureImportance.forEach((k, v) -> {
136-
// This indicates regression, or logistic regression
146+
// This indicates logistic regression (binary classification)
137147
// If the length > 1, we assume multi-class classification.
138148
if (v.length == 1) {
139-
importances.add(FeatureImportance.forRegression(k, v[0]));
149+
assert predictedValue == 1 || predictedValue == 0;
150+
// If predicted value is `1`, then the other class is `0`
151+
// If predicted value is `0`, then the other class is `1`
152+
final int otherClass = 1 - predictedValue;
153+
String predictedLabel = classificationLabels == null ? null : classificationLabels.get(predictedValue);
154+
String otherLabel = classificationLabels == null ? null : classificationLabels.get(otherClass);
155+
importances.add(FeatureImportance.forBinaryClassification(k,
156+
v[0],
157+
Arrays.asList(
158+
new FeatureImportance.ClassImportance(
159+
fieldType.transformPredictedValue((double)predictedValue, predictedLabel),
160+
v[0]),
161+
new FeatureImportance.ClassImportance(
162+
fieldType.transformPredictedValue((double)otherClass, otherLabel),
163+
-v[0])
164+
)));
140165
} else {
141166
List<FeatureImportance.ClassImportance> classImportance = new ArrayList<>(v.length);
142167
// If the classificationLabels exist, their length must match leaf_value length
143168
assert classificationLabels == null || classificationLabels.size() == v.length;
144169
for (int i = 0; i < v.length; i++) {
170+
String label = classificationLabels == null ? null : classificationLabels.get(i);
145171
classImportance.add(new FeatureImportance.ClassImportance(
146-
classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i),
172+
fieldType.transformPredictedValue((double)i, label),
147173
v[i]));
148174
}
149175
importances.add(FeatureImportance.forClassification(k, classImportance));

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/EnsembleInferenceModel.java

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
4444
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.decodeFeatureImportances;
4545
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.sumDoubleArrays;
46-
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportance;
46+
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceClassification;
47+
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.transformFeatureImportanceRegression;
4748
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.AGGREGATE_OUTPUT;
4849
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_LABELS;
4950
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble.CLASSIFICATION_WEIGHTS;
@@ -154,14 +155,7 @@ private InferenceResults innerInfer(double[] features, InferenceConfig config, M
154155
RawInferenceResults inferenceResult = (RawInferenceResults) result;
155156
inferenceResults[i++] = inferenceResult.getValue();
156157
if (config.requestingImportance()) {
157-
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
158-
assert modelFeatureImportance.length == featureInfluence.length;
159-
for (int j = 0; j < modelFeatureImportance.length; j++) {
160-
if (featureInfluence[j] == null) {
161-
featureInfluence[j] = new double[modelFeatureImportance[j].length];
162-
}
163-
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
164-
}
158+
addFeatureImportance(featureInfluence, inferenceResult);
165159
}
166160
}
167161
double[] processed = outputAggregator.processValues(inferenceResults);
@@ -176,18 +170,22 @@ private InferenceResults innerInfer(double[] features, InferenceConfig config, M
176170
InferenceResults result = model.infer(features, subModelInferenceConfig);
177171
assert result instanceof RawInferenceResults;
178172
RawInferenceResults inferenceResult = (RawInferenceResults) result;
179-
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
180-
assert modelFeatureImportance.length == featureInfluence.length;
181-
for (int j = 0; j < modelFeatureImportance.length; j++) {
182-
if (featureInfluence[j] == null) {
183-
featureInfluence[j] = new double[modelFeatureImportance[j].length];
184-
}
185-
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
186-
}
173+
addFeatureImportance(featureInfluence, inferenceResult);
187174
}
188175
return featureInfluence;
189176
}
190177

178+
private void addFeatureImportance(double[][] featureInfluence, RawInferenceResults inferenceResult) {
179+
double[][] modelFeatureImportance = inferenceResult.getFeatureImportance();
180+
assert modelFeatureImportance.length == featureInfluence.length;
181+
for (int j = 0; j < modelFeatureImportance.length; j++) {
182+
if (featureInfluence[j] == null) {
183+
featureInfluence[j] = new double[modelFeatureImportance[j].length];
184+
}
185+
featureInfluence[j] = sumDoubleArrays(featureInfluence[j], modelFeatureImportance[j]);
186+
}
187+
}
188+
191189
private InferenceResults buildResults(double[] processedInferences,
192190
double[][] featureImportance,
193191
Map<String, String> featureDecoderMap,
@@ -208,7 +206,7 @@ private InferenceResults buildResults(double[] processedInferences,
208206
case REGRESSION:
209207
return new RegressionInferenceResults(outputAggregator.aggregate(processedInferences),
210208
config,
211-
transformFeatureImportance(decodedFeatureImportance, null));
209+
transformFeatureImportanceRegression(decodedFeatureImportance));
212210
case CLASSIFICATION:
213211
ClassificationConfig classificationConfig = (ClassificationConfig) config;
214212
assert classificationWeights == null || processedInferences.length == classificationWeights.length;
@@ -220,10 +218,13 @@ private InferenceResults buildResults(double[] processedInferences,
220218
classificationConfig.getNumTopClasses(),
221219
classificationConfig.getPredictionFieldType());
222220
final InferenceHelpers.TopClassificationValue value = topClasses.v1();
223-
return new ClassificationInferenceResults((double)value.getValue(),
221+
return new ClassificationInferenceResults(value.getValue(),
224222
classificationLabel(topClasses.v1().getValue(), classificationLabels),
225223
topClasses.v2(),
226-
transformFeatureImportance(decodedFeatureImportance, classificationLabels),
224+
transformFeatureImportanceClassification(decodedFeatureImportance,
225+
value.getValue(),
226+
classificationLabels,
227+
classificationConfig.getPredictionFieldType()),
227228
config,
228229
value.getProbability(),
229230
value.getScore());

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/TreeInferenceModel.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,17 @@ private InferenceResults buildResult(double[] value,
188188
return new ClassificationInferenceResults(classificationValue.getValue(),
189189
classificationLabel(classificationValue.getValue(), classificationLabels),
190190
topClasses.v2(),
191-
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, classificationLabels),
191+
InferenceHelpers.transformFeatureImportanceClassification(decodedFeatureImportance,
192+
classificationValue.getValue(),
193+
classificationLabels,
194+
classificationConfig.getPredictionFieldType()),
192195
config,
193196
classificationValue.getProbability(),
194197
classificationValue.getScore());
195198
case REGRESSION:
196199
return new RegressionInferenceResults(value[0],
197200
config,
198-
InferenceHelpers.transformFeatureImportance(decodedFeatureImportance, null));
201+
InferenceHelpers.transformFeatureImportanceRegression(decodedFeatureImportance));
199202
default:
200203
throw new UnsupportedOperationException("unsupported target_type [" + targetType + "] for inference on tree model");
201204
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
import org.elasticsearch.common.io.stream.StreamOutput;
1313
import org.elasticsearch.common.io.stream.Writeable;
1414
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
15+
import org.elasticsearch.common.xcontent.ObjectParser;
1516
import org.elasticsearch.common.xcontent.ToXContentObject;
1617
import org.elasticsearch.common.xcontent.XContentBuilder;
18+
import org.elasticsearch.common.xcontent.XContentParseException;
1719
import org.elasticsearch.common.xcontent.XContentParser;
1820

1921
import java.io.IOException;
@@ -185,8 +187,17 @@ public static class ClassImportance implements ToXContentObject, Writeable {
185187
private static ConstructingObjectParser<ClassImportance, Void> createParser(boolean ignoreUnknownFields) {
186188
ConstructingObjectParser<ClassImportance, Void> parser = new ConstructingObjectParser<>(NAME,
187189
ignoreUnknownFields,
188-
a -> new ClassImportance((String)a[0], (Importance)a[1]));
189-
parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME);
190+
a -> new ClassImportance(a[0], (Importance)a[1]));
191+
parser.declareField(ConstructingObjectParser.constructorArg(), (p, c) -> {
192+
if (p.currentToken() == XContentParser.Token.VALUE_STRING) {
193+
return p.text();
194+
} else if (p.currentToken() == XContentParser.Token.VALUE_NUMBER) {
195+
return p.numberValue();
196+
} else if (p.currentToken() == XContentParser.Token.VALUE_BOOLEAN) {
197+
return p.booleanValue();
198+
}
199+
throw new XContentParseException("Unsupported token [" + p.currentToken() + "]");
200+
}, CLASS_NAME, ObjectParser.ValueType.VALUE);
190201
parser.declareObject(ConstructingObjectParser.constructorArg(),
191202
ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER,
192203
IMPORTANCE);
@@ -197,22 +208,22 @@ public static ClassImportance fromXContent(XContentParser parser, boolean lenien
197208
return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null);
198209
}
199210

200-
public final String className;
211+
public final Object className;
201212
public final Importance importance;
202213

203214
public ClassImportance(StreamInput in) throws IOException {
204-
this.className = in.readString();
215+
this.className = in.readGenericValue();
205216
this.importance = new Importance(in);
206217
}
207218

208-
ClassImportance(String className, Importance importance) {
219+
ClassImportance(Object className, Importance importance) {
209220
this.className = className;
210221
this.importance = importance;
211222
}
212223

213224
@Override
214225
public void writeTo(StreamOutput out) throws IOException {
215-
out.writeString(className);
226+
out.writeGenericValue(className);
216227
importance.writeTo(out);
217228
}
218229

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/inference/InferenceDefinitionTests.java

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.xpack.core.ml.inference.InferenceToXContentCompressor;
1818
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
1919
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
20+
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
2021
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
2122

2223
import java.io.IOException;
@@ -154,10 +155,26 @@ public void testComplexInferenceDefinitionInferWithCustomPreProcessor() throws I
154155

155156
ClassificationInferenceResults results = (ClassificationInferenceResults) inferenceDefinition.infer(featureMap, config);
156157
assertThat(results.valueAsString(), equalTo("second"));
157-
assertThat(results.getFeatureImportance().get(0).getFeatureName(), equalTo("col2"));
158-
assertThat(results.getFeatureImportance().get(0).getImportance(), closeTo(0.944, 0.001));
159-
assertThat(results.getFeatureImportance().get(1).getFeatureName(), equalTo("col1_male"));
160-
assertThat(results.getFeatureImportance().get(1).getImportance(), closeTo(0.199, 0.001));
158+
FeatureImportance featureImportance1 = results.getFeatureImportance().get(0);
159+
assertThat(featureImportance1.getFeatureName(), equalTo("col2"));
160+
assertThat(featureImportance1.getImportance(), closeTo(0.944, 0.001));
161+
for (FeatureImportance.ClassImportance classImportance : featureImportance1.getClassImportance()) {
162+
if (classImportance.getClassName().equals("second")) {
163+
assertThat(classImportance.getImportance(), closeTo(0.944, 0.001));
164+
} else {
165+
assertThat(classImportance.getImportance(), closeTo(-0.944, 0.001));
166+
}
167+
}
168+
FeatureImportance featureImportance2 = results.getFeatureImportance().get(1);
169+
assertThat(featureImportance2.getFeatureName(), equalTo("col1_male"));
170+
assertThat(featureImportance2.getImportance(), closeTo(0.199, 0.001));
171+
for (FeatureImportance.ClassImportance classImportance : featureImportance2.getClassImportance()) {
172+
if (classImportance.getClassName().equals("second")) {
173+
assertThat(classImportance.getImportance(), closeTo(0.199, 0.001));
174+
} else {
175+
assertThat(classImportance.getImportance(), closeTo(-0.199, 0.001));
176+
}
177+
}
161178
}
162179

163180
public static String getClassificationDefinition(boolean customPreprocessor) {

0 commit comments

Comments
 (0)