Skip to content

Commit be53e31

Browse files
authored
[ML][Inference] Adding memory and compute estimates to inference (#48955)
* [ML][Inference] Adding memory and compute estimates to inference * Make nodes non-empty * fixing tests
1 parent cfd5c64 commit be53e31

File tree

20 files changed

+198
-23
lines changed

20 files changed

+198
-23
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.client.common.TimeUtil;
2323
import org.elasticsearch.common.ParseField;
2424
import org.elasticsearch.common.Strings;
25+
import org.elasticsearch.common.unit.ByteSizeValue;
2526
import org.elasticsearch.common.xcontent.ObjectParser;
2627
import org.elasticsearch.common.xcontent.ToXContentObject;
2728
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject {
4748
public static final ParseField TAGS = new ParseField("tags");
4849
public static final ParseField METADATA = new ParseField("metadata");
4950
public static final ParseField INPUT = new ParseField("input");
51+
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
52+
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
5053

5154
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
5255
true,
@@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject {
6669
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
6770
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
6871
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
72+
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
73+
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
6974
}
7075

7176
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@@ -81,6 +86,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
8186
private final List<String> tags;
8287
private final Map<String, Object> metadata;
8388
private final TrainedModelInput input;
89+
private final Long estimatedHeapMemory;
90+
private final Long estimatedOperations;
8491

8592
TrainedModelConfig(String modelId,
8693
String createdBy,
@@ -90,7 +97,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
9097
TrainedModelDefinition definition,
9198
List<String> tags,
9299
Map<String, Object> metadata,
93-
TrainedModelInput input) {
100+
TrainedModelInput input,
101+
Long estimatedHeapMemory,
102+
Long estimatedOperations) {
94103
this.modelId = modelId;
95104
this.createdBy = createdBy;
96105
this.version = version;
@@ -100,6 +109,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
100109
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
101110
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
102111
this.input = input;
112+
this.estimatedHeapMemory = estimatedHeapMemory;
113+
this.estimatedOperations = estimatedOperations;
103114
}
104115

105116
public String getModelId() {
@@ -138,6 +149,18 @@ public TrainedModelInput getInput() {
138149
return input;
139150
}
140151

152+
public ByteSizeValue getEstimatedHeapMemory() {
153+
return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory);
154+
}
155+
156+
public Long getEstimatedHeapMemoryBytes() {
157+
return estimatedHeapMemory;
158+
}
159+
160+
public Long getEstimatedOperations() {
161+
return estimatedOperations;
162+
}
163+
141164
public static Builder builder() {
142165
return new Builder();
143166
}
@@ -172,6 +195,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
172195
if (input != null) {
173196
builder.field(INPUT.getPreferredName(), input);
174197
}
198+
if (estimatedHeapMemory != null) {
199+
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory);
200+
}
201+
if (estimatedOperations != null) {
202+
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
203+
}
175204
builder.endObject();
176205
return builder;
177206
}
@@ -194,6 +223,8 @@ public boolean equals(Object o) {
194223
Objects.equals(definition, that.definition) &&
195224
Objects.equals(tags, that.tags) &&
196225
Objects.equals(input, that.input) &&
226+
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
227+
Objects.equals(estimatedOperations, that.estimatedOperations) &&
197228
Objects.equals(metadata, that.metadata);
198229
}
199230

@@ -206,6 +237,8 @@ public int hashCode() {
206237
definition,
207238
description,
208239
tags,
240+
estimatedHeapMemory,
241+
estimatedOperations,
209242
metadata,
210243
input);
211244
}
@@ -222,6 +255,8 @@ public static class Builder {
222255
private List<String> tags;
223256
private TrainedModelDefinition definition;
224257
private TrainedModelInput input;
258+
private Long estimatedHeapMemory;
259+
private Long estimatedOperations;
225260

226261
public Builder setModelId(String modelId) {
227262
this.modelId = modelId;
@@ -277,6 +312,16 @@ public Builder setInput(TrainedModelInput input) {
277312
return this;
278313
}
279314

315+
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
316+
this.estimatedHeapMemory = estimatedHeapMemory;
317+
return this;
318+
}
319+
320+
public Builder setEstimatedOperations(Long estimatedOperations) {
321+
this.estimatedOperations = estimatedOperations;
322+
return this;
323+
}
324+
280325
public TrainedModelConfig build() {
281326
return new TrainedModelConfig(
282327
modelId,
@@ -287,7 +332,9 @@ public TrainedModelConfig build() {
287332
definition,
288333
tags,
289334
metadata,
290-
input);
335+
input,
336+
estimatedHeapMemory,
337+
estimatedOperations);
291338
}
292339
}
293340

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,10 @@ protected TrainedModelConfig createTestInstance() {
6464
randomBoolean() ? null :
6565
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
6666
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
67-
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
67+
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
68+
randomBoolean() ? null : randomNonNegativeLong(),
69+
randomBoolean() ? null : randomNonNegativeLong());
70+
6871
}
6972

