Skip to content

Commit 72c2709

Browse files
authored
[ML][Inference] Adding classification_weights to ensemble models (#50874) (#50994)
* [ML][Inference] Adding classification_weights to ensemble models. classification_weights are a way to allow models to prefer specific classification results over others. This might be advantageous if classification value probabilities are a known quantity, and it can improve model error rates.
1 parent de5713f commit 72c2709

File tree

12 files changed

+190
-156
lines changed

12 files changed

+190
-156
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/Ensemble.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.common.xcontent.XContentParser;
3030

3131
import java.io.IOException;
32+
import java.util.Arrays;
3233
import java.util.Collections;
3334
import java.util.List;
3435
import java.util.Objects;
@@ -41,6 +42,7 @@ public class Ensemble implements TrainedModel {
4142
public static final ParseField AGGREGATE_OUTPUT = new ParseField("aggregate_output");
4243
public static final ParseField TARGET_TYPE = new ParseField("target_type");
4344
public static final ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels");
45+
public static final ParseField CLASSIFICATION_WEIGHTS = new ParseField("classification_weights");
4446

4547
private static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(
4648
NAME,
@@ -60,6 +62,7 @@ public class Ensemble implements TrainedModel {
6062
AGGREGATE_OUTPUT);
6163
PARSER.declareString(Ensemble.Builder::setTargetType, TARGET_TYPE);
6264
PARSER.declareStringArray(Ensemble.Builder::setClassificationLabels, CLASSIFICATION_LABELS);
65+
PARSER.declareDoubleArray(Ensemble.Builder::setClassificationWeights, CLASSIFICATION_WEIGHTS);
6366
}
6467

6568
public static Ensemble fromXContent(XContentParser parser) {
@@ -71,17 +74,20 @@ public static Ensemble fromXContent(XContentParser parser) {
7174
private final OutputAggregator outputAggregator;
7275
private final TargetType targetType;
7376
private final List<String> classificationLabels;
77+
private final double[] classificationWeights;
7478

7579
Ensemble(List<String> featureNames,
7680
List<TrainedModel> models,
7781
@Nullable OutputAggregator outputAggregator,
7882
TargetType targetType,
79-
@Nullable List<String> classificationLabels) {
83+
@Nullable List<String> classificationLabels,
84+
@Nullable double[] classificationWeights) {
8085
this.featureNames = featureNames;
8186
this.models = models;
8287
this.outputAggregator = outputAggregator;
8388
this.targetType = targetType;
8489
this.classificationLabels = classificationLabels;
90+
this.classificationWeights = classificationWeights;
8591
}
8692

8793
@Override
@@ -116,6 +122,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
116122
if (classificationLabels != null) {
117123
builder.field(CLASSIFICATION_LABELS.getPreferredName(), classificationLabels);
118124
}
125+
if (classificationWeights != null) {
126+
builder.field(CLASSIFICATION_WEIGHTS.getPreferredName(), classificationWeights);
127+
}
119128
builder.endObject();
120129
return builder;
121130
}
@@ -129,12 +138,18 @@ public boolean equals(Object o) {
129138
&& Objects.equals(models, that.models)
130139
&& Objects.equals(targetType, that.targetType)
131140
&& Objects.equals(classificationLabels, that.classificationLabels)
141+
&& Arrays.equals(classificationWeights, that.classificationWeights)
132142
&& Objects.equals(outputAggregator, that.outputAggregator);
133143
}
134144

135145
@Override
136146
public int hashCode() {
137-
return Objects.hash(featureNames, models, outputAggregator, classificationLabels, targetType);
147+
return Objects.hash(featureNames,
148+
models,
149+
outputAggregator,
150+
classificationLabels,
151+
targetType,
152+
Arrays.hashCode(classificationWeights));
138153
}
139154

140155
public static Builder builder() {
@@ -147,6 +162,7 @@ public static class Builder {
147162
private OutputAggregator outputAggregator;
148163
private TargetType targetType;
149164
private List<String> classificationLabels;
165+
private double[] classificationWeights;
150166

151167
public Builder setFeatureNames(List<String> featureNames) {
152168
this.featureNames = featureNames;
@@ -173,6 +189,11 @@ public Builder setClassificationLabels(List<String> classificationLabels) {
173189
return this;
174190
}
175191

192+
public Builder setClassificationWeights(List<Double> classificationWeights) {
193+
this.classificationWeights = classificationWeights.stream().mapToDouble(Double::doubleValue).toArray();
194+
return this;
195+
}
196+
176197
private void setOutputAggregatorFromParser(List<OutputAggregator> outputAggregators) {
177198
this.setOutputAggregator(outputAggregators.get(0));
178199
}
@@ -182,7 +203,7 @@ private void setTargetType(String targetType) {
182203
}
183204

184205
public Ensemble build() {
185-
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels);
206+
return new Ensemble(featureNames, trainedModels, outputAggregator, targetType, classificationLabels, classificationWeights);
186207
}
187208
}
188209
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,19 @@ public static Ensemble createRandom(TargetType targetType) {
8080
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
8181
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
8282
}
83+
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
84+
Stream.generate(ESTestCase::randomDouble)
85+
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())
86+
.mapToDouble(Double::valueOf)
87+
.toArray() :
88+
null;
89+
8390
return new Ensemble(featureNames,
8491
models,
8592
outputAggregator,
8693
targetType,
87-
categoryLabels);
94+
categoryLabels,
95+
thresholds);
8896
}
8997

9098
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,26 @@ public static class TopClassEntry implements Writeable {
112112

113113
public final ParseField CLASS_NAME = new ParseField("class_name");
114114
public final ParseField CLASS_PROBABILITY = new ParseField("class_probability");
115+
public final ParseField CLASS_SCORE = new ParseField("class_score");
115116

116117
private final String classification;
117118
private final double probability;
119+
private final double score;
118120

119-
public TopClassEntry(String classification, Double probability) {
121+
public TopClassEntry(String classification, double probability) {
122+
this(classification, probability, probability);
123+
}
124+
125+
public TopClassEntry(String classification, double probability, double score) {
120126
this.classification = ExceptionsHelper.requireNonNull(classification, CLASS_NAME);
121-
this.probability = ExceptionsHelper.requireNonNull(probability, CLASS_PROBABILITY);
127+
this.probability = probability;
128+
this.score = score;
122129
}
123130

124131
public TopClassEntry(StreamInput in) throws IOException {
125132
this.classification = in.readString();
126133
this.probability = in.readDouble();
134+
this.score = in.readDouble();
127135
}
128136

129137
public String getClassification() {
@@ -134,31 +142,36 @@ public double getProbability() {
134142
return probability;
135143
}
136144

145+
public double getScore() {
146+
return score;
147+
}
148+
137149
public Map<String, Object> asValueMap() {
138-
Map<String, Object> map = new HashMap<>(2);
150+
Map<String, Object> map = new HashMap<>(3, 1.0f);
139151
map.put(CLASS_NAME.getPreferredName(), classification);
140152
map.put(CLASS_PROBABILITY.getPreferredName(), probability);
153+
map.put(CLASS_SCORE.getPreferredName(), score);
141154
return map;
142155
}
143156

144157
@Override
145158
public void writeTo(StreamOutput out) throws IOException {
146159
out.writeString(classification);
147160
out.writeDouble(probability);
161+
out.writeDouble(score);
148162
}
149163

150164
@Override
151165
public boolean equals(Object object) {
152166
if (object == this) { return true; }
153167
if (object == null || getClass() != object.getClass()) { return false; }
154168
TopClassEntry that = (TopClassEntry) object;
155-
return Objects.equals(classification, that.classification) &&
156-
Objects.equals(probability, that.probability);
169+
return Objects.equals(classification, that.classification) && probability == that.probability && score == that.score;
157170
}
158171

159172
@Override
160173
public int hashCode() {
161-
return Objects.hash(classification, probability);
174+
return Objects.hash(classification, probability, score);
162175
}
163176
}
164177
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceHelpers.java

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
77

88
import org.elasticsearch.common.Nullable;
9+
import org.elasticsearch.common.collect.Tuple;
910
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
1011
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1112

@@ -20,25 +21,38 @@ public final class InferenceHelpers {
2021

2122
private InferenceHelpers() { }
2223

23-
public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List<Double> probabilities,
24-
List<String> classificationLabels,
25-
int numToInclude) {
26-
if (numToInclude == 0) {
27-
return Collections.emptyList();
28-
}
29-
int[] sortedIndices = IntStream.range(0, probabilities.size())
30-
.boxed()
31-
.sorted(Comparator.comparing(probabilities::get).reversed())
32-
.mapToInt(i -> i)
33-
.toArray();
24+
/**
25+
* @return Tuple of the highest scored index and the top classes
26+
*/
27+
public static Tuple<Integer, List<ClassificationInferenceResults.TopClassEntry>> topClasses(List<Double> probabilities,
28+
List<String> classificationLabels,
29+
@Nullable double[] classificationWeights,
30+
int numToInclude) {
3431

3532
if (classificationLabels != null && probabilities.size() != classificationLabels.size()) {
3633
throw ExceptionsHelper
3734
.serverError(
3835
"model returned classification probabilities of size [{}] which is not equal to classification labels size [{}]",
3936
null,
4037
probabilities.size(),
41-
classificationLabels);
38+
classificationLabels.size());
39+
}
40+
41+
List<Double> scores = classificationWeights == null ?
42+
probabilities :
43+
IntStream.range(0, probabilities.size())
44+
.mapToDouble(i -> probabilities.get(i) * classificationWeights[i])
45+
.boxed()
46+
.collect(Collectors.toList());
47+
48+
int[] sortedIndices = IntStream.range(0, probabilities.size())
49+
.boxed()
50+
.sorted(Comparator.comparing(scores::get).reversed())
51+
.mapToInt(i -> i)
52+
.toArray();
53+
54+
if (numToInclude == 0) {
55+
return Tuple.tuple(sortedIndices[0], Collections.emptyList());
4256
}
4357

4458
List<String> labels = classificationLabels == null ?
@@ -50,26 +64,24 @@ public static List<ClassificationInferenceResults.TopClassEntry> topClasses(List
5064
List<ClassificationInferenceResults.TopClassEntry> topClassEntries = new ArrayList<>(count);
5165
for(int i = 0; i < count; i++) {
5266
int idx = sortedIndices[i];
53-
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx)));
67+
topClassEntries.add(new ClassificationInferenceResults.TopClassEntry(labels.get(idx), probabilities.get(idx), scores.get(idx)));
5468
}
5569

56-
return topClassEntries;
70+
return Tuple.tuple(sortedIndices[0], topClassEntries);
5771
}
5872

59-
public static String classificationLabel(double inferenceValue, @Nullable List<String> classificationLabels) {
60-
assert inferenceValue == Math.rint(inferenceValue);
73+
public static String classificationLabel(Integer inferenceValue, @Nullable List<String> classificationLabels) {
6174
if (classificationLabels == null) {
6275
return String.valueOf(inferenceValue);
6376
}
64-
int label = Double.valueOf(inferenceValue).intValue();
65-
if (label < 0 || label >= classificationLabels.size()) {
77+
if (inferenceValue < 0 || inferenceValue >= classificationLabels.size()) {
6678
throw ExceptionsHelper.serverError(
6779
"model returned classification value of [{}] which is not a valid index in classification labels [{}]",
6880
null,
69-
label,
81+
inferenceValue,
7082
classificationLabels);
7183
}
72-
return classificationLabels.get(label);
84+
return classificationLabels.get(inferenceValue);
7385
}
7486

7587
public static Double toDouble(Object value) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,14 @@
66
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
77

88
import org.apache.lucene.util.Accountable;
9-
import org.elasticsearch.common.Nullable;
109
import org.elasticsearch.common.io.stream.NamedWriteable;
1110
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
1211
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
1312

14-
import java.util.List;
1513
import java.util.Map;
1614

1715
public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
1816

19-
/**
20-
* @return List of featureNames expected by the model. In the order that they are expected
21-
*/
22-
List<String> getFeatureNames();
23-
2417
/**
2518
* Infer against the provided fields
2619
*
@@ -36,12 +29,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
3629
*/
3730
TargetType targetType();
3831

39-
/**
40-
* @return Ordinal encoded list of classification labels.
41-
*/
42-
@Nullable
43-
List<String> classificationLabels();
44-
4532
/**
4633
* Runs validations against the model.
4734
*

0 commit comments

Comments
 (0)