From 76952ac35b24e6dcb9b7c6e7df5b1ad0a6547d6d Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 14 Dec 2021 12:19:31 +0000 Subject: [PATCH 1/3] Make warnings errors --- .../deployment/DeploymentManager.java | 23 ++----------- .../ml/inference/nlp/FillMaskProcessor.java | 8 +++-- .../xpack/ml/inference/nlp/NerProcessor.java | 5 +-- .../nlp/TextClassificationProcessor.java | 13 +++----- .../nlp/ZeroShotClassificationProcessor.java | 16 +++++---- .../inference/nlp/FillMaskProcessorTests.java | 3 +- .../ml/inference/nlp/NerProcessorTests.java | 10 +++--- .../nlp/TextClassificationProcessorTests.java | 33 +++++++------------ 8 files changed, 45 insertions(+), 66 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 1fa51fce2052c..299d1282b526d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.lucene.util.SetOnce; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; @@ -34,7 +33,6 @@ import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; @@ -374,7 +372,7 @@ protected void doRun() throws Exception { processContext, request.tokenization, processor.getResultProcessor((NlpConfig) config), - ActionListener.wrap(this::onSuccess, f -> handleFailure(f, this)) + ActionListener.wrap(this::onSuccess, this::onFailure) ), this::onFailure ) @@ -382,24 +380,9 @@ protected void doRun() throws Exception { processContext.process.get().writeInferenceRequest(request.processInput); } catch (IOException e) { logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e); - handleFailure(ExceptionsHelper.serverError("error writing to process", e), this); + onFailure(ExceptionsHelper.serverError("error writing to process", e)); } catch (Exception e) { - handleFailure(e, this); - } - } - - private static void handleFailure(Exception e, ActionListener listener) { - Throwable unwrapped = org.elasticsearch.ExceptionsHelper.unwrapCause(e); - if (unwrapped instanceof ElasticsearchException ex) { - if (ex.status() == RestStatus.BAD_REQUEST) { - listener.onResponse(new WarningInferenceResults(ex.getMessage())); - } else { - listener.onFailure(ex); - } - } else if (unwrapped instanceof IllegalArgumentException) { - listener.onResponse(new WarningInferenceResults(e.getMessage())); - } else { - listener.onFailure(e); + onFailure(e); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java index 588eff96608fa..6a9ce0ea344fd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java @@ -7,10 +7,11 @@ package org.elasticsearch.xpack.ml.inference.nlp; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -86,7 +87,7 @@ static InferenceResults processResult( String resultsField ) { if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) { - return new WarningInferenceResults("No valid tokens for inference"); + throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR); } if (tokenizer.getMaskTokenId().isEmpty()) { @@ -105,8 +106,9 @@ static InferenceResults processResult( } } if (maskTokenIndex == -1) { - return new WarningInferenceResults( + throw new ElasticsearchStatusException( "mask token id [{}] not found in the tokenization {}", + RestStatus.INTERNAL_SERVER_ERROR, maskTokenId, List.of(tokenization.getTokenizations().get(0).getTokenIds()) ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java index 7fc86b437a388..3c9baac52bf77 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java @@ -7,12 +7,13 @@ package org.elasticsearch.xpack.ml.inference.nlp; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken; @@ -195,7 +196,7 @@ static class NerResultProcessor implements NlpTask.ResultProcessor { @Override public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) { if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) { - return new WarningInferenceResults("no valid tokens to build result"); + throw new ElasticsearchStatusException("no valid tokenization to build result", RestStatus.INTERNAL_SERVER_ERROR); } // TODO - process all results in the batch diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java index c80adb8cea02e..f7929d96490aa 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java @@ -7,10 +7,11 @@ package org.elasticsearch.xpack.ml.inference.nlp; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer; @@ -39,11 +40,6 @@ public class TextClassificationProcessor implements NlpTask.Processor { // negative values are a special case of asking for ALL classes. Since we require the output size to equal the classLabel size // This is a nice way of setting the value this.numTopClasses = config.getNumTopClasses() < 0 ? this.classLabels.length : config.getNumTopClasses(); - validate(); - } - - private void validate() { - // validation occurs in TextClassificationConfig } @Override @@ -87,14 +83,15 @@ static InferenceResults processResult( String resultsField ) { if (pyTorchResult.getInferenceResult().length < 1) { - return new WarningInferenceResults("Text classification result has no data"); + throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR); } // TODO only the first entry in the batch result is verified and // checked. Implement for all in batch if (pyTorchResult.getInferenceResult()[0][0].length != labels.size()) { - return new WarningInferenceResults( + throw new ElasticsearchStatusException( "Expected exactly [{}] values in text classification result; got [{}]", + RestStatus.INTERNAL_SERVER_ERROR, labels.size(), pyTorchResult.getInferenceResult()[0][0].length ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java index 3330743029583..b504114a1b582 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java @@ -7,11 +7,12 @@ package org.elasticsearch.xpack.ml.inference.nlp; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.logging.LoggerMessageFormat; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig; @@ -114,7 +115,7 @@ static class RequestBuilder implements NlpTask.RequestBuilder { @Override public NlpTask.Request buildRequest(List inputs, String requestId, Tokenization.Truncate truncate) throws IOException { if (inputs.size() > 1) { - throw new IllegalArgumentException("Unable to do zero-shot classification on more than one text input at a time"); + throw ExceptionsHelper.badRequestException("Unable to do zero-shot classification on more than one text input at a time"); } List tokenizations = new ArrayList<>(labels.length); for (String label : labels) { @@ -148,13 +149,14 @@ static class ResultProcessor implements NlpTask.ResultProcessor { @Override public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) { if (pyTorchResult.getInferenceResult().length < 1) { - return new WarningInferenceResults("Zero shot classification result has no data"); + throw new ElasticsearchStatusException("Zero shot classification result has no data", RestStatus.INTERNAL_SERVER_ERROR); } // TODO only the first entry in the batch result is verified and // checked. Implement for all in batch if (pyTorchResult.getInferenceResult()[0].length != labels.length) { - return new WarningInferenceResults( + throw new ElasticsearchStatusException( "Expected exactly [{}] values in zero shot classification result; got [{}]", + RestStatus.INTERNAL_SERVER_ERROR, labels.length, pyTorchResult.getInferenceResult().length ); @@ -165,8 +167,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn int v = 0; for (double[] vals : pyTorchResult.getInferenceResult()[0]) { if (vals.length != 3) { - return new WarningInferenceResults( + throw new ElasticsearchStatusException( "Expected exactly [{}] values in inner zero shot classification result; got [{}]", + RestStatus.INTERNAL_SERVER_ERROR, 3, vals.length ); @@ -181,8 +184,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn int v = 0; for (double[] vals : pyTorchResult.getInferenceResult()[0]) { if (vals.length != 3) { - return new WarningInferenceResults( + throw new ElasticsearchStatusException( "Expected exactly [{}] values in inner zero shot classification result; got [{}]", + RestStatus.INTERNAL_SERVER_ERROR, 3, vals.length ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index df03613167bc7..a1227402f03b2 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig; import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer; @@ -92,7 +91,7 @@ public void testProcessResults_GivenMissingTokens() { PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null); assertThat( FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)), - instanceOf(WarningInferenceResults.class) + instanceOf(ElasticsearchStatusException.class) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java index fb4991ec0966d..a29fcac6d7fef 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java @@ -7,10 +7,10 @@ package org.elasticsearch.xpack.ml.inference.nlp; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.ml.inference.results.NerResults; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; @@ -91,10 +91,12 @@ public void testValidate_NotAEntityLabel() { public void testProcessResults_GivenNoTokens() { NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, false); TokenizationResult tokenization = tokenize(List.of(BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN), ""); - assertThat( - processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null)), - instanceOf(WarningInferenceResults.class) + + var e = expectThrows( + ElasticsearchStatusException.class, + () -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null)) ); + assertThat(e, instanceOf(ElasticsearchStatusException.class)); } public void testProcessResults() { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java index 2b7deb359229e..0f1b03e4bea56 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java @@ -7,11 +7,10 @@ package org.elasticsearch.xpack.ml.inference.nlp; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults; -import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization; @@ -25,6 +24,7 @@ import java.util.Map; import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.instanceOf; @@ -33,30 +33,21 @@ public class TextClassificationProcessorTests extends ESTestCase { public void testInvalidResult() { { PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, null); - InferenceResults inferenceResults = TextClassificationProcessor.processResult( - null, - torchResult, - randomInt(), - List.of("a", "b"), - randomAlphaOfLength(10) + var e = expectThrows( + ElasticsearchStatusException.class, + () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10)) ); - assertThat(inferenceResults, instanceOf(WarningInferenceResults.class)); - assertEquals("Text classification result has no data", ((WarningInferenceResults) inferenceResults).getWarning()); + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), containsString("Text classification result has no data")); } { PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, null); - InferenceResults inferenceResults = TextClassificationProcessor.processResult( - null, - torchResult, - randomInt(), - List.of("a", "b"), - randomAlphaOfLength(10) - ); - assertThat(inferenceResults, instanceOf(WarningInferenceResults.class)); - assertEquals( - "Expected exactly [2] values in text classification result; got [1]", - ((WarningInferenceResults) inferenceResults).getWarning() + var e = expectThrows( + ElasticsearchStatusException.class, + () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10)) ); + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), containsString("Expected exactly [2] values in text classification result; got [1]")); } } From ec16126530091bd766a2fa7c75120000a2347235 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 14 Dec 2021 20:02:36 +0000 Subject: [PATCH 2/3] fix tests --- .../xpack/ml/inference/deployment/DeploymentManager.java | 2 +- .../xpack/ml/inference/nlp/FillMaskProcessorTests.java | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java index 299d1282b526d..beed432b91350 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java @@ -372,7 +372,7 @@ protected void doRun() throws Exception { processContext, request.tokenization, processor.getResultProcessor((NlpConfig) config), - ActionListener.wrap(this::onSuccess, this::onFailure) + this ), this::onFailure ) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index a1227402f03b2..76bb98b31aeed 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -27,7 +27,6 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -89,9 +88,9 @@ public void testProcessResults_GivenMissingTokens() { tokenization.addTokenization("", false, Collections.emptyList(), new int[] {}, new int[] {}); PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null); - assertThat( - FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)), - instanceOf(ElasticsearchStatusException.class) + expectThrows( + ElasticsearchStatusException.class, + () -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)) ); } From 39299d5821d11e5caa72dd9dd6f88fde05d7ac69 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 14 Dec 2021 20:57:01 +0000 Subject: [PATCH 3/3] fix another test --- .../elasticsearch/xpack/ml/integration/PyTorchModelIT.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index 773d9a9bfd002..f4725e78e1e0d 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -527,8 +527,9 @@ public void testTruncation() throws IOException { startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString()); String input = "once twice thrice"; + var e = expectThrows(ResponseException.class, () -> EntityUtils.toString(infer("once twice thrice", modelId).getEntity())); assertThat( - EntityUtils.toString(infer("once twice thrice", modelId).getEntity()), + e.getMessage(), containsString("Input too large. The tokenized input length [3] exceeds the maximum sequence length [2]") ); @@ -637,7 +638,8 @@ public void testPipelineWithBadProcessor() throws IOException { """; response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity()); - assertThat(response, containsString("warning")); + assertThat(response, containsString("no value could be found for input field [input]")); + assertThat(response, containsString("status_exception")); } public void testDeleteModelWithDeploymentUsedByIngestProcessor() throws IOException {