7073
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.common.io.stream.StreamInput;
1313
import org.elasticsearch.common.io.stream.StreamOutput;
1414
import org.elasticsearch.common.io.stream.Writeable;
15+
import org.elasticsearch.common.unit.ByteSizeValue;
1516
import org.elasticsearch.common.xcontent.ObjectParser;
1617
import org.elasticsearch.common.xcontent.ToXContentObject;
1718
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -34,6 +35,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
3435

3536
public static final String NAME = "trained_model_config";
3637

38+
private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";
39+
3740
public static final ParseField MODEL_ID = new ParseField("model_id");
3841
public static final ParseField CREATED_BY = new ParseField("created_by");
3942
public static final ParseField VERSION = new ParseField("version");
@@ -43,6 +46,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
4346
public static final ParseField TAGS = new ParseField("tags");
4447
public static final ParseField METADATA = new ParseField("metadata");
4548
public static final ParseField INPUT = new ParseField("input");
49+
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
50+
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
4651

4752
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
4853
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
@@ -66,6 +71,8 @@ private static ObjectParser<TrainedModelConfig.Builder, Void> createParser(boole
6671
parser.declareObject(TrainedModelConfig.Builder::setInput,
6772
(p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields),
6873
INPUT);
74+
parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
75+
parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
6976
return parser;
7077
}
7178

@@ -81,6 +88,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
8188
private final List<String> tags;
8289
private final Map<String, Object> metadata;
8390
private final TrainedModelInput input;
91+
private final long estimatedHeapMemory;
92+
private final long estimatedOperations;
8493

8594
private final TrainedModelDefinition definition;
8695

@@ -92,7 +101,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
92101
TrainedModelDefinition definition,
93102
List<String> tags,
94103
Map<String, Object> metadata,
95-
TrainedModelInput input) {
104+
TrainedModelInput input,
105+
Long estimatedHeapMemory,
106+
Long estimatedOperations) {
96107
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
97108
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
98109
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
@@ -102,6 +113,15 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
102113
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
103114
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
104115
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
116+
if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) {
117+
throw new IllegalArgumentException(
118+
"[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0");
119+
}
120+
this.estimatedHeapMemory = estimatedHeapMemory;
121+
if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) {
122+
throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0");
123+
}
124+
this.estimatedOperations = estimatedOperations;
105125
}
106126

107127
public TrainedModelConfig(StreamInput in) throws IOException {
@@ -114,6 +134,8 @@ public TrainedModelConfig(StreamInput in) throws IOException {
114134
tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
115135
metadata = in.readMap();
116136
input = new TrainedModelInput(in);
137+
estimatedHeapMemory = in.readVLong();
138+
estimatedOperations = in.readVLong();
117139
}
118140

119141
public String getModelId() {
@@ -157,6 +179,14 @@ public static Builder builder() {
157179
return new Builder();
158180
}
159181

182+
public long getEstimatedHeapMemory() {
183+
return estimatedHeapMemory;
184+
}
185+
186+
public long getEstimatedOperations() {
187+
return estimatedOperations;
188+
}
189+
160190
@Override
161191
public void writeTo(StreamOutput out) throws IOException {
162192
out.writeString(modelId);
@@ -168,6 +198,8 @@ public void writeTo(StreamOutput out) throws IOException {
168198
out.writeCollection(tags, StreamOutput::writeString);
169199
out.writeMap(metadata);
170200
input.writeTo(out);
201+
out.writeVLong(estimatedHeapMemory);
202+
out.writeVLong(estimatedOperations);
171203
}
172204

173205
@Override
@@ -192,6 +224,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
192224
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
193225
}
194226
builder.field(INPUT.getPreferredName(), input);
227+
builder.humanReadableField(
228+
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
229+
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
230+
new ByteSizeValue(estimatedHeapMemory));
231+
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
195232
builder.endObject();
196233
return builder;
197234
}
@@ -214,6 +251,8 @@ public boolean equals(Object o) {
214251
Objects.equals(definition, that.definition) &&
215252
Objects.equals(tags, that.tags) &&
216253
Objects.equals(input, that.input) &&
254+
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
255+
Objects.equals(estimatedOperations, that.estimatedOperations) &&
217256
Objects.equals(metadata, that.metadata);
218257
}
219258

@@ -227,6 +266,8 @@ public int hashCode() {
227266
description,
228267
tags,
229268
metadata,
269+
estimatedHeapMemory,
270+
estimatedOperations,
230271
input);
231272
}
232273

