Skip to content

Commit 756a297

Browse files
authored
[ML] adds multi-class feature importance support (#53803)
Adds multi-class feature importance calculation.

Feature importance objects are now mapped as follows.

(Logistic) regression:
```
{
  "feature_name": "feature_0",
  "importance": -1.3
}
```

Multi-class (class names are `foo`, `bar`, `baz`):
```
{
  "feature_name": "feature_0",
  "importance": 2.0, // sum(abs()) of class importances
  "foo": 1.0,
  "bar": 0.5,
  "baz": -0.5
}
```

To get the full benefit of aggregating and searching on feature importance, users should update their index mapping as follows before turning this option on in their pipelines:
```
"ml.inference.feature_importance": {
  "type": "nested",
  "dynamic": true,
  "properties": {
    "feature_name": { "type": "keyword" },
    "importance": { "type": "double" }
  }
}
```

The mapping field name is as follows: `ml.<inference.target_field>.<inference.tag>.feature_importance`. If `inference.tag` is not provided in the processor definition, it is not part of the field path. `inference.target_field` defaults to `ml.inference`.

//cc @lcawl ^ Where should we document this?

If this makes it into 7.7, there should be no BWC concerns for feature_importance at inference, since 7.7 is the first version to have it.
1 parent ecdbd37 commit 756a297

File tree

18 files changed

+410
-138
lines changed

18 files changed

+410
-138
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,21 @@ public ClassificationInferenceResults(double value,
3535
String classificationLabel,
3636
List<TopClassEntry> topClasses,
3737
InferenceConfig config) {
38-
this(value, classificationLabel, topClasses, Collections.emptyMap(), (ClassificationConfig)config);
38+
this(value, classificationLabel, topClasses, Collections.emptyList(), (ClassificationConfig)config);
3939
}
4040

4141
public ClassificationInferenceResults(double value,
4242
String classificationLabel,
4343
List<TopClassEntry> topClasses,
44-
Map<String, Double> featureImportance,
44+
List<FeatureImportance> featureImportance,
4545
InferenceConfig config) {
4646
this(value, classificationLabel, topClasses, featureImportance, (ClassificationConfig)config);
4747
}
4848

4949
private ClassificationInferenceResults(double value,
5050
String classificationLabel,
5151
List<TopClassEntry> topClasses,
52-
Map<String, Double> featureImportance,
52+
List<FeatureImportance> featureImportance,
5353
ClassificationConfig classificationConfig) {
5454
super(value,
5555
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
@@ -118,7 +118,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
118118
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
119119
}
120120
if (getFeatureImportance().size() > 0) {
121-
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
121+
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
122+
.stream()
123+
.map(FeatureImportance::toMap)
124+
.collect(Collectors.toList()));
122125
}
123126
}
124127

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.inference.results;
7+
8+
import org.elasticsearch.common.io.stream.StreamInput;
9+
import org.elasticsearch.common.io.stream.StreamOutput;
10+
import org.elasticsearch.common.io.stream.Writeable;
11+
12+
import java.io.IOException;
13+
import java.util.Collections;
14+
import java.util.LinkedHashMap;
15+
import java.util.Map;
16+
import java.util.Objects;
17+
18+
public class FeatureImportance implements Writeable {
19+
20+
private final Map<String, Double> classImportance;
21+
private final double importance;
22+
private final String featureName;
23+
private static final String IMPORTANCE = "importance";
24+
private static final String FEATURE_NAME = "feature_name";
25+
26+
public static FeatureImportance forRegression(String featureName, double importance) {
27+
return new FeatureImportance(featureName, importance, null);
28+
}
29+
30+
public static FeatureImportance forClassification(String featureName, Map<String, Double> classImportance) {
31+
return new FeatureImportance(featureName, classImportance.values().stream().mapToDouble(Math::abs).sum(), classImportance);
32+
}
33+
34+
private FeatureImportance(String featureName, double importance, Map<String, Double> classImportance) {
35+
this.featureName = Objects.requireNonNull(featureName);
36+
this.importance = importance;
37+
this.classImportance = classImportance == null ? null : Collections.unmodifiableMap(classImportance);
38+
}
39+
40+
public FeatureImportance(StreamInput in) throws IOException {
41+
this.featureName = in.readString();
42+
this.importance = in.readDouble();
43+
if (in.readBoolean()) {
44+
this.classImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
45+
} else {
46+
this.classImportance = null;
47+
}
48+
}
49+
50+
public Map<String, Double> getClassImportance() {
51+
return classImportance;
52+
}
53+
54+
public double getImportance() {
55+
return importance;
56+
}
57+
58+
public String getFeatureName() {
59+
return featureName;
60+
}
61+
62+
@Override
63+
public void writeTo(StreamOutput out) throws IOException {
64+
out.writeString(this.featureName);
65+
out.writeDouble(this.importance);
66+
out.writeBoolean(this.classImportance != null);
67+
if (this.classImportance != null) {
68+
out.writeMap(this.classImportance, StreamOutput::writeString, StreamOutput::writeDouble);
69+
}
70+
}
71+
72+
public Map<String, Object> toMap() {
73+
Map<String, Object> map = new LinkedHashMap<>();
74+
map.put(FEATURE_NAME, featureName);
75+
map.put(IMPORTANCE, importance);
76+
if (classImportance != null) {
77+
classImportance.forEach(map::put);
78+
}
79+
return map;
80+
}
81+
82+
@Override
83+
public boolean equals(Object object) {
84+
if (object == this) { return true; }
85+
if (object == null || getClass() != object.getClass()) { return false; }
86+
FeatureImportance that = (FeatureImportance) object;
87+
return Objects.equals(featureName, that.featureName)
88+
&& Objects.equals(importance, that.importance)
89+
&& Objects.equals(classImportance, that.classImportance);
90+
}
91+
92+
@Override
93+
public int hashCode() {
94+
return Objects.hash(featureName, importance, classImportance);
95+
}
96+
97+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ public class RawInferenceResults implements InferenceResults {
1818
public static final String NAME = "raw";
1919

2020
private final double[] value;
21-
private final Map<String, Double> featureImportance;
21+
private final Map<String, double[]> featureImportance;
2222

23-
public RawInferenceResults(double[] value, Map<String, Double> featureImportance) {
23+
public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
2424
this.value = value;
2525
this.featureImportance = featureImportance;
2626
}
@@ -29,7 +29,7 @@ public double[] getValue() {
2929
return value;
3030
}
3131

32-
public Map<String, Double> getFeatureImportance() {
32+
public Map<String, double[]> getFeatureImportance() {
3333
return featureImportance;
3434
}
3535

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414

1515
import java.io.IOException;
1616
import java.util.Collections;
17-
import java.util.Map;
17+
import java.util.List;
1818
import java.util.Objects;
19+
import java.util.stream.Collectors;
1920

2021
public class RegressionInferenceResults extends SingleValueInferenceResults {
2122

@@ -24,14 +25,14 @@ public class RegressionInferenceResults extends SingleValueInferenceResults {
2425
private final String resultsField;
2526

2627
public RegressionInferenceResults(double value, InferenceConfig config) {
27-
this(value, (RegressionConfig) config, Collections.emptyMap());
28+
this(value, (RegressionConfig) config, Collections.emptyList());
2829
}
2930

30-
public RegressionInferenceResults(double value, InferenceConfig config, Map<String, Double> featureImportance) {
31+
public RegressionInferenceResults(double value, InferenceConfig config, List<FeatureImportance> featureImportance) {
3132
this(value, (RegressionConfig)config, featureImportance);
3233
}
3334

34-
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, Map<String, Double> featureImportance) {
35+
private RegressionInferenceResults(double value, RegressionConfig regressionConfig, List<FeatureImportance> featureImportance) {
3536
super(value,
3637
SingleValueInferenceResults.takeTopFeatureImportances(featureImportance,
3738
regressionConfig.getNumTopFeatureImportanceValues()));
@@ -70,7 +71,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
7071
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
7172
document.setFieldValue(parentResultField + "." + this.resultsField, value());
7273
if (getFeatureImportance().size() > 0) {
73-
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance());
74+
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
75+
.stream()
76+
.map(FeatureImportance::toMap)
77+
.collect(Collectors.toList()));
7478
}
7579
}
7680

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/SingleValueInferenceResults.java

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,45 +8,46 @@
88
import org.elasticsearch.Version;
99
import org.elasticsearch.common.io.stream.StreamInput;
1010
import org.elasticsearch.common.io.stream.StreamOutput;
11-
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1211

1312
import java.io.IOException;
1413
import java.util.Collections;
15-
import java.util.LinkedHashMap;
16-
import java.util.Map;
14+
import java.util.List;
15+
import java.util.stream.Collectors;
1716

1817
public abstract class SingleValueInferenceResults implements InferenceResults {
1918

2019
private final double value;
21-
private final Map<String, Double> featureImportance;
20+
private final List<FeatureImportance> featureImportance;
2221

23-
static Map<String, Double> takeTopFeatureImportances(Map<String, Double> unsortedFeatureImportances, int numTopFeatures) {
24-
return unsortedFeatureImportances.entrySet()
25-
.stream()
26-
.sorted((l, r)-> Double.compare(Math.abs(r.getValue()), Math.abs(l.getValue())))
22+
static List<FeatureImportance> takeTopFeatureImportances(List<FeatureImportance> unsortedFeatureImportances, int numTopFeatures) {
23+
if (unsortedFeatureImportances == null || unsortedFeatureImportances.isEmpty()) {
24+
return unsortedFeatureImportances;
25+
}
26+
return unsortedFeatureImportances.stream()
27+
.sorted((l, r)-> Double.compare(Math.abs(r.getImportance()), Math.abs(l.getImportance())))
2728
.limit(numTopFeatures)
28-
.collect(LinkedHashMap::new, (h, e) -> h.put(e.getKey(), e.getValue()) , LinkedHashMap::putAll);
29+
.collect(Collectors.toList());
2930
}
3031

3132
SingleValueInferenceResults(StreamInput in) throws IOException {
3233
value = in.readDouble();
3334
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
34-
this.featureImportance = in.readMap(StreamInput::readString, StreamInput::readDouble);
35+
this.featureImportance = in.readList(FeatureImportance::new);
3536
} else {
36-
this.featureImportance = Collections.emptyMap();
37+
this.featureImportance = Collections.emptyList();
3738
}
3839
}
3940

40-
SingleValueInferenceResults(double value, Map<String, Double> featureImportance) {
41+
SingleValueInferenceResults(double value, List<FeatureImportance> featureImportance) {
4142
this.value = value;
42-
this.featureImportance = ExceptionsHelper.requireNonNull(featureImportance, "featureImportance");
43+
this.featureImportance = featureImportance == null ? Collections.emptyList() : featureImportance;
4344
}
4445

4546
public Double value() {
4647
return value;
4748
}
4849

49-
public Map<String, Double> getFeatureImportance() {
50+
public List<FeatureImportance> getFeatureImportance() {
5051
return featureImportance;
5152
}
5253

@@ -58,7 +59,7 @@ public String valueAsString() {
5859
public void writeTo(StreamOutput out) throws IOException {
5960
out.writeDouble(value);
6061
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
61-
out.writeMap(this.featureImportance, StreamOutput::writeString, StreamOutput::writeDouble);
62+
out.writeList(this.featureImportance);
6263
}
6364
}
6465

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import org.elasticsearch.common.Nullable;
99
import org.elasticsearch.common.collect.Tuple;
1010
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
11+
import org.elasticsearch.xpack.core.ml.inference.results.FeatureImportance;
1112
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1213

1314
import java.util.ArrayList;
1415
import java.util.Collections;
1516
import java.util.Comparator;
1617
import java.util.HashMap;
18+
import java.util.LinkedHashMap;
1719
import java.util.List;
1820
import java.util.Map;
1921
import java.util.stream.Collectors;
@@ -100,18 +102,46 @@ public static Double toDouble(Object value) {
100102
return null;
101103
}
102104

103-
public static Map<String, Double> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
104-
Map<String, Double> featureImportances) {
105+
public static Map<String, double[]> decodeFeatureImportances(Map<String, String> processedFeatureToOriginalFeatureMap,
106+
Map<String, double[]> featureImportances) {
105107
if (processedFeatureToOriginalFeatureMap == null || processedFeatureToOriginalFeatureMap.isEmpty()) {
106108
return featureImportances;
107109
}
108110

109-
Map<String, Double> originalFeatureImportance = new HashMap<>();
111+
Map<String, double[]> originalFeatureImportance = new HashMap<>();
110112
featureImportances.forEach((feature, importance) -> {
111113
String featureName = processedFeatureToOriginalFeatureMap.getOrDefault(feature, feature);
112-
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : v1 + importance);
114+
originalFeatureImportance.compute(featureName, (f, v1) -> v1 == null ? importance : sumDoubleArrays(importance, v1));
113115
});
114-
115116
return originalFeatureImportance;
116117
}
118+
119+
public static List<FeatureImportance> transformFeatureImportance(Map<String, double[]> featureImportance,
120+
@Nullable List<String> classificationLabels) {
121+
List<FeatureImportance> importances = new ArrayList<>(featureImportance.size());
122+
featureImportance.forEach((k, v) -> {
123+
// This indicates regression, or logistic regression
124+
// If the length > 1, we assume multi-class classification.
125+
if (v.length == 1) {
126+
importances.add(FeatureImportance.forRegression(k, v[0]));
127+
} else {
128+
Map<String, Double> classImportance = new LinkedHashMap<>(v.length, 1.0f);
129+
// If the classificationLabels exist, their length must match leaf_value length
130+
assert classificationLabels == null || classificationLabels.size() == v.length;
131+
for (int i = 0; i < v.length; i++) {
132+
classImportance.put(classificationLabels == null ? String.valueOf(i) : classificationLabels.get(i), v[i]);
133+
}
134+
importances.add(FeatureImportance.forClassification(k, classImportance));
135+
}
136+
});
137+
return importances;
138+
}
139+
140+
public static double[] sumDoubleArrays(double[] sumTo, double[] inc) {
141+
assert sumTo != null && inc != null && sumTo.length == inc.length;
142+
for (int i = 0; i < inc.length; i++) {
143+
sumTo[i] += inc[i];
144+
}
145+
return sumTo;
146+
}
117147
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
6060
* NOTE: Must be thread safe
6161
* @param fields The fields inferring against
6262
* @param featureDecoder A Map translating processed feature names to their original feature names
63-
* @return A {@code Map<String, Double>} mapping each featureName to its importance
63+
* @return A {@code Map<String, double[]>} mapping each featureName to its importance
6464
*/
65-
Map<String, Double> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
65+
Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
6666

6767
default Version getMinimalCompatibilityVersion() {
6868
return Version.V_7_6_0;

0 commit comments

Comments
 (0)