
Commit 2dec141

[ML] fail inference processor more consistently on certain error types (#81475)
This updates the following scenarios and causes NER/native inference to fail instead of writing a warning:
- missing vocabulary values
- missing model/deployment
- native process failed
- native process stopping
- request timed out
- misconfigured inference task update type
1 parent 3acf0e8 commit 2dec141
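The changes below converge on one routing rule for inference failures, implemented in DeploymentManager.handleFailure in the diff: only 400-class (bad request) validation problems are still downgraded to per-document warnings; everything else now fails the processor. A hedged sketch of that rule, simplified from the commit (the real method also downgrades a bare IllegalArgumentException; the class wrapper here is illustrative only):

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;

final class FailureRouting {
    static void route(Exception e, ActionListener<InferenceResults> listener) {
        Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
        if (unwrapped instanceof ElasticsearchException ex && ex.status() == RestStatus.BAD_REQUEST) {
            // 400-class validation problems still surface as per-document warnings
            listener.onResponse(new WarningInferenceResults(ex.getMessage()));
        } else {
            // missing or stopping deployments, native process errors, timeouts and
            // misconfigured task updates now propagate as hard failures
            listener.onFailure(e);
        }
    }
}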

6 files changed, +129 −39 lines

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 82 additions & 3 deletions
@@ -40,6 +40,7 @@
 
 import static org.elasticsearch.xpack.ml.integration.InferenceIngestIT.putPipeline;
 import static org.elasticsearch.xpack.ml.integration.InferenceIngestIT.simulateRequest;
+import static org.hamcrest.Matchers.allOf;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -465,7 +466,11 @@ public void testInferencePipelineAgainstUnallocatedModel() throws IOException {
         String response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
         assertThat(
             response,
-            containsString("model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API.")
+            allOf(
+                containsString("model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API."),
+                containsString("error"),
+                not(containsString("warning"))
+            )
         );
 
         client().performRequest(
@@ -520,9 +525,8 @@ public void testTruncation() throws IOException {
         startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());
 
         String input = "once twice thrice";
-        ResponseException ex = expectThrows(ResponseException.class, () -> infer("once twice thrice", modelId));
         assertThat(
-            ex.getMessage(),
+            EntityUtils.toString(infer("once twice thrice", modelId).getEntity()),
             containsString("Input too large. The tokenized input length [3] exceeds the maximum sequence length [2]")
         );
 
@@ -578,6 +582,81 @@ public void testStopUsedDeploymentByIngestProcessor() throws IOException {
         stopDeployment(modelId, true);
     }
 
+    public void testPipelineWithBadProcessor() throws IOException {
+        String model = "deployed";
+        createTrainedModel(model);
+        putVocabulary(List.of("once", "twice"), model);
+        putModelDefinition(model);
+        startDeployment(model);
+        String source = """
+            {
+              "pipeline": {
+                "processors": [
+                  {
+                    "inference": {
+                      "model_id": "deployed",
+                      "inference_config": {
+                        "ner": {}
+                      }
+                    }
+                  }
+                ]
+              },
+              "docs": [
+                {"_source": {"input": "my words"}}]
+            }
+            """;
+
+        String response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
+        assertThat(
+            response,
+            allOf(
+                containsString("inference not possible. Task is configured with [pass_through] but received update of type [ner]"),
+                containsString("error"),
+                not(containsString("warning"))
+            )
+        );
+
+        source = """
+            {
+              "pipeline": {
+                "processors": [
+                  {
+                    "inference": {
+                      "model_id": "deployed"
+                    }
+                  }
+                ]
+              },
+              "docs": [
+                {"_source": {"input": "my words"}}]
+            }
+            """;
+
+        response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
+        assertThat(response, allOf(containsString("error"), not(containsString("warning"))));
+
+        // Missing input field is a warning
+        source = """
+            {
+              "pipeline": {
+                "processors": [
+                  {
+                    "inference": {
+                      "model_id": "deployed"
+                    }
+                  }
+                ]
+              },
+              "docs": [
+                {"_source": {"something": "my words"}}]
+            }
+            """;
+
+        response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
+        assertThat(response, containsString("warning"));
+    }
+
     public void testDeleteModelWithDeploymentUsedByIngestProcessor() throws IOException {
         String modelId = "test_delete_model_with_used_deployment";
         createTrainedModel(modelId);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

Lines changed: 1 addition & 17 deletions
@@ -6,8 +6,6 @@
  */
 package org.elasticsearch.xpack.ml.action;
 
-import org.elasticsearch.ElasticsearchStatusException;
-import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.HandledTransportAction;
@@ -18,7 +16,6 @@
 import org.elasticsearch.license.License;
 import org.elasticsearch.license.LicenseUtils;
 import org.elasticsearch.license.XPackLicenseState;
-import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.Task;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -31,7 +28,6 @@
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
 import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
 import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
 import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
@@ -195,19 +191,7 @@ private void inferSingleDocAgainstAllocatedModel(
             ML_ORIGIN,
             InferTrainedModelDeploymentAction.INSTANCE,
             request,
-            ActionListener.wrap(r -> listener.onResponse(r.getResults()), e -> {
-                Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
-                if (unwrapped instanceof ElasticsearchStatusException) {
-                    ElasticsearchStatusException ex = (ElasticsearchStatusException) unwrapped;
-                    if (ex.status().equals(RestStatus.TOO_MANY_REQUESTS)) {
-                        listener.onFailure(ex);
-                    } else {
-                        listener.onResponse(new WarningInferenceResults(ex.getMessage()));
-                    }
-                } else {
-                    listener.onResponse(new WarningInferenceResults(e.getMessage()));
-                }
-            })
+            ActionListener.wrap(r -> listener.onResponse(r.getResults()), listener::onFailure)
         );
     }
 }
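With the warning-wrapping removed from the transport action, failures now reach the caller's listener directly. A hedged sketch of the caller side (not part of this commit; handleResults, logger, and the exact Response accessor are assumptions for illustration):

// Hypothetical caller: exceptions that this action previously converted
// into WarningInferenceResults now arrive through onFailure.
client.execute(
    InternalInferModelAction.INSTANCE,
    request,
    ActionListener.wrap(
        response -> handleResults(response.getInferenceResults()),
        e -> logger.error("inference request failed", e)
    )
);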

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 31 additions & 7 deletions
@@ -11,6 +11,7 @@
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;
@@ -33,6 +34,7 @@
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
+import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
@@ -227,16 +229,18 @@ public void infer(
     ) {
         if (task.isStopped()) {
             listener.onFailure(
-                new IllegalStateException(
-                    "[" + task.getModelId() + "] is stopping or stopped due to [" + task.stoppedReason().orElse("") + "]"
+                ExceptionsHelper.conflictStatusException(
+                    "[{}] is stopping or stopped due to [{}]",
+                    task.getModelId(),
+                    task.stoppedReason().orElse("")
                 )
             );
             return;
         }
 
         ProcessContext processContext = processContextByAllocation.get(task.getId());
         if (processContext == null) {
-            listener.onFailure(new IllegalStateException("[" + task.getModelId() + "] process context missing"));
+            listener.onFailure(ExceptionsHelper.conflictStatusException("[{}] process context missing", task.getModelId()));
             return;
         }
 
@@ -258,7 +262,7 @@ public void infer(
         }
     }
 
-    static class InferenceAction extends AbstractRunnable {
+    static class InferenceAction extends AbstractRunnable implements ActionListener<InferenceResults> {
         private final String modelId;
         private final long requestId;
         private final TimeValue timeout;
@@ -304,6 +308,11 @@ void onTimeout() {
             logger.debug("[{}] request [{}] received timeout after [{}] but listener already alerted", modelId, requestId, timeout);
         }
 
+        @Override
+        public void onResponse(InferenceResults inferenceResults) {
+            onSuccess(inferenceResults);
+        }
+
         void onSuccess(InferenceResults inferenceResults) {
             timeoutHandler.cancel();
             if (notified.compareAndSet(false, true)) {
@@ -362,17 +371,32 @@ protected void doRun() throws Exception {
                         processContext,
                         request.tokenization,
                         processor.getResultProcessor((NlpConfig) config),
-                        ActionListener.wrap(this::onSuccess, this::onFailure)
+                        ActionListener.wrap(this::onSuccess, f -> handleFailure(f, this))
                     ),
                     this::onFailure
                 )
            );
            processContext.process.get().writeInferenceRequest(request.processInput);
        } catch (IOException e) {
            logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
-           onFailure(ExceptionsHelper.serverError("error writing to process", e));
+           handleFailure(ExceptionsHelper.serverError("error writing to process", e), this);
        } catch (Exception e) {
-           onFailure(e);
+           handleFailure(e, this);
+       }
+   }
+
+   private static void handleFailure(Exception e, ActionListener<InferenceResults> listener) {
+       Throwable unwrapped = org.elasticsearch.ExceptionsHelper.unwrapCause(e);
+       if (unwrapped instanceof ElasticsearchException ex) {
+           if (ex.status() == RestStatus.BAD_REQUEST) {
+               listener.onResponse(new WarningInferenceResults(ex.getMessage()));
+           } else {
+               listener.onFailure(ex);
+           }
+       } else if (unwrapped instanceof IllegalArgumentException) {
+           listener.onResponse(new WarningInferenceResults(e.getMessage()));
+       } else {
+           listener.onFailure(e);
       }
   }

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

Lines changed: 5 additions & 2 deletions
@@ -10,10 +10,12 @@
 import org.apache.logging.log4j.LogManager;
 import org.apache.logging.log4j.Logger;
 import org.apache.lucene.util.SetOnce;
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.license.LicensedFeature;
 import org.elasticsearch.license.XPackLicenseState;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.tasks.CancellableTask;
 import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.xpack.core.ml.MlTasks;
@@ -110,14 +112,15 @@ protected void onCancelled() {
     public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {
         if (inferenceConfigHolder.get() == null) {
             listener.onFailure(
-                ExceptionsHelper.badRequestException("[{}] inference not possible against uninitialized model", params.getModelId())
+                ExceptionsHelper.conflictStatusException("[{}] inference not possible against uninitialized model", params.getModelId())
             );
             return;
         }
         if (update.isSupported(inferenceConfigHolder.get()) == false) {
             listener.onFailure(
-                ExceptionsHelper.badRequestException(
+                new ElasticsearchStatusException(
                     "[{}] inference not possible. Task is configured with [{}] but received update of type [{}]",
+                    RestStatus.FORBIDDEN,
                     params.getModelId(),
                     inferenceConfigHolder.get().getName(),
                     update.getName()
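Read together with DeploymentManager.handleFailure above, the status choice here is what flips the mismatched-update case from warning to error. A small hedged illustration (model and task names borrowed from the new integration test):

// FORBIDDEN is not BAD_REQUEST, so handleFailure routes this exception to
// listener.onFailure(...) and the simulate response carries "error", not "warning".
ElasticsearchStatusException mismatch = new ElasticsearchStatusException(
    "[{}] inference not possible. Task is configured with [{}] but received update of type [{}]",
    RestStatus.FORBIDDEN,
    "deployed", "pass_through", "ner"
);
assert mismatch.status() != RestStatus.BAD_REQUEST;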

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

Lines changed: 7 additions & 8 deletions
@@ -13,12 +13,12 @@
 import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
+import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
 
@@ -35,19 +35,19 @@ public class FillMaskProcessor implements NlpTask.Processor {
     @Override
     public void validateInputs(List<String> inputs) {
         if (inputs.isEmpty()) {
-            throw new IllegalArgumentException("input request is empty");
+            throw ExceptionsHelper.badRequestException("input request is empty");
         }
 
         final String mask = tokenizer.getMaskToken();
         for (String input : inputs) {
             int maskIndex = input.indexOf(mask);
             if (maskIndex < 0) {
-                throw new IllegalArgumentException("no " + mask + " token could be found");
+                throw ExceptionsHelper.badRequestException("no {} token could be found", mask);
             }
 
             maskIndex = input.indexOf(mask, maskIndex + mask.length());
             if (maskIndex > 0) {
-                throw new IllegalArgumentException("only one " + mask + " token should exist in the input");
+                throw ExceptionsHelper.badRequestException("only one {} token should exist in the input", mask);
             }
         }
     }
@@ -59,8 +59,7 @@ public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
 
     @Override
     public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
-        if (config instanceof FillMaskConfig) {
-            FillMaskConfig fillMaskConfig = (FillMaskConfig) config;
+        if (config instanceof FillMaskConfig fillMaskConfig) {
             return (tokenization, result) -> processResult(
                 tokenization,
                 result,
@@ -91,7 +90,7 @@ static InferenceResults processResult(
         }
 
         if (tokenizer.getMaskTokenId().isEmpty()) {
-            return new WarningInferenceResults(
+            throw ExceptionsHelper.conflictStatusException(
                 "The token id for the mask token {} is not known in the tokenizer. Check the vocabulary contains the mask token",
                 tokenizer.getMaskToken()
             );
@@ -109,7 +108,7 @@ static InferenceResults processResult(
             return new WarningInferenceResults(
                 "mask token id [{}] not found in the tokenization {}",
                 maskTokenId,
-                Arrays.asList(tokenization.getTokenizations().get(0).getTokenIds())
+                List.of(tokenization.getTokenizations().get(0).getTokenIds())
             );
         }
 

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

Lines changed: 3 additions & 2 deletions
@@ -7,6 +7,7 @@
 
 package org.elasticsearch.xpack.ml.inference.nlp;
 
+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -103,7 +104,7 @@ public void testValidate_GivenMissingMaskToken() {
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
         FillMaskProcessor processor = new FillMaskProcessor(tokenizer, config);
 
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input));
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> processor.validateInputs(input));
         assertThat(e.getMessage(), containsString("no [MASK] token could be found"));
     }
 
@@ -116,7 +117,7 @@ public void testProcessResults_GivenMultipleMaskTokens() {
         FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
         FillMaskProcessor processor = new FillMaskProcessor(tokenizer, config);
 
-        IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input));
+        ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> processor.validateInputs(input));
         assertThat(e.getMessage(), containsString("only one [MASK] token should exist in the input"));
     }
 }
