Skip to content

Commit c5c7ee9

Browse files
authored
[7.x] [ML] Start gathering and storing inference stats (#53429) (#54738)
* [ML] Start gathering and storing inference stats (#53429) This PR enables stats on inference to be gathered and stored in the `.ml-stats-*` indices. Each node + model_id will have its own running stats document and these will later be summed together when returning _stats to the user. `.ml-stats-*` is ILM managed (when possible). So, at any point the underlying index could change. This means that a stats document that is read in and then later updated will actually be a new doc in a new index. This complicates matters as this means that having a running knowledge of seq_no and primary_term is complicated and almost impossible. This is because we don't know the latest index name. We should also strive for throughput, as this code sits in the middle of an ingest pipeline (or even a query).
1 parent 7a8a66d commit c5c7ee9

File tree

17 files changed

+957
-54
lines changed

17 files changed

+957
-54
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsAction.java

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.elasticsearch.xpack.core.ml.action;
77

88
import org.elasticsearch.ElasticsearchException;
9+
import org.elasticsearch.Version;
910
import org.elasticsearch.action.ActionRequestBuilder;
1011
import org.elasticsearch.action.ActionType;
1112
import org.elasticsearch.client.ElasticsearchClient;
@@ -20,6 +21,7 @@
2021
import org.elasticsearch.xpack.core.action.AbstractGetResourcesResponse;
2122
import org.elasticsearch.xpack.core.action.util.QueryPage;
2223
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
24+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats;
2325

2426
import java.io.IOException;
2527
import java.util.ArrayList;
@@ -37,6 +39,7 @@ public class GetTrainedModelsStatsAction extends ActionType<GetTrainedModelsStat
3739

3840
public static final ParseField MODEL_ID = new ParseField("model_id");
3941
public static final ParseField PIPELINE_COUNT = new ParseField("pipeline_count");
42+
public static final ParseField INFERENCE_STATS = new ParseField("inference_stats");
4043

4144
private GetTrainedModelsStatsAction() {
4245
super(NAME, GetTrainedModelsStatsAction.Response::new);
@@ -78,25 +81,32 @@ public static class Response extends AbstractGetResourcesResponse<Response.Train
7881
public static class TrainedModelStats implements ToXContentObject, Writeable {
7982
private final String modelId;
8083
private final IngestStats ingestStats;
84+
private final InferenceStats inferenceStats;
8185
private final int pipelineCount;
8286

8387
private static final IngestStats EMPTY_INGEST_STATS = new IngestStats(new IngestStats.Stats(0, 0, 0, 0),
8488
Collections.emptyList(),
8589
Collections.emptyMap());
8690

87-
public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount) {
91+
public TrainedModelStats(String modelId, IngestStats ingestStats, int pipelineCount, InferenceStats inferenceStats) {
8892
this.modelId = Objects.requireNonNull(modelId);
8993
this.ingestStats = ingestStats == null ? EMPTY_INGEST_STATS : ingestStats;
9094
if (pipelineCount < 0) {
9195
throw new ElasticsearchException("[{}] must be a greater than or equal to 0", PIPELINE_COUNT.getPreferredName());
9296
}
9397
this.pipelineCount = pipelineCount;
98+
this.inferenceStats = inferenceStats;
9499
}
95100

96101
public TrainedModelStats(StreamInput in) throws IOException {
97102
modelId = in.readString();
98103
ingestStats = new IngestStats(in);
99104
pipelineCount = in.readVInt();
105+
if (in.getVersion().onOrAfter(Version.V_7_8_0)) {
106+
this.inferenceStats = in.readOptionalWriteable(InferenceStats::new);
107+
} else {
108+
this.inferenceStats = null;
109+
}
100110
}
101111

102112
public String getModelId() {
@@ -120,6 +130,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
120130
// Ingest stats is a fragment
121131
ingestStats.toXContent(builder, params);
122132
}
133+
if (this.inferenceStats != null) {
134+
builder.field(INFERENCE_STATS.getPreferredName(), this.inferenceStats);
135+
}
123136
builder.endObject();
124137
return builder;
125138
}
@@ -129,11 +142,14 @@ public void writeTo(StreamOutput out) throws IOException {
129142
out.writeString(modelId);
130143
ingestStats.writeTo(out);
131144
out.writeVInt(pipelineCount);
145+
if (out.getVersion().onOrAfter(Version.V_7_8_0)) {
146+
out.writeOptionalWriteable(this.inferenceStats);
147+
}
132148
}
133149

134150
@Override
135151
public int hashCode() {
136-
return Objects.hash(modelId, ingestStats, pipelineCount);
152+
return Objects.hash(modelId, ingestStats, pipelineCount, inferenceStats);
137153
}
138154

139155
@Override
@@ -147,7 +163,8 @@ public boolean equals(Object obj) {
147163
TrainedModelStats other = (TrainedModelStats) obj;
148164
return Objects.equals(this.modelId, other.modelId)
149165
&& Objects.equals(this.ingestStats, other.ingestStats)
150-
&& Objects.equals(this.pipelineCount, other.pipelineCount);
166+
&& Objects.equals(this.pipelineCount, other.pipelineCount)
167+
&& Objects.equals(this.inferenceStats, other.inferenceStats);
151168
}
152169
}
153170

@@ -171,6 +188,7 @@ public static class Builder {
171188
private long totalModelCount;
172189
private Set<String> expandedIds;
173190
private Map<String, IngestStats> ingestStatsMap;
191+
private Map<String, InferenceStats> inferenceStatsMap;
174192

175193
public Builder setTotalModelCount(long totalModelCount) {
176194
this.totalModelCount = totalModelCount;
@@ -191,13 +209,23 @@ public Builder setIngestStatsByModelId(Map<String, IngestStats> ingestStatsByMod
191209
return this;
192210
}
193211

212+
public Builder setInferenceStatsByModelId(Map<String, InferenceStats> infereceStatsByModelId) {
213+
this.inferenceStatsMap = infereceStatsByModelId;
214+
return this;
215+
}
216+
194217
public Response build() {
195218
List<TrainedModelStats> trainedModelStats = new ArrayList<>(expandedIds.size());
196219
expandedIds.forEach(id -> {
197220
IngestStats ingestStats = ingestStatsMap.get(id);
198-
trainedModelStats.add(new TrainedModelStats(id, ingestStats, ingestStats == null ?
199-
0 :
200-
ingestStats.getPipelineStats().size()));
221+
InferenceStats inferenceStats = inferenceStatsMap.get(id);
222+
trainedModelStats.add(new TrainedModelStats(
223+
id,
224+
ingestStats,
225+
ingestStats == null ?
226+
0 :
227+
ingestStats.getPipelineStats().size(),
228+
inferenceStats));
201229
});
202230
trainedModelStats.sort(Comparator.comparing(TrainedModelStats::getModelId));
203231
return new Response(new QueryPage<>(trainedModelStats, totalModelCount, RESULTS_FIELD));
Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
7+
8+
import org.elasticsearch.common.Nullable;
9+
import org.elasticsearch.common.ParseField;
10+
import org.elasticsearch.common.io.stream.StreamInput;
11+
import org.elasticsearch.common.io.stream.StreamOutput;
12+
import org.elasticsearch.common.io.stream.Writeable;
13+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
14+
import org.elasticsearch.common.xcontent.ObjectParser;
15+
import org.elasticsearch.common.xcontent.ToXContentObject;
16+
import org.elasticsearch.common.xcontent.XContentBuilder;
17+
import org.elasticsearch.xpack.core.common.time.TimeUtils;
18+
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
19+
20+
import java.io.IOException;
21+
import java.time.Instant;
22+
import java.util.Objects;
23+
import java.util.concurrent.atomic.LongAdder;
24+
25+
public class InferenceStats implements ToXContentObject, Writeable {
26+
27+
public static final String NAME = "inference_stats";
28+
public static final ParseField MISSING_ALL_FIELDS_COUNT = new ParseField("missing_all_fields_count");
29+
public static final ParseField INFERENCE_COUNT = new ParseField("inference_count");
30+
public static final ParseField MODEL_ID = new ParseField("model_id");
31+
public static final ParseField NODE_ID = new ParseField("node_id");
32+
public static final ParseField FAILURE_COUNT = new ParseField("failure_count");
33+
public static final ParseField TYPE = new ParseField("type");
34+
public static final ParseField TIMESTAMP = new ParseField("time_stamp");
35+
36+
public static final ConstructingObjectParser<InferenceStats, Void> PARSER = new ConstructingObjectParser<>(
37+
NAME,
38+
true,
39+
a -> new InferenceStats((Long)a[0], (Long)a[1], (Long)a[2], (String)a[3], (String)a[4], (Instant)a[5])
40+
);
41+
static {
42+
PARSER.declareLong(ConstructingObjectParser.constructorArg(), MISSING_ALL_FIELDS_COUNT);
43+
PARSER.declareLong(ConstructingObjectParser.constructorArg(), INFERENCE_COUNT);
44+
PARSER.declareLong(ConstructingObjectParser.constructorArg(), FAILURE_COUNT);
45+
PARSER.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID);
46+
PARSER.declareString(ConstructingObjectParser.constructorArg(), NODE_ID);
47+
PARSER.declareField(ConstructingObjectParser.constructorArg(),
48+
p -> TimeUtils.parseTimeFieldToInstant(p, TIMESTAMP.getPreferredName()),
49+
TIMESTAMP,
50+
ObjectParser.ValueType.VALUE);
51+
}
52+
public static InferenceStats emptyStats(String modelId, String nodeId) {
53+
return new InferenceStats(0L, 0L, 0L, modelId, nodeId, Instant.now());
54+
}
55+
56+
public static String docId(String modelId, String nodeId) {
57+
return NAME + "-" + modelId + "-" + nodeId;
58+
}
59+
60+
private final long missingAllFieldsCount;
61+
private final long inferenceCount;
62+
private final long failureCount;
63+
private final String modelId;
64+
private final String nodeId;
65+
private final Instant timeStamp;
66+
67+
private InferenceStats(Long missingAllFieldsCount,
68+
Long inferenceCount,
69+
Long failureCount,
70+
String modelId,
71+
String nodeId,
72+
Instant instant) {
73+
this(unbox(missingAllFieldsCount),
74+
unbox(inferenceCount),
75+
unbox(failureCount),
76+
modelId,
77+
nodeId,
78+
instant);
79+
}
80+
81+
public InferenceStats(long missingAllFieldsCount,
82+
long inferenceCount,
83+
long failureCount,
84+
String modelId,
85+
String nodeId,
86+
Instant timeStamp) {
87+
this.missingAllFieldsCount = missingAllFieldsCount;
88+
this.inferenceCount = inferenceCount;
89+
this.failureCount = failureCount;
90+
this.modelId = modelId;
91+
this.nodeId = nodeId;
92+
this.timeStamp = timeStamp == null ?
93+
Instant.ofEpochMilli(Instant.now().toEpochMilli()) :
94+
Instant.ofEpochMilli(timeStamp.toEpochMilli());
95+
}
96+
97+
public InferenceStats(StreamInput in) throws IOException {
98+
this.missingAllFieldsCount = in.readVLong();
99+
this.inferenceCount = in.readVLong();
100+
this.failureCount = in.readVLong();
101+
this.modelId = in.readOptionalString();
102+
this.nodeId = in.readOptionalString();
103+
this.timeStamp = in.readInstant();
104+
}
105+
106+
public long getMissingAllFieldsCount() {
107+
return missingAllFieldsCount;
108+
}
109+
110+
public long getInferenceCount() {
111+
return inferenceCount;
112+
}
113+
114+
public long getFailureCount() {
115+
return failureCount;
116+
}
117+
118+
public String getModelId() {
119+
return modelId;
120+
}
121+
122+
public String getNodeId() {
123+
return nodeId;
124+
}
125+
126+
public Instant getTimeStamp() {
127+
return timeStamp;
128+
}
129+
130+
@Override
131+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
132+
builder.startObject();
133+
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) {
134+
assert modelId != null : "model_id cannot be null when storing inference stats";
135+
assert nodeId != null : "node_id cannot be null when storing inference stats";
136+
builder.field(TYPE.getPreferredName(), NAME);
137+
builder.field(MODEL_ID.getPreferredName(), modelId);
138+
builder.field(NODE_ID.getPreferredName(), nodeId);
139+
}
140+
builder.field(FAILURE_COUNT.getPreferredName(), failureCount);
141+
builder.field(INFERENCE_COUNT.getPreferredName(), inferenceCount);
142+
builder.field(MISSING_ALL_FIELDS_COUNT.getPreferredName(), missingAllFieldsCount);
143+
builder.timeField(TIMESTAMP.getPreferredName(), TIMESTAMP.getPreferredName() + "_string", timeStamp.toEpochMilli());
144+
builder.endObject();
145+
return builder;
146+
}
147+
148+
@Override
149+
public boolean equals(Object o) {
150+
if (this == o) return true;
151+
if (o == null || getClass() != o.getClass()) return false;
152+
InferenceStats that = (InferenceStats) o;
153+
return missingAllFieldsCount == that.missingAllFieldsCount
154+
&& inferenceCount == that.inferenceCount
155+
&& failureCount == that.failureCount
156+
&& Objects.equals(modelId, that.modelId)
157+
&& Objects.equals(nodeId, that.nodeId)
158+
&& Objects.equals(timeStamp, that.timeStamp);
159+
}
160+
161+
@Override
162+
public int hashCode() {
163+
return Objects.hash(missingAllFieldsCount, inferenceCount, failureCount, modelId, nodeId, timeStamp);
164+
}
165+
166+
@Override
167+
public String toString() {
168+
return "InferenceStats{" +
169+
"missingAllFieldsCount=" + missingAllFieldsCount +
170+
", inferenceCount=" + inferenceCount +
171+
", failureCount=" + failureCount +
172+
", modelId='" + modelId + '\'' +
173+
", nodeId='" + nodeId + '\'' +
174+
", timeStamp=" + timeStamp +
175+
'}';
176+
}
177+
178+
private static long unbox(@Nullable Long value) {
179+
return value == null ? 0L : value;
180+
}
181+
182+
public static Accumulator accumulator(InferenceStats stats) {
183+
return new Accumulator(stats);
184+
}
185+
186+
@Override
187+
public void writeTo(StreamOutput out) throws IOException {
188+
out.writeVLong(this.missingAllFieldsCount);
189+
out.writeVLong(this.inferenceCount);
190+
out.writeVLong(this.failureCount);
191+
out.writeOptionalString(this.modelId);
192+
out.writeOptionalString(this.nodeId);
193+
out.writeInstant(timeStamp);
194+
}
195+
196+
public static class Accumulator {
197+
198+
private final LongAdder missingFieldsAccumulator = new LongAdder();
199+
private final LongAdder inferenceAccumulator = new LongAdder();
200+
private final LongAdder failureCountAccumulator = new LongAdder();
201+
private final String modelId;
202+
private final String nodeId;
203+
204+
public Accumulator(String modelId, String nodeId) {
205+
this.modelId = modelId;
206+
this.nodeId = nodeId;
207+
}
208+
209+
public Accumulator(InferenceStats previousStats) {
210+
this.modelId = previousStats.modelId;
211+
this.nodeId = previousStats.nodeId;
212+
this.missingFieldsAccumulator.add(previousStats.missingAllFieldsCount);
213+
this.inferenceAccumulator.add(previousStats.inferenceCount);
214+
this.failureCountAccumulator.add(previousStats.failureCount);
215+
}
216+
217+
public Accumulator merge(InferenceStats otherStats) {
218+
this.missingFieldsAccumulator.add(otherStats.missingAllFieldsCount);
219+
this.inferenceAccumulator.add(otherStats.inferenceCount);
220+
this.failureCountAccumulator.add(otherStats.failureCount);
221+
return this;
222+
}
223+
224+
public void incMissingFields() {
225+
this.missingFieldsAccumulator.increment();
226+
}
227+
228+
public void incInference() {
229+
this.inferenceAccumulator.increment();
230+
}
231+
232+
public void incFailure() {
233+
this.failureCountAccumulator.increment();
234+
}
235+
236+
public InferenceStats currentStats() {
237+
return currentStats(Instant.now());
238+
}
239+
240+
public InferenceStats currentStats(Instant timeStamp) {
241+
return new InferenceStats(missingFieldsAccumulator.longValue(),
242+
inferenceAccumulator.longValue(),
243+
failureCountAccumulator.longValue(),
244+
modelId,
245+
nodeId,
246+
timeStamp);
247+
}
248+
}
249+
}

0 commit comments

Comments
 (0)