Skip to content

Commit 1b3c1d2

Browse files
authored
[ML][Inference] separating definition and config object storage (#48651)
This separates out the definition object from being stored within the configuration object in the index. This allows us to gather the config object without decompressing a potentially large definition. Additionally, input is moved to the TrainedModelConfig object and out of the definition. This is so the trained input fields are accessible outside the potentially large model definition.
1 parent db01555 commit 1b3c1d2

File tree

21 files changed

+677
-297
lines changed

21 files changed

+677
-297
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public class TrainedModelConfig implements ToXContentObject {
4646
public static final ParseField DEFINITION = new ParseField("definition");
4747
public static final ParseField TAGS = new ParseField("tags");
4848
public static final ParseField METADATA = new ParseField("metadata");
49+
public static final ParseField INPUT = new ParseField("input");
4950

5051
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
5152
true,
@@ -64,6 +65,7 @@ public class TrainedModelConfig implements ToXContentObject {
6465
DEFINITION);
6566
PARSER.declareStringArray(TrainedModelConfig.Builder::setTags, TAGS);
6667
PARSER.declareObject(TrainedModelConfig.Builder::setMetadata, (p, c) -> p.map(), METADATA);
68+
PARSER.declareObject(TrainedModelConfig.Builder::setInput, (p, c) -> TrainedModelInput.fromXContent(p), INPUT);
6769
}
6870

6971
public static TrainedModelConfig.Builder fromXContent(XContentParser parser) throws IOException {
@@ -78,6 +80,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
7880
private final TrainedModelDefinition definition;
7981
private final List<String> tags;
8082
private final Map<String, Object> metadata;
83+
private final TrainedModelInput input;
8184

8285
TrainedModelConfig(String modelId,
8386
String createdBy,
@@ -86,7 +89,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
8689
Instant createTime,
8790
TrainedModelDefinition definition,
8891
List<String> tags,
89-
Map<String, Object> metadata) {
92+
Map<String, Object> metadata,
93+
TrainedModelInput input) {
9094
this.modelId = modelId;
9195
this.createdBy = createdBy;
9296
this.version = version;
@@ -95,6 +99,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser) thr
9599
this.description = description;
96100
this.tags = tags == null ? null : Collections.unmodifiableList(tags);
97101
this.metadata = metadata == null ? null : Collections.unmodifiableMap(metadata);
102+
this.input = input;
98103
}
99104

100105
public String getModelId() {
@@ -129,6 +134,10 @@ public TrainedModelDefinition getDefinition() {
129134
return definition;
130135
}
131136

137+
public TrainedModelInput getInput() {
138+
return input;
139+
}
140+
132141
public static Builder builder() {
133142
return new Builder();
134143
}
@@ -160,6 +169,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
160169
if (metadata != null) {
161170
builder.field(METADATA.getPreferredName(), metadata);
162171
}
172+
if (input != null) {
173+
builder.field(INPUT.getPreferredName(), input);
174+
}
163175
builder.endObject();
164176
return builder;
165177
}
@@ -181,6 +193,7 @@ public boolean equals(Object o) {
181193
Objects.equals(createTime, that.createTime) &&
182194
Objects.equals(definition, that.definition) &&
183195
Objects.equals(tags, that.tags) &&
196+
Objects.equals(input, that.input) &&
184197
Objects.equals(metadata, that.metadata);
185198
}
186199

@@ -193,7 +206,8 @@ public int hashCode() {
193206
definition,
194207
description,
195208
tags,
196-
metadata);
209+
metadata,
210+
input);
197211
}
198212

199213

