Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public class TrainedModelConfig implements ToXContentObject {
public static final ParseField DEFINITION = new ParseField("definition");
public static final ParseField TAGS = new ParseField("tags");
public static final ParseField METADATA = new ParseField("metadata");
public static final ParseField INPUT = new ParseField("input");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
Expand All @@ -64,6 +65,7 @@ public class TrainedModelConfig implements ToXContentObject {
DEFINITION);
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
}

public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
Expand All @@ -78,6 +80,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
private final TrainedModelDefinition definition;
private final List<String> tags;
private final Map<String, Object> metadata;
private final TrainedModelInput input;

TrainedModelConfig(String modelId,
String createdBy,
Expand All @@ -86,7 +89,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
Instant createTime,
TrainedModelDefinition definition,
List<String> tags,
Map<String, Object> metadata) {
Map<String, Object> metadata,
TrainedModelInput input) {
this.modelId = modelId;
this.createdBy = createdBy;
this.version = version;
Expand All @@ -95,6 +99,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
this.description = description;
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
this.input = input;
}

public String getModelId() {
Expand Down Expand Up @@ -129,6 +134,10 @@ public TrainedModelDefinition getDefinition() {
return definition;
}

public TrainedModelInput getInput() {
return input;
}

public static Builder builder() {
return new Builder();
}
Expand Down Expand Up @@ -160,6 +169,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (metadata != null) {
builder.field(METADATA.getPreferredName(), metadata);
}
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
builder.endObject();
return builder;
}
Expand All @@ -181,6 +193,7 @@ public boolean equals(Object o) {
Objects.equals(createTime, that.createTime) &&
Objects.equals(definition, that.definition) &&
Objects.equals(tags, that.tags) &&
Objects.equals(input, that.input) &&
Objects.equals(metadata, that.metadata);
}

Expand All @@ -193,7 +206,8 @@ public int hashCode() {
definition,
description,
tags,
metadata);
metadata,
input);
}


Expand All @@ -207,6 +221,7 @@ public static class Builder {
private Map<String, Object> metadata;
private List<String> tags;
private TrainedModelDefinition definition;
private TrainedModelInput input;

public Builder setModelId(String modelId) {
this.modelId = modelId;
Expand Down Expand Up @@ -257,6 +272,11 @@ public Builder setDefinition(TrainedModelDefinition definition) {
return this;
}

public Builder setInput(TrainedModelInput input) {
this.input = input;
return this;
}

public TrainedModelConfig build() {
return new TrainedModelConfig(
modelId,
Expand All @@ -266,7 +286,9 @@ public TrainedModelConfig build() {
createTime,
definition,
tags,
metadata);
metadata,
input);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
Expand All @@ -39,7 +38,6 @@ public class TrainedModelDefinition implements ToXContentObject {

public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
public static final ParseField INPUT = new ParseField("input");

public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
true,
Expand All @@ -53,7 +51,6 @@ public class TrainedModelDefinition implements ToXContentObject {
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
PREPROCESSORS);
PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
}

public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
Expand All @@ -62,12 +59,10 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser)

private final TrainedModel trainedModel;
private final List<PreProcessor> preProcessors;
private final Input input;

TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
this.trainedModel = trainedModel;
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
this.input = input;
}

@Override
Expand All @@ -83,9 +78,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
true,
PREPROCESSORS.getPreferredName(),
preProcessors);
if (input != null) {
builder.field(INPUT.getPreferredName(), input);
}
builder.endObject();
return builder;
}
Expand All @@ -98,10 +90,6 @@ public List<PreProcessor> getPreProcessors() {
return preProcessors;
}

public Input getInput() {
return input;
}

@Override
public String toString() {
return Strings.toString(this);
Expand All @@ -113,20 +101,18 @@ public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition that = (TrainedModelDefinition) o;
return Objects.equals(trainedModel, that.trainedModel) &&
Objects.equals(preProcessors, that.preProcessors) &&
Objects.equals(input, that.input);
Objects.equals(preProcessors, that.preProcessors);
}

@Override
public int hashCode() {
return Objects.hash(trainedModel, preProcessors, input);
return Objects.hash(trainedModel, preProcessors);
}

public static class Builder {

private List<PreProcessor> preProcessors;
private TrainedModel trainedModel;
private Input input;

public Builder setPreProcessors(List<PreProcessor> preProcessors) {
this.preProcessors = preProcessors;
Expand All @@ -138,71 +124,14 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
return this;
}

public Builder setInput(Input input) {
this.input = input;
return this;
}

private Builder setTrainedModel(List<TrainedModel> trainedModel) {
assert trainedModel.size() == 1;
return setTrainedModel(trainedModel.get(0));
}

public TrainedModelDefinition build() {
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
}
}

public static class Input implements ToXContentObject {

public static final String NAME = "trained_mode_definition_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new Input((List<String>)a[0]));
static {
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
}

public static Input fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
}

private final List<String> fieldNames;

public Input(List<String> fieldNames) {
this.fieldNames = fieldNames;
}

public List<String> getFieldNames() {
return fieldNames;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fieldNames != null) {
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
return Objects.equals(fieldNames, that.fieldNames);
}

@Override
public int hashCode() {
return Objects.hash(fieldNames);
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.client.ml.inference;

import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

public class TrainedModelInput implements ToXContentObject {

public static final String NAME = "trained_model_config_input";
public static final ParseField FIELD_NAMES = new ParseField("field_names");

@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<TrainedModelInput, Void> PARSER = new ConstructingObjectParser<>(NAME,
true,
a -> new TrainedModelInput((List<String>) a[0]));

static {
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
}

private final List<String> fieldNames;

public TrainedModelInput(List<String> fieldNames) {
this.fieldNames = fieldNames;
}

public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}

public List<String> getFieldNames() {
return fieldNames;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (fieldNames != null) {
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
}
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
TrainedModelInput that = (TrainedModelInput) o;
return Objects.equals(fieldNames, that.fieldNames);
}

@Override
public int hashCode() {
return Objects.hash(fieldNames);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ protected TrainedModelConfig createTestInstance() {
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
randomBoolean() ? null :
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@ public static TrainedModelDefinition.Builder createRandomBuilder() {
TargetMeanEncodingTests.createRandom()))
.limit(numberOfProcessors)
.collect(Collectors.toList()))
.setTrainedModel(randomFrom(TreeTests.createRandom()))
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
.limit(randomLongBetween(1, 10))
.collect(Collectors.toList())));
.setTrainedModel(randomFrom(TreeTests.createRandom()));
}

@Override
Expand Down
Loading