Skip to content

Commit eefe768

Browse files
authored
[7.x][ML] ML Model Inference Ingest Processor (#49052) (#49257)
* [ML] ML Model Inference Ingest Processor (#49052) * [ML][Inference] adds lazy model loader and inference (#47410) This adds a couple of things: - A model loader service that is accessible via transport calls. This service will load in models and cache them. They will stay loaded until a processor no longer references them - A Model class and its first sub-class LocalModel. Used to cache model information and run inference. - Transport action and handler for requests to infer against a local model Related Feature PRs: * [ML][Inference] Adjust inference configuration option API (#47812) * [ML][Inference] adds logistic_regression output aggregator (#48075) * [ML][Inference] Adding read/del trained models (#47882) * [ML][Inference] Adding inference ingest processor (#47859) * [ML][Inference] fixing classification inference for ensemble (#48463) * [ML][Inference] Adding model memory estimations (#48323) * [ML][Inference] adding more options to inference processor (#48545) * [ML][Inference] handle string values better in feature extraction (#48584) * [ML][Inference] Adding _stats endpoint for inference (#48492) * [ML][Inference] add inference processors and trained models to usage (#47869) * [ML][Inference] add new flag for optionally including model definition (#48718) * [ML][Inference] adding license checks (#49056) * [ML][Inference] Adding memory and compute estimates to inference (#48955) * fixing version of indexed docs for model inference
1 parent 48f53ef commit eefe768

File tree

97 files changed

+7855
-362
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+7855
-362
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.client.common.TimeUtil;
2323
import org.elasticsearch.common.ParseField;
2424
import org.elasticsearch.common.Strings;
25+
import org.elasticsearch.common.unit.ByteSizeValue;
2526
import org.elasticsearch.common.xcontent.ObjectParser;
2627
import org.elasticsearch.common.xcontent.ToXContentObject;
2728
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject {
4748
public static final ParseField TAGS = new ParseField("tags");
4849
public static final ParseField METADATA = new ParseField("metadata");
4950
public static final ParseField INPUT = new ParseField("input");
51+
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
52+
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
5053

5154
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
5255
true,
@@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject {
6669
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
6770
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
6871
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
72+
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
73+
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
6974
}
7075

7176
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@@ -81,6 +86,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
8186
private final List<String> tags;
8287
private final Map<String, Object> metadata;
8388
private final TrainedModelInput input;
89+
private final Long estimatedHeapMemory;
90+
private final Long estimatedOperations;
8491

8592
TrainedModelConfig(String modelId,
8693
String createdBy,
@@ -90,7 +97,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
9097
TrainedModelDefinition definition,
9198
List<String> tags,
9299
Map<String, Object> metadata,
93-
TrainedModelInput input) {
100+
TrainedModelInput input,
101+
Long estimatedHeapMemory,
102+
Long estimatedOperations) {
94103
this.modelId = modelId;
95104
this.createdBy = createdBy;
96105
this.version = version;
@@ -100,6 +109,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
100109
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
101110
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
102111
this.input = input;
112+
this.estimatedHeapMemory = estimatedHeapMemory;
113+
this.estimatedOperations = estimatedOperations;
103114
}
104115

105116
public String getModelId() {
@@ -138,6 +149,18 @@ public TrainedModelInput getInput() {
138149
return input;
139150
}
140151

152+
public ByteSizeValue getEstimatedHeapMemory() {
153+
return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory);
154+
}
155+
156+
public Long getEstimatedHeapMemoryBytes() {
157+
return estimatedHeapMemory;
158+
}
159+
160+
public Long getEstimatedOperations() {
161+
return estimatedOperations;
162+
}
163+
141164
public static Builder builder() {
142165
return new Builder();
143166
}
@@ -172,6 +195,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
172195
if (input != null) {
173196
builder.field(INPUT.getPreferredName(), input);
174197
}
198+
if (estimatedHeapMemory != null) {
199+
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory);
200+
}
201+
if (estimatedOperations != null) {
202+
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
203+
}
175204
builder.endObject();
176205
return builder;
177206
}
@@ -194,6 +223,8 @@ public boolean equals(Object o) {
194223
Objects.equals(definition, that.definition) &&
195224
Objects.equals(tags, that.tags) &&
196225
Objects.equals(input, that.input) &&
226+
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
227+
Objects.equals(estimatedOperations, that.estimatedOperations) &&
197228
Objects.equals(metadata, that.metadata);
198229
}
199230

@@ -206,6 +237,8 @@ public int hashCode() {
206237
definition,
207238
description,
208239
tags,
240+
estimatedHeapMemory,
241+
estimatedOperations,
209242
metadata,
210243
input);
211244
}
@@ -222,6 +255,8 @@ public static class Builder {
222255
private List<String> tags;
223256
private TrainedModelDefinition definition;
224257
private TrainedModelInput input;
258+
private Long estimatedHeapMemory;
259+
private Long estimatedOperations;
225260

226261
public Builder setModelId(String modelId) {
227262
this.modelId = modelId;
@@ -277,6 +312,16 @@ public Builder setInput(TrainedModelInput input) {
277312
return this;
278313
}
279314

315+
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
316+
this.estimatedHeapMemory = estimatedHeapMemory;
317+
return this;
318+
}
319+
320+
public Builder setEstimatedOperations(Long estimatedOperations) {
321+
this.estimatedOperations = estimatedOperations;
322+
return this;
323+
}
324+
280325
public TrainedModelConfig build() {
281326
return new TrainedModelConfig(
282327
modelId,
@@ -287,7 +332,9 @@ public TrainedModelConfig build() {
287332
definition,
288333
tags,
289334
metadata,
290-
input);
335+
input,
336+
estimatedHeapMemory,
337+
estimatedOperations);
291338
}
292339
}
293340

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ protected TrainedModelConfig createTestInstance() {
6464
randomBoolean() ? null :
6565
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
6666
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
67-
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
67+
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
68+
randomBoolean() ? null : randomNonNegativeLong(),
69+
randomBoolean() ? null : randomNonNegativeLong());
70+
6871
}
6972

