@@ -527,8 +527,9 @@ public void testTruncation() throws IOException {
startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());

String input = "once twice thrice";
+var e = expectThrows(ResponseException.class, () -> EntityUtils.toString(infer("once twice thrice", modelId).getEntity()));
assertThat(
-EntityUtils.toString(infer("once twice thrice", modelId).getEntity()),
+e.getMessage(),
containsString("Input too large. The tokenized input length [3] exceeds the maximum sequence length [2]")
);

@@ -637,7 +638,8 @@ public void testPipelineWithBadProcessor() throws IOException {
""";

response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
assertThat(response, containsString("warning"));
assertThat(response, containsString("no value could be found for input field [input]"));
assertThat(response, containsString("status_exception"));
}

public void testDeleteModelWithDeploymentUsedByIngestProcessor() throws IOException {
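
Note: the simulate response can now be asserted on both the human-readable reason and the exception type because ElasticsearchStatusException is serialized in REST error bodies under the type name "status_exception". The relevant fragment of the response looks roughly like this (shape illustrative, abbreviated):

    "error": {
      "type": "status_exception",
      "reason": "no value could be found for input field [input]"
    }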
@@ -11,7 +11,6 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.SetOnce;
-import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
@@ -34,7 +33,6 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
@@ -374,32 +372,17 @@ protected void doRun() throws Exception {
processContext,
request.tokenization,
processor.getResultProcessor((NlpConfig) config),
-ActionListener.wrap(this::onSuccess, f -> handleFailure(f, this))
+this
),
this::onFailure
)
);
processContext.process.get().writeInferenceRequest(request.processInput);
} catch (IOException e) {
logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
-handleFailure(ExceptionsHelper.serverError("error writing to process", e), this);
+onFailure(ExceptionsHelper.serverError("error writing to process", e));
} catch (Exception e) {
-handleFailure(e, this);
+onFailure(e);
}
}
-
-private static void handleFailure(Exception e, ActionListener<InferenceResults> listener) {
-Throwable unwrapped = org.elasticsearch.ExceptionsHelper.unwrapCause(e);
-if (unwrapped instanceof ElasticsearchException ex) {
-if (ex.status() == RestStatus.BAD_REQUEST) {
-listener.onResponse(new WarningInferenceResults(ex.getMessage()));
-} else {
-listener.onFailure(ex);
-}
-} else if (unwrapped instanceof IllegalArgumentException) {
-listener.onResponse(new WarningInferenceResults(e.getMessage()));
-} else {
-listener.onFailure(e);
-}
-}
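
Note: the removed handleFailure helper downgraded BAD_REQUEST and IllegalArgumentException failures into WarningInferenceResults delivered via onResponse; after this change every failure propagates through onFailure. A minimal sketch of how a caller now sees the two paths (the listener wiring and the handleResults name are illustrative, not part of this change):

    ActionListener<InferenceResults> listener = ActionListener.wrap(
        results -> handleResults(results),  // success now always carries real results, never a warning
        e -> {
            // all failures land here, typically as an ElasticsearchStatusException
            // with an explicit RestStatus attached
            logger.warn("inference request failed", e);
        }
    );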

@@ -7,10 +7,11 @@

package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -86,7 +87,7 @@ static InferenceResults processResult(
String resultsField
) {
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
return new WarningInferenceResults("No valid tokens for inference");
throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
}

if (tokenizer.getMaskTokenId().isEmpty()) {
@@ -105,8 +106,9 @@
}
}
if (maskTokenIndex == -1) {
-return new WarningInferenceResults(
+throw new ElasticsearchStatusException(
"mask token id [{}] not found in the tokenization {}",
+RestStatus.INTERNAL_SERVER_ERROR,
maskTokenId,
List.of(tokenization.getTokenizations().get(0).getTokenIds())
);
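
Note: in the ElasticsearchStatusException constructor the RestStatus argument sits between the message and the message parameters, and the {} placeholders are filled from the trailing varargs (LoggerMessageFormat style). A generic sketch of the pattern introduced throughout this change (the count variables are illustrative):

    // message first, then the status, then the values for the {} placeholders
    throw new ElasticsearchStatusException(
        "expected [{}] values; got [{}]",
        RestStatus.INTERNAL_SERVER_ERROR,
        expectedCount,   // fills the first {}
        actualCount      // fills the second {}
    );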
@@ -7,12 +7,13 @@

package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;
@@ -195,7 +196,7 @@ static class NerResultProcessor implements NlpTask.ResultProcessor {
@Override
public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
return new WarningInferenceResults("no valid tokens to build result");
throw new ElasticsearchStatusException("no valid tokenization to build result", RestStatus.INTERNAL_SERVER_ERROR);
}
// TODO - process all results in the batch

@@ -7,10 +7,11 @@

package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;
@@ -39,11 +40,6 @@ public class TextClassificationProcessor implements NlpTask.Processor {
// negative values are a special case of asking for ALL classes. Since we require the output size to equal the classLabel size
// This is a nice way of setting the value
this.numTopClasses = config.getNumTopClasses() < 0 ? this.classLabels.length : config.getNumTopClasses();
-validate();
}
-
-private void validate() {
-// validation occurs in TextClassificationConfig
-}

@Override
@@ -87,14 +83,15 @@ static InferenceResults processResult(
String resultsField
) {
if (pyTorchResult.getInferenceResult().length < 1) {
return new WarningInferenceResults("Text classification result has no data");
throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
}

// TODO only the first entry in the batch result is verified and
// checked. Implement for all in batch
if (pyTorchResult.getInferenceResult()[0][0].length != labels.size()) {
-return new WarningInferenceResults(
+throw new ElasticsearchStatusException(
"Expected exactly [{}] values in text classification result; got [{}]",
+RestStatus.INTERNAL_SERVER_ERROR,
labels.size(),
pyTorchResult.getInferenceResult()[0][0].length
);
@@ -7,11 +7,12 @@

package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.logging.LoggerMessageFormat;
+import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
@@ -114,7 +115,7 @@ static class RequestBuilder implements NlpTask.RequestBuilder {
@Override
public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
if (inputs.size() > 1) {
-throw new IllegalArgumentException("Unable to do zero-shot classification on more than one text input at a time");
+throw ExceptionsHelper.badRequestException("Unable to do zero-shot classification on more than one text input at a time");
}
List<TokenizationResult.Tokenization> tokenizations = new ArrayList<>(labels.length);
for (String label : labels) {
@@ -148,13 +149,14 @@ static class ResultProcessor implements NlpTask.ResultProcessor {
@Override
public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
if (pyTorchResult.getInferenceResult().length < 1) {
return new WarningInferenceResults("Zero shot classification result has no data");
throw new ElasticsearchStatusException("Zero shot classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
}
// TODO only the first entry in the batch result is verified and
// checked. Implement for all in batch
if (pyTorchResult.getInferenceResult()[0].length != labels.length) {
-return new WarningInferenceResults(
+throw new ElasticsearchStatusException(
"Expected exactly [{}] values in zero shot classification result; got [{}]",
+RestStatus.INTERNAL_SERVER_ERROR,
labels.length,
pyTorchResult.getInferenceResult().length
);
@@ -165,8 +167,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
int v = 0;
for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
if (vals.length != 3) {
-return new WarningInferenceResults(
+throw new ElasticsearchStatusException(
"Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+RestStatus.INTERNAL_SERVER_ERROR,
3,
vals.length
);
@@ -181,8 +184,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
int v = 0;
for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
if (vals.length != 3) {
-return new WarningInferenceResults(
+throw new ElasticsearchStatusException(
"Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+RestStatus.INTERNAL_SERVER_ERROR,
3,
vals.length
);
@@ -11,7 +11,6 @@
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer;
@@ -28,7 +27,6 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
-import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

@@ -90,9 +88,9 @@ public void testProcessResults_GivenMissingTokens() {
tokenization.addTokenization("", false, Collections.emptyList(), new int[] {}, new int[] {});

PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null);
-assertThat(
-FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)),
-instanceOf(WarningInferenceResults.class)
+expectThrows(
+ElasticsearchStatusException.class,
+() -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10))
);
}
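
Note: the tests move from inspecting a returned WarningInferenceResults to asserting that the call throws. expectThrows (inherited by ESTestCase from the Lucene test framework) runs the lambda, asserts the exception type, and returns the caught exception for further assertions. A generic sketch of the pattern (the call under test is a placeholder):

    ElasticsearchStatusException e = expectThrows(
        ElasticsearchStatusException.class,
        () -> processorUnderTest.processResult(tokenization, pyTorchResult)  // placeholder call
    );
    assertThat(e.getMessage(), containsString("expected message fragment"));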

@@ -7,10 +7,10 @@

package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
@@ -91,10 +91,12 @@ public void testValidate_NotAEntityLabel() {
public void testProcessResults_GivenNoTokens() {
NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, false);
TokenizationResult tokenization = tokenize(List.of(BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN), "");
-assertThat(
-processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null)),
-instanceOf(WarningInferenceResults.class)
+
+var e = expectThrows(
+ElasticsearchStatusException.class,
+() -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null))
);
+assertThat(e, instanceOf(ElasticsearchStatusException.class));
}

public void testProcessResults() {
@@ -7,11 +7,10 @@

package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
@@ -25,6 +24,7 @@
import java.util.Map;

import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
+import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;

@@ -33,30 +33,21 @@ public class TextClassificationProcessorTests extends ESTestCase {
public void testInvalidResult() {
{
PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, null);
-InferenceResults inferenceResults = TextClassificationProcessor.processResult(
-null,
-torchResult,
-randomInt(),
-List.of("a", "b"),
-randomAlphaOfLength(10)
+var e = expectThrows(
+ElasticsearchStatusException.class,
+() -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
);
-assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
-assertEquals("Text classification result has no data", ((WarningInferenceResults) inferenceResults).getWarning());
+assertThat(e, instanceOf(ElasticsearchStatusException.class));
+assertThat(e.getMessage(), containsString("Text classification result has no data"));
}
{
PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
-InferenceResults inferenceResults = TextClassificationProcessor.processResult(
-null,
-torchResult,
-randomInt(),
-List.of("a", "b"),
-randomAlphaOfLength(10)
-);
-assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
-assertEquals(
-"Expected exactly [2] values in text classification result; got [1]",
-((WarningInferenceResults) inferenceResults).getWarning()
+var e = expectThrows(
+ElasticsearchStatusException.class,
+() -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
);
+assertThat(e, instanceOf(ElasticsearchStatusException.class));
+assertThat(e.getMessage(), containsString("Expected exactly [2] values in text classification result; got [1]"));
}
}
