Skip to content

Commit 4e1f029

Browse files
authored
[ML][Inference] adds new default_field_map field to trained models (#53294)
Adds a new `default_field_map` field to trained model config objects. This allows the model creator to supply field map if it knows that there should be some map for inference to work directly against the training data. The use case internally is having analytics jobs supply a field mapping for multi-field fields. This allows us to use the model "out of the box" on data where we trained on `foo.keyword` but the `_source` only references `foo`.
1 parent ab66529 commit 4e1f029

File tree

19 files changed

+226
-41
lines changed

19 files changed

+226
-41
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/TrainedModelConfig.java

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ public class TrainedModelConfig implements ToXContentObject {
5353
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
5454
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
5555
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
56+
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
5657

5758
public static final ObjectParser<Builder, Void> PARSER = new ObjectParser<>(NAME,
5859
true,
@@ -76,6 +77,7 @@ public class TrainedModelConfig implements ToXContentObject {
7677
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedHeapMemory, ESTIMATED_HEAP_MEMORY_USAGE_BYTES);
7778
PARSER.declareLong(TrainedModelConfig.Builder::setEstimatedOperations, ESTIMATED_OPERATIONS);
7879
PARSER.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
80+
PARSER.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
7981
}
8082

8183
public static TrainedModelConfig fromXContent(XContentParser parser) throws IOException {
@@ -95,6 +97,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
9597
private final Long estimatedHeapMemory;
9698
private final Long estimatedOperations;
9799
private final String licenseLevel;
100+
private final Map<String, String> defaultFieldMap;
98101

99102
TrainedModelConfig(String modelId,
100103
String createdBy,
@@ -108,7 +111,8 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
108111
TrainedModelInput input,
109112
Long estimatedHeapMemory,
110113
Long estimatedOperations,
111-
String licenseLevel) {
114+
String licenseLevel,
115+
Map<String, String> defaultFieldMap) {
112116
this.modelId = modelId;
113117
this.createdBy = createdBy;
114118
this.version = version;
@@ -122,6 +126,7 @@ public static TrainedModelConfig fromXContent(XContentParser parser) throws IOEx
122126
this.estimatedHeapMemory = estimatedHeapMemory;
123127
this.estimatedOperations = estimatedOperations;
124128
this.licenseLevel = licenseLevel;
129+
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
125130
}
126131

127132
public String getModelId() {
@@ -180,6 +185,10 @@ public String getLicenseLevel() {
180185
return licenseLevel;
181186
}
182187

188+
public Map<String, String> getDefaultFieldMap() {
189+
return defaultFieldMap;
190+
}
191+
183192
public static Builder builder() {
184193
return new Builder();
185194
}
@@ -226,6 +235,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
226235
if (licenseLevel != null) {
227236
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel);
228237
}
238+
if (defaultFieldMap != null) {
239+
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
240+
}
229241
builder.endObject();
230242
return builder;
231243
}
@@ -252,6 +264,7 @@ public boolean equals(Object o) {
252264
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
253265
Objects.equals(estimatedOperations, that.estimatedOperations) &&
254266
Objects.equals(licenseLevel, that.licenseLevel) &&
267+
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
255268
Objects.equals(metadata, that.metadata);
256269
}
257270

@@ -269,7 +282,8 @@ public int hashCode() {
269282
estimatedOperations,
270283
metadata,
271284
licenseLevel,
272-
input);
285+
input,
286+
defaultFieldMap);
273287
}
274288

275289