7073
@Override

server/src/main/java/org/elasticsearch/ingest/IngestStats.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import java.util.HashMap;
3434
import java.util.List;
3535
import java.util.Map;
36+
import java.util.Objects;
3637
import java.util.concurrent.TimeUnit;
3738

3839
public class IngestStats implements Writeable, ToXContentFragment {
@@ -150,6 +151,21 @@ public Map<String, List<ProcessorStat>> getProcessorStats() {
150151
return processorStats;
151152
}
152153

154+
@Override
155+
public boolean equals(Object o) {
156+
if (this == o) return true;
157+
if (o == null || getClass() != o.getClass()) return false;
158+
IngestStats that = (IngestStats) o;
159+
return Objects.equals(totalStats, that.totalStats)
160+
&& Objects.equals(pipelineStats, that.pipelineStats)
161+
&& Objects.equals(processorStats, that.processorStats);
162+
}
163+
164+
@Override
165+
public int hashCode() {
166+
return Objects.hash(totalStats, pipelineStats, processorStats);
167+
}
168+
153169
public static class Stats implements Writeable, ToXContentFragment {
154170

155171
private final long ingestCount;
@@ -218,6 +234,22 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
218234
builder.field("failed", ingestFailedCount);
219235
return builder;
220236
}
237+
238+
@Override
239+
public boolean equals(Object o) {
240+
if (this == o) return true;
241+
if (o == null || getClass() != o.getClass()) return false;
242+
IngestStats.Stats that = (IngestStats.Stats) o;
243+
return Objects.equals(ingestCount, that.ingestCount)
244+
&& Objects.equals(ingestTimeInMillis, that.ingestTimeInMillis)
245+
&& Objects.equals(ingestFailedCount, that.ingestFailedCount)
246+
&& Objects.equals(ingestCurrent, that.ingestCurrent);
247+
}
248+
249+
@Override
250+
public int hashCode() {
251+
return Objects.hash(ingestCount, ingestTimeInMillis, ingestFailedCount, ingestCurrent);
252+
}
221253
}
222254