@@ -207,6 +221,7 @@ public static class Builder {
207221
private Map<String, Object> metadata;
208222
private List<String> tags;
209223
private TrainedModelDefinition definition;
224+
private TrainedModelInput input;
210225

211226
public Builder setModelId(String modelId) {
212227
this.modelId = modelId;
@@ -257,6 +272,11 @@ public Builder setDefinition(TrainedModelDefinition definition) {
257272
return this;
258273
}
259274

275+
public Builder setInput(TrainedModelInput input) {
276+
this.input = input;
277+
return this;
278+
}
279+
260280
public TrainedModelConfig build() {
261281
return new TrainedModelConfig(
262282
modelId,
@@ -266,7 +286,9 @@ public TrainedModelConfig build() {
266286
createTime,
267287
definition,
268288
tags,
269-
metadata);
289+
metadata,
290+
input);
270291
}
271292
}
293+
272294
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelDefinition.java

Lines changed: 4 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import org.elasticsearch.client.ml.inference.trainedmodel.TrainedModel;
2323
import org.elasticsearch.common.ParseField;
2424
import org.elasticsearch.common.Strings;
25-
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
2625
import org.elasticsearch.common.xcontent.ObjectParser;
2726
import org.elasticsearch.common.xcontent.ToXContentObject;
2827
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -39,7 +38,6 @@ public class TrainedModelDefinition implements ToXContentObject {
3938

4039
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
4140
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
42-
public static final ParseField INPUT = new ParseField("input");
4341

4442
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
4543
true,
@@ -53,7 +51,6 @@ public class TrainedModelDefinition implements ToXContentObject {
5351
(p, c, n) -> p.namedObject(PreProcessor.class, n, null),
5452
(trainedModelDefBuilder) -> {/* Does not matter client side*/ },
5553
PREPROCESSORS);
56-
PARSER.declareObject(TrainedModelDefinition.Builder::setInput, (p, c) -> Input.fromXContent(p), INPUT);
5754
}
5855

5956
public static TrainedModelDefinition.Builder fromXContent(XContentParser parser) throws IOException {
@@ -62,12 +59,10 @@ public static TrainedModelDefinition.Builder fromXContent(XContentParser parser)
6259

6360
private final TrainedModel trainedModel;
6461
private final List<PreProcessor> preProcessors;
65-
private final Input input;
6662

67-
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors, Input input) {
63+
TrainedModelDefinition(TrainedModel trainedModel, List<PreProcessor> preProcessors) {
6864
this.trainedModel = trainedModel;
6965
this.preProcessors = preProcessors == null ? Collections.emptyList() : Collections.unmodifiableList(preProcessors);
70-
this.input = input;
7166
}
7267

7368
@Override
@@ -83,9 +78,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
8378
true,
8479
PREPROCESSORS.getPreferredName(),
8580
preProcessors);
86-
if (input != null) {
87-
builder.field(INPUT.getPreferredName(), input);
88-
}
8981
builder.endObject();
9082
return builder;
9183
}
@@ -98,10 +90,6 @@ public List<PreProcessor> getPreProcessors() {
9890
return preProcessors;
9991
}
10092