@@ -288,6 +302,7 @@ public static class Builder {
288302
private Long estimatedHeapMemory;
289303
private Long estimatedOperations;
290304
private String licenseLevel;
305+
private Map<String, String> defaultFieldMap;
291306

292307
public Builder setModelId(String modelId) {
293308
this.modelId = modelId;
@@ -367,6 +382,11 @@ private Builder setLicenseLevel(String licenseLevel) {
367382
return this;
368383
}
369384

385+
public Builder setDefaultFieldMap(Map<String, String> defaultFieldMap) {
386+
this.defaultFieldMap = defaultFieldMap;
387+
return this;
388+
}
389+
370390
public TrainedModelConfig build() {
371391
return new TrainedModelConfig(
372392
modelId,
@@ -381,7 +401,8 @@ public TrainedModelConfig build() {
381401
input,
382402
estimatedHeapMemory,
383403
estimatedOperations,
384-
licenseLevel);
404+
licenseLevel,
405+
defaultFieldMap);
385406
}
386407
}
387408

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/TrainedModelConfigTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.ArrayList;
3131
import java.util.Collections;
3232
import java.util.List;
33+
import java.util.function.Function;
3334
import java.util.function.Predicate;
3435
import java.util.stream.Collectors;
3536
import java.util.stream.Stream;
@@ -52,7 +53,11 @@ public static TrainedModelConfig createTestTrainedModelConfig() {
5253
randomBoolean() ? null : TrainedModelInputTests.createRandomInput(),
5354
randomBoolean() ? null : randomNonNegativeLong(),
5455
randomBoolean() ? null : randomNonNegativeLong(),
55-
randomBoolean() ? null : randomFrom("platinum", "basic"));
56+
randomBoolean() ? null : randomFrom("platinum", "basic"),
57+
randomBoolean() ? null :
58+
Stream.generate(() -> randomAlphaOfLength(10))
59+
.limit(randomIntBetween(1, 10))
60+
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
5661
}
5762

5863
@Override

docs/reference/ingest/processors/inference.asciidoc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ ingested in the pipeline.
1414
| Name | Required | Default | Description
1515
| `model_id` | yes | - | (String) The ID of the model to load and infer against.
1616
| `target_field` | no | `ml.inference.<processor_tag>` | (String) Field added to incoming documents to contain results objects.
17-
| `field_mappings` | yes | - | (Object) Maps the document field names to the known field names of the model.
17+
| `field_mappings` | yes | - | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration.
1818
| `inference_config` | yes | - | (Object) Contains the inference type and its options. There are two types: <<inference-processor-regression-opt,`regression`>> and <<inference-processor-classification-opt,`classification`>>.
1919
include::common-options.asciidoc[]
2020
|======

docs/reference/ml/ml-shared.asciidoc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,17 @@ The estimated number of operations to use the trained model.
15061506
`license_level`:::
15071507
(string)
15081508
The license level of the trained model.
1509+
1510+
`default_field_map` :::
1511+
(object)
1512+
A string to string object that contains the default field map to use
1513+
when inferring against the model. For example, data frame analytics
1514+
may train the model on a specific multi-field `foo.keyword`.
1515+
The analytics job would then supply a default field map entry for
1516+
`"foo" : "foo.keyword"`.
1517+
1518+
Any field map described in the inference configuration takes precedence.
1519+
15091520
end::trained-model-configs[]
15101521

15111522
tag::training-percent[]

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.io.IOException;
3232
import java.time.Instant;
3333
import java.util.Collections;
34+
import java.util.HashMap;
3435
import java.util.List;
3536
import java.util.Map;
3637
import java.util.Objects;
@@ -60,6 +61,7 @@ public class TrainedModelConfig implements ToXContentObject, Writeable {
6061
public static final ParseField ESTIMATED_HEAP_MEMORY_USAGE_BYTES = new ParseField("estimated_heap_memory_usage_bytes");
6162
public static final ParseField ESTIMATED_OPERATIONS = new ParseField("estimated_operations");
6263
public static final ParseField LICENSE_LEVEL = new ParseField("license_level");
64+
public static final ParseField DEFAULT_FIELD_MAP = new ParseField("default_field_map");
6365

6466
// These parsers follow the pattern that metadata is parsed leniently (to allow for enhancements), whilst config is parsed strictly
6567
public static final ObjectParser<TrainedModelConfig.Builder, Void> LENIENT_PARSER = createParser(true);
@@ -90,6 +92,7 @@ private static ObjectParser<TrainedModelConfig.Builder, Void> createParser(boole
9092
DEFINITION);
9193
parser.declareString(TrainedModelConfig.Builder::setLazyDefinition, COMPRESSED_DEFINITION);
9294
parser.declareString(TrainedModelConfig.Builder::setLicenseLevel, LICENSE_LEVEL);
95+
parser.declareObject(TrainedModelConfig.Builder::setDefaultFieldMap, (p, c) -> p.mapStrings(), DEFAULT_FIELD_MAP);
9396
return parser;
9497
}
9598

