Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.client.common.TimeUtil;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
Expand All @@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject {
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
}

public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
Expand All @@ -81,6 +86,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
private final List<String> tags;
private final Map<String, Object> metadata;
private final TrainedModelInput input;
private final Long estimatedHeapMemory;
private final Long estimatedOperations;

TrainedModelConfig(String modelId,
String createdBy,
Expand All @@ -90,7 +97,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata,
TrainedModelInput input) {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
Expand All @@ -100,6 +109,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = input;
this.estimatedHeapMemory = estimatedHeapMemory;
this.estimatedOperations = estimatedOperations;
}

public String getModelId() {
Expand Down Expand Up @@ -138,6 +149,18 @@ public TrainedModelInput getInput() {
return input;
}

/**
 * Returns the estimated heap memory usage wrapped as a {@link ByteSizeValue}.
 *
 * @return the estimate as a byte-size value, or {@code null} when no estimate is set
 */
public ByteSizeValue getEstimatedHeapMemory() {
    if (estimatedHeapMemory == null) {
        return null;
    }
    return new ByteSizeValue(estimatedHeapMemory);
}

/**
 * Returns the raw estimated heap memory usage in bytes.
 *
 * @return the estimate in bytes, or {@code null} when no estimate is set
 */
public Long getEstimatedHeapMemoryBytes() {
    return this.estimatedHeapMemory;
}

/**
 * Returns the estimated number of operations for this model.
 *
 * @return the estimate, or {@code null} when no estimate is set
 */
public Long getEstimatedOperations() {
    return this.estimatedOperations;
}

/**
 * Creates an empty {@link Builder}.
 *
 * @return a new builder with no fields set
 */
public static Builder builder() {
    final Builder emptyBuilder = new Builder();
    return emptyBuilder;
}
Expand Down Expand Up @@ -172,6 +195,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
if (estimatedHeapMemory != null) {
builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory);
}
if (estimatedOperations != null) {
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
}
builder.endObject();
return builder;
}
Expand All @@ -194,6 +223,8 @@ public boolean equals(Object o) {
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(metadata, that.metadata);
}

Expand All @@ -206,6 +237,8 @@ public int hashCode() {
definition,
description,
tags,
estimatedHeapMemory,
estimatedOperations,
metadata,
input);
}
Expand All @@ -222,6 +255,8 @@ public static class Builder {
private List<String> tags;
private TrainedModelDefinition definition;
private TrainedModelInput input;
private Long estimatedHeapMemory;
private Long estimatedOperations;

public Builder setModelId(String modelId) {
this.modelId = modelId;
Expand Down Expand Up @@ -277,6 +312,16 @@ public Builder setInput(TrainedModelInput input) {
return this;
}

/**
 * Sets the estimated heap memory usage, in bytes.
 *
 * @param estimatedHeapMemory the estimate in bytes; may be {@code null} to leave it unset
 * @return this builder, for call chaining
 */
public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}

/**
 * Sets the estimated number of operations.
 *
 * @param estimatedOperations the estimate; may be {@code null} to leave it unset
 * @return this builder, for call chaining
 */
public Builder setEstimatedOperations(Long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}

public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
Expand All @@ -287,7 +332,9 @@ public TrainedModelConfig build() {
definition,
tags,
metadata,
input);
input,
estimatedHeapMemory,
estimatedOperations);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,10 @@ protected TrainedModelConfig createTestInstance() {
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
randomBoolean() ? null : randomNonNegativeLong(),
randomBoolean() ? null : randomNonNegativeLong());

}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -34,6 +35,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {

public static final String NAME = "trained_model_config";

private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage";

public static final ParseField MODEL_ID = new ParseField("model_id");
public static final ParseField CREATED_BY = new ParseField("created_by");
public static final ParseField VERSION = new ParseField("version");
Expand All @@ -43,6 +46,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");

// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
Expand All @@ -66,6 +71,8 @@ private static ObjectParser<TrainedModelConfig.Builder, Void> createParser(boole
parser.declareObject(TrainedModelConfig.Builder::setInput,
(p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields),
INPUT);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
return parser;
}

Expand All @@ -81,6 +88,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
private final List<String> tags;
private final Map<String, Object> metadata;
private final TrainedModelInput input;
private final long estimatedHeapMemory;
private final long estimatedOperations;

private final TrainedModelDefinition definition;

Expand All @@ -92,7 +101,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata,
TrainedModelInput input) {
TrainedModelInput input,
Long estimatedHeapMemory,
Long estimatedOperations) {
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
Expand All @@ -102,6 +113,15 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS));
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = ExceptionsHelper.requireNonNull(input, INPUT);
if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) {
throw new IllegalArgumentException(
"[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0");
}
this.estimatedHeapMemory = estimatedHeapMemory;
if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) {
throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0");
}
this.estimatedOperations = estimatedOperations;
}

