Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.license.License;
import org.elasticsearch.xcontent.ToXContent;
Expand Down Expand Up @@ -145,6 +146,82 @@ public void testGetTrainedModelConfig() throws Exception {
assertThat(getConfigHolder.get().getMetadata(), hasKey("hyperparameters"));
}

public void testGetTrainedModelConfigWithMultiDocDefinition() throws Exception {
String modelId = "test-get-trained-model-config";
TrainedModelConfig config = buildTrainedModelConfig(modelId);

AtomicReference<Void> dummy = new AtomicReference<>();
AtomicReference<Boolean> booleanDummy = new AtomicReference<>();
AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

BytesReference definition = config.getCompressedDefinition();

blockingCall(
listener -> trainedModelProvider.storeTrainedModelDefinitionDoc(
new TrainedModelDefinitionDoc(
new BytesArray(definition.array(), 0, definition.length() - 5),
modelId,
0,
(long) definition.length(),
definition.length() - 5,
1,
false
),
listener
),
dummy::set,
e -> fail(e.getMessage())
);
blockingCall(
listener -> trainedModelProvider.storeTrainedModelDefinitionDoc(
new TrainedModelDefinitionDoc(
new BytesArray(definition.array(), definition.length() - 5, 5),
modelId,
1,
(long) definition.length(),
5,
1,
true
),
listener
),
dummy::set,
e -> fail(e.getMessage())
);
blockingCall(
listener -> trainedModelProvider.storeTrainedModelConfig(
new TrainedModelConfig.Builder(config).clearDefinition().build(),
listener
),
booleanDummy::set,
e -> fail(e.getMessage())
);
blockingCall(
listener -> trainedModelProvider.refreshInferenceIndex(listener),
new AtomicReference<RefreshResponse>(),
new AtomicReference<>()
);

AtomicReference<TrainedModelConfig> getConfigHolder = new AtomicReference<>();
blockingCall(
listener -> trainedModelProvider.getTrainedModel(modelId, GetTrainedModelsAction.Includes.forModelDefinition(), listener),
getConfigHolder,
exceptionHolder
);
if (exceptionHolder.get() != null) {
throw exceptionHolder.get();
}
getConfigHolder.get().ensureParsedDefinition(xContentRegistry());
assertThat(getConfigHolder.get(), is(not(nullValue())));
assertThat(getConfigHolder.get(), equalTo(config));
assertThat(getConfigHolder.get().getModelDefinition(), is(not(nullValue())));

try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
// Should not throw
getConfigHolder.get().toXContent(builder, ToXContent.EMPTY_PARAMS);
}
}

