Skip to content

Commit 78fcd05

Browse files
authored
[ML] inference performance optimizations and refactor (#57674)
This is a major refactor of the underlying inference logic. The main change is that the model configuration is now separated from the inference interfaces. This has the following benefits:
- We can store extra things with the model that are not necessary for inference (e.g. tree node split information gain).
- We can optimize inference separately from model serialization and storage.
- The user is oblivious to the optimizations (other than seeing the benefits).

A major part of this commit is removing all inference-related methods from the trained model configurations (ensemble, tree, etc.) and moving them to a new class. This new class satisfies a new interface that is ONLY for inference.

The optimizations currently applied are:
- Feature maps are flattened once.
- Feature extraction only happens once at the highest level (improves inference + feature importance throughput).
- Only what we need for inference + feature importance is stored on heap.
1 parent 1f28bd0 commit 78fcd05

File tree

39 files changed

+2366
-1608
lines changed

39 files changed

+2366
-1608
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/InferenceToXContentCompressor.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import org.elasticsearch.common.bytes.BytesArray;
1111
import org.elasticsearch.common.bytes.BytesReference;
1212
import org.elasticsearch.common.io.stream.BytesStreamOutput;
13+
import org.elasticsearch.common.unit.ByteSizeUnit;
14+
import org.elasticsearch.common.unit.ByteSizeValue;
1315
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
1416
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
1517
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -37,7 +39,7 @@ public final class InferenceToXContentCompressor {
3739
// Either 10% of the configured JVM heap, or 1 GB, which ever is smaller
3840
private static final long MAX_INFLATED_BYTES = Math.min(
3941
(long)((0.10) * JvmInfo.jvmInfo().getMem().getHeapMax().getBytes()),
40-
1_000_000_000); // 1 gb maximum
42+
new ByteSizeValue(1, ByteSizeUnit.GB).getBytes());
4143

4244
private InferenceToXContentCompressor() {}
4345

@@ -46,9 +48,9 @@ public static <T extends ToXContentObject> String deflate(T objectToCompress) th
4648
return deflate(reference);
4749
}
4850

