Skip to content

Commit a03b0c7

Browse files
authored
[8.0] [ML] fail inference processor more consistently on certain error types (#81475) (#81546)
* [ML] fail inference processor more consistently on certain error types (#81475)

This updates the following scenarios and causes NER/native inference to fail and not write a warning:

- missing vocabulary values
- missing model/deployment
- native process failed
- native process stopping
- request timed out
- misconfigured inference task update type

* fixing for backport

* fixing backport

* fixing backport
1 parent c355885 commit a03b0c7

File tree

6 files changed

+126
-34
lines changed

6 files changed

+126
-34
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040

4141
import static org.elasticsearch.xpack.ml.integration.InferenceIngestIT.putPipeline;
4242
import static org.elasticsearch.xpack.ml.integration.InferenceIngestIT.simulateRequest;
43+
import static org.hamcrest.Matchers.allOf;
4344
import static org.hamcrest.Matchers.containsString;
4445
import static org.hamcrest.Matchers.equalTo;
4546
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -465,7 +466,11 @@ public void testInferencePipelineAgainstUnallocatedModel() throws IOException {
465466
String response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
466467
assertThat(
467468
response,
468-
containsString("model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API.")
469+
allOf(
470+
containsString("model [not-deployed] must be deployed to use. Please deploy with the start trained model deployment API."),
471+
containsString("error"),
472+
not(containsString("warning"))
473+
)
469474
);
470475

471476
client().performRequest(
@@ -528,6 +533,81 @@ public void testStopUsedDeploymentByIngestProcessor() throws IOException {
528533
stopDeployment(modelId, true);
529534
}
530535

536+
public void testPipelineWithBadProcessor() throws IOException {
537+
String model = "deployed";
538+
createTrainedModel(model);
539+
putVocabulary(List.of("once", "twice"), model);
540+
putModelDefinition(model);
541+
startDeployment(model);
542+
String source = """
543+
{
544+
"pipeline": {
545+
"processors": [
546+
{
547+
"inference": {
548+
"model_id": "deployed",
549+
"inference_config": {
550+
"ner": {}
551+
}
552+
}
553+
}
554+
]
555+
},
556+
"docs": [
557+
{"_source": {"input": "my words"}}]
558+
}
559+
""";
560+
561+
String response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
562+
assertThat(
563+
response,
564+
allOf(
565+
containsString("inference not possible. Task is configured with [pass_through] but received update of type [ner]"),
566+
containsString("error"),
567+
not(containsString("warning"))
568+
)
569+
);
570+
571+
source = """
572+
{
573+
"pipeline": {
574+
"processors": [
575+
{
576+
"inference": {
577+
"model_id": "deployed"
578+
}
579+
}
580+
]
581+
},
582+
"docs": [
583+
{"_source": {"input": "my words"}}]
584+
}
585+
""";
586+
587+
response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
588+
assertThat(response, allOf(containsString("error"), not(containsString("warning"))));
589+
590+
// Missing input field is a warning
591+
source = """
592+
{
593+
"pipeline": {
594+
"processors": [
595+
{
596+
"inference": {
597+
"model_id": "deployed"
598+
}
599+
}
600+
]
601+
},
602+
"docs": [
603+
{"_source": {"something": "my words"}}]
604+
}
605+
""";
606+
607+
response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
608+
assertThat(response, containsString("warning"));
609+
}
610+
531611
private int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {
532612
int inferenceCount = 0;
533613
for (var node : nodes) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportInternalInferModelAction.java

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
*/
77
package org.elasticsearch.xpack.ml.action;
88

9-
import org.elasticsearch.ElasticsearchStatusException;
10-
import org.elasticsearch.ExceptionsHelper;
119
import org.elasticsearch.action.ActionListener;
1210
import org.elasticsearch.action.support.ActionFilters;
1311
import org.elasticsearch.action.support.HandledTransportAction;
@@ -17,7 +15,6 @@
1715
import org.elasticsearch.core.TimeValue;
1816
import org.elasticsearch.license.LicenseUtils;
1917
import org.elasticsearch.license.XPackLicenseState;
20-
import org.elasticsearch.rest.RestStatus;
2118
import org.elasticsearch.tasks.Task;
2219
import org.elasticsearch.tasks.TaskId;
2320
import org.elasticsearch.threadpool.ThreadPool;
@@ -30,7 +27,6 @@
3027
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Request;
3128
import org.elasticsearch.xpack.core.ml.action.InternalInferModelAction.Response;
3229
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
33-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
3430
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
3531
import org.elasticsearch.xpack.ml.inference.allocation.TrainedModelAllocationMetadata;
3632
import org.elasticsearch.xpack.ml.inference.loadingservice.LocalModel;
@@ -191,19 +187,7 @@ private void inferSingleDocAgainstAllocatedModel(
191187
ML_ORIGIN,
192188
InferTrainedModelDeploymentAction.INSTANCE,
193189
request,
194-
ActionListener.wrap(r -> listener.onResponse(r.getResults()), e -> {
195-
Throwable unwrapped = ExceptionsHelper.unwrapCause(e);
196-
if (unwrapped instanceof ElasticsearchStatusException) {
197-
ElasticsearchStatusException ex = (ElasticsearchStatusException) unwrapped;
198-
if (ex.status().equals(RestStatus.TOO_MANY_REQUESTS)) {
199-
listener.onFailure(ex);
200-
} else {
201-
listener.onResponse(new WarningInferenceResults(ex.getMessage()));
202-
}
203-
} else {
204-
listener.onResponse(new WarningInferenceResults(e.getMessage()));
205-
}
206-
})
190+
ActionListener.wrap(r -> listener.onResponse(r.getResults()), listener::onFailure)
207191
);
208192
}
209193
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.logging.log4j.message.ParameterizedMessage;
1313
import org.apache.lucene.util.SetOnce;
14+
import org.elasticsearch.ElasticsearchException;
1415
import org.elasticsearch.ElasticsearchStatusException;
1516
import org.elasticsearch.ResourceNotFoundException;
1617
import org.elasticsearch.action.ActionListener;
@@ -33,6 +34,7 @@
3334
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
3435
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
3536
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
37+
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
3638
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
3739
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
3840
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
@@ -227,16 +229,18 @@ public void infer(
227229
) {
228230
if (task.isStopped()) {
229231
listener.onFailure(
230-
new IllegalStateException(
231-
"[" + task.getModelId() + "] is stopping or stopped due to [" + task.stoppedReason().orElse("") + "]"
232+
ExceptionsHelper.conflictStatusException(
233+
"[{}] is stopping or stopped due to [{}]",
234+
task.getModelId(),
235+
task.stoppedReason().orElse("")
232236
)
233237
);
234238
return;
235239
}
236240

237241
ProcessContext processContext = processContextByAllocation.get(task.getId());
238242
if (processContext == null) {
239-
listener.onFailure(new IllegalStateException("[" + task.getModelId() + "] process context missing"));
243+
listener.onFailure(ExceptionsHelper.conflictStatusException("[{}] process context missing", task.getModelId()));
240244
return;
241245
}
242246

@@ -258,7 +262,7 @@ public void infer(
258262
}
259263
}
260264

261-
static class InferenceAction extends AbstractRunnable {
265+
static class InferenceAction extends AbstractRunnable implements ActionListener<InferenceResults> {
262266
private final String modelId;
263267
private final long requestId;
264268
private final TimeValue timeout;
@@ -304,6 +308,11 @@ void onTimeout() {
304308
logger.debug("[{}] request [{}] received timeout after [{}] but listener already alerted", modelId, requestId, timeout);
305309
}
306310

311+
@Override
312+
public void onResponse(InferenceResults inferenceResults) {
313+
onSuccess(inferenceResults);
314+
}
315+
307316
void onSuccess(InferenceResults inferenceResults) {
308317
timeoutHandler.cancel();
309318
if (notified.compareAndSet(false, true)) {
@@ -360,17 +369,32 @@ protected void doRun() throws Exception {
360369
processContext,
361370
request.tokenization,
362371
processor.getResultProcessor((NlpConfig) config),
363-
ActionListener.wrap(this::onSuccess, this::onFailure)
372+
ActionListener.wrap(this::onSuccess, f -> handleFailure(f, this))
364373
),
365374
this::onFailure
366375
)
367376
);
368377
processContext.process.get().writeInferenceRequest(request.processInput);
369378
} catch (IOException e) {
370379
logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
371-
onFailure(ExceptionsHelper.serverError("error writing to process", e));
380+
handleFailure(ExceptionsHelper.serverError("error writing to process", e), this);
372381
} catch (Exception e) {
373-
onFailure(e);
382+
handleFailure(e, this);
383+
}
384+
}
385+
386+
private static void handleFailure(Exception e, ActionListener<InferenceResults> listener) {
387+
Throwable unwrapped = org.elasticsearch.ExceptionsHelper.unwrapCause(e);
388+
if (unwrapped instanceof ElasticsearchException ex) {
389+
if (ex.status() == RestStatus.BAD_REQUEST) {
390+
listener.onResponse(new WarningInferenceResults(ex.getMessage()));
391+
} else {
392+
listener.onFailure(ex);
393+
}
394+
} else if (unwrapped instanceof IllegalArgumentException) {
395+
listener.onResponse(new WarningInferenceResults(e.getMessage()));
396+
} else {
397+
listener.onFailure(e);
374398
}
375399
}
376400

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/TrainedModelDeploymentTask.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,12 @@
1010
import org.apache.logging.log4j.LogManager;
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.lucene.util.SetOnce;
13+
import org.elasticsearch.ElasticsearchStatusException;
1314
import org.elasticsearch.action.ActionListener;
1415
import org.elasticsearch.core.TimeValue;
1516
import org.elasticsearch.license.LicensedFeature;
1617
import org.elasticsearch.license.XPackLicenseState;
18+
import org.elasticsearch.rest.RestStatus;
1719
import org.elasticsearch.tasks.CancellableTask;
1820
import org.elasticsearch.tasks.TaskId;
1921
import org.elasticsearch.xpack.core.ml.MlTasks;
@@ -110,14 +112,15 @@ protected void onCancelled() {
110112
public void infer(Map<String, Object> doc, InferenceConfigUpdate update, TimeValue timeout, ActionListener<InferenceResults> listener) {
111113
if (inferenceConfigHolder.get() == null) {
112114
listener.onFailure(
113-
ExceptionsHelper.badRequestException("[{}] inference not possible against uninitialized model", params.getModelId())
115+
ExceptionsHelper.conflictStatusException("[{}] inference not possible against uninitialized model", params.getModelId())
114116
);
115117
return;
116118
}
117119
if (update.isSupported(inferenceConfigHolder.get()) == false) {
118120
listener.onFailure(
119-
ExceptionsHelper.badRequestException(
121+
new ElasticsearchStatusException(
120122
"[{}] inference not possible. Task is configured with [{}] but received update of type [{}]",
123+
RestStatus.FORBIDDEN,
121124
params.getModelId(),
122125
inferenceConfigHolder.get().getName(),
123126
update.getName()

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1414
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
1515
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
16+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1617
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
1718
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizer;
1819
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -36,18 +37,18 @@ public class FillMaskProcessor implements NlpTask.Processor {
3637
@Override
3738
public void validateInputs(List<String> inputs) {
3839
if (inputs.isEmpty()) {
39-
throw new IllegalArgumentException("input request is empty");
40+
throw ExceptionsHelper.badRequestException("input request is empty");
4041
}
4142

4243
for (String input : inputs) {
4344
int maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN);
4445
if (maskIndex < 0) {
45-
throw new IllegalArgumentException("no " + BertTokenizer.MASK_TOKEN + " token could be found");
46+
throw ExceptionsHelper.badRequestException("no {} token could be found", BertTokenizer.MASK_TOKEN);
4647
}
4748

4849
maskIndex = input.indexOf(BertTokenizer.MASK_TOKEN, maskIndex + BertTokenizer.MASK_TOKEN.length());
4950
if (maskIndex > 0) {
50-
throw new IllegalArgumentException("only one " + BertTokenizer.MASK_TOKEN + " token should exist in the input");
51+
throw ExceptionsHelper.badRequestException("only one {} token should exist in the input", BertTokenizer.MASK_TOKEN);
5152
}
5253
}
5354
}
@@ -59,8 +60,7 @@ public NlpTask.RequestBuilder getRequestBuilder(NlpConfig config) {
5960

6061
@Override
6162
public NlpTask.ResultProcessor getResultProcessor(NlpConfig config) {
62-
if (config instanceof FillMaskConfig) {
63-
FillMaskConfig fillMaskConfig = (FillMaskConfig) config;
63+
if (config instanceof FillMaskConfig fillMaskConfig) {
6464
return (tokenization, result) -> processResult(
6565
tokenization,
6666
result,

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
package org.elasticsearch.xpack.ml.inference.nlp;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.test.ESTestCase;
1112
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
1213
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
@@ -88,7 +89,7 @@ public void testValidate_GivenMissingMaskToken() {
8889
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
8990
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
9091

91-
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input));
92+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> processor.validateInputs(input));
9293
assertThat(e.getMessage(), containsString("no [MASK] token could be found"));
9394
}
9495

@@ -98,7 +99,7 @@ public void testProcessResults_GivenMultipleMaskTokens() {
9899
FillMaskConfig config = new FillMaskConfig(new VocabularyConfig("test-index"), null, null, null);
99100
FillMaskProcessor processor = new FillMaskProcessor(mock(BertTokenizer.class), config);
100101

101-
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> processor.validateInputs(input));
102+
ElasticsearchStatusException e = expectThrows(ElasticsearchStatusException.class, () -> processor.validateInputs(input));
102103
assertThat(e.getMessage(), containsString("only one [MASK] token should exist in the input"));
103104
}
104105
}

0 commit comments

Comments (0)