Skip to content

Commit 7ed67e6

Browse files
authored
[ML] fixing bug when returning multi-doc compressed model definitions (#80377)
When deserializing multi-doc compressed model definitions, I periodically receive weird errors regarding bytes references. Turns out, when we parse the individual bytes references we parse them into `ByteArrays` which satisfy the `bytes()` method. But, `CompositeBytesReference`, when there are multiple BytesReferences, does not allow `bytes()` to be called.
1 parent b230ede commit 7ed67e6

File tree

5 files changed

+238
-11
lines changed

5 files changed

+238
-11
lines changed

x-pack/plugin/ml/src/internalClusterTest/java/org/elasticsearch/xpack/ml/integration/TrainedModelProviderIT.java

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.action.index.IndexRequestBuilder;
1515
import org.elasticsearch.action.index.IndexResponse;
1616
import org.elasticsearch.action.support.WriteRequest;
17+
import org.elasticsearch.common.bytes.BytesArray;
1718
import org.elasticsearch.common.bytes.BytesReference;
1819
import org.elasticsearch.license.License;
1920
import org.elasticsearch.xcontent.ToXContent;
@@ -145,6 +146,82 @@ public void testGetTrainedModelConfig() throws Exception {
145146
assertThat(getConfigHolder.get().getMetadata(), hasKey("hyperparameters"));
146147
}
147148

149+
public void testGetTrainedModelConfigWithMultiDocDefinition() throws Exception {
150+
String modelId = "test-get-trained-model-config";
151+
TrainedModelConfig config = buildTrainedModelConfig(modelId);
152+
153+
AtomicReference<Void> dummy = new AtomicReference<>();
154+
AtomicReference<Boolean> booleanDummy = new AtomicReference<>();
155+
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();
156+
157+
BytesReference definition = config.getCompressedDefinition();
158+
159+
blockingCall(
160+
listener -> trainedModelProvider.storeTrainedModelDefinitionDoc(
161+
new TrainedModelDefinitionDoc(
162+
new BytesArray(definition.array(), 0, definition.length() - 5),
163+
modelId,
164+
0,
165+
(long) definition.length(),
166+
definition.length() - 5,
167+
1,
168+
false
169+
),
170+
listener
171+
),
172+
dummy::set,
173+
e -> fail(e.getMessage())
174+
);
175+
blockingCall(
176+
listener -> trainedModelProvider.storeTrainedModelDefinitionDoc(
177+
new TrainedModelDefinitionDoc(
178+
new BytesArray(definition.array(), definition.length() - 5, 5),
179+
modelId,
180+
1,
181+
(long) definition.length(),
182+
5,
183+
1,
184+
true
185+
),
186+
listener
187+
),
188+
dummy::set,
189+
e -> fail(e.getMessage())
190+
);
191+
blockingCall(
192+
listener -> trainedModelProvider.storeTrainedModelConfig(
193+
new TrainedModelConfig.Builder(config).clearDefinition().build(),
194+
listener
195+
),
196+
booleanDummy::set,
197+
e -> fail(e.getMessage())
198+
);
199+
blockingCall(
200+
listener -> trainedModelProvider.refreshInferenceIndex(listener),
201+
new AtomicReference<RefreshResponse>(),
202+
new AtomicReference<>()
203+
);
204+
205+
AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
206+
blockingCall(
207+
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
208+
getConfigHolder,
209+
exceptionHolder
210+
);
211+
if (exceptionHolder.get() != null) {
212+
throw exceptionHolder.get();
213+
}
214+
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
215+
assertThat(getConfigHolder.get(), is(not(nullValue())));
216+
assertThat(getConfigHolder.get(), equalTo(config));
217+
assertThat(getConfigHolder.get().getModelDefinition(), is(not(nullValue())));
218+
219+
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
220+
// Should not throw
221+
getConfigHolder.get().toXContent(builder, ToXContent.EMPTY_PARAMS);
222+
}
223+
}
224+
148225
public void testGetTrainedModelConfigWithoutDefinition() throws Exception {
149226
String modelId = "test-get-trained-model-config-no-definition";
150227
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).build();

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelDefinitionDoc.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public static String docId(String modelId, int docNum) {
8686
private final int compressionVersion;
8787
private final boolean eos;
8888

89-
private TrainedModelDefinitionDoc(
89+
public TrainedModelDefinitionDoc(
9090
BytesReference binaryData,
9191
String modelId,
9292
int docNum,

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProvider.java

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.elasticsearch.common.CheckedBiFunction;
3636
import org.elasticsearch.common.Numbers;
3737
import org.elasticsearch.common.Strings;
38+
import org.elasticsearch.common.bytes.BytesArray;
3839
import org.elasticsearch.common.bytes.BytesReference;
3940
import org.elasticsearch.common.bytes.CompositeBytesReference;
4041
import org.elasticsearch.common.regex.Regex;
@@ -1219,14 +1220,16 @@ private static <T> List<T> handleHits(
12191220
return results;
12201221
}
12211222

1222-
private static BytesReference getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId)
1223-
throws ElasticsearchException {
1223+
static BytesReference getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId) throws ElasticsearchException {
12241224

1225-
BytesReference[] bb = new BytesReference[docs.size()];
1226-
for (int i = 0; i < docs.size(); i++) {
1227-
bb[i] = docs.get(i).getBinaryData();
1228-
}
1229-
BytesReference bytes = CompositeBytesReference.of(bb);
1225+
// If the user requested the compressed data string, we need access to the underlying bytes.
1226+
// BytesArray gives us that access.
1227+
BytesReference bytes = docs.size() == 1
1228+
? docs.get(0).getBinaryData()
1229+
: new BytesArray(
1230+
CompositeBytesReference.of(docs.stream().map(TrainedModelDefinitionDoc::getBinaryData).toArray(BytesReference[]::new))
1231+
.toBytesRef()
1232+
);
12301233

12311234
if (docs.get(0).getTotalDefinitionLength() != null) {
12321235
if (bytes.length() != docs.get(0).getTotalDefinitionLength()) {
@@ -1264,7 +1267,6 @@ private TrainedModelConfig.Builder parseModelConfigLenientlyFromSource(BytesRefe
12641267
// lang ident model were the only models supported. Models created after
12651268
// VERSION_3RD_PARTY_CONFIG_ADDED must have modelType set, if not set modelType
12661269
// is a tree ensemble
1267-
assert builder.getVersion().before(TrainedModelConfig.VERSION_3RD_PARTY_CONFIG_ADDED);
12681270
builder.setModelType(TrainedModelType.TREE_ENSEMBLE);
12691271
}
12701272
return builder;

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlSingleNodeTestCase.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,18 @@ protected void waitForMlTemplates() throws Exception {
119119
}
120120

121121
protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
122+
throws InterruptedException {
123+
blockingCall(function, response::set, error::set);
124+
}
125+
126+
protected <T> void blockingCall(Consumer<ActionListener<T>> function, Consumer<T> response, Consumer<Exception> error)
122127
throws InterruptedException {
123128
CountDownLatch latch = new CountDownLatch(1);
124129
ActionListener<T> listener = ActionListener.wrap(r -> {
125-
response.set(r);
130+
response.accept(r);
126131
latch.countDown();
127132
}, e -> {
128-
error.set(e);
133+
error.accept(e);
129134
latch.countDown();
130135
});
131136

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/persistence/TrainedModelProviderTests.java

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,17 @@
2424
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
2525
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
2626

27+
import java.nio.ByteBuffer;
28+
import java.nio.charset.StandardCharsets;
2729
import java.util.Arrays;
30+
import java.util.Base64;
2831
import java.util.Collections;
2932
import java.util.HashSet;
3033
import java.util.List;
3134
import java.util.TreeSet;
3235

36+
import static org.hamcrest.Matchers.containsString;
37+
import static org.hamcrest.Matchers.emptyString;
3338
import static org.hamcrest.Matchers.equalTo;
3439
import static org.hamcrest.Matchers.hasSize;
3540
import static org.hamcrest.Matchers.instanceOf;
@@ -183,6 +188,144 @@ public void testChunkDefinitionWithSize() {
183188
}
184189
}
185190

191+
public void testGetDefinitionFromDocsTruncated() {
192+
String modelId = randomAlphaOfLength(10);
193+
Exception ex = expectThrows(
194+
Exception.class,
195+
() -> TrainedModelProvider.getDefinitionFromDocs(
196+
List.of(
197+
new TrainedModelDefinitionDoc(
198+
new BytesArray(randomByteArrayOfLength(10)),
199+
modelId,
200+
0,
201+
randomLongBetween(10, 100),
202+
10,
203+
1,
204+
randomBoolean()
205+
)
206+
),
207+
modelId
208+
)
209+
);
210+
assertThat(
211+
ex.getMessage(),
212+
containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
213+
);
214+
215+
ex = expectThrows(
216+
Exception.class,
217+
() -> TrainedModelProvider.getDefinitionFromDocs(
218+
List.of(
219+
new TrainedModelDefinitionDoc(
220+
new BytesArray(randomByteArrayOfLength(10)),
221+
modelId,
222+
0,
223+
randomLongBetween(21, 100),
224+
10,
225+
1,
226+
randomBoolean()
227+
),
228+
new TrainedModelDefinitionDoc(
229+
new BytesArray(randomByteArrayOfLength(10)),
230+
modelId,
231+
0,
232+
randomLongBetween(21, 100),
233+
10,
234+
1,
235+
randomBoolean()
236+
)
237+
),
238+
modelId
239+
)
240+
);
241+
assertThat(
242+
ex.getMessage(),
243+
containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
244+
);
245+
246+
ex = expectThrows(
247+
Exception.class,
248+
() -> TrainedModelProvider.getDefinitionFromDocs(
249+
List.of(
250+
new TrainedModelDefinitionDoc(
251+
new BytesArray(randomByteArrayOfLength(10)),
252+
modelId,
253+
0,
254+
randomFrom((Long) null, 20L),
255+
10,
256+
1,
257+
randomBoolean()
258+
),
259+
new TrainedModelDefinitionDoc(
260+
new BytesArray(randomByteArrayOfLength(10)),
261+
modelId,
262+
1,
263+
randomFrom((Long) null, 20L),
264+
10,
265+
1,
266+
false
267+
)
268+
),
269+
modelId
270+
)
271+
);
272+
assertThat(
273+
ex.getMessage(),
274+
containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
275+
);
276+
}
277+
278+
public void testGetDefinitionFromDocs() {
279+
String modelId = randomAlphaOfLength(10);
280+
281+
int byteArrayLength = randomIntBetween(1, 1000);
282+
BytesReference bytesReference = TrainedModelProvider.getDefinitionFromDocs(
283+
List.of(
284+
new TrainedModelDefinitionDoc(
285+
new BytesArray(randomByteArrayOfLength(byteArrayLength)),
286+
modelId,
287+
0,
288+
randomFrom((Long) null, (long) byteArrayLength),
289+
byteArrayLength,
290+
1,
291+
true
292+
)
293+
),
294+
modelId
295+
);
296+
// None of the following should throw
297+
ByteBuffer bb = Base64.getEncoder()
298+
.encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
299+
assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
300+
301+
bytesReference = TrainedModelProvider.getDefinitionFromDocs(
302+
List.of(
303+
new TrainedModelDefinitionDoc(
304+
new BytesArray(randomByteArrayOfLength(byteArrayLength)),
305+
modelId,
306+
0,
307+
randomFrom((Long) null, (long) byteArrayLength * 2),
308+
byteArrayLength,
309+
1,
310+
false
311+
),
312+
new TrainedModelDefinitionDoc(
313+
new BytesArray(randomByteArrayOfLength(byteArrayLength)),
314+
modelId,
315+
1,
316+
randomFrom((Long) null, (long) byteArrayLength * 2),
317+
byteArrayLength,
318+
1,
319+
true
320+
)
321+
),
322+
modelId
323+
);
324+
325+
bb = Base64.getEncoder().encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
326+
assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
327+
}
328+
186329
@Override
187330
protected NamedXContentRegistry xContentRegistry() {
188331
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());

0 commit comments

Comments
 (0)