public void testGetTrainedModelConfigWithoutDefinition() throws Exception {
String modelId = "test-get-trained-model-config-no-definition";
TrainedModelConfig config = buildTrainedModelConfigBuilder(modelId).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public static String docId(String modelId, int docNum) {
private final int compressionVersion;
private final boolean eos;

private TrainedModelDefinitionDoc(
public TrainedModelDefinitionDoc(
BytesReference binaryData,
String modelId,
int docNum,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.common.CheckedBiFunction;
import org.elasticsearch.common.Numbers;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.regex.Regex;
Expand Down Expand Up @@ -1219,14 +1220,16 @@ private static <T> List<T> handleHits(
return results;
}

private static BytesReference getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId)
throws ElasticsearchException {
static BytesReference getDefinitionFromDocs(List<TrainedModelDefinitionDoc> docs, String modelId) throws ElasticsearchException {

BytesReference[] bb = new BytesReference[docs.size()];
for (int i = 0; i < docs.size(); i++) {
bb[i] = docs.get(i).getBinaryData();
}
BytesReference bytes = CompositeBytesReference.of(bb);
// If the user requested the compressed data string, we need access to the underlying bytes.
// BytesArray gives us that access.
BytesReference bytes = docs.size() == 1
? docs.get(0).getBinaryData()
: new BytesArray(
CompositeBytesReference.of(docs.stream().map(TrainedModelDefinitionDoc::getBinaryData).toArray(BytesReference[]::new))
.toBytesRef()
);

if (docs.get(0).getTotalDefinitionLength() != null) {
if (bytes.length() != docs.get(0).getTotalDefinitionLength()) {
Expand Down Expand Up @@ -1264,7 +1267,6 @@ private TrainedModelConfig.Builder parseModelConfigLenientlyFromSource(BytesRefe
// lang ident model were the only models supported. Models created after
// VERSION_3RD_PARTY_CONFIG_ADDED must have modelType set, if not set modelType
// is a tree ensemble
assert builder.getVersion().before(TrainedModelConfig.VERSION_3RD_PARTY_CONFIG_ADDED);
builder.setModelType(TrainedModelType.TREE_ENSEMBLE);
}
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,18 @@ protected void waitForMlTemplates() throws Exception {
}

protected <T> void blockingCall(Consumer<ActionListener<T>> function, AtomicReference<T> response, AtomicReference<Exception> error)
throws InterruptedException {
blockingCall(function, response::set, error::set);
}

protected <T> void blockingCall(Consumer<ActionListener<T>> function, Consumer<T> response, Consumer<Exception> error)
throws InterruptedException {
CountDownLatch latch = new CountDownLatch(1);
ActionListener<T> listener = ActionListener.wrap(r -> {
response.set(r);
response.accept(r);
latch.countDown();
}, e -> {
error.set(e);
error.accept(e);
latch.countDown();
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,17 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfigTests;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.TreeSet;

import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.emptyString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
Expand Down Expand Up @@ -183,6 +188,144 @@ public void testChunkDefinitionWithSize() {
}
}

public void testGetDefinitionFromDocsTruncated() {
String modelId = randomAlphaOfLength(10);
Exception ex = expectThrows(
Exception.class,
() -> TrainedModelProvider.getDefinitionFromDocs(
List.of(
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(10)),
modelId,
0,
randomLongBetween(10, 100),
10,
1,
randomBoolean()
)
),
modelId
)
);
assertThat(
ex.getMessage(),
containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
);

ex = expectThrows(
Exception.class,
() -> TrainedModelProvider.getDefinitionFromDocs(
List.of(
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(10)),
modelId,
0,
randomLongBetween(21, 100),
10,
1,
randomBoolean()
),
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(10)),
modelId,
0,
randomLongBetween(21, 100),
10,
1,
randomBoolean()
)
),
modelId
)
);
assertThat(
ex.getMessage(),
containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
);

ex = expectThrows(
Exception.class,
() -> TrainedModelProvider.getDefinitionFromDocs(
List.of(
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(10)),
modelId,
0,
randomFrom((Long) null, 20L),
10,
1,
randomBoolean()
),
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(10)),
modelId,
1,
randomFrom((Long) null, 20L),
10,
1,
false
)
),
modelId
)
);
assertThat(
ex.getMessage(),
containsString("Model definition truncated. Unable to deserialize trained model definition [" + modelId + "]")
);
}

public void testGetDefinitionFromDocs() {
String modelId = randomAlphaOfLength(10);

int byteArrayLength = randomIntBetween(1, 1000);
BytesReference bytesReference = TrainedModelProvider.getDefinitionFromDocs(
List.of(
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(byteArrayLength)),
modelId,
0,
randomFrom((Long) null, (long) byteArrayLength),
byteArrayLength,
1,
true
)
),
modelId
);
// None of the following should throw
ByteBuffer bb = Base64.getEncoder()
.encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));

bytesReference = TrainedModelProvider.getDefinitionFromDocs(
List.of(
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(byteArrayLength)),
modelId,
0,
randomFrom((Long) null, (long) byteArrayLength * 2),
byteArrayLength,
1,
false
),
new TrainedModelDefinitionDoc(
new BytesArray(randomByteArrayOfLength(byteArrayLength)),
modelId,
1,
randomFrom((Long) null, (long) byteArrayLength * 2),
byteArrayLength,
1,
true
)
),
modelId
);

bb = Base64.getEncoder().encode(ByteBuffer.wrap(bytesReference.array(), bytesReference.arrayOffset(), bytesReference.length()));
assertThat(new String(bb.array(), StandardCharsets.UTF_8), is(not(emptyString())));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
Expand Down