@@ -108,6 +111,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
108111
private final long estimatedHeapMemory;
109112
private final long estimatedOperations;
110113
private final License.OperationMode licenseLevel;
114+
private final Map<String, String> defaultFieldMap;
111115

112116
private final LazyModelDefinition definition;
113117

@@ -122,7 +126,8 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
122126
TrainedModelInput input,
123127
Long estimatedHeapMemory,
124128
Long estimatedOperations,
125-
String licenseLevel) {
129+
String licenseLevel,
130+
Map<String, String> defaultFieldMap) {
126131
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
127132
this.createdBy = ExceptionsHelper.requireNonNull(createdBy, CREATED_BY);
128133
this.version = ExceptionsHelper.requireNonNull(version, VERSION);
@@ -142,6 +147,7 @@ public static TrainedModelConfig.Builder fromXContent(XContentParser parser, boo
142147
}
143148
this.estimatedOperations = estimatedOperations;
144149
this.licenseLevel = License.OperationMode.parse(ExceptionsHelper.requireNonNull(licenseLevel, LICENSE_LEVEL));
150+
this.defaultFieldMap = defaultFieldMap == null ? null : Collections.unmodifiableMap(defaultFieldMap);
145151
}
146152

147153
public TrainedModelConfig(StreamInput in) throws IOException {
@@ -157,6 +163,13 @@ public TrainedModelConfig(StreamInput in) throws IOException {
157163
estimatedHeapMemory = in.readVLong();
158164
estimatedOperations = in.readVLong();
159165
licenseLevel = License.OperationMode.parse(in.readString());
166+
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
167+
this.defaultFieldMap = in.readBoolean() ?
168+
Collections.unmodifiableMap(in.readMap(StreamInput::readString, StreamInput::readString)) :
169+
null;
170+
} else {
171+
this.defaultFieldMap = null;
172+
}
160173
}
161174

162175
public String getModelId() {
@@ -187,6 +200,10 @@ public Map<String, Object> getMetadata() {
187200
return metadata;
188201
}
189202

203+
public Map<String, String> getDefaultFieldMap() {
204+
return defaultFieldMap;
205+
}
206+
190207
@Nullable
191208
public String getCompressedDefinition() throws IOException {
192209
if (definition == null) {
@@ -249,6 +266,14 @@ public void writeTo(StreamOutput out) throws IOException {
249266
out.writeVLong(estimatedHeapMemory);
250267
out.writeVLong(estimatedOperations);
251268
out.writeString(licenseLevel.description());
269+
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
270+
if (defaultFieldMap != null) {
271+
out.writeBoolean(true);
272+
out.writeMap(defaultFieldMap, StreamOutput::writeString, StreamOutput::writeString);
273+
} else {
274+
out.writeBoolean(false);
275+
}
276+
}
252277
}
253278

254279
@Override
@@ -283,6 +308,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
283308
new ByteSizeValue(estimatedHeapMemory));
284309
builder.field(ESTIMATED_OPERATIONS.getPreferredName(), estimatedOperations);
285310
builder.field(LICENSE_LEVEL.getPreferredName(), licenseLevel.description());
311+
if (defaultFieldMap != null && defaultFieldMap.isEmpty() == false) {
312+
builder.field(DEFAULT_FIELD_MAP.getPreferredName(), defaultFieldMap);
313+
}
286314
builder.endObject();
287315
return builder;
288316
}
@@ -308,6 +336,7 @@ public boolean equals(Object o) {
308336
Objects.equals(estimatedHeapMemory, that.estimatedHeapMemory) &&
309337
Objects.equals(estimatedOperations, that.estimatedOperations) &&
310338
Objects.equals(licenseLevel, that.licenseLevel) &&
339+
Objects.equals(defaultFieldMap, that.defaultFieldMap) &&
311340
Objects.equals(metadata, that.metadata);
312341
}
313342