223255
/**
@@ -270,6 +302,20 @@ public String getPipelineId() {
270302
public Stats getStats() {
271303
return stats;
272304
}
305+
306+
@Override
307+
public boolean equals(Object o) {
308+
if (this == o) return true;
309+
if (o == null || getClass() != o.getClass()) return false;
310+
IngestStats.PipelineStat that = (IngestStats.PipelineStat) o;
311+
return Objects.equals(pipelineId, that.pipelineId)
312+
&& Objects.equals(stats, that.stats);
313+
}
314+
315+
@Override
316+
public int hashCode() {
317+
return Objects.hash(pipelineId, stats);
318+
}
273319
}
274320

275321
/**
@@ -297,5 +343,21 @@ public String getType() {
297343
public Stats getStats() {
298344
return stats;
299345
}
346+
347+
348+
@Override
349+
public boolean equals(Object o) {
350+
if (this == o) return true;
351+
if (o == null || getClass() != o.getClass()) return false;
352+
IngestStats.ProcessorStat that = (IngestStats.ProcessorStat) o;
353+
return Objects.equals(name, that.name)
354+
&& Objects.equals(type, that.type)
355+
&& Objects.equals(stats, that.stats);
356+
}
357+
358+
@Override
359+
public int hashCode() {
360+
return Objects.hash(name, type, stats);
361+
}
300362
}
301363
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/XPackClientPlugin.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@
8888
import org.elasticsearch.xpack.core.ml.action.DeleteForecastAction;
8989
import org.elasticsearch.xpack.core.ml.action.DeleteJobAction;
9090
import org.elasticsearch.xpack.core.ml.action.DeleteModelSnapshotAction;
91+
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
9192
import org.elasticsearch.xpack.core.ml.action.EstimateMemoryUsageAction;
9293
import org.elasticsearch.xpack.core.ml.action.EvaluateDataFrameAction;
9394
import org.elasticsearch.xpack.core.ml.action.FinalizeJobExecutionAction;
@@ -109,6 +110,9 @@
109110
import org.elasticsearch.xpack.core.ml.action.GetModelSnapshotsAction;
110111
import org.elasticsearch.xpack.core.ml.action.GetOverallBucketsAction;
111112
import org.elasticsearch.xpack.core.ml.action.GetRecordsAction;
113+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsAction;
114+
import org.elasticsearch.xpack.core.ml.action.GetTrainedModelsStatsAction;
115+
import org.elasticsearch.xpack.core.ml.action.InferModelAction;
112116
import org.elasticsearch.xpack.core.ml.action.IsolateDatafeedAction;
113117
import org.elasticsearch.xpack.core.ml.action.KillProcessAction;
114118
import org.elasticsearch.xpack.core.ml.action.MlInfoAction;
@@ -153,6 +157,19 @@
153157
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.Recall;
154158
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.ScoreByThresholdResult;
155159
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
160+
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
161+
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
162+
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
163+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
164+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
165+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig;
166+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
167+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
168+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.LogisticRegression;
169+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.OutputAggregator;
170+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedMode;
171+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.WeightedSum;
172+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
156173
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncoding;
157174
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncoding;
158175
import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor;
@@ -371,6 +388,10 @@ public List<ActionType<? extends ActionResponse>> getClientActions() {
371388
StopDataFrameAnalyticsAction.INSTANCE,
372389
EvaluateDataFrameAction.INSTANCE,
373390
EstimateMemoryUsageAction.INSTANCE,
391+
InferModelAction.INSTANCE,
392+
GetTrainedModelsAction.INSTANCE,
393+
DeleteTrainedModelAction.INSTANCE,
394+
GetTrainedModelsStatsAction.INSTANCE,
374395
// security
375396
ClearRealmCacheAction.INSTANCE,
376397
ClearRolesCacheAction.INSTANCE,
@@ -519,6 +540,16 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
519540
new NamedWriteableRegistry.Entry(OutputAggregator.class,
520541
LogisticRegression.NAME.getPreferredName(),
521542
LogisticRegression::new),
543+
// ML - Inference Results
544+
new NamedWriteableRegistry.Entry(InferenceResults.class,
545+
ClassificationInferenceResults.NAME,
546+
ClassificationInferenceResults::new),
547+
new NamedWriteableRegistry.Entry(InferenceResults.class,
548+
RegressionInferenceResults.NAME,
549+
RegressionInferenceResults::new),
550+
// ML - Inference Configuration
551+
new NamedWriteableRegistry.Entry(InferenceConfig.class, ClassificationConfig.NAME, ClassificationConfig::new),
552+
new NamedWriteableRegistry.Entry(InferenceConfig.class, RegressionConfig.NAME, RegressionConfig::new),
522553

523554
// monitoring
524555
new NamedWriteableRegistry.Entry(XPackFeatureSet.Usage.class, XPackField.MONITORING, MonitoringFeatureSetUsage::new),

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MachineLearningFeatureSetUsage.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,26 @@ public class MachineLearningFeatureSetUsage extends XPackFeatureSet.Usage {
2929
public static final String CREATED_BY = "created_by";
3030
public static final String NODE_COUNT = "node_count";
3131
public static final String DATA_FRAME_ANALYTICS_JOBS_FIELD = "data_frame_analytics_jobs";
32+
public static final String INFERENCE_FIELD = "inference";
3233

3334
private final Map<String, Object> jobsUsage;
3435
private final Map<String, Object> datafeedsUsage;
3536
private final Map<String, Object> analyticsUsage;
37+
private final Map<String, Object> inferenceUsage;
3638
private final int nodeCount;
3739

3840
/**
 * Creates the machine learning usage report.
 *
 * @param available      passed through to the parent usage constructor
 * @param enabled        passed through to the parent usage constructor
 * @param jobsUsage      usage map for jobs; must not be {@code null}
 * @param datafeedsUsage usage map for datafeeds; must not be {@code null}
 * @param analyticsUsage usage map for data frame analytics jobs; must not be {@code null}
 * @param inferenceUsage usage map for inference; must not be {@code null}
 * @param nodeCount      node count to report
 * @throws NullPointerException if any of the usage maps is {@code null}; the message names the offending argument
 */
