From 95b0f395035bd846f1dc05918f5fe4f641924e86 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 29 Oct 2019 15:14:12 -0400 Subject: [PATCH 1/4] [ML][Inference] separating definition and config object storage --- .../ml/inference/TrainedModelConfig.java | 80 ++++++- .../ml/inference/TrainedModelDefinition.java | 79 +------ .../ml/inference/TrainedModelConfigTests.java | 5 +- .../TrainedModelDefinitionTests.java | 5 +- .../core/ml/inference/TrainedModelConfig.java | 105 ++++++++- .../ml/inference/TrainedModelDefinition.java | 110 +++------ .../xpack/core/ml/job/messages/Messages.java | 3 +- .../ml/inference/TrainedModelConfigTests.java | 53 ++++- .../TrainedModelDefinitionTests.java | 24 +- .../process/AnalyticsProcessManager.java | 2 +- .../process/AnalyticsResultProcessor.java | 23 +- .../process/results/AnalyticsResult.java | 14 +- .../persistence/InferenceInternalIndex.java | 6 +- .../persistence/TrainedModelProvider.java | 211 +++++++++++++----- .../AnalyticsResultProcessorTests.java | 16 +- .../process/results/AnalyticsResultTests.java | 7 +- .../integration/TrainedModelProviderIT.java | 56 ++++- 17 files changed, 503 insertions(+), 296 deletions(-) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index f50c9b69eef53..ece9765960f06 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -22,6 +22,7 @@ import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -46,6 +47,7 @@ public class TrainedModelConfig implements ToXContentObject { public static final ParseField DEFINITION = new ParseField("definition"); public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); + public static final ParseField INPUT = new ParseField("input"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -64,6 +66,7 @@ public class TrainedModelConfig implements ToXContentObject { DEFINITION); PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); + PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT); } public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { @@ -78,6 +81,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr private final TrainedModelDefinition definition; private final List tags; private final Map metadata; + private final Input input; TrainedModelConfig(String modelId, String createdBy, @@ -86,7 +90,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr Instant createTime, TrainedModelDefinition definition, List tags, - Map metadata) { + Map metadata, + Input input) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; @@ -95,6 +100,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr this.description = description; this.tags = tags == null ? null : Collections.unmodifiableList(tags); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); + this.input = input; } public String getModelId() { @@ -129,6 +135,10 @@ public TrainedModelDefinition getDefinition() { return definition; } + public Input getInput() { + return input; + } + public static Builder builder() { return new Builder(); } @@ -160,6 +170,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (metadata != null) { builder.field(METADATA.getPreferredName(), metadata); } + if (input != null) { + builder.field(INPUT.getPreferredName(), input); + } builder.endObject(); return builder; } @@ -181,6 +194,7 @@ public boolean equals(Object o) { Objects.equals(createTime, that.createTime) && Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && + Objects.equals(input, that.input) && Objects.equals(metadata, that.metadata); } @@ -193,7 +207,8 @@ public int hashCode() { definition, description, tags, - metadata); + metadata, + input); } @@ -207,6 +222,7 @@ public static class Builder { private Map metadata; private List tags; private TrainedModelDefinition definition; + private Input input; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -257,6 +273,11 @@ public Builder setDefinition(TrainedModelDefinition definition) { return this; } + public Builder setInput(Input input) { + this.input = input; + return this; + } + public TrainedModelConfig build() { return new TrainedModelConfig( modelId, @@ -266,7 +287,60 @@ public TrainedModelConfig build() { createTime, definition, tags, - metadata); + metadata, + input); + } + } + + public static class Input implements ToXContentObject { + + public static final String NAME = "trained_model_config_input"; + public static final ParseField FIELD_NAMES = new ParseField("field_names"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new Input((List)a[0])); + static { + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); + } + + public static Input fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private final List fieldNames; + + public Input(List fieldNames) { + this.fieldNames = fieldNames; } + + public List getFieldNames() { + return fieldNames; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (fieldNames != null) { + builder.field(FIELD_NAMES.getPreferredName(), fieldNames); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Input that = (Input) o; + return Objects.equals(fieldNames, that.fieldNames); + } + + @Override + public int hashCode() { + return Objects.hash(fieldNames); + } + } } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java index dec834fa328f1..e01d08d019fdc 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java @@ -22,7 +22,6 @@ import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -39,7 +38,6 @@ public class TrainedModelDefinition implements ToXContentObject { public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); - public static final ParseField INPUT = new ParseField("input"); public static final ObjectParser PARSER = new ObjectParser<>(NAME, true, @@ -53,7 +51,6 @@ public class TrainedModelDefinition implements ToXContentObject { (p, c, n) -> p.namedObject(PreProcessor.class, n, null), (trainedModelDefBuilder) -> {/* Does not matter client side*/ }, PREPROCESSORS); - PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT); } public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException { @@ -62,12 +59,10 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) private final TrainedModel trainedModel; private final List preProcessors; - private final Input input; - TrainedModelDefinition(TrainedModel trainedModel, List preProcessors, Input input) { + TrainedModelDefinition(TrainedModel trainedModel, List preProcessors) { this.trainedModel = trainedModel; this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); - this.input = input; } @Override @@ -83,9 +78,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws true, PREPROCESSORS.getPreferredName(), preProcessors); - if (input != null) { - builder.field(INPUT.getPreferredName(), input); - } builder.endObject(); return builder; } @@ -98,10 +90,6 @@ public List getPreProcessors() { return preProcessors; } - public Input getInput() { - return input; - } - @Override public String toString() { return Strings.toString(this); @@ -113,20 +101,18 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; TrainedModelDefinition that = (TrainedModelDefinition) o; return Objects.equals(trainedModel, that.trainedModel) && - Objects.equals(preProcessors, that.preProcessors) && - Objects.equals(input, that.input); + Objects.equals(preProcessors, that.preProcessors); } @Override public int hashCode() { - return Objects.hash(trainedModel, preProcessors, input); + return Objects.hash(trainedModel, preProcessors); } public static class Builder { private List preProcessors; private TrainedModel trainedModel; - private Input input; public Builder setPreProcessors(List preProcessors) { this.preProcessors = preProcessors; @@ -138,71 +124,14 @@ public Builder setTrainedModel(TrainedModel trainedModel) { return this; } - public Builder setInput(Input input) { - this.input = input; - return this; - } - private Builder setTrainedModel(List trainedModel) { assert trainedModel.size() == 1; return setTrainedModel(trainedModel.get(0)); } public TrainedModelDefinition build() { - return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input); - } - } - - public static class Input implements ToXContentObject { - - public static final String NAME = "trained_mode_definition_input"; - public static final ParseField FIELD_NAMES = new ParseField("field_names"); - - @SuppressWarnings("unchecked") - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, - true, - a -> new Input((List)a[0])); - static { - PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); - } - - public static Input fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); + return new TrainedModelDefinition(this.trainedModel, this.preProcessors); } - - private final List fieldNames; - - public Input(List fieldNames) { - this.fieldNames = fieldNames; - } - - public List getFieldNames() { - return fieldNames; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (fieldNames != null) { - builder.field(FIELD_NAMES.getPreferredName(), fieldNames); - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o; - return Objects.equals(fieldNames, that.fieldNames); - } - - @Override - public int hashCode() { - return Objects.hash(fieldNames); - } - } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index f136bedcc8b44..b5ce306f25472 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -63,7 +63,10 @@ protected TrainedModelConfig createTestInstance() { randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), - randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); + randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), + randomBoolean() ? null : new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomLongBetween(1, 10)) + .collect(Collectors.toList()))); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java index ff53c0d8fc082..8eeec2ce2fcb9 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java @@ -64,10 +64,7 @@ public static TrainedModelDefinition.Builder createRandomBuilder() { TargetMeanEncodingTests.createRandom())) .limit(numberOfProcessors) .collect(Collectors.toList())) - .setTrainedModel(randomFrom(TreeTests.createRandom())) - .setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomLongBetween(1, 10)) - .collect(Collectors.toList()))); + .setTrainedModel(randomFrom(TreeTests.createRandom())); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index e1c24eee02b88..48976725a46f3 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -42,6 +43,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable { public static final ParseField DEFINITION = new ParseField("definition"); public static final ParseField TAGS = new ParseField("tags"); public static final ParseField METADATA = new ParseField("metadata"); + public static final ParseField INPUT = new ParseField("input"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -61,10 +63,10 @@ private static ObjectParser createParser(boole ObjectParser.ValueType.VALUE); parser.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); - parser.declareObject(TrainedModelConfig.Builder::setDefinition, - (p, c) -> TrainedModelDefinition.fromXContent(p, ignoreUnknownFields), - DEFINITION); parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE); + parser.declareObject(TrainedModelConfig.Builder::setInput, + (p, c) -> TrainedModelConfig.Input.fromXContent(p, ignoreUnknownFields), + INPUT); return parser; } @@ -79,10 +81,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo private final Instant createTime; private final List tags; private final Map metadata; + private final Input input; - // TODO how to reference and store large models that will not be executed in Java??? - // Potentially allow this to be null and have an {index: indexName, doc: model_doc_id} or something - // TODO Should this be lazily parsed when loading via the index??? private final TrainedModelDefinition definition; TrainedModelConfig(String modelId, @@ -92,7 +92,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo Instant createTime, TrainedModelDefinition definition, List tags, - Map metadata) { + Map metadata, + Input input) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); @@ -101,6 +102,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo this.description = description; this.tags = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(tags, TAGS)); this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata); + this.input = ExceptionsHelper.requireNonNull(input, INPUT); } public TrainedModelConfig(StreamInput in) throws IOException { @@ -112,6 +114,7 @@ public TrainedModelConfig(StreamInput in) throws IOException { definition = in.readOptionalWriteable(TrainedModelDefinition::new); tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); metadata = in.readMap(); + input = new Input(in); } public String getModelId() { @@ -147,6 +150,10 @@ public TrainedModelDefinition getDefinition() { return definition; } + public Input getInput() { + return input; + } + public static Builder builder() { return new Builder(); } @@ -161,6 +168,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(definition); out.writeCollection(tags, StreamOutput::writeString); out.writeMap(metadata); + input.writeTo(out); } @Override @@ -173,7 +181,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(DESCRIPTION.getPreferredName(), description); } builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli()); - if (definition != null) { + + // We don't store the definition in the same document as the configuration + if ((params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) && definition != null) { builder.field(DEFINITION.getPreferredName(), definition); } builder.field(TAGS.getPreferredName(), tags); @@ -183,6 +193,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); } + builder.field(INPUT.getPreferredName(), input); builder.endObject(); return builder; } @@ -204,6 +215,7 @@ public boolean equals(Object o) { Objects.equals(createTime, that.createTime) && Objects.equals(definition, that.definition) && Objects.equals(tags, that.tags) && + Objects.equals(input, that.input) && Objects.equals(metadata, that.metadata); } @@ -216,7 +228,8 @@ public int hashCode() { definition, description, tags, - metadata); + metadata, + input); } public static class Builder { @@ -228,6 +241,7 @@ public static class Builder { private Instant createTime; private List tags = Collections.emptyList(); private Map metadata; + private Input input; private TrainedModelDefinition definition; public Builder setModelId(String modelId) { @@ -279,9 +293,14 @@ public Builder setDefinition(TrainedModelDefinition definition) { return this; } + public Builder setInput(Input input) { + this.input = input; + return this; + } + // TODO move to REST level instead of here in the builder public void validate() { - // We require a definition to be available until we support other means of supplying the definition + // We require a definition to be available here even though it will be stored in a different doc ExceptionsHelper.requireNonNull(definition, DEFINITION); ExceptionsHelper.requireNonNull(modelId, MODEL_ID); @@ -320,7 +339,71 @@ public TrainedModelConfig build() { createTime == null ? Instant.now() : createTime, definition, tags, - metadata); + metadata, + input); + } + } + + public static class Input implements ToXContentObject, Writeable { + + public static final String NAME = "trained_model_config_input"; + public static final ParseField FIELD_NAMES = new ParseField("field_names"); + + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new Input((List)a[0])); + parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); + return parser; + } + + public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + private final List fieldNames; + + public Input(List fieldNames) { + this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES)); + } + + public Input(StreamInput in) throws IOException { + this.fieldNames = Collections.unmodifiableList(in.readStringList()); + } + + public List getFieldNames() { + return fieldNames; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(fieldNames); } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_NAMES.getPreferredName(), fieldNames); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + TrainedModelConfig.Input that = (TrainedModelConfig.Input) o; + return Objects.equals(fieldNames, that.fieldNames); + } + + @Override + public int hashCode() { + return Objects.hash(fieldNames); + } + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index f85c184646e1f..38a05048694a4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -5,16 +5,17 @@ */ package org.elasticsearch.xpack.core.ml.inference; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.inference.preprocessing.LenientlyParsedPreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.PreProcessor; import org.elasticsearch.xpack.core.ml.inference.preprocessing.StrictlyParsedPreProcessor; @@ -23,6 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import java.io.IOException; import java.util.Collections; @@ -31,11 +33,10 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { - public static final String NAME = "trained_mode_definition"; + public static final String NAME = "trained_model_definition"; public static final ParseField TRAINED_MODEL = new ParseField("trained_model"); public static final ParseField PREPROCESSORS = new ParseField("preprocessors"); - public static final ParseField INPUT = new ParseField("input"); // These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly public static final ObjectParser LENIENT_PARSER = createParser(true); @@ -57,7 +58,7 @@ private static ObjectParser createParser(b p.namedObject(StrictlyParsedPreProcessor.class, n, null), (trainedModelDefBuilder) -> trainedModelDefBuilder.setProcessorsInOrder(true), PREPROCESSORS); - parser.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT); + parser.declareString(TrainedModelDefinition.Builder::setModelId, TrainedModelConfig.MODEL_ID); return parser; } @@ -65,27 +66,31 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser, return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); } + public static String docId(String modelId) { + return NAME + "-" + modelId; + } + private final TrainedModel trainedModel; private final List preProcessors; - private final Input input; + private final String modelId; - TrainedModelDefinition(TrainedModel trainedModel, List preProcessors, Input input) { + TrainedModelDefinition(TrainedModel trainedModel, List preProcessors, @Nullable String modelId) { this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL); this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); - this.input = ExceptionsHelper.requireNonNull(input, INPUT); + this.modelId = modelId; } public TrainedModelDefinition(StreamInput in) throws IOException { this.trainedModel = in.readNamedWriteable(TrainedModel.class); this.preProcessors = in.readNamedWriteableList(PreProcessor.class); - this.input = new Input(in); + this.modelId = in.readOptionalString(); } @Override public void writeTo(StreamOutput out) throws IOException { out.writeNamedWriteable(trainedModel); out.writeNamedWriteableList(preProcessors); - input.writeTo(out); + out.writeOptionalString(modelId); } @Override @@ -101,7 +106,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws true, PREPROCESSORS.getPreferredName(), preProcessors); - builder.field(INPUT.getPreferredName(), input); + if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false)) { + builder.field(InferenceIndexConstants.DOC_TYPE.getPreferredName(), NAME); + assert modelId != null; + builder.field(TrainedModelConfig.MODEL_ID.getPreferredName(), modelId); + } builder.endObject(); return builder; } @@ -114,10 +123,6 @@ public List getPreProcessors() { return preProcessors; } - public Input getInput() { - return input; - } - @Override public String toString() { return Strings.toString(this); @@ -129,21 +134,21 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; TrainedModelDefinition that = (TrainedModelDefinition) o; return Objects.equals(trainedModel, that.trainedModel) && - Objects.equals(input, that.input) && - Objects.equals(preProcessors, that.preProcessors); + Objects.equals(preProcessors, that.preProcessors) && + Objects.equals(modelId, that.modelId); } @Override public int hashCode() { - return Objects.hash(trainedModel, input, preProcessors); + return Objects.hash(trainedModel, preProcessors, modelId); } public static class Builder { private List preProcessors; private TrainedModel trainedModel; + private String modelId; private boolean processorsInOrder; - private Input input; private static Builder builderForParser() { return new Builder(false); @@ -167,8 +172,8 @@ public Builder setTrainedModel(TrainedModel trainedModel) { return this; } - public Builder setInput(Input input) { - this.input = input; + public Builder setModelId(String modelId) { + this.modelId = modelId; return this; } @@ -188,71 +193,8 @@ public TrainedModelDefinition build() { if (preProcessors != null && preProcessors.size() > 1 && processorsInOrder == false) { throw new IllegalArgumentException("preprocessors must be an array of preprocessor objects"); } - return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input); + return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.modelId); } } - public static class Input implements ToXContentObject, Writeable { - - public static final String NAME = "trained_mode_definition_input"; - public static final ParseField FIELD_NAMES = new ParseField("field_names"); - - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - - @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { - ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, - ignoreUnknownFields, - a -> new Input((List)a[0])); - parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); - return parser; - } - - public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException { - return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); - } - - private final List fieldNames; - - public Input(List fieldNames) { - this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES)); - } - - public Input(StreamInput in) throws IOException { - this.fieldNames = Collections.unmodifiableList(in.readStringList()); - } - - public List getFieldNames() { - return fieldNames; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeStringCollection(fieldNames); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(FIELD_NAMES.getPreferredName(), fieldNames); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o; - return Objects.equals(fieldNames, that.fieldNames); - } - - @Override - public int hashCode() { - return Objects.hash(fieldNames); - } - - } - } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java index 75cc468160d17..5bbe46ecd8b11 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java @@ -82,9 +82,8 @@ public final class Messages { public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists"; public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]"; - public static final String INFERENCE_FAILED_TO_SERIALIZE_MODEL = - "Failed to serialize the trained model [{0}] with version [{1}] for storage"; public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]"; + public static final String MODEL_DEFINITION_NOT_FOUND = "Could not find trained model definition [{0}]"; public static final String JOB_AUDIT_DATAFEED_DATA_SEEN_AGAIN = "Datafeed has started retrieving data again"; public static final String JOB_AUDIT_CREATED = "Job created"; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 678bb8a2982b4..139899e2fd2b2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -7,15 +7,24 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.MlStrings; +import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.junit.Before; import java.io.IOException; @@ -27,8 +36,11 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; public class TrainedModelConfigTests extends AbstractSerializingTestCase { @@ -63,9 +75,12 @@ protected TrainedModelConfig createTestInstance() { Version.CURRENT, randomBoolean() ? null : randomAlphaOfLength(100), Instant.ofEpochMilli(randomNonNegativeLong()), - randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(), + null, // is not parsed so should not be provided tags, - randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10))); + randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), + new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toList()))); } @Override @@ -88,6 +103,30 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { return new NamedWriteableRegistry(entries); } + public void testToXContentWithParams() throws IOException { + TrainedModelConfig config = new TrainedModelConfig( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + Version.CURRENT, + randomBoolean() ? null : randomAlphaOfLength(100), + Instant.ofEpochMilli(randomNonNegativeLong()), + TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(), + Collections.emptyList(), + randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), + new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toList()))); + + BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); + assertThat(reference.utf8ToString(), containsString("definition")); + + reference = XContentHelper.toXContent(config, + XContentType.JSON, + new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), + false); + assertThat(reference.utf8ToString(), not(containsString("definition"))); + } + public void testValidateWithNullDefinition() { IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> TrainedModelConfig.builder().validate()); assertThat(ex.getMessage(), equalTo("[definition] must not be null.")); @@ -97,7 +136,7 @@ public void testValidateWithInvalidID() { String modelId = "InvalidID-"; ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId)) .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.INVALID_ID, "model_id", modelId))); } @@ -106,7 +145,7 @@ public void testValidateWithLongID() { String modelId = IntStream.range(0, 100).mapToObj(x -> "a").collect(Collectors.joining()); ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId)) .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo(Messages.getMessage(Messages.ID_TOO_LONG, "model_id", modelId, MlStrings.ID_LENGTH_LIMIT))); } @@ -115,21 +154,21 @@ public void testValidateWithIllegallyUserProvidedFields() { String modelId = "simplemodel"; ElasticsearchException ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId)) .setCreateTime(Instant.now()) .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo("illegal to set [create_time] at inference model creation")); ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId)) .setVersion(Version.CURRENT) .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo("illegal to set [version] at inference model creation")); ex = expectThrows(ElasticsearchException.class, () -> TrainedModelConfig.builder() - .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId)) .setCreatedBy("ml_user") .setModelId(modelId).validate()); assertThat(ex.getMessage(), equalTo("illegal to set [created_by] at inference model creation")); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java index 5339d93bf9100..41ff51be8212d 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java @@ -58,9 +58,10 @@ protected Predicate getRandomFieldsExcludeFilter() { return field -> !field.isEmpty(); } - public static TrainedModelDefinition.Builder createRandomBuilder() { + public static TrainedModelDefinition.Builder createRandomBuilder(String modelId) { int numberOfProcessors = randomIntBetween(1, 10); return new TrainedModelDefinition.Builder() + .setModelId(modelId) .setPreProcessors( randomBoolean() ? null : Stream.generate(() -> randomFrom(FrequencyEncodingTests.createRandom(), @@ -68,22 +69,11 @@ public static TrainedModelDefinition.Builder createRandomBuilder() { TargetMeanEncodingTests.createRandom())) .limit(numberOfProcessors) .collect(Collectors.toList())) - .setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomLongBetween(1, 10)) - .collect(Collectors.toList()))) .setTrainedModel(randomFrom(TreeTests.createRandom())); } private static final String ENSEMBLE_MODEL = "" + "{\n" + - " \"input\": {\n" + - " \"field_names\": [\n" + - " \"col1\",\n" + - " \"col2\",\n" + - " \"col3\",\n" + - " \"col4\"\n" + - " ]\n" + - " },\n" + " \"preprocessors\": [\n" + " {\n" + " \"one_hot_encoding\": {\n" + @@ -203,14 +193,6 @@ public static TrainedModelDefinition.Builder createRandomBuilder() { "}"; private static final String TREE_MODEL = "" + "{\n" + - " \"input\": {\n" + - " \"field_names\": [\n" + - " \"col1\",\n" + - " \"col2\",\n" + - " \"col3\",\n" + - " \"col4\"\n" + - " ]\n" + - " },\n" + " \"preprocessors\": [\n" + " {\n" + " \"one_hot_encoding\": {\n" + @@ -293,7 +275,7 @@ public void testTreeSchemaDeserialization() throws IOException { @Override protected TrainedModelDefinition createTestInstance() { - return createRandomBuilder().build(); + return createRandomBuilder(null).build(); } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 85b40bd64934f..2fe5004aabcd6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -366,7 +366,7 @@ private synchronized boolean startProcess(DataFrameDataExtractorFactory dataExtr DataFrameRowsJoiner dataFrameRowsJoiner = new DataFrameRowsJoiner(config.getId(), client, dataExtractorFactory.newExtractor(true)); resultProcessor = new AnalyticsResultProcessor(config, dataFrameRowsJoiner, this::isProcessKilled, task.getProgressTracker(), - trainedModelProvider, auditor); + trainedModelProvider, auditor, dataExtractor.getFieldNames()); return true; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 72337f77e9f84..a3c2ba5aae37e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -26,6 +26,7 @@ import java.time.Instant; import java.util.Collections; import java.util.Iterator; +import java.util.List; import java.util.Objects; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -41,18 +42,21 @@ public class AnalyticsResultProcessor { private final ProgressTracker progressTracker; private final TrainedModelProvider trainedModelProvider; private final DataFrameAnalyticsAuditor auditor; + private final List fieldNames; private final CountDownLatch completionLatch = new CountDownLatch(1); private volatile String failure; public AnalyticsResultProcessor(DataFrameAnalyticsConfig analytics, DataFrameRowsJoiner dataFrameRowsJoiner, Supplier isProcessKilled, ProgressTracker progressTracker, - TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor) { + TrainedModelProvider trainedModelProvider, DataFrameAnalyticsAuditor auditor, + List fieldNames) { this.analytics = Objects.requireNonNull(analytics); this.dataFrameRowsJoiner = Objects.requireNonNull(dataFrameRowsJoiner); this.isProcessKilled = Objects.requireNonNull(isProcessKilled); this.progressTracker = Objects.requireNonNull(progressTracker); this.trainedModelProvider = Objects.requireNonNull(trainedModelProvider); this.auditor = Objects.requireNonNull(auditor); + this.fieldNames = Collections.unmodifiableList(Objects.requireNonNull(fieldNames)); } @Nullable @@ -111,13 +115,13 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo if (progressPercent != null) { progressTracker.analyzingPercent.set(progressPercent); } - TrainedModelDefinition inferenceModel = result.getInferenceModel(); - if (inferenceModel != null) { - createAndIndexInferenceModel(inferenceModel); + TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder(); + if (inferenceModelBuilder != null) { + createAndIndexInferenceModel(inferenceModelBuilder); } } - private void createAndIndexInferenceModel(TrainedModelDefinition inferenceModel) { + private void createAndIndexInferenceModel(TrainedModelDefinition.Builder inferenceModel) { TrainedModelConfig trainedModelConfig = createTrainedModelConfig(inferenceModel); CountDownLatch latch = storeTrainedModel(trainedModelConfig); @@ -131,10 +135,12 @@ private void createAndIndexInferenceModel(TrainedModelDefinition inferenceModel) } } - private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition inferenceModel) { + private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Builder inferenceModel) { Instant createTime = Instant.now(); + String modelId = analytics.getId() + "-" + createTime.toEpochMilli(); + TrainedModelDefinition definition = inferenceModel.setModelId(modelId).build(); return TrainedModelConfig.builder() - .setModelId(analytics.getId() + "-" + createTime.toEpochMilli()) + .setModelId(modelId) .setCreatedBy("data-frame-analytics") .setVersion(Version.CURRENT) .setCreateTime(createTime) @@ -142,7 +148,8 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition infer .setDescription(analytics.getDescription()) .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) - .setDefinition(inferenceModel) + .setDefinition(definition) + .setInput(new TrainedModelConfig.Input(fieldNames)) .build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index c383fd195767a..8b301a44b83d3 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -24,23 +24,25 @@ public class AnalyticsResult implements ToXContentObject { public static final ParseField INFERENCE_MODEL = new ParseField("inference_model"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(), - a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition) a[2])); + a -> new AnalyticsResult((RowResults) a[0], (Integer) a[1], (TrainedModelDefinition.Builder) a[2])); static { PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); - PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.STRICT_PARSER.apply(p, null).build(), + PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.LENIENT_PARSER.apply(p, null), INFERENCE_MODEL); } private final RowResults rowResults; private final Integer progressPercent; + private final TrainedModelDefinition.Builder inferenceModelBuilder; private final TrainedModelDefinition inferenceModel; - public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition inferenceModel) { + public AnalyticsResult(RowResults rowResults, Integer progressPercent, TrainedModelDefinition.Builder inferenceModelBuilder) { this.rowResults = rowResults; this.progressPercent = progressPercent; - this.inferenceModel = inferenceModel; + this.inferenceModelBuilder = inferenceModelBuilder; + this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build(); } public RowResults getRowResults() { @@ -51,8 +53,8 @@ public Integer getProgressPercent() { return progressPercent; } - public TrainedModelDefinition getInferenceModel() { - return inferenceModel; + public TrainedModelDefinition.Builder getInferenceModelBuilder() { + return inferenceModelBuilder; } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java index 33a4180b25f5c..19d5d33abe4cd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/InferenceInternalIndex.java @@ -86,6 +86,9 @@ private static void addInferenceDocFields(XContentBuilder builder) throws IOExce .startObject(TrainedModelConfig.CREATED_BY.getPreferredName()) .field(TYPE, KEYWORD) .endObject() + .startObject(TrainedModelConfig.INPUT.getPreferredName()) + .field(ENABLED, false) + .endObject() .startObject(TrainedModelConfig.VERSION.getPreferredName()) .field(TYPE, KEYWORD) .endObject() @@ -95,9 +98,6 @@ private static void addInferenceDocFields(XContentBuilder builder) throws IOExce .startObject(TrainedModelConfig.CREATE_TIME.getPreferredName()) .field(TYPE, DATE) .endObject() - .startObject(TrainedModelConfig.DEFINITION.getPreferredName()) - .field(ENABLED, false) - .endObject() .startObject(TrainedModelConfig.TAGS.getPreferredName()) .field(TYPE, KEYWORD) .endObject() diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 6f1e543896c9d..1266e4b11f9ee 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -8,7 +8,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.ResourceNotFoundException; @@ -16,14 +15,18 @@ import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.search.SearchAction; -import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.index.IndexResponse; +import org.elasticsearch.action.search.MultiSearchAction; +import org.elasticsearch.action.search.MultiSearchRequestBuilder; +import org.elasticsearch.action.search.MultiSearchResponse; import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.client.Client; +import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentParser; @@ -34,6 +37,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -51,6 +55,8 @@ public class TrainedModelProvider { private static final Logger logger = LogManager.getLogger(TrainedModelProvider.class); private final Client client; private final NamedXContentRegistry xContentRegistry; + private static final ToXContent.Params FOR_INTERNAL_STORAGE_PARAMS = + new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")); public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistry) { this.client = client; @@ -58,76 +64,169 @@ public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistr } public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener listener) { - try (XContentBuilder builder = XContentFactory.jsonBuilder()) { - XContentBuilder source = trainedModelConfig.toXContent(builder, - new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true"))); - IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME) - .opType(DocWriteRequest.OpType.CREATE) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .id(trainedModelConfig.getModelId()) - .source(source); - executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, - ActionListener.wrap( - r -> listener.onResponse(true), - e -> { - logger.error(new ParameterizedMessage( - "[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e); - if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { - listener.onFailure(new ResourceAlreadyExistsException( - Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); - } else { - listener.onFailure( - new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, - RestStatus.INTERNAL_SERVER_ERROR, - e, - trainedModelConfig.getModelId())); - } - })); - } catch (IOException e) { - // not expected to happen but for the sake of completeness - listener.onFailure(new ElasticsearchParseException( - Messages.getMessage(Messages.INFERENCE_FAILED_TO_SERIALIZE_MODEL, trainedModelConfig.getModelId()), - e)); - } + ActionListener putDefinitionListener = ActionListener.wrap( + r -> listener.onResponse(true), + e -> { + logger.error(new ParameterizedMessage( + "[{}] failed to store trained model definition for inference", trainedModelConfig.getModelId()), e); + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, + RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelConfig.getModelId())); + } + } + ); + + ActionListener putConfigListener = ActionListener.wrap( + r -> { + if (trainedModelConfig.getDefinition() != null) { + indexObject(TrainedModelDefinition.docId(trainedModelConfig.getModelId()), + trainedModelConfig.getDefinition(), + putDefinitionListener); + } else { + listener.onResponse(true); + } + }, + e -> { + logger.error(new ParameterizedMessage( + "[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e); + if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { + listener.onFailure(new ResourceAlreadyExistsException( + Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); + } else { + listener.onFailure( + new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, + RestStatus.INTERNAL_SERVER_ERROR, + e, + trainedModelConfig.getModelId())); + } + } + ); + + indexObject(trainedModelConfig.getModelId(), trainedModelConfig, putConfigListener); } - public void getTrainedModel(String modelId, ActionListener listener) { + public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener listener) { + QueryBuilder queryBuilder = QueryBuilders.constantScoreQuery(QueryBuilders .idsQuery() .addIds(modelId)); - SearchRequest searchRequest = client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) - .setQuery(queryBuilder) - // use sort to get the last - .addSort("_index", SortOrder.DESC) - .setSize(1) - .request(); - - executeAsyncWithOrigin(client, ML_ORIGIN, SearchAction.INSTANCE, searchRequest, - ActionListener.wrap( - searchResponse -> { - if (searchResponse.getHits().getHits().length == 0) { + MultiSearchRequestBuilder multiSearchRequestBuilder = client.prepareMultiSearch() + .add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setQuery(queryBuilder) + // use sort to get the last + .addSort("_index", SortOrder.DESC) + .setSize(1) + .request()); + + if (includeDefinition) { + multiSearchRequestBuilder.add(client.prepareSearch(InferenceIndexConstants.INDEX_PATTERN) + .setQuery(QueryBuilders.constantScoreQuery(QueryBuilders + .idsQuery() + .addIds(TrainedModelDefinition.docId(modelId)))) + // use sort to get the last + .addSort("_index", SortOrder.DESC) + .setSize(1) + .request()); + } + + ActionListener multiSearchResponseActionListener = ActionListener.wrap( + multiSearchResponse -> { + TrainedModelConfig.Builder builder; + TrainedModelDefinition definition; + try { + builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource); + } catch(ResourceNotFoundException ex) { + listener.onFailure(new ResourceNotFoundException( + Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + return; + } catch (Exception ex) { + listener.onFailure(ex); + return; + } + + if (includeDefinition) { + try { + definition = handleSearchItem(multiSearchResponse.getResponses()[1], + modelId, + this::parseModelDefinitionDocLenientlyFromSource); + builder.setDefinition(definition); + } catch(ResourceNotFoundException ex) { listener.onFailure(new ResourceNotFoundException( - Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); + Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + return; + } catch (Exception ex) { + listener.onFailure(ex); return; } - BytesReference source = searchResponse.getHits().getHits()[0].getSourceRef(); - parseInferenceDocLenientlyFromSource(source, modelId, listener); - }, - listener::onFailure)); + } + listener.onResponse(builder.build()); + }, + listener::onFailure + ); + + executeAsyncWithOrigin(client, + ML_ORIGIN, + MultiSearchAction.INSTANCE, + multiSearchRequestBuilder.request(), + multiSearchResponseActionListener); } - private void parseInferenceDocLenientlyFromSource(BytesReference source, - String modelId, - ActionListener modelListener) { + private static T handleSearchItem(MultiSearchResponse.Item item, + String resourceId, + CheckedBiFunction parseLeniently) throws Exception { + if (item.isFailure()) { + throw item.getFailure(); + } + if (item.getResponse().getHits().getHits().length == 0) { + throw new ResourceNotFoundException(resourceId); + } + return parseLeniently.apply(item.getResponse().getHits().getHits()[0].getSourceRef(), resourceId); + } + + private TrainedModelConfig.Builder parseInferenceDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { try (InputStream stream = source.streamInput(); XContentParser parser = XContentFactory.xContent(XContentType.JSON) .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { - modelListener.onResponse(TrainedModelConfig.fromXContent(parser, true).build()); + return TrainedModelConfig.fromXContent(parser, true); } catch (Exception e) { logger.error(new ParameterizedMessage("[{}] failed to parse model", modelId), e); - modelListener.onFailure(e); + throw e; + } + } + + private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesReference source, String modelId) throws Exception { + try (InputStream stream = source.streamInput(); + XContentParser parser = XContentFactory.xContent(XContentType.JSON) + .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, stream)) { + return TrainedModelDefinition.fromXContent(parser, true).build(); + } catch (Exception e) { + logger.error(new ParameterizedMessage("[{}] failed to parse model definition", modelId), e); + throw e; + } + } + + private void indexObject(String docId, ToXContentObject body, ActionListener indexListener) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS); + + IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME) + .opType(DocWriteRequest.OpType.CREATE) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .id(docId) + .source(source); + + executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexListener); + } catch (IOException ex) { + // not expected to happen but for the sake of completeness + indexListener.onFailure(ex); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index 8612f263a0c7a..cb90b39772a0c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -126,9 +126,10 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() { return null; }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build(); + List expectedFieldNames = Arrays.asList("foo", "bar", "baz"); + TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(JOB_ID); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel))); - AnalyticsResultProcessor resultProcessor = createResultProcessor(); + AnalyticsResultProcessor resultProcessor = createResultProcessor(expectedFieldNames); resultProcessor.process(process); resultProcessor.awaitForCompletion(); @@ -142,7 +143,8 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() { assertThat(storedModel.getCreatedBy(), equalTo("data-frame-analytics")); assertThat(storedModel.getTags(), contains(JOB_ID)); assertThat(storedModel.getDescription(), equalTo(JOB_DESCRIPTION)); - assertThat(storedModel.getDefinition(), equalTo(inferenceModel)); + assertThat(storedModel.getDefinition(), equalTo(inferenceModel.build())); + assertThat(storedModel.getInput().getFieldNames(), equalTo(expectedFieldNames)); Map metadata = storedModel.getMetadata(); assertThat(metadata.size(), equalTo(1)); assertThat(metadata, hasKey("analytics_config")); @@ -166,7 +168,7 @@ public void testProcess_GivenInferenceModelFailedToStore() { return null; }).when(trainedModelProvider).storeTrainedModel(any(TrainedModelConfig.class), any(ActionListener.class)); - TrainedModelDefinition inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build(); + TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder("failed_model"); givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel))); AnalyticsResultProcessor resultProcessor = createResultProcessor(); @@ -192,7 +194,11 @@ private void givenDataFrameRows(int rows) { } private AnalyticsResultProcessor createResultProcessor() { + return createResultProcessor(Collections.emptyList()); + } + + private AnalyticsResultProcessor createResultProcessor(List fieldNames) { return new AnalyticsResultProcessor(analyticsConfig, dataFrameRowsJoiner, () -> false, progressTracker, trainedModelProvider, - auditor); + auditor, fieldNames); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 2ad4da9cc3483..4b2b30358582d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; -import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -33,7 +32,7 @@ protected NamedXContentRegistry xContentRegistry() { protected AnalyticsResult createTestInstance() { RowResults rowResults = null; Integer progressPercent = null; - TrainedModelDefinition inferenceModel = null; + TrainedModelDefinition.Builder inferenceModel = null; if (randomBoolean()) { rowResults = RowResultsTests.createRandom(); } @@ -41,13 +40,13 @@ protected AnalyticsResult createTestInstance() { progressPercent = randomIntBetween(0, 100); } if (randomBoolean()) { - inferenceModel = TrainedModelDefinitionTests.createRandomBuilder().build(); + inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(null); } return new AnalyticsResult(rowResults, progressPercent, inferenceModel); } @Override - protected AnalyticsResult doParseInstance(XContentParser parser) throws IOException { + protected AnalyticsResult doParseInstance(XContentParser parser) { return AnalyticsResult.PARSER.apply(parser, null); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 6dfd444d0589b..302dd6084bf3c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.search.SearchModule; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; @@ -21,6 +22,8 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; @@ -75,29 +78,72 @@ public void testGetTrainedModelConfig() throws Exception { assertThat(exceptionHolder.get(), is(nullValue())); AtomicReference getConfigHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder); + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); assertThat(getConfigHolder.get(), is(not(nullValue()))); assertThat(getConfigHolder.get(), equalTo(config)); + assertThat(getConfigHolder.get().getDefinition(), is(not(nullValue()))); + } + + public void testGetTrainedModelConfigWithoutDefinition() throws Exception { + String modelId = "test-get-trained-model-config-no-definition"; + TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId); + TrainedModelConfig config = configBuilder.build(); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + AtomicReference getConfigHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, false, listener), getConfigHolder, exceptionHolder); + assertThat(getConfigHolder.get(), is(not(nullValue()))); + assertThat(getConfigHolder.get(), + equalTo(configBuilder.setCreateTime(config.getCreateTime()).setDefinition((TrainedModelDefinition) null).build())); + assertThat(getConfigHolder.get().getDefinition(), is(nullValue())); } public void testGetMissingTrainingModelConfig() throws Exception { String modelId = "test-get-missing-trained-model-config"; AtomicReference getConfigHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); - blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, listener), getConfigHolder, exceptionHolder); + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(not(nullValue()))); assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); } - private static TrainedModelConfig buildTrainedModelConfig(String modelId) { + public void testGetMissingTrainingModelConfigDefinition() throws Exception { + String modelId = "test-get-missing-trained-model-config-definition"; + TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).setDefinition((TrainedModelDefinition) null).build(); + AtomicReference putConfigHolder = new AtomicReference<>(); + AtomicReference exceptionHolder = new AtomicReference<>(); + + blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder); + assertThat(putConfigHolder.get(), is(true)); + assertThat(exceptionHolder.get(), is(nullValue())); + + AtomicReference getConfigHolder = new AtomicReference<>(); + blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); + assertThat(exceptionHolder.get(), is(not(nullValue()))); + assertThat(exceptionHolder.get().getMessage(), + equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); + } + + private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) { return TrainedModelConfig.builder() .setCreatedBy("ml_test") - .setDefinition(TrainedModelDefinitionTests.createRandomBuilder()) + .setDefinition(TrainedModelDefinitionTests.createRandomBuilder(modelId)) .setDescription("trained model config for test") .setModelId(modelId) .setVersion(Version.CURRENT) - .build(); + .setInput(new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomIntBetween(0, 10)) + .collect(Collectors.toList()))); + } + + private static TrainedModelConfig buildTrainedModelConfig(String modelId) { + return buildTrainedModelConfigBuilder(modelId).build(); } @Override From acf3be9140d361bf0eedd20497a34179a21e296a Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 30 Oct 2019 08:24:20 -0400 Subject: [PATCH 2/4] addressing PR comments --- .../client/ml/inference/Input.java | 82 ++++++++++++++++++ .../ml/inference/TrainedModelConfig.java | 52 ------------ .../ml/inference/TrainedModelConfigTests.java | 2 +- .../xpack/core/ml/inference/Input.java | 84 ++++++++++++++++++ .../core/ml/inference/TrainedModelConfig.java | 65 +------------- .../ml/inference/TrainedModelDefinition.java | 4 +- .../ml/inference/TrainedModelConfigTests.java | 8 +- .../process/AnalyticsResultProcessor.java | 3 +- .../process/results/AnalyticsResult.java | 4 +- .../persistence/TrainedModelProvider.java | 85 +++++++++++-------- .../integration/TrainedModelProviderIT.java | 13 ++- 11 files changed, 235 insertions(+), 167 deletions(-) create mode 100644 client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/Input.java create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/Input.java diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/Input.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/Input.java new file mode 100644 index 0000000000000..62f195a2d5bed --- /dev/null +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/Input.java @@ -0,0 +1,82 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class Input implements ToXContentObject { + + public static final String NAME = "trained_model_config_input"; + public static final ParseField FIELD_NAMES = new ParseField("field_names"); + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + true, + a -> new Input((List) a[0])); + + static { + PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); + } + + private final List fieldNames; + + public Input(List fieldNames) { + this.fieldNames = fieldNames; + } + + public static Input fromXContent(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + public List getFieldNames() { + return fieldNames; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (fieldNames != null) { + builder.field(FIELD_NAMES.getPreferredName(), fieldNames); + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Input that = (Input) o; + return Objects.equals(fieldNames, that.fieldNames); + } + + @Override + public int hashCode() { + return Objects.hash(fieldNames); + } + +} diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index ece9765960f06..7d2e347605b8e 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -22,7 +22,6 @@ import org.elasticsearch.client.common.TimeUtil; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -292,55 +291,4 @@ public TrainedModelConfig build() { } } - public static class Input implements ToXContentObject { - - public static final String NAME = "trained_model_config_input"; - public static final ParseField FIELD_NAMES = new ParseField("field_names"); - - @SuppressWarnings("unchecked") - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, - true, - a -> new Input((List)a[0])); - static { - PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); - } - - public static Input fromXContent(XContentParser parser) throws IOException { - return PARSER.parse(parser, null); - } - - private final List fieldNames; - - public Input(List fieldNames) { - this.fieldNames = fieldNames; - } - - public List getFieldNames() { - return fieldNames; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (fieldNames != null) { - builder.field(FIELD_NAMES.getPreferredName(), fieldNames); - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - Input that = (Input) o; - return Objects.equals(fieldNames, that.fieldNames); - } - - @Override - public int hashCode() { - return Objects.hash(fieldNames); - } - - } } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index b5ce306f25472..cac3d0cdef154 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -64,7 +64,7 @@ protected TrainedModelConfig createTestInstance() { randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomBoolean() ? null : new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + randomBoolean() ? null : new Input(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomLongBetween(1, 10)) .collect(Collectors.toList()))); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/Input.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/Input.java new file mode 100644 index 0000000000000..c052a1322969d --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/Input.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ConstructingObjectParser; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + + +public class Input implements ToXContentObject, Writeable { + + public static final String NAME = "trained_model_config_input"; + public static final ParseField FIELD_NAMES = new ParseField("field_names"); + + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + private final List fieldNames; + + public Input(List fieldNames) { + this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES)); + } + + public Input(StreamInput in) throws IOException { + this.fieldNames = Collections.unmodifiableList(in.readStringList()); + } + + @SuppressWarnings("unchecked") + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + ignoreUnknownFields, + a -> new Input((List) a[0])); + parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); + return parser; + } + + public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException { + return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); + } + + public List getFieldNames() { + return fieldNames; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(fieldNames); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(FIELD_NAMES.getPreferredName(), fieldNames); + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Input that = (Input) o; + return Objects.equals(fieldNames, that.fieldNames); + } + + @Override + public int hashCode() { + return Objects.hash(fieldNames); + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 48976725a46f3..10514a3c4e7bc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -12,7 +12,6 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.ObjectParser; import org.elasticsearch.common.xcontent.ToXContentObject; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -65,7 +64,7 @@ private static ObjectParser createParser(boole parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE); parser.declareObject(TrainedModelConfig.Builder::setInput, - (p, c) -> TrainedModelConfig.Input.fromXContent(p, ignoreUnknownFields), + (p, c) -> Input.fromXContent(p, ignoreUnknownFields), INPUT); return parser; } @@ -344,66 +343,4 @@ public TrainedModelConfig build() { } } - public static class Input implements ToXContentObject, Writeable { - - public static final String NAME = "trained_model_config_input"; - public static final ParseField FIELD_NAMES = new ParseField("field_names"); - - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); - - @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { - ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, - ignoreUnknownFields, - a -> new Input((List)a[0])); - parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); - return parser; - } - - public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException { - return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); - } - - private final List fieldNames; - - public Input(List fieldNames) { - this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES)); - } - - public Input(StreamInput in) throws IOException { - this.fieldNames = Collections.unmodifiableList(in.readStringList()); - } - - public List getFieldNames() { - return fieldNames; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeStringCollection(fieldNames); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(FIELD_NAMES.getPreferredName(), fieldNames); - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - TrainedModelConfig.Input that = (TrainedModelConfig.Input) o; - return Objects.equals(fieldNames, that.fieldNames); - } - - @Override - public int hashCode() { - return Objects.hash(fieldNames); - } - - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java index 38a05048694a4..63a5b1fd1d698 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java @@ -45,7 +45,7 @@ public class TrainedModelDefinition implements ToXContentObject, Writeable { private static ObjectParser createParser(boolean ignoreUnknownFields) { ObjectParser parser = new ObjectParser<>(NAME, ignoreUnknownFields, - TrainedModelDefinition.Builder::new); + TrainedModelDefinition.Builder::builderForParser); parser.declareNamedObjects(TrainedModelDefinition.Builder::setTrainedModel, (p, c, n) -> ignoreUnknownFields ? p.namedObject(LenientlyParsedTrainedModel.class, n, null) : @@ -74,7 +74,7 @@ public static String docId(String modelId) { private final List preProcessors; private final String modelId; - TrainedModelDefinition(TrainedModel trainedModel, List preProcessors, @Nullable String modelId) { + private TrainedModelDefinition(TrainedModel trainedModel, List preProcessors, @Nullable String modelId) { this.trainedModel = ExceptionsHelper.requireNonNull(trainedModel, TRAINED_MODEL); this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors); this.modelId = modelId; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 139899e2fd2b2..03353d717c73c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -11,14 +11,10 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.xcontent.DeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.ToXContent; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentParser; -import org.elasticsearch.common.xcontent.XContentParserUtils; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractSerializingTestCase; @@ -78,7 +74,7 @@ protected TrainedModelConfig createTestInstance() { null, // is not parsed so should not be provided tags, randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + new Input(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomInt(10)) .collect(Collectors.toList()))); } @@ -113,7 +109,7 @@ public void testToXContentWithParams() throws IOException { TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(), Collections.emptyList(), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + new Input(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomInt(10)) .collect(Collectors.toList()))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index a3c2ba5aae37e..5a61cbec88c4b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; +import org.elasticsearch.xpack.core.ml.inference.Input; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker; @@ -149,7 +150,7 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) .setDefinition(definition) - .setInput(new TrainedModelConfig.Input(fieldNames)) + .setInput(new Input(fieldNames)) .build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index 8b301a44b83d3..0127ff26f3c86 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -29,8 +29,8 @@ public class AnalyticsResult implements ToXContentObject { static { PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE); PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT); - PARSER.declareObject(optionalConstructorArg(), (p, c) -> TrainedModelDefinition.LENIENT_PARSER.apply(p, null), - INFERENCE_MODEL); + // TODO change back to STRICT_PARSER once native side is aligned + PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL); } private final RowResults rowResults; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index 1266e4b11f9ee..bfca78ac18aea 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -13,9 +13,10 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest; -import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.search.MultiSearchAction; import org.elasticsearch.action.search.MultiSearchRequestBuilder; import org.elasticsearch.action.search.MultiSearchResponse; @@ -65,11 +66,22 @@ public TrainedModelProvider(Client client, NamedXContentRegistry xContentRegistr public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListener listener) { - ActionListener putDefinitionListener = ActionListener.wrap( - r -> listener.onResponse(true), + if (trainedModelConfig.getDefinition() == null) { + listener.onFailure(ExceptionsHelper.badRequestException("Unable to store [{}]. [{}] is required", + trainedModelConfig.getModelId(), + TrainedModelConfig.DEFINITION.getPreferredName())); + return; + } + + BulkRequest bulkRequest = client.prepareBulk(InferenceIndexConstants.LATEST_INDEX_NAME) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .add(createRequest(trainedModelConfig.getModelId(), trainedModelConfig)) + .add(createRequest(TrainedModelDefinition.docId(trainedModelConfig.getModelId()), trainedModelConfig.getDefinition())) + .request(); + + ActionListener wrappedListener = ActionListener.wrap( + listener::onResponse, e -> { - logger.error(new ParameterizedMessage( - "[{}] failed to store trained model definition for inference", trainedModelConfig.getModelId()), e); if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { listener.onFailure(new ResourceAlreadyExistsException( Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); @@ -83,33 +95,31 @@ public void storeTrainedModel(TrainedModelConfig trainedModelConfig, ActionListe } ); - ActionListener putConfigListener = ActionListener.wrap( + ActionListener bulkResponseActionListener = ActionListener.wrap( r -> { - if (trainedModelConfig.getDefinition() != null) { - indexObject(TrainedModelDefinition.docId(trainedModelConfig.getModelId()), - trainedModelConfig.getDefinition(), - putDefinitionListener); - } else { - listener.onResponse(true); + assert r.getItems().length == 2; + if (r.getItems()[0].isFailed()) { + logger.error(new ParameterizedMessage( + "[{}] failed to store trained model config for inference", + trainedModelConfig.getModelId()), + r.getItems()[0].getFailure().getCause()); + wrappedListener.onFailure(r.getItems()[0].getFailure().getCause()); + return; } - }, - e -> { - logger.error(new ParameterizedMessage( - "[{}] failed to store trained model for inference", trainedModelConfig.getModelId()), e); - if (ExceptionsHelper.unwrapCause(e) instanceof VersionConflictEngineException) { - listener.onFailure(new ResourceAlreadyExistsException( - Messages.getMessage(Messages.INFERENCE_TRAINED_MODEL_EXISTS, trainedModelConfig.getModelId()))); - } else { - listener.onFailure( - new ElasticsearchStatusException(Messages.INFERENCE_FAILED_TO_STORE_MODEL, - RestStatus.INTERNAL_SERVER_ERROR, - e, - trainedModelConfig.getModelId())); + if (r.getItems()[1].isFailed()) { + logger.error(new ParameterizedMessage( + "[{}] failed to store trained model definition for inference", + trainedModelConfig.getModelId()), + r.getItems()[1].getFailure().getCause()); + wrappedListener.onFailure(r.getItems()[1].getFailure().getCause()); + return; } - } + wrappedListener.onResponse(true); + }, + wrappedListener::onFailure ); - indexObject(trainedModelConfig.getModelId(), trainedModelConfig, putConfigListener); + executeAsyncWithOrigin(client, ML_ORIGIN, BulkAction.INSTANCE, bulkRequest, bulkResponseActionListener); } public void getTrainedModel(final String modelId, final boolean includeDefinition, final ActionListener listener) { @@ -142,7 +152,7 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio TrainedModelDefinition definition; try { builder = handleSearchItem(multiSearchResponse.getResponses()[0], modelId, this::parseInferenceDocLenientlyFromSource); - } catch(ResourceNotFoundException ex) { + } catch (ResourceNotFoundException ex) { listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.INFERENCE_NOT_FOUND, modelId))); return; @@ -157,7 +167,7 @@ public void getTrainedModel(final String modelId, final boolean includeDefinitio modelId, this::parseModelDefinitionDocLenientlyFromSource); builder.setDefinition(definition); - } catch(ResourceNotFoundException ex) { + } catch (ResourceNotFoundException ex) { listener.onFailure(new ResourceNotFoundException( Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId))); return; @@ -213,20 +223,21 @@ private TrainedModelDefinition parseModelDefinitionDocLenientlyFromSource(BytesR } } - private void indexObject(String docId, ToXContentObject body, ActionListener indexListener) { + private IndexRequest createRequest(String docId, ToXContentObject body) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) { XContentBuilder source = body.toXContent(builder, FOR_INTERNAL_STORAGE_PARAMS); - IndexRequest indexRequest = new IndexRequest(InferenceIndexConstants.LATEST_INDEX_NAME) + return new IndexRequest() .opType(DocWriteRequest.OpType.CREATE) - .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .id(docId) .source(source); - - executeAsyncWithOrigin(client, ML_ORIGIN, IndexAction.INSTANCE, indexRequest, indexListener); } catch (IOException ex) { - // not expected to happen but for the sake of completeness - indexListener.onFailure(ex); + // This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again + // that is not the users fault. We did something wrong and should throw. + throw new ElasticsearchStatusException("Unexpected serialization exception for [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + ex, + docId); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 302dd6084bf3c..ff8fa29f1419a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -6,13 +6,17 @@ package org.elasticsearch.xpack.ml.integration; import org.elasticsearch.Version; +import org.elasticsearch.action.delete.DeleteRequest; +import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.xpack.core.ml.inference.Input; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider; @@ -115,7 +119,7 @@ public void testGetMissingTrainingModelConfig() throws Exception { public void testGetMissingTrainingModelConfigDefinition() throws Exception { String modelId = "test-get-missing-trained-model-config-definition"; - TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).setDefinition((TrainedModelDefinition) null).build(); + TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).build(); AtomicReference putConfigHolder = new AtomicReference<>(); AtomicReference exceptionHolder = new AtomicReference<>(); @@ -123,6 +127,11 @@ public void testGetMissingTrainingModelConfigDefinition() throws Exception { assertThat(putConfigHolder.get(), is(true)); assertThat(exceptionHolder.get(), is(nullValue())); + client().delete(new DeleteRequest(InferenceIndexConstants.LATEST_INDEX_NAME) + .id(TrainedModelDefinition.docId(config.getModelId())) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)) + .actionGet(); + AtomicReference getConfigHolder = new AtomicReference<>(); blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder); assertThat(exceptionHolder.get(), is(not(nullValue()))); @@ -137,7 +146,7 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setDescription("trained model config for test") .setModelId(modelId) .setVersion(Version.CURRENT) - .setInput(new TrainedModelConfig.Input(Stream.generate(() -> randomAlphaOfLength(10)) + .setInput(new Input(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomIntBetween(0, 10)) .collect(Collectors.toList()))); } From 5b3e499a232131ea62fd397f88c6d42572550024 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 30 Oct 2019 10:06:34 -0400 Subject: [PATCH 3/4] renaming input and using exceptions helper --- .../ml/inference/TrainedModelConfig.java | 12 +++++------ .../{Input.java => TrainedModelInput.java} | 12 +++++------ .../ml/inference/TrainedModelConfigTests.java | 2 +- .../core/ml/inference/TrainedModelConfig.java | 14 ++++++------- .../{Input.java => TrainedModelInput.java} | 20 +++++++++---------- .../ml/inference/TrainedModelConfigTests.java | 4 ++-- .../process/AnalyticsResultProcessor.java | 4 ++-- .../persistence/TrainedModelProvider.java | 7 +++---- .../integration/TrainedModelProviderIT.java | 4 ++-- 9 files changed, 39 insertions(+), 40 deletions(-) rename client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/{Input.java => TrainedModelInput.java} (84%) rename x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/{Input.java => TrainedModelInput.java} (73%) diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java index 7d2e347605b8e..273aa6b021325 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java @@ -65,7 +65,7 @@ public class TrainedModelConfig implements ToXContentObject { DEFINITION); PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS); PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); - PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT); + PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT); } public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException { @@ -80,7 +80,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr private final TrainedModelDefinition definition; private final List tags; private final Map metadata; - private final Input input; + private final TrainedModelInput input; TrainedModelConfig(String modelId, String createdBy, @@ -90,7 +90,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr TrainedModelDefinition definition, List tags, Map metadata, - Input input) { + TrainedModelInput input) { this.modelId = modelId; this.createdBy = createdBy; this.version = version; @@ -134,7 +134,7 @@ public TrainedModelDefinition getDefinition() { return definition; } - public Input getInput() { + public TrainedModelInput getInput() { return input; } @@ -221,7 +221,7 @@ public static class Builder { private Map metadata; private List tags; private TrainedModelDefinition definition; - private Input input; + private TrainedModelInput input; public Builder setModelId(String modelId) { this.modelId = modelId; @@ -272,7 +272,7 @@ public Builder setDefinition(TrainedModelDefinition definition) { return this; } - public Builder setInput(Input input) { + public Builder setInput(TrainedModelInput input) { this.input = input; return this; } diff --git a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/Input.java b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java similarity index 84% rename from client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/Input.java rename to client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java index 62f195a2d5bed..10f849cac481a 100644 --- a/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/Input.java +++ b/client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelInput.java @@ -28,15 +28,15 @@ import java.util.List; import java.util.Objects; -public class Input implements ToXContentObject { +public class TrainedModelInput implements ToXContentObject { public static final String NAME = "trained_model_config_input"; public static final ParseField FIELD_NAMES = new ParseField("field_names"); @SuppressWarnings("unchecked") - public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>(NAME, true, - a -> new Input((List) a[0])); + a -> new TrainedModelInput((List) a[0])); static { PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); @@ -44,11 +44,11 @@ public class Input implements ToXContentObject { private final List fieldNames; - public Input(List fieldNames) { + public TrainedModelInput(List fieldNames) { this.fieldNames = fieldNames; } - public static Input fromXContent(XContentParser parser) throws IOException { + public static TrainedModelInput fromXContent(XContentParser parser) throws IOException { return PARSER.parse(parser, null); } @@ -70,7 +70,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - Input that = (Input) o; + TrainedModelInput that = (TrainedModelInput) o; return Objects.equals(fieldNames, that.fieldNames); } diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index cac3d0cdef154..634d5d9edf09f 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -64,7 +64,7 @@ protected TrainedModelConfig createTestInstance() { randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomBoolean() ? null : new Input(Stream.generate(() -> randomAlphaOfLength(10)) + randomBoolean() ? null : new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomLongBetween(1, 10)) .collect(Collectors.toList()))); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java index 10514a3c4e7bc..04eece32b5cd8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java @@ -64,7 +64,7 @@ private static ObjectParser createParser(boole parser.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA); parser.declareString((trainedModelConfig, s) -> {}, InferenceIndexConstants.DOC_TYPE); parser.declareObject(TrainedModelConfig.Builder::setInput, - (p, c) -> Input.fromXContent(p, ignoreUnknownFields), + (p, c) -> TrainedModelInput.fromXContent(p, ignoreUnknownFields), INPUT); return parser; } @@ -80,7 +80,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo private final Instant createTime; private final List tags; private final Map metadata; - private final Input input; + private final TrainedModelInput input; private final TrainedModelDefinition definition; @@ -92,7 +92,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo TrainedModelDefinition definition, List tags, Map metadata, - Input input) { + TrainedModelInput input) { this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY); this.version = ExceptionsHelper.requireNonNull(version, VERSION); @@ -113,7 +113,7 @@ public TrainedModelConfig(StreamInput in) throws IOException { definition = in.readOptionalWriteable(TrainedModelDefinition::new); tags = Collections.unmodifiableList(in.readList(StreamInput::readString)); metadata = in.readMap(); - input = new Input(in); + input = new TrainedModelInput(in); } public String getModelId() { @@ -149,7 +149,7 @@ public TrainedModelDefinition getDefinition() { return definition; } - public Input getInput() { + public TrainedModelInput getInput() { return input; } @@ -240,7 +240,7 @@ public static class Builder { private Instant createTime; private List tags = Collections.emptyList(); private Map metadata; - private Input input; + private TrainedModelInput input; private TrainedModelDefinition definition; public Builder setModelId(String modelId) { @@ -292,7 +292,7 @@ public Builder setDefinition(TrainedModelDefinition definition) { return this; } - public Builder setInput(Input input) { + public Builder setInput(TrainedModelInput input) { this.input = input; return this; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/Input.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelInput.java similarity index 73% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/Input.java rename to x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelInput.java index c052a1322969d..d55dcfc04d324 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/Input.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelInput.java @@ -21,33 +21,33 @@ import java.util.Objects; -public class Input implements ToXContentObject, Writeable { +public class TrainedModelInput implements ToXContentObject, Writeable { public static final String NAME = "trained_model_config_input"; public static final ParseField FIELD_NAMES = new ParseField("field_names"); - public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); - public static final ConstructingObjectParser STRICT_PARSER = createParser(false); + public static final ConstructingObjectParser LENIENT_PARSER = createParser(true); + public static final ConstructingObjectParser STRICT_PARSER = createParser(false); private final List fieldNames; - public Input(List fieldNames) { + public TrainedModelInput(List fieldNames) { this.fieldNames = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(fieldNames, FIELD_NAMES)); } - public Input(StreamInput in) throws IOException { + public TrainedModelInput(StreamInput in) throws IOException { this.fieldNames = Collections.unmodifiableList(in.readStringList()); } @SuppressWarnings("unchecked") - private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { - ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, + private static ConstructingObjectParser createParser(boolean ignoreUnknownFields) { + ConstructingObjectParser parser = new ConstructingObjectParser<>(NAME, ignoreUnknownFields, - a -> new Input((List) a[0])); + a -> new TrainedModelInput((List) a[0])); parser.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES); return parser; } - public static Input fromXContent(XContentParser parser, boolean lenient) throws IOException { + public static TrainedModelInput fromXContent(XContentParser parser, boolean lenient) throws IOException { return lenient ? LENIENT_PARSER.parse(parser, null) : STRICT_PARSER.parse(parser, null); } @@ -72,7 +72,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; - Input that = (Input) o; + TrainedModelInput that = (TrainedModelInput) o; return Objects.equals(fieldNames, that.fieldNames); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 03353d717c73c..0af4ce037426f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -74,7 +74,7 @@ protected TrainedModelConfig createTestInstance() { null, // is not parsed so should not be provided tags, randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - new Input(Stream.generate(() -> randomAlphaOfLength(10)) + new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomInt(10)) .collect(Collectors.toList()))); } @@ -109,7 +109,7 @@ public void testToXContentWithParams() throws IOException { TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(), Collections.emptyList(), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - new Input(Stream.generate(() -> randomAlphaOfLength(10)) + new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomInt(10)) .collect(Collectors.toList()))); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 5a61cbec88c4b..eb47c17137a3b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -15,9 +15,9 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig; -import org.elasticsearch.xpack.core.ml.inference.Input; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsTask.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; @@ -150,7 +150,7 @@ private TrainedModelConfig createTrainedModelConfig(TrainedModelDefinition.Build .setMetadata(Collections.singletonMap("analytics_config", XContentHelper.convertToMap(JsonXContent.jsonXContent, analytics.toString(), true))) .setDefinition(definition) - .setInput(new Input(fieldNames)) + .setInput(new TrainedModelInput(fieldNames)) .build(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java index bfca78ac18aea..59f87120b20cd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java @@ -234,10 +234,9 @@ private IndexRequest createRequest(String docId, ToXContentObject body) { } catch (IOException ex) { // This should never happen. If we were able to deserialize the object (from Native or REST) and then fail to serialize it again // that is not the users fault. We did something wrong and should throw. - throw new ElasticsearchStatusException("Unexpected serialization exception for [{}]", - RestStatus.INTERNAL_SERVER_ERROR, - ex, - docId); + throw ExceptionsHelper.serverError( + new ParameterizedMessage("Unexpected serialization exception for [{}]", docId).getFormattedMessage(), + ex); } } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index ff8fa29f1419a..559ee5bdc2014 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -11,11 +11,11 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.search.SearchModule; -import org.elasticsearch.xpack.core.ml.inference.Input; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; @@ -146,7 +146,7 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setDescription("trained model config for test") .setModelId(modelId) .setVersion(Version.CURRENT) - .setInput(new Input(Stream.generate(() -> randomAlphaOfLength(10)) + .setInput(new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) .limit(randomIntBetween(0, 10)) .collect(Collectors.toList()))); } From 126eabf98e3a17f5b912a4932c89f3937e6d8172 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 30 Oct 2019 10:56:25 -0400 Subject: [PATCH 4/4] adding tests --- .../ml/inference/TrainedModelConfigTests.java | 4 +- .../ml/inference/TrainedModelInputTests.java | 58 ++++++++++++++++++ .../ml/inference/TrainedModelConfigTests.java | 9 +-- .../ml/inference/TrainedModelInputTests.java | 59 +++++++++++++++++++ .../integration/TrainedModelProviderIT.java | 8 +-- 5 files changed, 122 insertions(+), 16 deletions(-) create mode 100644 client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelInputTests.java diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java index 634d5d9edf09f..6d1e04f066cb7 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java @@ -64,9 +64,7 @@ protected TrainedModelConfig createTestInstance() { randomBoolean() ? null : Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - randomBoolean() ? null : new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomLongBetween(1, 10)) - .collect(Collectors.toList()))); + randomBoolean() ? null : TrainedModelInputTests.createRandomInput()); } @Override diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java new file mode 100644 index 0000000000000..30b6c46402df4 --- /dev/null +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelInputTests.java @@ -0,0 +1,58 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.client.ml.inference; + +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractXContentTestCase; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TrainedModelInputTests extends AbstractXContentTestCase { + + @Override + protected TrainedModelInput doParseInstance(XContentParser parser) throws IOException { + return TrainedModelInput.fromXContent(parser); + } + + @Override + protected boolean supportsUnknownFields() { + return true; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + public static TrainedModelInput createRandomInput() { + return new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomLongBetween(1, 10)) + .collect(Collectors.toList())); + } + + @Override + protected TrainedModelInput createTestInstance() { + return createRandomInput(); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java index 0af4ce037426f..17c4d01d797f8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java @@ -32,7 +32,6 @@ import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.IntStream; -import java.util.stream.Stream; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -74,9 +73,7 @@ protected TrainedModelConfig createTestInstance() { null, // is not parsed so should not be provided tags, randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomInt(10)) - .collect(Collectors.toList()))); + TrainedModelInputTests.createRandomInput()); } @Override @@ -109,9 +106,7 @@ public void testToXContentWithParams() throws IOException { TrainedModelDefinitionTests.createRandomBuilder(randomAlphaOfLength(10)).build(), Collections.emptyList(), randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)), - new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomInt(10)) - .collect(Collectors.toList()))); + TrainedModelInputTests.createRandomInput()); BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false); assertThat(reference.utf8ToString(), containsString("definition")); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelInputTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelInputTests.java new file mode 100644 index 0000000000000..1949e48dbc052 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelInputTests.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.ml.inference; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.test.AbstractSerializingTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import java.util.stream.Stream; + + +public class TrainedModelInputTests extends AbstractSerializingTestCase { + + private boolean lenient; + + @Before + public void chooseStrictOrLenient() { + lenient = randomBoolean(); + } + + @Override + protected TrainedModelInput doParseInstance(XContentParser parser) throws IOException { + return TrainedModelInput.fromXContent(parser, lenient); + } + + @Override + protected boolean supportsUnknownFields() { + return lenient; + } + + @Override + protected Predicate getRandomFieldsExcludeFilter() { + return field -> !field.isEmpty(); + } + + public static TrainedModelInput createRandomInput() { + return new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) + .limit(randomInt(10)) + .collect(Collectors.toList())); + } + + @Override + protected TrainedModelInput createTestInstance() { + return createRandomInput(); + } + + @Override + protected Writeable.Reader instanceReader() { + return TrainedModelInput::new; + } + +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java index 559ee5bdc2014..75e83acb0b22e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java @@ -15,7 +15,7 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition; import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests; -import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; +import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests; import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.ml.MlSingleNodeTestCase; @@ -26,8 +26,6 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; -import java.util.stream.Stream; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; @@ -146,9 +144,7 @@ private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String .setDescription("trained model config for test") .setModelId(modelId) .setVersion(Version.CURRENT) - .setInput(new TrainedModelInput(Stream.generate(() -> randomAlphaOfLength(10)) - .limit(randomIntBetween(0, 10)) - .collect(Collectors.toList()))); + .setInput(TrainedModelInputTests.createRandomInput()); } private static TrainedModelConfig buildTrainedModelConfig(String modelId) {