diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 273aa6b021325..384bfe53e4bb5 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -22,6 +22,7 @@ import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -47,6 +48,8 @@ public class TrainedModelConfig implements ToXContentObject { public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField INPUT = new ParseField("input"); + public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); + public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -66,6 +69,8 @@ public class TrainedModelConfig implements ToXContentObject { PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT); + PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES); + PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); } public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { @@ -81,6 +86,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr private final List tags; private final Map metadata; private final TrainedModelInput input; + private final Long estimatedHeapMemory; + private final Long estimatedOperations; TrainedModelConfig(String modelId, String createdBy, @@ -90,7 +97,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr TrainedModelDefinition definition, List tags, Map metadata, - TrainedModelInput input) { + TrainedModelInput input, + Long estimatedHeapMemory, + Long estimatedOperations) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; @@ -100,6 +109,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr this.tags = tags == null ? null : Collections.unmodifiableList(tags); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.input = input; + this.estimatedHeapMemory = estimatedHeapMemory; + this.estimatedOperations = estimatedOperations; } public String getModelId() { @@ -138,6 +149,18 @@ public TrainedModelInput getInput() { return input; } + public ByteSizeValue getEstimatedHeapMemory() { + return estimatedHeapMemory == null ? null : new ByteSizeValue(estimatedHeapMemory); + } + + public Long getEstimatedHeapMemoryBytes() { + return estimatedHeapMemory; + } + + public Long getEstimatedOperations() { + return estimatedOperations; + } + public static Builder builder() { return new Builder(); } @@ -172,6 +195,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (input != null) { builder.field(INPUT.getPreferredName(), input); } + if (estimatedHeapMemory != null) { + builder.field(ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), estimatedHeapMemory); + } + if (estimatedOperations != null) { + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); + } builder.endObject(); return builder; } @@ -194,6 +223,8 @@ public boolean equals(Object o) { Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && Objects.equals(input, that.input) && + Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && + Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(metadata, that.metadata); } @@ -206,6 +237,8 @@ public int hashCode() { definition, description, tags, + estimatedHeapMemory, + estimatedOperations, metadata, input); } @@ -222,6 +255,8 @@ public static class Builder { private List tags; private TrainedModelDefinition definition; private TrainedModelInput input; + private Long estimatedHeapMemory; + private Long estimatedOperations; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -277,6 +312,16 @@ public Builder setInput(TrainedModelInput input) { return this; } + public Builder setEstimatedHeapMemory(Long estimatedHeapMemory) { + this.estimatedHeapMemory = estimatedHeapMemory; + return this; + } + + public Builder setEstimatedOperations(Long estimatedOperations) { + this.estimatedOperations = estimatedOperations; + return this; + } + public TrainedModelConfig build() { return new TrainedModelConfig( modelId, @@ -287,7 +332,9 @@ public TrainedModelConfig build() { definition, tags, metadata, - input); + input, + estimatedHeapMemory, + estimatedOperations); } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index 6d1e04f066cb7..7afba62861362 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -64,7 +64,10 @@ protected TrainedModelConfig createTestInstance() { randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomBoolean() ? null : TrainedModelInputTests.createRandomInput()); + randomBoolean() ? null : TrainedModelInputTests.createRandomInput(), + randomBoolean() ? null : randomNonNegativeLong(), + randomBoolean() ? null : randomNonNegativeLong()); + } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 7078322b5d331..5361760e5ca26 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -34,6 +35,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final String NAME = "trained_model_config"; + private static final String ESTIMATED_HEAP_MEMORY_USAGE_HUMAN = "estimated_heap_memory_usage"; + public static final ParseField MODEL_ID = new ParseField("model_id"); public static final ParseField CREATED_BY = new ParseField("created_by"); public static final ParseField VERSION = new ParseField("version"); @@ -43,6 +46,8 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); public static final ParseField INPUT = new ParseField("input"); + public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes"); + public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -66,6 +71,8 @@ private static ObjectParser createParser(boole parser.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields), INPUT); + parser.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES); + parser.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS); return parser; } @@ -81,6 +88,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo private final List tags; private final Map metadata; private final TrainedModelInput input; + private final long estimatedHeapMemory; + private final long estimatedOperations; private final TrainedModelDefinition definition; @@ -92,7 +101,9 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo TrainedModelDefinition definition, List tags, Map metadata, - TrainedModelInput input) { + TrainedModelInput input, + Long estimatedHeapMemory, + Long estimatedOperations) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); @@ -102,6 +113,15 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS)); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); this.input = ExceptionsHelper.requireNonNull(input, INPUT); + if (ExceptionsHelper.requireNonNull(estimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES) < 0) { + throw new IllegalArgumentException( + "[" + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName() + "] must be greater than or equal to 0"); + } + this.estimatedHeapMemory = estimatedHeapMemory; + if (ExceptionsHelper.requireNonNull(estimatedOperations, ESTIMATED_OPERATIONS) < 0) { + throw new IllegalArgumentException("[" + ESTIMATED_OPERATIONS.getPreferredName() + "] must be greater than or equal to 0"); + } + this.estimatedOperations = estimatedOperations; } public TrainedModelConfig(StreamInput in) throws IOException { @@ -114,6 +134,8 @@ public TrainedModelConfig(StreamInput in) throws IOException { tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); metadata = in.readMap(); input = new TrainedModelInput(in); + estimatedHeapMemory = in.readVLong(); + estimatedOperations = in.readVLong(); } public String getModelId() { @@ -157,6 +179,14 @@ public static Builder builder() { return new Builder(); } + public long getEstimatedHeapMemory() { + return estimatedHeapMemory; + } + + public long getEstimatedOperations() { + return estimatedOperations; + } + @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); @@ -168,6 +198,8 @@ public void writeTo(StreamOutput out) throws IOException { out.writeCollection(tags, StreamOutput::writeString); out.writeMap(metadata); input.writeTo(out); + out.writeVLong(estimatedHeapMemory); + out.writeVLong(estimatedOperations); } @Override @@ -192,6 +224,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); } builder.field(INPUT.getPreferredName(), input); + builder.humanReadableField( + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName(), + ESTIMATED_HEAP_MEMORY_USAGE_HUMAN, + new ByteSizeValue(estimatedHeapMemory)); + builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations); builder.endObject(); return builder; } @@ -214,6 +251,8 @@ public boolean equals(Object o) { Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && Objects.equals(input, that.input) && + Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) && + Objects.equals(estimatedOperations, that.estimatedOperations) && Objects.equals(metadata, that.metadata); } @@ -227,6 +266,8 @@ public int hashCode() { description, tags, metadata, + estimatedHeapMemory, + estimatedOperations, input); } @@ -241,6 +282,8 @@ public static class Builder { private Map metadata; private TrainedModelInput input; private TrainedModelDefinition definition; + private Long estimatedHeapMemory; + private Long estimatedOperations; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -296,6 +339,16 @@ public Builder setInput(TrainedModelInput input) { return this; } + public Builder setEstimatedHeapMemory(long estimatedHeapMemory) { + this.estimatedHeapMemory = estimatedHeapMemory; + return this; + } + + public Builder setEstimatedOperations(long estimatedOperations) { + this.estimatedOperations = estimatedOperations; + return this; + } + // TODO move to REST level instead of here in the builder public void validate() { // We require a definition to be available here even though it will be stored in a different doc @@ -326,6 +379,16 @@ public void validate() { throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", CREATE_TIME.getPreferredName()); } + + if (estimatedHeapMemory != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", + ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName()); + } + + if (estimatedOperations != null) { + throw ExceptionsHelper.badRequestException("illegal to set [{}] at inference model creation", + ESTIMATED_OPERATIONS.getPreferredName()); + } } public TrainedModelConfig build() { @@ -338,7 +401,9 @@ public TrainedModelConfig build() { definition, tags, metadata, - input); + input, + estimatedHeapMemory, + estimatedOperations); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java index a9028efdffa94..e206a70918096 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java @@ -50,4 +50,9 @@ public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accou * @throws org.elasticsearch.ElasticsearchException if validations fail */ void validate(); + + /** + * @return The estimated number of operations required at inference time + */ + long estimatedNumOperations(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java index 79883f4db4b4e..a59f1a1c245d9 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.OptionalDouble; import java.util.stream.Collectors; import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceHelpers.classificationLabel; @@ -251,6 +252,14 @@ public void validate() { this.models.forEach(TrainedModel::validate); } + @Override + public long estimatedNumOperations() { + OptionalDouble avg = models.stream().mapToLong(TrainedModel::estimatedNumOperations).average(); + assert avg.isPresent() : "unexpected null when calculating number of operations"; + // Average operations for each model and the operations required for processing and aggregating with the outputAggregator + return (long)Math.ceil(avg.getAsDouble()) + 2 * (models.size() - 1); + } + public static Builder builder() { return new Builder(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java index 14f2b1b64b523..2dba96916390c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/LogisticRegression.java @@ -157,6 +157,7 @@ public int hashCode() { @Override public long ramBytesUsed() { - return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java index 29b311794be06..73689d16b1cf8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java @@ -174,6 +174,7 @@ public int hashCode() { @Override public long ramBytesUsed() { - return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java index b9b34508b88ba..ed1c13cf10203 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedSum.java @@ -150,6 +150,7 @@ public boolean compatibleWith(TargetType targetType) { @Override public long ramBytesUsed() { - return SHALLOW_SIZE + RamUsageEstimator.sizeOf(weights); + long weightSize = weights == null ? 0L : RamUsageEstimator.sizeOf(weights); + return SHALLOW_SIZE + weightSize; } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java index 60192f46234b9..1408b17a0691a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/Tree.java @@ -87,7 +87,10 @@ public static Tree fromXContentLenient(XContentParser parser) { Tree(List featureNames, List nodes, TargetType targetType, List classificationLabels) { this.featureNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(featureNames, FEATURE_NAMES)); - this.nodes = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE)); + if(ExceptionsHelper.requireNonNull(nodes, TREE_STRUCTURE).size() == 0) { + throw new IllegalArgumentException("[tree_structure] must not be empty"); + } + this.nodes = Collections.unmodifiableList(nodes); this.targetType = ExceptionsHelper.requireNonNull(targetType, TARGET_TYPE); this.classificationLabels = classificationLabels == null ? null : Collections.unmodifiableList(classificationLabels); this.highestOrderCategory = new CachedSupplier<>(() -> this.maxLeafValue()); @@ -257,6 +260,12 @@ public void validate() { detectCycle(); } + @Override + public long estimatedNumOperations() { + // Grabbing the features from the doc + the depth of the tree + return (long)Math.ceil(Math.log(nodes.size())) + featureNames.size(); + } + private void checkTargetType() { if ((this.targetType == TargetType.CLASSIFICATION) != (this.classificationLabels != null)) { throw ExceptionsHelper.badRequestException( @@ -265,9 +274,6 @@ private void checkTargetType() { } private void detectCycle() { - if (nodes.isEmpty()) { - return; - } Set visited = new HashSet<>(nodes.size()); Queue toVisit = new ArrayDeque<>(nodes.size()); toVisit.add(0); @@ -288,10 +294,6 @@ private void detectCycle() { } private void detectMissingNodes() { - if (nodes.isEmpty()) { - return; - } - List missingNodes = new ArrayList<>(); for (int i = 0; i < nodes.size(); i++) { TreeNode currentNode = nodes.get(i); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 5583081df2732..3f3f3cb9a3ad8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -74,7 +74,9 @@ protected TrainedModelConfig createTestInstance() { null, // is not parsed so should not be provided tags, randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - TrainedModelInputTests.createRandomInput()); + TrainedModelInputTests.createRandomInput(), + randomNonNegativeLong(), + randomNonNegativeLong()); } @Override @@ -117,7 +119,9 @@ public void testToXContentWithParams() throws IOException { TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(), Collections.emptyList(), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - TrainedModelInputTests.createRandomInput()); + TrainedModelInputTests.createRandomInput(), + randomNonNegativeLong(), + randomNonNegativeLong()); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); assertThat(reference.utf8ToString(), containsString("definition")); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 5fdadac712d0d..69ff018db9b44 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -23,6 +23,7 @@ import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests; import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.TreeTests; @@ -78,7 +79,7 @@ public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) TargetMeanEncodingTests.createRandom())) .limit(numberOfProcessors) .collect(Collectors.toList())) - .setTrainedModel(randomFrom(TreeTests.createRandom())); + .setTrainedModel(randomFrom(TreeTests.createRandom(), EnsembleTests.createRandom())); } private static final String ENSEMBLE_MODEL = "" + diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java index 753a9d3dd3cad..c38591ab6cfc5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/EnsembleTests.java @@ -445,6 +445,18 @@ public void testRegressionInference() { closeTo(((SingleValueInferenceResults)ensemble.infer(featureMap, new RegressionConfig())).value(), 0.00001)); } + public void testOperationsEstimations() { + Tree tree1 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar"), 2); + Tree tree2 = TreeTests.buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); + Tree tree3 = TreeTests.buildRandomTree(Arrays.asList("foo", "baz"), 3); + Ensemble ensemble = Ensemble.builder().setTrainedModels(Arrays.asList(tree1, tree2, tree3)) + .setTargetType(TargetType.CLASSIFICATION) + .setFeatureNames(Arrays.asList("foo", "bar", "baz")) + .setOutputAggregator(new LogisticRegression(new double[]{0.1, 0.4, 1.0})) + .build(); + assertThat(ensemble.estimatedNumOperations(), equalTo(9L)); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java index 11bf44fd165e4..7f5158706941f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/tree/TreeTests.java @@ -300,6 +300,11 @@ public void testTreeWithTargetTypeAndLabelsMismatch() { assertThat(ex.getMessage(), equalTo(msg)); } + public void testOperationsEstimations() { + Tree tree = buildRandomTree(Arrays.asList("foo", "bar", "baz"), 5); + assertThat(tree.estimatedNumOperations(), equalTo(7L)); + } + private static Map zipObjMap(List keys, List values) { return IntStream.range(0, keys.size()).boxed().collect(Collectors.toMap(keys::get, values::get)); } diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java index 790d623fbec9c..17b9ed512c82b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java @@ -366,12 +366,12 @@ private Map generateSourceDoc() { private static final String REGRESSION_CONFIG = "{" + " \"model_id\": \"test_regression\",\n" + - " \"model_version\": 0,\n" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for regression\",\n" + " \"version\": \"8.0.0\",\n" + " \"created_by\": \"ml_test\",\n" + - " \"model_type\": \"local\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + " \"created_time\": 0" + "}"; @@ -499,12 +499,12 @@ private Map generateSourceDoc() { private static final String CLASSIFICATION_CONFIG = "" + "{\n" + " \"model_id\": \"test_classification\",\n" + - " \"model_version\": 0,\n" + " \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," + " \"description\": \"test model for classification\",\n" + " \"version\": \"8.0.0\",\n" + " \"created_by\": \"benwtrent\",\n" + - " \"model_type\": \"local\",\n" + + " \"estimated_heap_memory_usage_bytes\": 0," + + " \"estimated_operations\": 0," + " \"created_time\": 0\n" + "}"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java index 1982cec7eca0c..153b169ea8f32 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelIT.java @@ -192,6 +192,8 @@ private static String buildRegressionModel(String modelId) throws IOException { .setCreatedBy("ml_test") .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) .build() .toXContent(builder, new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); return XContentHelper.convertToJson(BytesReference.bytes(builder), false, XContentType.JSON); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index eb47c17137a3b..3abc3b5e43cce 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -150,6 +150,8 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) .setDefinition(definition) + .setEstimatedHeapMemory(definition.ramBytesUsed()) + .setEstimatedOperations(definition.getTrainedModel().estimatedNumOperations()) .setInput(new TrainedModelInput(fieldNames)) .build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java index 19d5d33abe4cd..aa80807aae85e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java @@ -24,6 +24,7 @@ import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.DYNAMIC; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.ENABLED; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.KEYWORD; +import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.LONG; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.PROPERTIES; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TEXT; import static org.elasticsearch.xpack.core.ml.job.persistence.ElasticsearchMappings.TYPE; @@ -103,6 +104,12 @@ private static void addInferenceDocFields(XContentBuilder builder) throws IOExce .endObject() .startObject(TrainedModelConfig.METADATA.getPreferredName()) .field(ENABLED, false) + .endObject() + .startObject(TrainedModelConfig.ESTIMATED_OPERATIONS.getPreferredName()) + .field(TYPE, LONG) + .endObject() + .startObject(TrainedModelConfig.ESTIMATED_HEAP_MEMORY_USAGE_BYTES.getPreferredName()) + .field(TYPE, LONG) .endObject(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index cb90b39772a0c..bdccdf8c6722f 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -145,6 +145,8 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() { assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build())); assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames)); + assertThat(storedModel.getEstimatedHeapMemory(), equalTo(inferenceModel.build().ramBytesUsed())); + assertThat(storedModel.getEstimatedOperations(), equalTo(inferenceModel.build().getTrainedModel().estimatedNumOperations())); Map metadata = storedModel.getMetadata(); assertThat(metadata.size(), equalTo(1)); assertThat(metadata, hasKey("analytics_config")); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java index 099baf949b684..0b0f7514caf2e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/ModelInferenceActionIT.java @@ -69,6 +69,8 @@ public void testInferModels() throws Exception { .setModelId(modelId1)) .setVersion(Version.CURRENT) .setCreateTime(Instant.now()) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) .build(); TrainedModelConfig config2 = buildTrainedModelConfigBuilder(modelId1) .setInput(new TrainedModelInput(Arrays.asList("field1", "field2"))) @@ -77,6 +79,8 @@ public void testInferModels() throws Exception { .setTrainedModel(buildRegression()) .setModelId(modelId2)) .setVersion(Version.CURRENT) + .setEstimatedOperations(0) + .setEstimatedHeapMemory(0) .setCreateTime(Instant.now()) .build(); AtomicReference putConfigHolder = new AtomicReference<>(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 75e83acb0b22e..10644cb6da547 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -144,6 +144,8 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setDescription("trained model config for test") .setModelId(modelId) .setVersion(Version.CURRENT) + .setEstimatedHeapMemory(0) + .setEstimatedOperations(0) .setInput(TrainedModelInputTests.createRandomInput()); }