@@ -324,7 +353,8 @@ public int hashCode() {
324353
estimatedHeapMemory,
325354
estimatedOperations,
326355
input,
327-
licenseLevel);
356+
licenseLevel,
357+
defaultFieldMap);
328358
}
329359

330360
public static class Builder {
@@ -341,6 +371,7 @@ public static class Builder {
341371
private Long estimatedOperations;
342372
private LazyModelDefinition definition;
343373
private String licenseLevel;
374+
private Map<String, String> defaultFieldMap;
344375

345376
public Builder() {}
346377

@@ -357,6 +388,7 @@ public Builder(TrainedModelConfig config) {
357388
this.estimatedOperations = config.estimatedOperations;
358389
this.estimatedHeapMemory = config.estimatedHeapMemory;
359390
this.licenseLevel = config.licenseLevel.description();
391+
this.defaultFieldMap = config.defaultFieldMap == null ? null : new HashMap<>(config.defaultFieldMap);
360392
}
361393

362394
public Builder setModelId(String modelId) {
@@ -475,6 +507,11 @@ public Builder setLicenseLevel(String licenseLevel) {
475507
return this;
476508
}
477509

510+
public Builder setDefaultFieldMap(Map<String, String> defaultFieldMap) {
511+
this.defaultFieldMap = defaultFieldMap;
512+
return this;
513+
}
514+
478515
public Builder validate() {
479516
return validate(false);
480517
}
@@ -567,7 +604,8 @@ public TrainedModelConfig build() {
567604
input,
568605
estimatedHeapMemory == null ? 0 : estimatedHeapMemory,
569606
estimatedOperations == null ? 0 : estimatedOperations,
570-
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel);
607+
licenseLevel == null ? License.OperationMode.PLATINUM.description() : licenseLevel,
608+
defaultFieldMap);
571609
}
572610
}
573611

x-pack/plugin/core/src/main/resources/org/elasticsearch/xpack/core/ml/inference_index_template.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@
6464
},
6565
"total_definition_length": {
6666
"type": "long"
67+
},
68+
"default_field_map": {
69+
"enabled": false
6770
}
6871
}
6972
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfigTests.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@
3434
import java.util.Collections;
3535
import java.util.List;
3636
import java.util.Map;
37+
import java.util.function.Function;
3738
import java.util.function.Predicate;
3839
import java.util.stream.Collectors;
3940
import java.util.stream.IntStream;
41+
import java.util.stream.Stream;
4042

4143
import static org.elasticsearch.test.AbstractXContentTestCase.xContentTester;
4244
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
@@ -137,7 +139,11 @@ public void testToXContentWithParams() throws IOException {
137139
TrainedModelInputTests.createRandomInput(),
138140
randomNonNegativeLong(),
139141
randomNonNegativeLong(),
140-
"platinum");
142+
"platinum",
143+
randomBoolean() ? null :
144+
Stream.generate(() -> randomAlphaOfLength(10))
145+
.limit(randomIntBetween(1, 10))
146+
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
141147

142148
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
143149
assertThat(reference.utf8ToString(), containsString("\"compressed_definition\""));
@@ -172,7 +178,11 @@ public void testParseWithBothDefinitionAndCompressedSupplied() throws IOExceptio
172178
TrainedModelInputTests.createRandomInput(),
173179
randomNonNegativeLong(),
174180
randomNonNegativeLong(),
175-
"platinum");
181+
"platinum",
182+
randomBoolean() ? null :
183+
Stream.generate(() -> randomAlphaOfLength(10))
184+
.limit(randomIntBetween(1, 10))
185+
.collect(Collectors.toMap(Function.identity(), (k) -> randomAlphaOfLength(10))));
176186

177187
BytesReference reference = XContentHelper.toXContent(config, XContentType.JSON, ToXContent.EMPTY_PARAMS, false);
178188
Map<String, Object> objectMap = XContentHelper.convertToMap(reference, true, XContentType.JSON).v2();

0 commit comments

Comments
 (0)