diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java index 5c453c3985fb3..f91e30dd9ff15 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/persistence/InferenceIndexConstants.java @@ -16,8 +16,10 @@ public final class InferenceIndexConstants { * version: 7.8.0: * - adds inference_config definition to trained model config * + * version: 7.10.0: 000003 + * - adds trained_model_metadata object */ - public static final String INDEX_VERSION = "000002"; + public static final String INDEX_VERSION = "000003"; public static final String INDEX_NAME_PREFIX = ".ml-inference-"; public static final String INDEX_PATTERN = INDEX_NAME_PREFIX + "*"; public static final String LATEST_INDEX_NAME = INDEX_NAME_PREFIX + INDEX_VERSION; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java new file mode 100644 index 0000000000000..4fe3464f8e9cc --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportance.java @@ -0,0 +1,242 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class TotalFeatureImportance implements ToXContentObject, Writeable { + + private static final String NAME = "total_feature_importance"; + public static final ParseField FEATURE_NAME = new ParseField("feature_name"); + public static final ParseField IMPORTANCE = new ParseField("importance"); + public static final ParseField CLASSES = new ParseField("classes"); + public static final ParseField MEAN_MAGNITUDE = new ParseField("mean_magnitude"); + public static final ParseField MIN = new ParseField("min"); + public static final ParseField MAX = new ParseField("max"); + + // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new TotalFeatureImportance((String)a[0], (Importance)a[1], (List)a[2])); + parser.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME); + 
parser.declareObject(ConstructingObjectParser.optionalConstructorArg(), + ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER, + IMPORTANCE); + parser.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), + ignoreUnknownFields ? ClassImportance.LENIENT_PARSER : ClassImportance.STRICT_PARSER, + CLASSES); + return parser; + } + + public static TotalFeatureImportance fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + public final String featureName; + public final Importance importance; + public final List classImportances; + + public TotalFeatureImportance(StreamInput in) throws IOException { + this.featureName = in.readString(); + this.importance = in.readOptionalWriteable(Importance::new); + this.classImportances = in.readList(ClassImportance::new); + } + + TotalFeatureImportance(String featureName, @Nullable Importance importance, @Nullable List classImportances) { + this.featureName = featureName; + this.importance = importance; + this.classImportances = classImportances == null ? 
Collections.emptyList() : classImportances; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(featureName); + out.writeOptionalWriteable(importance); + out.writeList(classImportances); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FEATURE_NAME.getPreferredName(), featureName); + if (importance != null) { + builder.field(IMPORTANCE.getPreferredName(), importance); + } + if (classImportances.isEmpty() == false) { + builder.field(CLASSES.getPreferredName(), classImportances); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TotalFeatureImportance that = (TotalFeatureImportance) o; + return Objects.equals(that.importance, importance) + && Objects.equals(featureName, that.featureName) + && Objects.equals(classImportances, that.classImportances); + } + + @Override + public int hashCode() { + return Objects.hash(featureName, importance, classImportances); + } + + public static class Importance implements ToXContentObject, Writeable { + private static final String NAME = "importance"; + + // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new Importance((double)a[0], (double)a[1], (double)a[2])); + parser.declareDouble(ConstructingObjectParser.constructorArg(), MEAN_MAGNITUDE); + parser.declareDouble(ConstructingObjectParser.constructorArg(), 
MIN); + parser.declareDouble(ConstructingObjectParser.constructorArg(), MAX); + return parser; + } + + private final double meanMagnitude; + private final double min; + private final double max; + + public Importance(double meanMagnitude, double min, double max) { + this.meanMagnitude = meanMagnitude; + this.min = min; + this.max = max; + } + + public Importance(StreamInput in) throws IOException { + this.meanMagnitude = in.readDouble(); + this.min = in.readDouble(); + this.max = in.readDouble(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Importance that = (Importance) o; + return Double.compare(that.meanMagnitude, meanMagnitude) == 0 && + Double.compare(that.min, min) == 0 && + Double.compare(that.max, max) == 0; + } + + @Override + public int hashCode() { + return Objects.hash(meanMagnitude, min, max); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeDouble(meanMagnitude); + out.writeDouble(min); + out.writeDouble(max); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(MEAN_MAGNITUDE.getPreferredName(), meanMagnitude); + builder.field(MIN.getPreferredName(), min); + builder.field(MAX.getPreferredName(), max); + builder.endObject(); + return builder; + } + } + + public static class ClassImportance implements ToXContentObject, Writeable { + private static final String NAME = "total_class_importance"; + + public static final ParseField CLASS_NAME = new ParseField("class_name"); + public static final ParseField IMPORTANCE = new ParseField("importance"); + + // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser 
STRICT_PARSER = createParser(false); + + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new ClassImportance((String)a[0], (Importance)a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), CLASS_NAME); + parser.declareObject(ConstructingObjectParser.constructorArg(), + ignoreUnknownFields ? Importance.LENIENT_PARSER : Importance.STRICT_PARSER, + IMPORTANCE); + return parser; + } + + public static ClassImportance fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + public final String className; + public final Importance importance; + + public ClassImportance(StreamInput in) throws IOException { + this.className = in.readString(); + this.importance = new Importance(in); + } + + ClassImportance(String className, Importance importance) { + this.className = className; + this.importance = importance; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(className); + importance.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CLASS_NAME.getPreferredName(), className); + builder.field(IMPORTANCE.getPreferredName(), importance); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ClassImportance that = (ClassImportance) o; + return Objects.equals(that.importance, importance) && Objects.equals(className, that.className); + } + + @Override + public int hashCode() { + return Objects.hash(className, importance); + } + + } +} diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java new file mode 100644 index 0000000000000..dc3e8fc54d998 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadata.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +public class TrainedModelMetadata implements ToXContentObject, Writeable { + + public static final String NAME = "trained_model_metadata"; + public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance"); + public static final ParseField MODEL_ID = new ParseField("model_id"); + + // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), 
whilst config is parsed strictly + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new TrainedModelMetadata((String)a[0], (List)a[1])); + parser.declareString(ConstructingObjectParser.constructorArg(), MODEL_ID); + parser.declareObjectArray(ConstructingObjectParser.constructorArg(), + ignoreUnknownFields ? TotalFeatureImportance.LENIENT_PARSER : TotalFeatureImportance.STRICT_PARSER, + TOTAL_FEATURE_IMPORTANCE); + return parser; + } + + public static TrainedModelMetadata fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + public static String docId(String modelId) { + return NAME + "-" + modelId; + } + + private final List totalFeatureImportances; + private final String modelId; + + public TrainedModelMetadata(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.totalFeatureImportances = in.readList(TotalFeatureImportance::new); + } + + public TrainedModelMetadata(String modelId, List totalFeatureImportances) { + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); + this.totalFeatureImportances = Collections.unmodifiableList(totalFeatureImportances); + } + + public String getModelId() { + return modelId; + } + + public String getDocId() { + return docId(modelId); + } + + public List getTotalFeatureImportances() { + return totalFeatureImportances; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelMetadata that = (TrainedModelMetadata) o; + return 
Objects.equals(totalFeatureImportances, that.totalFeatureImportances) && + Objects.equals(modelId, that.modelId); + } + + @Override + public int hashCode() { + return Objects.hash(totalFeatureImportances, modelId); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeList(totalFeatureImportances); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); + } + builder.field(MODEL_ID.getPreferredName(), modelId); + builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), totalFeatureImportances); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 8ac621cec1294..a382f269bd17c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -92,12 +92,15 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists"; + public static final String INFERENCE_TRAINED_MODEL_METADATA_EXISTS = "Trained machine learning model metadata [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; + public static final String INFERENCE_FAILED_TO_STORE_MODEL_METADATA = "Failed to store trained machine learning model metadata [{0}]"; public static final 
String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}"; public static final String INFERENCE_CONFIG_NOT_SUPPORTED_ON_VERSION = "Configuration [{0}] requires minimum node version [{1}] (current minimum node version [{2}]"; public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; + public static final String MODEL_METADATA_NOT_FOUND = "Could not find trained model metadata [{0}]"; public static final String INFERENCE_CANNOT_DELETE_MODEL = "Unable to delete model [{0}]"; public static final String MODEL_DEFINITION_TRUNCATED = diff --git a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json index a77a0119e953b..00f5eb2a90fe2 100644 --- a/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json +++ b/x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json @@ -2,7 +2,7 @@ "order" : 0, "version" : ${xpack.ml.version.id}, "index_patterns" : [ - ".ml-inference-000002" + ".ml-inference-000003" ], "settings" : { "index" : { @@ -70,6 +70,50 @@ }, "inference_config": { "enabled": false + }, + "total_feature_importance": { + "type": "nested", + "dynamic": "false", + "properties": { + "importance": { + "properties": { + "min": { + "type": "double" + }, + "max": { + "type": "double" + }, + "mean_magnitude": { + "type": "double" + } + } + }, + "feature_name": { + "type": "keyword" + }, + "classes": { + "type": "nested", + "dynamic": "false", + "properties": { + "importance": { + "properties": { + "min": { + "type": "double" + }, + "max": { + "type": "double" + }, + "mean_magnitude": { + "type": "double" + } + } + }, + "class_name": { + "type": "keyword" + } + } + } + } } } } diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java new file mode 100644 index 0000000000000..fcf4978f5258e --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TotalFeatureImportanceTests.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TotalFeatureImportanceTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + public static TotalFeatureImportance randomInstance() { + return new TotalFeatureImportance( + randomAlphaOfLength(10), + randomBoolean() ? null : randomImportance(), + randomBoolean() ? 
+ null : + Stream.generate(() -> new TotalFeatureImportance.ClassImportance(randomAlphaOfLength(10), randomImportance())) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList()) + ); + } + + private static TotalFeatureImportance.Importance randomImportance() { + return new TotalFeatureImportance.Importance(randomDouble(), randomDouble(), randomDouble()); + } + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected TotalFeatureImportance createTestInstance() { + return randomInstance(); + } + + @Override + protected Writeable.Reader instanceReader() { + return TotalFeatureImportance::new; + } + + @Override + protected TotalFeatureImportance doParseInstance(XContentParser parser) throws IOException { + return TotalFeatureImportance.fromXContent(parser, lenient); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected TotalFeatureImportance mutateInstanceForVersion(TotalFeatureImportance instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java new file mode 100644 index 0000000000000..6567a729ee2b6 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/metadata/TrainedModelMetadataTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ +package org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.AbstractBWCSerializationTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TrainedModelMetadataTests extends AbstractBWCSerializationTestCase { + + private boolean lenient; + + public static TrainedModelMetadata randomInstance() { + return new TrainedModelMetadata( + randomAlphaOfLength(10), + Stream.generate(TotalFeatureImportanceTests::randomInstance).limit(randomIntBetween(1, 10)).collect(Collectors.toList())); + } + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected TrainedModelMetadata createTestInstance() { + return randomInstance(); + } + + @Override + protected Writeable.Reader instanceReader() { + return TrainedModelMetadata::new; + } + + @Override + protected TrainedModelMetadata doParseInstance(XContentParser parser) throws IOException { + return TrainedModelMetadata.fromXContent(parser, lenient); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected TrainedModelMetadata mutateInstanceForVersion(TrainedModelMetadata instance, Version version) { + return instance; + } +} diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java index 5331beacf4a69..3c7a9cfd8d14b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java +++ 
b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.ml.integration; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.DocWriteRequest; @@ -75,7 +74,6 @@ import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.startsWith; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1456") public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String BOOLEAN_FIELD = "boolean-field"; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java index 035d72ee5e433..47b53010f636c 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java @@ -5,7 +5,6 @@ */ package org.elasticsearch.xpack.ml.integration; -import org.apache.lucene.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -58,7 +57,6 @@ import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1456") public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase { private static final String NUMERICAL_FEATURE_FIELD = "feature"; diff --git 
a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java index 624aee9e41be3..b9549842333d7 100644 --- a/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java +++ b/x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/ChunkedTrainedModelPersisterIT.java @@ -22,8 +22,11 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister; +import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; @@ -40,8 +43,11 @@ import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.startsWith; public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase { @@ -76,10 +82,14 @@ public void testStoreModelViaChunkedPersister() throws IOException { //Accuracy for size is not tested here ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); - persister.createAndIndexInferenceModelMetadata(modelSizeInfo); + 
persister.createAndIndexInferenceModelConfig(modelSizeInfo); for (int i = 0; i < chunks.size(); i++) { persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1))); } + ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList())); + persister.createAndIndexInferenceModelMetadata(modelMetadata); PlainActionFuture>> getIdsFuture = new PlainActionFuture<>(); trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture); @@ -93,6 +103,13 @@ public void testStoreModelViaChunkedPersister() throws IOException { assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition)); assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations())); assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed())); + + PlainActionFuture getTrainedMetadataFuture = new PlainActionFuture<>(); + trainedModelProvider.getTrainedModelMetadata(ids.v2().iterator().next(), getTrainedMetadataFuture); + + TrainedModelMetadata storedMetadata = getTrainedMetadataFuture.actionGet(); + assertThat(storedMetadata.getModelId(), startsWith(modelId)); + assertThat(storedMetadata.getTotalFeatureImportances(), equalTo(modelMetadata.getFeatureImportances())); } private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 41dcea15577e8..69ade87dfd588 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; +import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; @@ -141,12 +142,16 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } ModelSizeInfo modelSize = result.getModelSizeInfo(); if (modelSize != null) { - latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize); + latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelConfig(modelSize); } TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk(); if (trainedModelDefinitionChunk != null && isCancelled == false) { chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk); } + ModelMetadata modelMetadata = result.getModelMetadata(); + if (modelMetadata != null) { + chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelMetadata); + } MemoryUsage memoryUsage = result.getMemoryUsage(); if (memoryUsage != null) { processMemoryUsage(memoryUsage); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java index 213fa1d369ffb..df2cb4cf32e8f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java +++ 
b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersister.java @@ -23,9 +23,11 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.security.user.XPackUser; +import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.extractor.ExtractedField; import org.elasticsearch.xpack.ml.extractor.ExtractedFields; @@ -75,7 +77,7 @@ public ChunkedTrainedModelPersister(TrainedModelProvider provider, } public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedModelDefinitionChunk) { - if (Strings.isNullOrEmpty(this.currentModelId.get())) { + if (readyToStoreNewModel.get()) { failureHandler.accept(ExceptionsHelper.serverError( "chunked inference model definition is attempting to be stored before trained model configuration" )); @@ -98,7 +100,7 @@ public void createAndIndexInferenceModelDoc(TrainedModelDefinitionChunk trainedM } } - public String createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelSize) { + public String createAndIndexInferenceModelConfig(ModelSizeInfo inferenceModelSize) { if (readyToStoreNewModel.compareAndSet(true, false) == false) { failureHandler.accept(ExceptionsHelper.serverError( "new inference model is attempting to be stored before completion previous model storage" @@ -106,19 +108,41 @@ public String createAndIndexInferenceModelMetadata(ModelSizeInfo inferenceModelS return null; } TrainedModelConfig trainedModelConfig 
= createTrainedModelConfig(inferenceModelSize); - CountDownLatch latch = storeTrainedModelMetadata(trainedModelConfig); + CountDownLatch latch = storeTrainedModelConfig(trainedModelConfig); try { if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) { - LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId()); + LOGGER.error("[{}] Timed out (30s) waiting for inference model config to be stored", analytics.getId()); } } catch (InterruptedException e) { Thread.currentThread().interrupt(); this.readyToStoreNewModel.set(true); - failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored")); + failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model config to be stored")); } return trainedModelConfig.getModelId(); } + public void createAndIndexInferenceModelMetadata(ModelMetadata modelMetadata) { + if (Strings.isNullOrEmpty(this.currentModelId.get())) { + failureHandler.accept(ExceptionsHelper.serverError( + "inference model metadata is attempting to be stored before trained model configuration" + )); + return; + } + TrainedModelMetadata trainedModelMetadata = new TrainedModelMetadata(this.currentModelId.get(), + modelMetadata.getFeatureImportances()); + + + CountDownLatch latch = storeTrainedModelMetadata(trainedModelMetadata); + try { + if (latch.await(STORE_TIMEOUT_SEC, TimeUnit.SECONDS) == false) { + LOGGER.error("[{}] Timed out (30s) waiting for inference model metadata to be stored", analytics.getId()); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + failureHandler.accept(ExceptionsHelper.serverError("interrupted waiting for inference model metadata to be stored")); + } + } + private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedModelDefinitionDoc) { CountDownLatch latch = new CountDownLatch(1); @@ -154,7 +178,6 @@ private CountDownLatch 
storeTrainedModelDoc(TrainedModelDefinitionDoc trainedMod analytics.getId(), this.currentModelId.get()); auditor.info(analytics.getId(), "Stored trained model with id [" + this.currentModelId.get() + "]"); - this.currentModelId.set(""); readyToStoreNewModel.set(true); provider.refreshInferenceIndex(refreshListener); }, @@ -171,26 +194,68 @@ private CountDownLatch storeTrainedModelDoc(TrainedModelDefinitionDoc trainedMod provider.storeTrainedModelDefinitionDoc(trainedModelDefinitionDoc, storeListener); return latch; } - private CountDownLatch storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig) { + + private CountDownLatch storeTrainedModelMetadata(TrainedModelMetadata trainedModelMetadata) { + CountDownLatch latch = new CountDownLatch(1); + + // Latch is attached to this action as it is the last one to execute. + ActionListener refreshListener = new LatchedActionListener<>(ActionListener.wrap( + refreshed -> { + if (refreshed != null) { + LOGGER.debug(() -> new ParameterizedMessage( + "[{}] refreshed inference index after model metadata store", + analytics.getId() + )); + } + }, + e -> LOGGER.warn( + new ParameterizedMessage("[{}] failed to refresh inference index after model metadata store", analytics.getId()), + e) + ), latch); + + // First, store the model metadata and refresh if necessary + ActionListener storeListener = ActionListener.wrap( + r -> { + LOGGER.debug( + "[{}] stored trained model metadata with id [{}]", + analytics.getId(), + this.currentModelId.get()); + readyToStoreNewModel.set(true); + provider.refreshInferenceIndex(refreshListener); + }, + e -> { + this.readyToStoreNewModel.set(true); + failureHandler.accept(ExceptionsHelper.serverError( + "error storing trained model metadata with id [{}]", + e, + trainedModelMetadata.getModelId())); + refreshListener.onResponse(null); + } + ); + provider.storeTrainedModelMetadata(trainedModelMetadata, storeListener); + return latch; + } + + private CountDownLatch 
storeTrainedModelConfig(TrainedModelConfig trainedModelConfig) { CountDownLatch latch = new CountDownLatch(1); ActionListener storeListener = ActionListener.wrap( aBoolean -> { if (aBoolean == false) { - LOGGER.error("[{}] Storing trained model metadata responded false", analytics.getId()); + LOGGER.error("[{}] Storing trained model config responded false", analytics.getId()); readyToStoreNewModel.set(true); - failureHandler.accept(ExceptionsHelper.serverError("storing trained model responded false")); + failureHandler.accept(ExceptionsHelper.serverError("storing trained model config responded false")); } else { - LOGGER.debug("[{}] Stored trained model metadata with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); + LOGGER.debug("[{}] Stored trained model config with id [{}]", analytics.getId(), trainedModelConfig.getModelId()); } }, e -> { readyToStoreNewModel.set(true); - failureHandler.accept(ExceptionsHelper.serverError("error storing trained model metadata with id [{}]", + failureHandler.accept(ExceptionsHelper.serverError("error storing trained model config with id [{}]", e, trainedModelConfig.getModelId())); } ); - provider.storeTrainedModelMetadata(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); + provider.storeTrainedModelConfig(trainedModelConfig, new LatchedActionListener<>(storeListener, latch)); return latch; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 0020c2df8bc88..ac2298f01c751 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -33,6 +33,7 @@ public class AnalyticsResult implements ToXContentObject { private static final ParseField OUTLIER_DETECTION_STATS = 
new ParseField("outlier_detection_stats"); private static final ParseField CLASSIFICATION_STATS = new ParseField("classification_stats"); private static final ParseField REGRESSION_STATS = new ParseField("regression_stats"); + private static final ParseField MODEL_METADATA = new ParseField("model_metadata"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), a -> new AnalyticsResult( @@ -43,7 +44,8 @@ public class AnalyticsResult implements ToXContentObject { (ClassificationStats) a[4], (RegressionStats) a[5], (ModelSizeInfo) a[6], - (TrainedModelDefinitionChunk) a[7] + (TrainedModelDefinitionChunk) a[7], + (ModelMetadata) a[8] )); static { @@ -55,6 +57,7 @@ public class AnalyticsResult implements ToXContentObject { PARSER.declareObject(optionalConstructorArg(), RegressionStats.STRICT_PARSER, REGRESSION_STATS); PARSER.declareObject(optionalConstructorArg(), ModelSizeInfo.PARSER, MODEL_SIZE_INFO); PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinitionChunk.PARSER, COMPRESSED_INFERENCE_MODEL); + PARSER.declareObject(optionalConstructorArg(), ModelMetadata.PARSER, MODEL_METADATA); } private final RowResults rowResults; @@ -65,6 +68,7 @@ public class AnalyticsResult implements ToXContentObject { private final RegressionStats regressionStats; private final ModelSizeInfo modelSizeInfo; private final TrainedModelDefinitionChunk trainedModelDefinitionChunk; + private final ModelMetadata modelMetadata; private AnalyticsResult(@Nullable RowResults rowResults, @Nullable PhaseProgress phaseProgress, @@ -73,7 +77,8 @@ private AnalyticsResult(@Nullable RowResults rowResults, @Nullable ClassificationStats classificationStats, @Nullable RegressionStats regressionStats, @Nullable ModelSizeInfo modelSizeInfo, - @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) { + @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk, + @Nullable ModelMetadata modelMetadata) { this.rowResults = 
rowResults; this.phaseProgress = phaseProgress; this.memoryUsage = memoryUsage; @@ -82,6 +87,7 @@ private AnalyticsResult(@Nullable RowResults rowResults, this.regressionStats = regressionStats; this.modelSizeInfo = modelSizeInfo; this.trainedModelDefinitionChunk = trainedModelDefinitionChunk; + this.modelMetadata = modelMetadata; } public RowResults getRowResults() { @@ -116,6 +122,10 @@ public TrainedModelDefinitionChunk getTrainedModelDefinitionChunk() { return trainedModelDefinitionChunk; } + public ModelMetadata getModelMetadata() { + return modelMetadata; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -143,6 +153,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (trainedModelDefinitionChunk != null) { builder.field(COMPRESSED_INFERENCE_MODEL.getPreferredName(), trainedModelDefinitionChunk); } + if (modelMetadata != null) { + builder.field(MODEL_METADATA.getPreferredName(), modelMetadata); + } builder.endObject(); return builder; } @@ -164,13 +177,14 @@ public boolean equals(Object other) { && Objects.equals(classificationStats, that.classificationStats) && Objects.equals(modelSizeInfo, that.modelSizeInfo) && Objects.equals(trainedModelDefinitionChunk, that.trainedModelDefinitionChunk) + && Objects.equals(modelMetadata, that.modelMetadata) && Objects.equals(regressionStats, that.regressionStats); } @Override public int hashCode() { return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats, - regressionStats, modelSizeInfo, trainedModelDefinitionChunk); + regressionStats, modelSizeInfo, trainedModelDefinitionChunk, modelMetadata); } public static Builder builder() { @@ -187,6 +201,7 @@ public static class Builder { private RegressionStats regressionStats; private ModelSizeInfo modelSizeInfo; private TrainedModelDefinitionChunk trainedModelDefinitionChunk; + private ModelMetadata 
modelMetadata; private Builder() {} @@ -230,6 +245,11 @@ public Builder setTrainedModelDefinitionChunk(TrainedModelDefinitionChunk traine return this; } + public Builder setModelMetadata(ModelMetadata modelMetadata) { + this.modelMetadata = modelMetadata; + return this; + } + public AnalyticsResult build() { return new AnalyticsResult( rowResults, @@ -239,7 +259,8 @@ public AnalyticsResult build() { classificationStats, regressionStats, modelSizeInfo, - trainedModelDefinitionChunk + trainedModelDefinitionChunk, + modelMetadata ); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/ModelMetadata.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/ModelMetadata.java new file mode 100644 index 0000000000000..fb05989652b29 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/ModelMetadata.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. 
+ */ + +package org.elasticsearch.xpack.ml.dataframe.process.results; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportance; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; + +public class ModelMetadata implements ToXContentObject { + + public static final ParseField TOTAL_FEATURE_IMPORTANCE = new ParseField("total_feature_importance"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "trained_model_metadata", + a -> new ModelMetadata((List) a[0])); + + static { + PARSER.declareObjectArray(constructorArg(), TotalFeatureImportance.STRICT_PARSER, TOTAL_FEATURE_IMPORTANCE); + } + + private final List featureImportances; + + public ModelMetadata(List featureImportances) { + this.featureImportances = featureImportances; + } + + public List getFeatureImportances() { + return featureImportances; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ModelMetadata that = (ModelMetadata) o; + return Objects.equals(featureImportances, that.featureImportances); + } + + @Override + public int hashCode() { + return Objects.hash(featureImportances); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TOTAL_FEATURE_IMPORTANCE.getPreferredName(), featureImportances); + builder.endObject(); + return builder; + } + +} diff --git 
a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index e460be8b59094..f119774a96880 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -74,6 +74,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceStats; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.inference.InferenceDefinition; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.langident.LangIdentNeuralNetwork; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; @@ -146,8 +147,7 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig, storeTrainedModelAndDefinition(trainedModelConfig, listener); } - public void storeTrainedModelMetadata(TrainedModelConfig trainedModelConfig, - ActionListener listener) { + public void storeTrainedModelConfig(TrainedModelConfig trainedModelConfig, ActionListener listener) { if (MODELS_STORED_AS_RESOURCE.contains(trainedModelConfig.getModelId())) { listener.onFailure(new ResourceAlreadyExistsException( Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); @@ -206,6 +206,68 @@ public void storeTrainedModelDefinitionDoc(TrainedModelDefinitionDoc trainedMode )); } + public void storeTrainedModelMetadata(TrainedModelMetadata trainedModelMetadata, ActionListener listener) { + if (MODELS_STORED_AS_RESOURCE.contains(trainedModelMetadata.getModelId())) { + listener.onFailure(new ResourceAlreadyExistsException( + 
Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelMetadata.getModelId()))); + return; + } + executeAsyncWithOrigin(client, + ML_ORIGIN, + IndexAction.INSTANCE, + createRequest(trainedModelMetadata.getDocId(), InferenceIndexConstants.LATEST_INDEX_NAME, trainedModelMetadata), + ActionListener.wrap( + indexResponse -> listener.onResponse(null), + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_METADATA_EXISTS, + trainedModelMetadata.getModelId()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL_METADATA, + RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelMetadata.getModelId())); + } + } + )); + } + + public void getTrainedModelMetadata(String modelId, ActionListener listener) { + SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders + .boolQuery() + .filter(QueryBuilders.termQuery(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId)) + .filter(QueryBuilders.termQuery(InferenceIndexConstants.DOC_TYPE.getPreferredName(), + TrainedModelMetadata.NAME)))) + .setSize(1) + // First find the latest index + .addSort("_index", SortOrder.DESC) + .request(); + executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, ActionListener.wrap( + searchResponse -> { + if (searchResponse.getHits().getHits().length == 0) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); + return; + } + List metadataList = handleHits(searchResponse.getHits().getHits(), + modelId, + this::parseMetadataLenientlyFromSource); + listener.onResponse(metadataList.get(0)); + }, + e -> { + if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { + 
listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.MODEL_METADATA_NOT_FOUND, modelId))); + return; + } + listener.onFailure(e); + } + )); + } + public void refreshInferenceIndex(ActionListener listener) { executeAsyncWithOrigin(client, ML_ORIGIN, @@ -927,6 +989,17 @@ private TrainedModelDefinitionDoc parseModelDefinitionDocLenientlyFromSource(Byt } } + private TrainedModelMetadata parseMetadataLenientlyFromSource(BytesReference source, String modelId) throws IOException { + try (InputStream stream = source.streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { + return TrainedModelMetadata.fromXContent(parser, true); + } catch (IOException e) { + logger.error(new ParameterizedMessage("[{}] failed to parse model metadata", modelId), e); + throw e; + } + } + private IndexRequest createRequest(String docId, String index, ToXContentObject body) { return createRequest(new IndexRequest(index), docId, body); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java index 5c450df29b360..b7abf9cdbeaa8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/ChunkedTrainedModelPersisterTests.java @@ -18,7 +18,10 @@ import org.elasticsearch.xpack.core.ml.dataframe.analyses.Classification; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; +import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TrainedModelMetadata; import org.elasticsearch.xpack.core.security.user.XPackUser; +import org.elasticsearch.xpack.ml.dataframe.process.results.ModelMetadata; import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.extractor.DocValueField; import org.elasticsearch.xpack.ml.extractor.ExtractedField; @@ -35,6 +38,8 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.contains; import static org.hamcrest.Matchers.containsString; @@ -78,7 +83,7 @@ public void testPersistAllDocs() { ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; storeListener.onResponse(true); return null; - }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelConfig.class), any(ActionListener.class)); + }).when(trainedModelProvider).storeTrainedModelConfig(any(TrainedModelConfig.class), any(ActionListener.class)); doAnswer(invocationOnMock -> { ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -86,22 +91,36 @@ public void testPersistAllDocs() { return null; }).when(trainedModelProvider).storeTrainedModelDefinitionDoc(any(TrainedModelDefinitionDoc.class), any(ActionListener.class)); + doAnswer(invocationOnMock -> { + ActionListener storeListener = (ActionListener) invocationOnMock.getArguments()[1]; + storeListener.onResponse(null); + return null; + }).when(trainedModelProvider).storeTrainedModelMetadata(any(TrainedModelMetadata.class), any(ActionListener.class)); + ChunkedTrainedModelPersister resultProcessor = createChunkedTrainedModelPersister(extractedFieldList, analyticsConfig); ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom(); TrainedModelDefinitionChunk chunk1 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 0, false); 
TrainedModelDefinitionChunk chunk2 = new TrainedModelDefinitionChunk(randomAlphaOfLength(10), 1, true); + ModelMetadata modelMetadata = new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList())); - resultProcessor.createAndIndexInferenceModelMetadata(modelSizeInfo); + resultProcessor.createAndIndexInferenceModelConfig(modelSizeInfo); resultProcessor.createAndIndexInferenceModelDoc(chunk1); resultProcessor.createAndIndexInferenceModelDoc(chunk2); + resultProcessor.createAndIndexInferenceModelMetadata(modelMetadata); ArgumentCaptor storedModelCaptor = ArgumentCaptor.forClass(TrainedModelConfig.class); - verify(trainedModelProvider).storeTrainedModelMetadata(storedModelCaptor.capture(), any(ActionListener.class)); + verify(trainedModelProvider).storeTrainedModelConfig(storedModelCaptor.capture(), any(ActionListener.class)); ArgumentCaptor storedDocCapture = ArgumentCaptor.forClass(TrainedModelDefinitionDoc.class); verify(trainedModelProvider, times(2)) .storeTrainedModelDefinitionDoc(storedDocCapture.capture(), any(ActionListener.class)); + ArgumentCaptor storedMetadataCaptor = ArgumentCaptor.forClass(TrainedModelMetadata.class); + verify(trainedModelProvider, times(1)) + .storeTrainedModelMetadata(storedMetadataCaptor.capture(), any(ActionListener.class)); + TrainedModelConfig storedModel = storedModelCaptor.getValue(); assertThat(storedModel.getLicenseLevel(), equalTo(License.OperationMode.PLATINUM)); assertThat(storedModel.getModelId(), containsString(JOB_ID)); @@ -132,6 +151,9 @@ public void testPersistAllDocs() { assertThat(storedModel.getModelId(), equalTo(storedDoc1.getModelId())); assertThat(storedModel.getModelId(), equalTo(storedDoc2.getModelId())); + TrainedModelMetadata storedMetadata = storedMetadataCaptor.getValue(); + assertThat(storedMetadata.getModelId(), equalTo(storedModel.getModelId())); + ArgumentCaptor auditCaptor = ArgumentCaptor.forClass(String.class); 
verify(auditor).info(eq(JOB_ID), auditCaptor.capture()); assertThat(auditCaptor.getValue(), containsString("Stored trained model with id [" + JOB_ID)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 35e562bca2459..44a58704445f4 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.metadata.TotalFeatureImportanceTests; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; @@ -24,6 +25,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class AnalyticsResultTests extends AbstractXContentTestCase { @@ -38,7 +41,6 @@ protected NamedXContentRegistry xContentRegistry() { protected AnalyticsResult createTestInstance() { AnalyticsResult.Builder builder = AnalyticsResult.builder(); - if (randomBoolean()) { builder.setRowResults(RowResultsTests.createRandom()); } @@ -64,6 +66,11 @@ protected AnalyticsResult createTestInstance() { String def = randomAlphaOfLengthBetween(100, 1000); builder.setTrainedModelDefinitionChunk(new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), 
randomBoolean())); } + if (randomBoolean()) { + builder.setModelMetadata(new ModelMetadata(Stream.generate(TotalFeatureImportanceTests::randomInstance) + .limit(randomIntBetween(1, 10)) + .collect(Collectors.toList()))); + } return builder.build(); } diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml index 1b5aa8bbc4929..e2da5db5b4495 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/ml/inference_stats_crud.yml @@ -5,12 +5,12 @@ setup: - allowed_warnings - do: allowed_warnings: - - "index [.ml-inference-000002] matches multiple legacy templates [.ml-inference-000002, global], composable templates will only match a single template" + - "index [.ml-inference-000003] matches multiple legacy templates [.ml-inference-000003, global], composable templates will only match a single template" headers: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-unused-regression-model1-0 - index: .ml-inference-000002 + index: .ml-inference-000003 body: > { "model_id": "a-unused-regression-model1", @@ -27,7 +27,7 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-unused-regression-model-0 - index: .ml-inference-000002 + index: .ml-inference-000003 body: > { "model_id": "a-unused-regression-model", @@ -43,7 +43,7 @@ setup: Authorization: "Basic eF9wYWNrX3Jlc3RfdXNlcjp4LXBhY2stdGVzdC1wYXNzd29yZA==" # run as x_pack_rest_user, i.e. the test setup superuser index: id: trained_model_config-a-used-regression-model-0 - index: .ml-inference-000002 + index: .ml-inference-000003 body: > { "model_id": "a-used-regression-model",