Skip to content

Commit c64e283

Browse files
authored
[7.x] [ML] handles compressed model stream from native process (#58009) (#58836)
* [ML] handles compressed model stream from native process (#58009) This moves model storage from handling the fully parsed JSON string to handling two separate types of documents. 1. ModelSizeInfo which contains model size information 2. TrainedModelDefinitionChunk which contains a particular chunk of the compressed model definition string. `model_size_info` is assumed to be handled first. This will generate the model_id and store the initial trained model config object. Then each chunk is assumed to be in correct order for concatenating the chunks to get a compressed definition. Native side change: elastic/ml-cpp#1349
1 parent 9c77862 commit c64e283

File tree

14 files changed

+853
-286
lines changed

14 files changed

+853
-286
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ public final class Messages {
8989
" (a-z and 0-9), hyphens or underscores; must start and end with alphanumeric";
9090

9191
public static final String INFERENCE_TRAINED_MODEL_EXISTS = "Trained machine learning model [{0}] already exists";
92+
public static final String INFERENCE_TRAINED_MODEL_DOC_EXISTS = "Trained machine learning model chunked doc [{0}][{1}] already exists";
9293
public static final String INFERENCE_FAILED_TO_STORE_MODEL = "Failed to store trained machine learning model [{0}]";
9394
public static final String INFERENCE_NOT_FOUND = "Could not find trained model [{0}]";
9495
public static final String INFERENCE_NOT_FOUND_MULTIPLE = "Could not find trained models {0}";

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.elasticsearch.xpack.ml.integration;
77

88
import org.apache.logging.log4j.message.ParameterizedMessage;
9-
import org.apache.lucene.util.LuceneTestCase;
109
import org.elasticsearch.ElasticsearchException;
1110
import org.elasticsearch.ElasticsearchStatusException;
1211
import org.elasticsearch.action.ActionModule;
@@ -67,7 +66,6 @@
6766
import static org.hamcrest.Matchers.nullValue;
6867
import static org.hamcrest.Matchers.startsWith;
6968

70-
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
7169
public class ClassificationIT extends MlNativeDataFrameAnalyticsIntegTestCase {
7270

7371
private static final String BOOLEAN_FIELD = "boolean-field";

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
*/
66
package org.elasticsearch.xpack.ml.integration;
77

8-
import org.apache.lucene.util.LuceneTestCase;
98
import org.elasticsearch.ElasticsearchException;
109
import org.elasticsearch.action.ActionModule;
1110
import org.elasticsearch.action.DocWriteRequest;
@@ -45,7 +44,6 @@
4544
import static org.hamcrest.Matchers.lessThan;
4645
import static org.hamcrest.Matchers.not;
4746

48-
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/1349")
4947
public class RegressionIT extends MlNativeDataFrameAnalyticsIntegTestCase {
5048

5149
private static final String NUMERICAL_FEATURE_FIELD = "feature";
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License;
4+
* you may not use this file except in compliance with the Elastic License.
5+
*/
6+
package org.elasticsearch.xpack.ml.integration;
7+
8+
import org.elasticsearch.ElasticsearchException;
9+
import org.elasticsearch.Version;
10+
import org.elasticsearch.action.support.PlainActionFuture;
11+
import org.elasticsearch.common.collect.Tuple;
12+
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
13+
import org.elasticsearch.license.License;
14+
import org.elasticsearch.xpack.core.action.util.PageParams;
15+
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
16+
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsDest;
17+
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource;
18+
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
19+
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
20+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
21+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
22+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
23+
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInputTests;
24+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
25+
import org.elasticsearch.xpack.ml.MlSingleNodeTestCase;
26+
import org.elasticsearch.xpack.ml.dataframe.process.ChunkedTrainedModelPersister;
27+
import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk;
28+
import org.elasticsearch.xpack.ml.extractor.DocValueField;
29+
import org.elasticsearch.xpack.ml.extractor.ExtractedField;
30+
import org.elasticsearch.xpack.ml.extractor.ExtractedFields;
31+
import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider;
32+
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo;
33+
import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests;
34+
import org.elasticsearch.xpack.ml.inference.persistence.TrainedModelProvider;
35+
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
36+
import org.junit.Before;
37+
38+
import java.io.IOException;
39+
import java.util.ArrayList;
40+
import java.util.Collections;
41+
import java.util.List;
42+
import java.util.Set;
43+
44+
import static org.hamcrest.Matchers.equalTo;
45+
46+
public class ChunkedTrainedModelPersisterIT extends MlSingleNodeTestCase {
47+
48+
private TrainedModelProvider trainedModelProvider;
49+
50+
@Before
51+
public void createComponents() throws Exception {
52+
trainedModelProvider = new TrainedModelProvider(client(), xContentRegistry());
53+
waitForMlTemplates();
54+
}
55+
56+
public void testStoreModelViaChunkedPersister() throws IOException {
57+
String modelId = "stored-chunked-model";
58+
DataFrameAnalyticsConfig analyticsConfig = new DataFrameAnalyticsConfig.Builder()
59+
.setId(modelId)
60+
.setSource(new DataFrameAnalyticsSource(new String[] {"my_source"}, null, null))
61+
.setDest(new DataFrameAnalyticsDest("my_dest", null))
62+
.setAnalysis(new Regression("foo"))
63+
.build();
64+
List<ExtractedField> extractedFieldList = Collections.singletonList(new DocValueField("foo", Collections.emptySet()));
65+
TrainedModelConfig.Builder configBuilder = buildTrainedModelConfigBuilder(modelId);
66+
String compressedDefinition = configBuilder.build().getCompressedDefinition();
67+
int totalSize = compressedDefinition.length();
68+
List<String> chunks = chunkStringWithSize(compressedDefinition, totalSize/3);
69+
70+
ChunkedTrainedModelPersister persister = new ChunkedTrainedModelPersister(trainedModelProvider,
71+
analyticsConfig,
72+
new DataFrameAnalyticsAuditor(client(), "test-node"),
73+
(ex) -> { throw new ElasticsearchException(ex); },
74+
new ExtractedFields(extractedFieldList, Collections.emptyMap())
75+
);
76+
77+
//Accuracy for size is not tested here
78+
ModelSizeInfo modelSizeInfo = ModelSizeInfoTests.createRandom();
79+
persister.createAndIndexInferenceModelMetadata(modelSizeInfo);
80+
for (int i = 0; i < chunks.size(); i++) {
81+
persister.createAndIndexInferenceModelDoc(new TrainedModelDefinitionChunk(chunks.get(i), i, i == (chunks.size() - 1)));
82+
}
83+
84+
PlainActionFuture<Tuple<Long, Set<String>>> getIdsFuture = new PlainActionFuture<>();
85+
trainedModelProvider.expandIds(modelId + "*", false, PageParams.defaultParams(), Collections.emptySet(), getIdsFuture);
86+
Tuple<Long, Set<String>> ids = getIdsFuture.actionGet();
87+
assertThat(ids.v1(), equalTo(1L));
88+
89+
PlainActionFuture<TrainedModelConfig> getTrainedModelFuture = new PlainActionFuture<>();
90+
trainedModelProvider.getTrainedModel(ids.v2().iterator().next(), true, getTrainedModelFuture);
91+
92+
TrainedModelConfig storedConfig = getTrainedModelFuture.actionGet();
93+
assertThat(storedConfig.getCompressedDefinition(), equalTo(compressedDefinition));
94+
assertThat(storedConfig.getEstimatedOperations(), equalTo((long)modelSizeInfo.numOperations()));
95+
assertThat(storedConfig.getEstimatedHeapMemory(), equalTo(modelSizeInfo.ramBytesUsed()));
96+
}
97+
98+
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
99+
TrainedModelDefinition.Builder definitionBuilder = TrainedModelDefinitionTests.createRandomBuilder();
100+
long bytesUsed = definitionBuilder.build().ramBytesUsed();
101+
long operations = definitionBuilder.build().getTrainedModel().estimatedNumOperations();
102+
return TrainedModelConfig.builder()
103+
.setCreatedBy("ml_test")
104+
.setParsedDefinition(TrainedModelDefinitionTests.createRandomBuilder(TargetType.REGRESSION))
105+
.setDescription("trained model config for test")
106+
.setModelId(modelId)
107+
.setVersion(Version.CURRENT)
108+
.setLicenseLevel(License.OperationMode.PLATINUM.description())
109+
.setEstimatedHeapMemory(bytesUsed)
110+
.setEstimatedOperations(operations)
111+
.setInput(TrainedModelInputTests.createRandomInput());
112+
}
113+
114+
public static List<String> chunkStringWithSize(String str, int chunkSize) {
115+
List<String> subStrings = new ArrayList<>((str.length() + chunkSize - 1) / chunkSize);
116+
for (int i = 0; i < str.length(); i += chunkSize) {
117+
subStrings.add(str.substring(i, Math.min(i + chunkSize, str.length())));
118+
}
119+
return subStrings;
120+
}
121+
122+
@Override
123+
public NamedXContentRegistry xContentRegistry() {
124+
List<NamedXContentRegistry.Entry> namedXContent = new ArrayList<>();
125+
namedXContent.addAll(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
126+
namedXContent.addAll(new MlModelSizeNamedXContentProvider().getNamedXContentParsers());
127+
return new NamedXContentRegistry(namedXContent);
128+
}
129+
130+
}

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,11 @@
3232
import java.util.Collections;
3333
import java.util.List;
3434
import java.util.concurrent.atomic.AtomicReference;
35+
import java.util.stream.Collectors;
36+
import java.util.stream.IntStream;
3537

3638
import static org.elasticsearch.xpack.core.ml.utils.ToXContentParams.FOR_INTERNAL_STORAGE;
39+
import static org.elasticsearch.xpack.ml.integration.ChunkedTrainedModelPersisterIT.chunkStringWithSize;
3740
import static org.hamcrest.CoreMatchers.is;
3841
import static org.hamcrest.Matchers.equalTo;
3942
import static org.hamcrest.Matchers.not;
@@ -157,8 +160,8 @@ public void testGetMissingTrainingModelConfigDefinition() throws Exception {
157160
equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_NOT_FOUND, modelId)));
158161
}
159162

160-
public void testGetTruncatedModelDefinition() throws Exception {
161-
String modelId = "test-get-truncated-model-config";
163+
public void testGetTruncatedModelDeprecatedDefinition() throws Exception {
164+
String modelId = "test-get-truncated-legacy-model-config";
162165
TrainedModelConfig config = buildTrainedModelConfig(modelId);
163166
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
164167
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
@@ -196,6 +199,51 @@ public void testGetTruncatedModelDefinition() throws Exception {
196199
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
197200
}
198201

202+
public void testGetTruncatedModelDefinition() throws Exception {
203+
String modelId = "test-get-truncated-model-config";
204+
TrainedModelConfig config = buildTrainedModelConfig(modelId);
205+
AtomicReference<Boolean> putConfigHolder = new AtomicReference<>();
206+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
207+
208+
blockingCall(listener -> trainedModelProvider.storeTrainedModel(config, listener), putConfigHolder, exceptionHolder);
209+
assertThat(putConfigHolder.get(), is(true));
210+
assertThat(exceptionHolder.get(), is(nullValue()));
211+
212+
List<String> chunks = chunkStringWithSize(config.getCompressedDefinition(), config.getCompressedDefinition().length()/3);
213+
214+
List<TrainedModelDefinitionDoc.Builder> docBuilders = IntStream.range(0, chunks.size())
215+
.mapToObj(i -> new TrainedModelDefinitionDoc.Builder()
216+
.setDocNum(i)
217+
.setCompressedString(chunks.get(i))
218+
.setCompressionVersion(TrainedModelConfig.CURRENT_DEFINITION_COMPRESSION_VERSION)
219+
.setDefinitionLength(chunks.get(i).length())
220+
.setEos(i == chunks.size() - 1)
221+
.setModelId(modelId))
222+
.collect(Collectors.toList());
223+
boolean missingEos = randomBoolean();
224+
docBuilders.get(docBuilders.size() - 1).setEos(missingEos == false);
225+
for (int i = missingEos ? 0 : 1 ; i < docBuilders.size(); ++i) {
226+
TrainedModelDefinitionDoc doc = docBuilders.get(i).build();
227+
try(XContentBuilder xContentBuilder = doc.toXContent(XContentFactory.jsonBuilder(),
228+
new ToXContent.MapParams(Collections.singletonMap(FOR_INTERNAL_STORAGE, "true")))) {
229+
AtomicReference<IndexResponse> putDocHolder = new AtomicReference<>();
230+
blockingCall(listener -> client().prepareIndex(InferenceIndexConstants.LATEST_INDEX_NAME, MapperService.SINGLE_MAPPING_NAME)
231+
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
232+
.setSource(xContentBuilder)
233+
.setId(TrainedModelDefinitionDoc.docId(modelId, 0))
234+
.execute(listener),
235+
putDocHolder,
236+
exceptionHolder);
237+
assertThat(exceptionHolder.get(), is(nullValue()));
238+
}
239+
}
240+
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
241+
blockingCall(listener -> trainedModelProvider.getTrainedModel(modelId, true, listener), getConfigHolder, exceptionHolder);
242+
assertThat(getConfigHolder.get(), is(nullValue()));
243+
assertThat(exceptionHolder.get(), is(not(nullValue())));
244+
assertThat(exceptionHolder.get().getMessage(), equalTo(Messages.getMessage(Messages.MODEL_DEFINITION_TRUNCATED, modelId)));
245+
}
246+
199247
private static TrainedModelConfig.Builder buildTrainedModelConfigBuilder(String modelId) {
200248
return TrainedModelConfig.builder()
201249
.setCreatedBy("ml_test")

0 commit comments

Comments
 (0)