public MachineLearningFeatureSetUsage(boolean available,
                                      boolean enabled,
                                      Map<String, Object> jobsUsage,
                                      Map<String, Object> datafeedsUsage,
                                      Map<String, Object> analyticsUsage,
                                      Map<String, Object> inferenceUsage,
                                      int nodeCount) {
    super(XPackField.MACHINE_LEARNING, available, enabled);
    // Name each argument so a failure identifies which map was missing.
    this.jobsUsage = Objects.requireNonNull(jobsUsage, "jobsUsage");
    this.datafeedsUsage = Objects.requireNonNull(datafeedsUsage, "datafeedsUsage");
    this.analyticsUsage = Objects.requireNonNull(analyticsUsage, "analyticsUsage");
    this.inferenceUsage = Objects.requireNonNull(inferenceUsage, "inferenceUsage");
    this.nodeCount = nodeCount;
}
5054

@@ -57,12 +61,17 @@ public MachineLearningFeatureSetUsage(StreamInput in) throws IOException {
5761
} else {
5862
this.analyticsUsage = Collections.emptyMap();
5963
}
64+
if (in.getVersion().onOrAfter(Version.V_7_6_0)) {
65+
this.inferenceUsage = in.readMap();
66+
} else {
67+
this.inferenceUsage = Collections.emptyMap();
68+
}
6069
if (in.getVersion().onOrAfter(Version.V_6_5_0)) {
6170
this.nodeCount = in.readInt();
6271
} else {
6372
this.nodeCount = -1;
6473
}
65-
}
74+
}
6675

6776
@Override
6877
public void writeTo(StreamOutput out) throws IOException {
@@ -72,17 +81,21 @@ public void writeTo(StreamOutput out) throws IOException {
7281
if (out.getVersion().onOrAfter(Version.V_7_4_0)) {
7382
out.writeMap(analyticsUsage);
7483
}
84+
if (out.getVersion().onOrAfter(Version.V_7_6_0)) {
85+
out.writeMap(inferenceUsage);
86+
}
7587
if (out.getVersion().onOrAfter(Version.V_6_5_0)) {
7688
out.writeInt(nodeCount);
7789
}
78-
}
90+
}
7991

8092
@Override
8193
protected void innerXContent(XContentBuilder builder, Params params) throws IOException {
8294
super.innerXContent(builder, params);
8395
builder.field(JOBS_FIELD, jobsUsage);
8496
builder.field(DATAFEEDS_FIELD, datafeedsUsage);
8597
builder.field(DATA_FRAME_ANALYTICS_JOBS_FIELD, analyticsUsage);
98+
builder.field(INFERENCE_FIELD, inferenceUsage);
8699
if (nodeCount >= 0) {
87100
builder.field(NODE_COUNT, nodeCount);
88101
}

0 commit comments

Comments
 (0)