49-
static <T> T inflate(String compressedString,
50-
CheckedFunction<XContentParser, T, IOException> parserFunction,
51-
NamedXContentRegistry xContentRegistry) throws IOException {
51+
public static <T> T inflate(String compressedString,
52+
CheckedFunction<XContentParser, T, IOException> parserFunction,
53+
NamedXContentRegistry xContentRegistry) throws IOException {
5254
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
5355
LoggingDeprecationHandler.INSTANCE,
5456
inflate(compressedString, MAX_INFLATED_BYTES))) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@
3030
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.StrictlyParsedOutputAggregator;
3131
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
3232
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
33+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.EnsembleInferenceModel;
34+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceModel;
35+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.TreeInferenceModel;
3336
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork;
3437
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
3538
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
@@ -119,6 +122,13 @@ public List<NamedXContentRegistry.Entry> getNamedXContentParsers() {
119122
namedXContent.add(new NamedXContentRegistry.Entry(InferenceConfigUpdate.class, RegressionConfigUpdate.NAME,
120123
RegressionConfigUpdate::fromXContentStrict));
121124

125+
// Inference models
126+
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Ensemble.NAME, EnsembleInferenceModel::fromXContent));
127+
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class, Tree.NAME, TreeInferenceModel::fromXContent));
128+
namedXContent.add(new NamedXContentRegistry.Entry(InferenceModel.class,
129+
LangIdentNeuralNetwork.NAME,
130+
LangIdentNeuralNetwork::fromXContentLenient));
131+
122132
return namedXContent;
123133
}
124134

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import java.io.IOException;
3535
import java.time.Instant;
36+
import java.util.Arrays;
3637
import java.util.Collections;
3738
import java.util.HashMap;
3839
import java.util.List;
@@ -607,6 +608,14 @@ public Builder validate(boolean forCreation) {
607608
if (input != null && input.getFieldNames().isEmpty()) {
608609
validationException = addValidationError("[input.field_names] must not be empty", validationException);
609610
}
611+
if (input != null && input.getFieldNames()
612+
.stream()
613+
.filter(s -> s.contains("."))
614+
.flatMap(s -> Arrays.stream(Strings.delimitedListToStringArray(s, ".")))
615+
.anyMatch(String::isEmpty)) {
616+
validationException = addValidationError("[input.field_names] must only contain valid dot delimited field names",
617+
validationException);
618+
}
610619
if (forCreation) {
611620
validationException = checkIllegalSetting(version, VERSION.getPreferredName(), validationException);
612621
validationException = checkIllegalSetting(createdBy, CREATED_BY.getPreferredName(), validationException);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

Lines changed: 0 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor;
2121
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
2222
import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor;
23-
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
24-
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
2523
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.LenientlyParsedTrainedModel;
2624
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.StrictlyParsedTrainedModel;
2725
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
@@ -32,9 +30,7 @@
3230
import java.util.ArrayList;
3331
import java.util.Collection;
3432
import java.util.Collections;
35-
import java.util.HashMap;
3633
import java.util.List;
37-
import java.util.Map;
3834
import java.util.Objects;
3935

4036
public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable {
@@ -73,7 +69,6 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser,
7369

7470
private final TrainedModel trainedModel;
7571
private final List<PreProcessor> preProcessors;
76-
private Map<String, String> decoderMap;
7772

7873
private TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
7974
this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL);
@@ -116,37 +111,6 @@ public List<PreProcessor> getPreProcessors() {
116111
return preProcessors;
117112
}
118113

119-
void preProcess(Map<String, Object> fields) {
120-
preProcessors.forEach(preProcessor -> preProcessor.process(fields));
121-
}
122-
123-
public InferenceResults infer(Map<String, Object> fields, InferenceConfig config) {
124-
preProcess(fields);
125-
if (config.requestingImportance() && trainedModel.supportsFeatureImportance() == false) {
126-
throw ExceptionsHelper.badRequestException(
127-
"Feature importance is not supported for the configured model of type [{}]",
128-
trainedModel.getName());
129-
}
130-
return trainedModel.infer(fields,
131-
config,
132-
config.requestingImportance() ? getDecoderMap() : Collections.emptyMap());
133-
}
134-
135-
private Map<String, String> getDecoderMap() {
136-
if (decoderMap != null) {
137-
return decoderMap;
138-
}
139-
synchronized (this) {
140-
if (decoderMap != null) {
141-
return decoderMap;
142-
}
143-
this.decoderMap = preProcessors.stream()
144-
.map(PreProcessor::reverseLookup)
145-
.collect(HashMap::new, Map::putAll, Map::putAll);
146-
return decoderMap;
147-
}
148-
}
149-
150114
@Override
151115
public String toString() {
152116
return Strings.toString(this);
@@ -218,14 +182,6 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
218182
return this;
219183
}
220184

221-
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
222-
if (trainedModel.size() != 1) {
223-
throw ExceptionsHelper.badRequestException("[{}] must have exactly one trained model defined.",
224-
TRAINED_MODEL.getPreferredName());
225-
}
226-
return setTrainedModel(trainedModel.get(0));
227-
}
228-
229185
private void setProcessorsInOrder(boolean value) {
230186
this.processorsInOrder = value;
231187
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.common.xcontent.XContentBuilder;
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17-
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1817

1918
import java.io.IOException;
2019
import java.util.Collections;
@@ -109,7 +108,7 @@ public String getName() {
109108

110109
@Override
111110
public void process(Map<String, Object> fields) {
112-
Object value = MapHelper.dig(field, fields);
111+
Object value = fields.get(field);
113112
if (value == null) {
114113
return;
115114
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.common.xcontent.XContentBuilder;
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17-
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1817

1918
import java.io.IOException;
2019
import java.util.Collections;
@@ -94,7 +93,7 @@ public String getName() {
9493

9594
@Override
9695
public void process(Map<String, Object> fields) {
97-
Object value = MapHelper.dig(field, fields);
96+
Object value = fields.get(field);
9897
if (value == null) {
9998
return;
10099
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import org.elasticsearch.common.xcontent.XContentBuilder;
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17-
import org.elasticsearch.xpack.core.ml.utils.MapHelper;
1817

1918
import java.io.IOException;
2019
import java.util.Collections;
@@ -120,7 +119,7 @@ public String getName() {
120119

121120
@Override
122121
public void process(Map<String, Object> fields) {
123-
Object value = MapHelper.dig(field, fields);
122+
Object value = fields.get(field);
124123
if (value == null) {
125124
return;
126125
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,17 +10,16 @@
1010

1111
import java.io.IOException;
1212
import java.util.Arrays;
13-
import java.util.Map;
1413
import java.util.Objects;
1514

1615
public class RawInferenceResults implements InferenceResults {
1716

1817
public static final String NAME = "raw";
1918

2019
private final double[] value;
21-
private final Map<String, double[]> featureImportance;
20+
private final double[][] featureImportance;
2221

23-
public RawInferenceResults(double[] value, Map<String, double[]> featureImportance) {
22+
public RawInferenceResults(double[] value, double[][] featureImportance) {
2423
this.value = value;
2524
this.featureImportance = featureImportance;
2625
}
@@ -29,7 +28,7 @@ public double[] getValue() {
2928
return value;
3029
}
3130

32-
public Map<String, double[]> getFeatureImportance() {
31+
public double[][] getFeatureImportance() {
3332
return featureImportance;
3433
}
3534

@@ -44,7 +43,7 @@ public boolean equals(Object object) {
4443
if (object == null || getClass() != object.getClass()) { return false; }
4544
RawInferenceResults that = (RawInferenceResults) object;
4645
return Arrays.equals(value, that.value)
47-
&& Objects.equals(featureImportance, that.featureImportance);
46+
&& Arrays.deepEquals(featureImportance, that.featureImportance);
4847
}
4948

5049
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,12 @@
77

88
import org.apache.lucene.util.Accountable;
99
import org.elasticsearch.Version;
10-
import org.elasticsearch.common.Nullable;
1110
import org.elasticsearch.common.io.stream.NamedWriteable;
12-
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
1311
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
1412

15-
import java.util.Map;
1613

1714
public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
1815

19-
/**
20-
* Infer against the provided fields
21-
*
22-
* NOTE: Must be thread safe
23-
*
24-
* @param fields The fields and their values to infer against
25-
* @param config The configuration options for inference
26-
* @param featureDecoderMap A map for decoding feature value names to their originating feature.
27-
* Necessary for feature influence.
28-
* @return The predicted value. For classification this will be discrete values (e.g. 0.0, or 1.0).
29-
* For regression this is continuous.
30-
*/
31-
InferenceResults infer(Map<String, Object> fields, InferenceConfig config, @Nullable Map<String, String> featureDecoderMap);
32-
3316
/**
3417
* @return {@link TargetType} for the model.
3518
*/
@@ -49,21 +32,6 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
4932
*/
5033
long estimatedNumOperations();
5134

52-
/**
53-
* @return Does the model support feature importance
54-
*/
55-
boolean supportsFeatureImportance();
56-
57-
/**
58-
* Calculates the importance of each feature reference by the model for the passed in field values
59-
*
60-
* NOTE: Must be thread safe
61-
* @param fields The fields inferring against
62-
* @param featureDecoder A Map translating processed feature names to their original feature names
63-
* @return A {@code Map<String, double[]>} mapping each featureName to its importance
64-
*/
65-
Map<String, double[]> featureImportance(Map<String, Object> fields, Map<String, String> featureDecoder);
66-
6735
default Version getMinimalCompatibilityVersion() {
6836
return Version.V_7_6_0;
6937
}

0 commit comments

Comments
 (0)