101-
public Input getInput() {
102-
return input;
103-
}
104-
10593
@Override
10694
public String toString() {
10795
return Strings.toString(this);
@@ -113,20 +101,18 @@ public boolean equals(Object o) {
113101
if (o == null || getClass() != o.getClass()) return false;
114102
TrainedModelDefinition that = (TrainedModelDefinition) o;
115103
return Objects.equals(trainedModel, that.trainedModel) &&
116-
Objects.equals(preProcessors, that.preProcessors) &&
117-
Objects.equals(input, that.input);
104+
Objects.equals(preProcessors, that.preProcessors);
118105
}
119106

120107
@Override
121108
public int hashCode() {
122-
return Objects.hash(trainedModel, preProcessors, input);
109+
return Objects.hash(trainedModel, preProcessors);
123110
}
124111

125112
public static class Builder {
126113

127114
private List<PreProcessor> preProcessors;
128115
private TrainedModel trainedModel;
129-
private Input input;
130116

131117
public Builder setPreProcessors(List<PreProcessor> preProcessors) {
132118
this.preProcessors = preProcessors;
@@ -138,71 +124,14 @@ public Builder setTrainedModel(TrainedModel trainedModel) {
138124
return this;
139125
}
140126

141-
public Builder setInput(Input input) {
142-
this.input = input;
143-
return this;
144-
}
145-
146127
private Builder setTrainedModel(List<TrainedModel> trainedModel) {
147128
assert trainedModel.size() == 1;
148129
return setTrainedModel(trainedModel.get(0));
149130
}
150131

151132
public TrainedModelDefinition build() {
152-
return new TrainedModelDefinition(this.trainedModel, this.preProcessors, this.input);
153-
}
154-
}
155-
156-
public static class Input implements ToXContentObject {
157-
158-
public static final String NAME = "trained_mode_definition_input";
159-
public static final ParseField FIELD_NAMES = new ParseField("field_names");
160-
161-
@SuppressWarnings("unchecked")
162-
public static final ConstructingObjectParser<Input, Void> PARSER = new ConstructingObjectParser<>(NAME,
163-
true,
164-
a -> new Input((List<String>)a[0]));
165-
static {
166-
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
167-
}
168-
169-
public static Input fromXContent(XContentParser parser) throws IOException {
170-
return PARSER.parse(parser, null);
133+
return new TrainedModelDefinition(this.trainedModel, this.preProcessors);
171134
}
172-
173-
private final List<String> fieldNames;
174-
175-
public Input(List<String> fieldNames) {
176-
this.fieldNames = fieldNames;
177-
}
178-
179-
public List<String> getFieldNames() {
180-
return fieldNames;
181-
}
182-
183-
@Override
184-
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
185-
builder.startObject();
186-
if (fieldNames != null) {
187-
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
188-
}
189-
builder.endObject();
190-
return builder;
191-
}
192-
193-
@Override
194-
public boolean equals(Object o) {
195-
if (this == o) return true;
196-
if (o == null || getClass() != o.getClass()) return false;
197-
TrainedModelDefinition.Input that = (TrainedModelDefinition.Input) o;
198-
return Objects.equals(fieldNames, that.fieldNames);
199-
}
200-
201-
@Override
202-
public int hashCode() {
203-
return Objects.hash(fieldNames);
204-
}
205-
206135
}
207136

208137
}
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
package org.elasticsearch.client.ml.inference;
20+
21+
import org.elasticsearch.common.ParseField;
22+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
23+
import org.elasticsearch.common.xcontent.ToXContentObject;
24+
import org.elasticsearch.common.xcontent.XContentBuilder;
25+
import org.elasticsearch.common.xcontent.XContentParser;
26+
27+
import java.io.IOException;
28+
import java.util.List;
29+
import java.util.Objects;
30+
31+
public class TrainedModelInput implements ToXContentObject {
32+
33+
public static final String NAME = "trained_model_config_input";
34+
public static final ParseField FIELD_NAMES = new ParseField("field_names");
35+
36+
@SuppressWarnings("unchecked")
37+
public static final ConstructingObjectParser<TrainedModelInput, Void> PARSER = new ConstructingObjectParser<>(NAME,
38+
true,
39+
a -> new TrainedModelInput((List<String>) a[0]));
40+
41+
static {
42+
PARSER.declareStringArray(ConstructingObjectParser.constructorArg(), FIELD_NAMES);
43+
}
44+
45+
private final List<String> fieldNames;
46+
47+
public TrainedModelInput(List<String> fieldNames) {
48+
this.fieldNames = fieldNames;
49+
}
50+
51+
public static TrainedModelInput fromXContent(XContentParser parser) throws IOException {
52+
return PARSER.parse(parser, null);
53+
}
54+
55+
public List<String> getFieldNames() {
56+
return fieldNames;
57+
}
58+
59+
@Override
60+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
61+
builder.startObject();
62+
if (fieldNames != null) {
63+
builder.field(FIELD_NAMES.getPreferredName(), fieldNames);
64+
}
65+
builder.endObject();
66+
return builder;
67+
}
68+
69+
@Override
70+
public boolean equals(Object o) {
71+
if (this == o) return true;
72+
if (o == null || getClass() != o.getClass()) return false;
73+
TrainedModelInput that = (TrainedModelInput) o;
74+
return Objects.equals(fieldNames, that.fieldNames);
75+
}
76+
77+
@Override
78+
public int hashCode() {
79+
return Objects.hash(fieldNames);
80+
}
81+
82+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ protected TrainedModelConfig createTestInstance() {
6363
randomBoolean() ? null : TrainedModelDefinitionTests.createRandomBuilder().build(),
6464
randomBoolean() ? null :
6565
Stream.generate(() -> randomAlphaOfLength(10)).limit(randomIntBetween(0, 5)).collect(Collectors.toList()),
66-
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)));
66+
randomBoolean() ? null : Collections.singletonMap(randomAlphaOfLength(10), randomAlphaOfLength(10)),
67+
randomBoolean() ? null : TrainedModelInputTests.createRandomInput());
6768
}
6869

6970
@Override

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelDefinitionTests.java

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,7 @@ public static TrainedModelDefinition.Builder createRandomBuilder() {
6464
TargetMeanEncodingTests.createRandom()))
6565
.limit(numberOfProcessors)
6666
.collect(Collectors.toList()))
67-
.setTrainedModel(randomFrom(TreeTests.createRandom()))
68-
.setInput(new TrainedModelDefinition.Input(Stream.generate(() -> randomAlphaOfLength(10))
69-
.limit(randomLongBetween(1, 10))
70-
.collect(Collectors.toList())));
67+
.setTrainedModel(randomFrom(TreeTests.createRandom()));
7168
}
7269

7370
@Override

0 commit comments

Comments
 (0)