public TrainedModelConfig(StreamInput in) throws IOException {
Expand All @@ -114,6 +134,8 @@ public TrainedModelConfig(StreamInput in) throws IOException {
tags = Collections.unmodifiableList(in.readList(StreamInput::readString));
metadata = in.readMap();
input = new TrainedModelInput(in);
estimatedHeapMemory = in.readVLong();
estimatedOperations = in.readVLong();
}

public String getModelId() {
Expand Down Expand Up @@ -157,6 +179,14 @@ public static Builder builder() {
return new Builder();
}

/**
 * Returns the estimated heap memory usage in bytes.
 *
 * @return the estimate in bytes
 */
public long getEstimatedHeapMemory() {
    return this.estimatedHeapMemory;
}

/**
 * Returns the estimated number of operations for this model.
 *
 * @return the estimated operation count
 */
public long getEstimatedOperations() {
    return this.estimatedOperations;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(modelId);
Expand All @@ -168,6 +198,8 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeCollection(tags, StreamOutput::writeString);
out.writeMap(metadata);
input.writeTo(out);
out.writeVLong(estimatedHeapMemory);
out.writeVLong(estimatedOperations);
}

@Override
Expand All @@ -192,6 +224,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME);
}
builder.field(INPUT.getPreferredName(), input);
builder.humanReadableField(
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(),
ESTIMATED_HEAP_MEMORY_USAGE_HUMAN,
new ByteSizeValue(estimatedHeapMemory));
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
builder.endObject();
return builder;
}
Expand All @@ -214,6 +251,8 @@ public boolean equals(Object o) {
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
Objects.equals(estimatedOperations, that.estimatedOperations) &&
Objects.equals(metadata, that.metadata);
}

Expand All @@ -227,6 +266,8 @@ public int hashCode() {
description,
tags,
metadata,
estimatedHeapMemory,
estimatedOperations,
input);
}

Expand All @@ -241,6 +282,8 @@ public static class Builder {
private Map<String, Object> metadata;
private TrainedModelInput input;
private TrainedModelDefinition definition;
private Long estimatedHeapMemory;
private Long estimatedOperations;

public Builder setModelId(String modelId) {
this.modelId = modelId;
Expand Down Expand Up @@ -296,6 +339,16 @@ public Builder setInput(TrainedModelInput input) {
return this;
}

/**
 * Sets the estimated heap memory usage, in bytes.
 *
 * The parameter is a primitive, so this setter cannot reset the value to {@code null}.
 *
 * @param estimatedHeapMemory the estimate in bytes
 * @return this builder, for call chaining
 */
public Builder setEstimatedHeapMemory(long estimatedHeapMemory) {
this.estimatedHeapMemory = estimatedHeapMemory;
return this;
}

/**
 * Sets the estimated number of operations.
 *
 * The parameter is a primitive, so this setter cannot reset the value to {@code null}.
 *
 * @param estimatedOperations the estimated operation count
 * @return this builder, for call chaining
 */
public Builder setEstimatedOperations(long estimatedOperations) {
this.estimatedOperations = estimatedOperations;
return this;
}

// TODO move to REST level instead of here in the builder
public void validate() {
// We require a definition to be available here even though it will be stored in a different doc
Expand Down Expand Up @@ -326,6 +379,16 @@ public void validate() {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
CREATE_TIME.getPreferredName());
}

if (estimatedHeapMemory != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName());
}

if (estimatedOperations != null) {
throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation",
ESTIMATED_OPERATIONS.getPreferredName());
}
}

public TrainedModelConfig build() {
Expand All @@ -338,7 +401,9 @@ public TrainedModelConfig build() {
definition,
tags,
metadata,
input);
input,
estimatedHeapMemory,
estimatedOperations);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou
* @throws org.elasticsearch.ElasticsearchException if validations fail
*/
void validate();

/**
* Estimates how many operations this model needs to produce a result at inference time.
*
* @return The estimated number of operations required at inference time
*/
long estimatedNumOperations();
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.stream.Collectors;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel;
Expand Down Expand Up @@ -251,6 +252,14 @@ public void validate() {
this.models.forEach(TrainedModel::validate);
}

/**
 * Estimates the number of operations this ensemble needs at inference time.
 *
 * The estimate is the rounded-up mean of the sub-models' own estimates plus
 * {@code 2 * (models.size() - 1)} for processing and aggregating their outputs.
 *
 * @return the estimated number of operations
 */
@Override
public long estimatedNumOperations() {
    final OptionalDouble meanSubModelOperations = models.stream()
        .mapToLong(TrainedModel::estimatedNumOperations)
        .average();
    assert meanSubModelOperations.isPresent() : "unexpected null when calculating number of operations";
    // Average operations for each model, rounded up to a whole operation
    final long subModelOperations = (long) Math.ceil(meanSubModelOperations.getAsDouble());
    // Operations required for processing and aggregating with the outputAggregator
    final int aggregationOperations = 2 * (models.size() - 1);
    return subModelOperations + aggregationOperations;
}

/**
 * Creates an empty {@link Builder}.
 *
 * @return a new builder with no fields set
 */
public static Builder builder() {
    final Builder emptyBuilder = new Builder();
    return emptyBuilder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ public int hashCode() {

/**
 * Estimates the memory held by this instance.
 *
 * @return {@code SHALLOW_SIZE} plus the size of {@code weights} as reported by
 *         {@code RamUsageEstimator.sizeOf}; a {@code null} {@code weights} contributes nothing
 */
@Override
public long ramBytesUsed() {
    // Guard against a null weights array rather than passing it to the estimator
    if (weights == null) {
        return SHALLOW_SIZE;
    }
    return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ public int hashCode() {

/**
 * Estimates the memory held by this instance.
 *
 * @return {@code SHALLOW_SIZE} plus the size of {@code weights} as reported by
 *         {@code RamUsageEstimator.sizeOf}; a {@code null} {@code weights} contributes nothing
 */
@Override
public long ramBytesUsed() {
    // Guard against a null weights array rather than passing it to the estimator
    if (weights == null) {
        return SHALLOW_SIZE;
    }
    return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights);
}
}
Loading