@@ -241,6 +282,8 @@ public static class Builder {
241282
private Map<String, Object> metadata;
242283
private TrainedModelInput input;
243284
private TrainedModelDefinition definition;
285+
private Long estimatedHeapMemory;
286+
private Long estimatedOperations;
244287

245288
public Builder setModelId(String modelId) {
246289
this.modelId = modelId;
@@ -296,6 +339,16 @@ public Builder setInput(TrainedModelInput input) {
296339
return this;
297340
}
298341

342+
public Builder setEstimatedHeapMemory(long estimatedHeapMemory) {
343+
this.estimatedHeapMemory = estimatedHeapMemory;
344+
return this;
345+
}
346+
347+
public Builder setEstimatedOperations(long estimatedOperations) {
348+
this.estimatedOperations = estimatedOperations;
349+
return this;
350+
}
351+
299352
// TODO move to REST level instead of here in the builder
300353
public void validate() {
301354
// We require a definition to be available here even though it will be stored in a different doc
@@ -326,6 +379,16 @@ public void validate() {
326379
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
327380
CREATE_TIME.getPreferredName());
328381
}
382+
383+
if (estimatedHeapMemory != null) {
384+
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
385+
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
386+
}
387+
388+
if (estimatedOperations != null) {
389+
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
390+
ESTIMATED_OPERATIONS.getPreferredName());
391+
}
329392
}
330393

331394
public TrainedModelConfig build() {
@@ -338,7 +401,9 @@ public TrainedModelConfig build() {
338401
definition,
339402
tags,
340403
metadata,
341-
input);
404+
input,
405+
estimatedHeapMemory,
406+
estimatedOperations);
342407
}
343408
}
344409

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
5050
* @throws org.elasticsearch.ElasticsearchException if validations fail
5151
*/
5252
void validate();
53+
54+
/**
55+
* @return The estimated number of operations required at inference time
56+
*/
57+
long estimatedNumOperations();
5358
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.List;
3939
import java.util.Map;
4040
import java.util.Objects;
41+
import java.util.OptionalDouble;
4142
import java.util.stream.Collectors;
4243

4344
import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
@@ -251,6 +252,14 @@ public void validate() {
251252
this.models.forEach(TrainedModel::validate);
252253
}
253254

255+
@Override
256+
public long estimatedNumOperations() {
257+
OptionalDouble avg = models.stream().mapToLong(TrainedModel::estimatedNumOperations).average();
258+
assert avg.isPresent() : "unexpected null when calculating number of operations";
259+
// Average operations for each model and the operations required for processing and aggregating with the outputAggregator
260+
return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1);
261+
}
262+
254263
public static Builder builder() {
255264
return new Builder();
256265
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ public int hashCode() {
157157

158158
@Override
159159
public long ramBytesUsed() {
160-
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights);
160+
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
161+
return SHALLOW_SIZE + weightSize;
161162
}
162163
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ public int hashCode() {
174174

175175
@Override
176176
public long ramBytesUsed() {
177-
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights);
177+
long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights);
178+
return SHALLOW_SIZE + weightSize;
178179
}
179180
}

0 commit comments

Comments
 (0)