diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
index 9498b58bb5b22..40eb8a77913b0 100644
--- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
+++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
@@ -75,7 +75,6 @@
  * torch.jit.save(traced_model, "simplemodel.pt")
  * ## End Python
  */
-@ESRestTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376")
 public class PyTorchModelIT extends ESRestTestCase {
 
     private static final String BASIC_AUTH_VALUE_SUPER_USER = UsernamePasswordToken.basicAuthHeaderValue(
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
index c430d2a873a6f..89bc976dbb60b 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java
@@ -12,6 +12,7 @@
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.xpack.core.ml.utils.Intervals;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
@@ -105,10 +106,12 @@ public void process(PyTorchProcess process) {
                     threadSettingsConsumer.accept(threadSettings);
                     processThreadSettings(result);
                 }
+                if (result.ackResult() != null) {
+                    processAcknowledgement(result);
+                }
                 if (result.errorResult() != null) {
                     processErrorResult(result);
                 }
-
             }
         } catch (Exception e) {
             // No need to report error as we're stopping
@@ -118,10 +121,13 @@ public void process(PyTorchProcess process) {
             pendingResults.forEach(
                 (id, pendingResult) -> pendingResult.listener.onResponse(
                     new PyTorchResult(
+                        id,
+                        null,
+                        null,
+                        null,
                         null,
                         null,
                         new ErrorResult(
-                            id,
                             isStopping
                                 ? "inference canceled as process is stopping"
                                 : "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
@@ -133,7 +139,7 @@ public void process(PyTorchProcess process) {
         } finally {
             pendingResults.forEach(
                 (id, pendingResult) -> pendingResult.listener.onResponse(
-                    new PyTorchResult(null, null, new ErrorResult(id, "inference canceled as process is stopping"))
+                    new PyTorchResult(id, false, null, null, null, null, new ErrorResult("inference canceled as process is stopping"))
                 )
             );
             pendingResults.clear();
@@ -144,12 +150,17 @@ public void process(PyTorchProcess process) {
 
     void processInferenceResult(PyTorchResult result) {
         PyTorchInferenceResult inferenceResult = result.inferenceResult();
         assert inferenceResult != null;
+        Long timeMs = result.timeMs();
+        if (timeMs == null) {
+            assert false : "time_ms should be set for an inference result";
+            timeMs = 0L;
+        }
 
-        logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, inferenceResult.getRequestId()));
-        processResult(inferenceResult);
-        PendingResult pendingResult = pendingResults.remove(inferenceResult.getRequestId());
+        logger.trace(() -> format("[%s] Parsed inference result with id [%s]", deploymentId, result.requestId()));
+        processResult(inferenceResult, timeMs, Boolean.TRUE.equals(result.isCacheHit()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
-            logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, inferenceResult.getRequestId()));
+            logger.debug(() -> format("[%s] no pending result for inference [%s]", deploymentId, result.requestId()));
         } else {
             pendingResult.listener.onResponse(result);
         }
@@ -159,10 +170,23 @@ void processThreadSettings(PyTorchResult result) {
         ThreadSettings threadSettings = result.threadSettings();
         assert threadSettings != null;
 
-        logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, threadSettings.requestId()));
-        PendingResult pendingResult = pendingResults.remove(threadSettings.requestId());
+        logger.trace(() -> format("[%s] Parsed thread settings result with id [%s]", deploymentId, result.requestId()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
+        if (pendingResult == null) {
+            logger.debug(() -> format("[%s] no pending result for thread settings [%s]", deploymentId, result.requestId()));
+        } else {
+            pendingResult.listener.onResponse(result);
+        }
+    }
+
+    void processAcknowledgement(PyTorchResult result) {
+        AckResult ack = result.ackResult();
+        assert ack != null;
+
+        logger.trace(() -> format("[%s] Parsed ack result with id [%s]", deploymentId, result.requestId()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
-            logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, threadSettings.requestId()));
+            logger.debug(() -> format("[%s] no pending result for ack [%s]", deploymentId, result.requestId()));
         } else {
             pendingResult.listener.onResponse(result);
         }
@@ -172,12 +196,15 @@ void processErrorResult(PyTorchResult result) {
         ErrorResult errorResult = result.errorResult();
         assert errorResult != null;
 
-        errorCount++;
+        // Only one result is processed at any time, but we need to stop this happening part way through another thread getting stats
+        synchronized (this) {
+            errorCount++;
+        }
 
-        logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, errorResult.requestId()));
-        PendingResult pendingResult = pendingResults.remove(errorResult.requestId());
+        logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, result.requestId()));
+        PendingResult pendingResult = pendingResults.remove(result.requestId());
         if (pendingResult == null) {
-            logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, errorResult.requestId()));
+            logger.debug(() -> format("[%s] no pending result for error [%s]", deploymentId, result.requestId()));
         } else {
             pendingResult.listener.onResponse(result);
         }
@@ -218,8 +245,8 @@ public synchronized ResultStats getResultStats() {
         );
     }
 
-    private synchronized void processResult(PyTorchInferenceResult result) {
-        timingStats.accept(result.getTimeMs());
+    private synchronized void processResult(PyTorchInferenceResult result, long timeMs, boolean isCacheHit) {
+        timingStats.accept(timeMs);
 
         lastResultTimeMs = currentTimeMsSupplier.getAsLong();
         if (lastResultTimeMs > currentPeriodEndTimeMs) {
@@ -240,15 +267,15 @@ private synchronized void processResult(PyTorchInferenceResult result) {
 
             lastPeriodCacheHitCount = 0;
             lastPeriodSummaryStats = new LongSummaryStatistics();
-            lastPeriodSummaryStats.accept(result.getTimeMs());
+            lastPeriodSummaryStats.accept(timeMs);
 
             // set to the end of the current bucket
             currentPeriodEndTimeMs = startTime + Intervals.alignToCeil(lastResultTimeMs - startTime, REPORTING_PERIOD_MS);
         } else {
-            lastPeriodSummaryStats.accept(result.getTimeMs());
+            lastPeriodSummaryStats.accept(timeMs);
         }
 
-        if (result.isCacheHit()) {
+        if (isCacheHit) {
             cacheHitCount++;
             lastPeriodCacheHitCount++;
         }
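Annotation: `process()` now dispatches on four mutually exclusive payloads of a single top-level `PyTorchResult`. A minimal consumer-side sketch (illustrative, not code from this PR; it uses only the two-argument `PyTorchResultProcessor` constructor, `registerRequest(...)`, and the `PyTorchResult` accessors that appear in this diff and its tests):

```java
PyTorchResultProcessor processor = new PyTorchResultProcessor("deployment-foo", threadSettings -> {});

processor.registerRequest("request-1", ActionListener.wrap(result -> {
    // At most one payload is non-null; request_id, cache_hit and time_ms now
    // live on the PyTorchResult envelope rather than on the inference payload.
    if (result.errorResult() != null) {
        throw new IllegalStateException(result.errorResult().error());
    } else if (result.ackResult() != null) {
        boolean ok = result.ackResult().acknowledged();   // control-message acknowledgement
    } else if (result.inferenceResult() != null) {
        double[][][] scores = result.inferenceResult().getInferenceResult();
        Long timeMs = result.timeMs();                    // set for inference results
    }
}, e -> {}));
```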
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java
new file mode 100644
index 0000000000000..06f2679c56c2d
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.xcontent.ConstructingObjectParser;
+import org.elasticsearch.xcontent.ParseField;
+import org.elasticsearch.xcontent.ToXContentObject;
+import org.elasticsearch.xcontent.XContentBuilder;
+
+import java.io.IOException;
+
+public record AckResult(boolean acknowledged) implements ToXContentObject {
+
+    public static final ParseField ACKNOWLEDGED = new ParseField("acknowledged");
+
+    public static ConstructingObjectParser<AckResult, Void> PARSER = new ConstructingObjectParser<>(
+        "ack",
+        a -> new AckResult((Boolean) a[0])
+    );
+
+    static {
+        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ACKNOWLEDGED);
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        builder.field(ACKNOWLEDGED.getPreferredName(), acknowledged);
+        builder.endObject();
+        return builder;
+    }
+}
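Annotation: a hedged round-trip sketch for the new record, in the style of the `AbstractXContentTestCase` tests added later in this PR. `createParser(...)` is the usual `ESTestCase` helper and is assumed here; this is not code from the PR itself:

```java
AckResult ack = new AckResult(true);
XContentBuilder builder = XContentFactory.jsonBuilder();
ack.toXContent(builder, ToXContent.EMPTY_PARAMS);
String json = Strings.toString(builder);   // {"acknowledged":true}
try (XContentParser parser = createParser(JsonXContent.jsonXContent, json)) {
    assertEquals(ack, AckResult.PARSER.parse(parser, null));
}
```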
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java
index 20e0855a50b3e..68fc5cc589231 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java
@@ -14,26 +14,22 @@
 
 import java.io.IOException;
 
-public record ErrorResult(String requestId, String error) implements ToXContentObject {
+public record ErrorResult(String error) implements ToXContentObject {
 
     public static final ParseField ERROR = new ParseField("error");
 
     public static ConstructingObjectParser<ErrorResult, Void> PARSER = new ConstructingObjectParser<>(
         "error",
-        a -> new ErrorResult((String) a[0], (String) a[1])
+        a -> new ErrorResult((String) a[0])
     );
 
     static {
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID);
         PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR);
     }
 
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        if (requestId != null) {
-            builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
-        }
         builder.field(ERROR.getPreferredName(), error);
         builder.endObject();
         return builder;
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java
index c4636f3110f4a..a1482851fc21d 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java
@@ -18,7 +18,6 @@
 
 import java.io.IOException;
 import java.util.Arrays;
-import java.util.Objects;
 
 /**
  * All results must have a request_id.
@@ -28,62 +27,38 @@ public class PyTorchInferenceResult implements ToXContentObject {
 
     private static final ParseField INFERENCE = new ParseField("inference");
-    private static final ParseField TIME_MS = new ParseField("time_ms");
-    private static final ParseField CACHE_HIT = new ParseField("cache_hit");
 
     public static final ConstructingObjectParser<PyTorchInferenceResult, Void> PARSER = new ConstructingObjectParser<>(
         "pytorch_inference_result",
-        a -> new PyTorchInferenceResult((String) a[0], (double[][][]) a[1], (Long) a[2], (Boolean) a[3])
+        a -> new PyTorchInferenceResult((double[][][]) a[0])
     );
 
     static {
-        PARSER.declareString(ConstructingObjectParser.constructorArg(), PyTorchResult.REQUEST_ID);
         PARSER.declareField(
             ConstructingObjectParser.optionalConstructorArg(),
             (p, c) -> MlParserUtils.parse3DArrayOfDoubles(INFERENCE.getPreferredName(), p),
             INFERENCE,
             ObjectParser.ValueType.VALUE_ARRAY
         );
-        PARSER.declareLong(ConstructingObjectParser.constructorArg(), TIME_MS);
-        PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), CACHE_HIT);
     }
 
     public static PyTorchInferenceResult fromXContent(XContentParser parser) throws IOException {
         return PARSER.parse(parser, null);
     }
 
-    private final String requestId;
     private final double[][][] inference;
-    private final long timeMs;
-    private final boolean cacheHit;
 
-    public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, long timeMs, boolean cacheHit) {
-        this.requestId = Objects.requireNonNull(requestId);
+    public PyTorchInferenceResult(@Nullable double[][][] inference) {
         this.inference = inference;
-        this.timeMs = timeMs;
-        this.cacheHit = cacheHit;
-    }
-
-    public String getRequestId() {
-        return requestId;
     }
 
     public double[][][] getInferenceResult() {
         return inference;
     }
 
-    public long getTimeMs() {
-        return timeMs;
-    }
-
-    public boolean isCacheHit() {
-        return cacheHit;
-    }
-
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
-        builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
         if (inference != null) {
             builder.startArray(INFERENCE.getPreferredName());
             for (double[][] doubles : inference) {
@@ -95,15 +70,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
             }
             builder.endArray();
         }
-        builder.field(TIME_MS.getPreferredName(), timeMs);
-        builder.field(CACHE_HIT.getPreferredName(), cacheHit);
         builder.endObject();
         return builder;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(requestId, timeMs, Arrays.deepHashCode(inference), cacheHit);
+        return Arrays.deepHashCode(inference);
     }
 
     @Override
@@ -112,9 +85,6 @@ public boolean equals(Object other) {
         if (other == null || getClass() != other.getClass()) return false;
 
         PyTorchInferenceResult that = (PyTorchInferenceResult) other;
-        return Objects.equals(requestId, that.requestId)
-            && Arrays.deepEquals(inference, that.inference)
-            && timeMs == that.timeMs
-            && cacheHit == that.cacheHit;
+        return Arrays.deepEquals(inference, that.inference);
     }
 }
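Annotation: stripping the per-request fields gives `PyTorchInferenceResult` value semantics over the scores alone, which is presumably what lets one inference payload be shared or cached across requests while per-request metadata rides on the envelope (assumed rationale, not stated in the diff). A tiny sketch of the new equality behaviour:

```java
double[][][] scores = { { { 0.25, 0.75 } } };
PyTorchInferenceResult first = new PyTorchInferenceResult(scores);
PyTorchInferenceResult second = new PyTorchInferenceResult(scores);
// equals/hashCode now depend only on the inference array (see above)
assert first.equals(second) && first.hashCode() == second.hashCode();
```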
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java
index f27037e38617d..11340d0bf542d 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java
@@ -19,24 +19,43 @@
  * The top level object capturing output from the pytorch process.
  */
 public record PyTorchResult(
+    String requestId,
+    Boolean isCacheHit,
+    Long timeMs,
     @Nullable PyTorchInferenceResult inferenceResult,
     @Nullable ThreadSettings threadSettings,
+    @Nullable AckResult ackResult,
     @Nullable ErrorResult errorResult
 ) implements ToXContentObject {
 
-    static final ParseField REQUEST_ID = new ParseField("request_id");
+    private static final ParseField REQUEST_ID = new ParseField("request_id");
+    private static final ParseField CACHE_HIT = new ParseField("cache_hit");
+    private static final ParseField TIME_MS = new ParseField("time_ms");
     private static final ParseField RESULT = new ParseField("result");
     private static final ParseField THREAD_SETTINGS = new ParseField("thread_settings");
+    private static final ParseField ACK = new ParseField("ack");
 
     public static ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>(
         "pytorch_result",
-        a -> new PyTorchResult((PyTorchInferenceResult) a[0], (ThreadSettings) a[1], (ErrorResult) a[2])
+        a -> new PyTorchResult(
+            (String) a[0],
+            (Boolean) a[1],
+            (Long) a[2],
+            (PyTorchInferenceResult) a[3],
+            (ThreadSettings) a[4],
+            (AckResult) a[5],
+            (ErrorResult) a[6]
+        )
     );
 
     static {
+        PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID);
+        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CACHE_HIT);
+        PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), TIME_MS);
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PyTorchInferenceResult.PARSER, RESULT);
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ThreadSettings.PARSER, THREAD_SETTINGS);
+        PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), AckResult.PARSER, ACK);
         PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ErrorResult.PARSER, ErrorResult.ERROR);
     }
 
@@ -47,12 +66,24 @@ public boolean isError() {
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startObject();
+        if (requestId != null) {
+            builder.field(REQUEST_ID.getPreferredName(), requestId);
+        }
+        if (isCacheHit != null) {
+            builder.field(CACHE_HIT.getPreferredName(), isCacheHit);
+        }
+        if (timeMs != null) {
+            builder.field(TIME_MS.getPreferredName(), timeMs);
+        }
         if (inferenceResult != null) {
             builder.field(RESULT.getPreferredName(), inferenceResult);
         }
         if (threadSettings != null) {
             builder.field(THREAD_SETTINGS.getPreferredName(), threadSettings);
         }
+        if (ackResult != null) {
+            builder.field(ACK.getPreferredName(), ackResult);
+        }
         if (errorResult != null) {
             builder.field(ErrorResult.ERROR.getPreferredName(), errorResult);
         }
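Annotation: for reference, the wire shapes the parser above now accepts — one JSON object per result, `request_id` always present, and exactly one payload field. The field names come from the `ParseField` declarations in this diff (including `ThreadSettings` below); the values are invented:

```json
{"request_id": "a", "cache_hit": false, "time_ms": 37, "result": {"inference": [[[0.1, 0.9]]]}}
{"request_id": "b", "thread_settings": {"num_threads_per_allocation": 2, "num_allocations": 4}}
{"request_id": "c", "ack": {"acknowledged": true}}
{"request_id": "d", "error": {"error": "something went wrong"}}
```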
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java
index 3d2ad6997545d..9154d33c04574 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java
@@ -14,20 +14,19 @@
 
 import java.io.IOException;
 
-public record ThreadSettings(int numThreadsPerAllocation, int numAllocations, String requestId) implements ToXContentObject {
+public record ThreadSettings(int numThreadsPerAllocation, int numAllocations) implements ToXContentObject {
 
     private static final ParseField NUM_ALLOCATIONS = new ParseField("num_allocations");
     private static final ParseField NUM_THREADS_PER_ALLOCATION = new ParseField("num_threads_per_allocation");
 
     public static ConstructingObjectParser<ThreadSettings, Void> PARSER = new ConstructingObjectParser<>(
         "thread_settings",
-        a -> new ThreadSettings((int) a[0], (int) a[1], (String) a[2])
+        a -> new ThreadSettings((int) a[0], (int) a[1])
     );
 
     static {
         PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_THREADS_PER_ALLOCATION);
         PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_ALLOCATIONS);
-        PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID);
     }
 
     @Override
@@ -35,9 +34,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         builder.startObject();
         builder.field(NUM_THREADS_PER_ALLOCATION.getPreferredName(), numThreadsPerAllocation);
         builder.field(NUM_ALLOCATIONS.getPreferredName(), numAllocations);
-        if (requestId != null) {
-            builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
-        }
         builder.endObject();
         return builder;
     }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
index b36ce41c5c49d..f3afb0286f076 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
@@ -64,7 +64,7 @@ public void testProcessResults() {
         String resultsField = randomAlphaOfLength(10);
         FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult(
             tokenization,
-            new PyTorchInferenceResult("1", scores, 0L, false),
+            new PyTorchInferenceResult(scores),
             tokenizer,
             4,
             resultsField
@@ -91,7 +91,7 @@ public void testProcessResults_GivenMissingTokens() {
             0
         );
 
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(new double[][][] { { {} } });
         expectThrows(
             ElasticsearchStatusException.class,
             () -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10))
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
index 416beaee9d3db..389a4fab802a0 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
@@ -72,7 +72,7 @@ public void testProcessResults_GivenNoTokens() {
 
         var e = expectThrows(
             ElasticsearchStatusException.class,
-            () -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, false))
+            () -> processor.processResult(tokenization, new PyTorchInferenceResult(null))
         );
         assertThat(e, instanceOf(ElasticsearchStatusException.class));
     }
@@ -113,7 +113,7 @@ public void testProcessResultsWithSpecialTokens() {
                 { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
                 { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
         assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -141,7 +141,7 @@ public void testProcessResults() {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
                 { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
         assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -178,7 +178,7 @@ public void testProcessResults_withIobMap() {
                 { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
                 { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
         assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
@@ -211,7 +211,7 @@ public void testProcessResults_withCustomIobMap() {
                 { 0, 0, 0, 0, 5 }, // in
                 { 6, 0, 0, 0, 0 } // london
             } };
-        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false));
+        NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores));
 
         assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)"));
         assertThat(result.getEntityGroups().size(), equalTo(2));
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java
index f988da404bdb3..ab8bdf4870973 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java
@@ -87,7 +87,7 @@ public void testProcessor() throws IOException {
         assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7));
         double[][][] scores = { { START_TOKEN_SCORES }, { END_TOKEN_SCORES } };
         NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores);
         QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
             tokenizationResult,
             pyTorchResult
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
index 357a0bd1bd611..3b48e75846243 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
@@ -32,7 +32,7 @@ public class TextClassificationProcessorTests extends ESTestCase {
 
     public void testInvalidResult() {
         {
-            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, false);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] {});
             var e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
@@ -41,7 +41,7 @@ public void testInvalidResult() {
             assertThat(e.getMessage(), containsString("Text classification result has no data"));
         }
         {
-            PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, false);
+            PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0 } } });
             var e = expectThrows(
                 ElasticsearchStatusException.class,
                 () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java
index 5601fd6b8baa8..10be6225163b6 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java
@@ -51,7 +51,7 @@ public void testProcessor() throws IOException {
         assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7));
         double[][][] scores = { { { 42 } } };
         NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(textSimilarityConfig);
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores);
         TextSimilarityInferenceResults result = (TextSimilarityInferenceResults) resultProcessor.processResult(
             tokenizationResult,
             pyTorchResult
@@ -74,7 +74,7 @@ public void testResultFunctions() {
         TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer);
         NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(textSimilarityConfig);
         double[][][] scores = { { { 42 }, { 12 }, { 100 } } };
-        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false);
+        PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores);
         TextSimilarityInferenceResults result = (TextSimilarityInferenceResults) resultProcessor.processResult(
             new BertTokenizationResult(List.of(), List.of(), 1),
             pyTorchResult
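Annotation: the same mechanical migration repeats across all the NLP processor tests above — request metadata moves out of `PyTorchInferenceResult` and onto the `PyTorchResult` envelope. Schematically, with hypothetical values:

```java
double[][][] scores = { { { 0.42 } } };
// Before: new PyTorchInferenceResult("1", scores, 1L, false)
// After: the payload keeps only the scores; the envelope carries the rest.
PyTorchInferenceResult slim = new PyTorchInferenceResult(scores);
PyTorchResult envelope = new PyTorchResult("1", false, 1L, slim, null, null, null);
```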
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java
index 11eb75ac91bdf..98da8da4b686a 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java
@@ -9,6 +9,7 @@
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
 import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
@@ -24,7 +25,6 @@
 import static org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor.REPORTING_PERIOD_MS;
 import static org.hamcrest.Matchers.closeTo;
-import static org.hamcrest.Matchers.comparesEqualTo;
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.nullValue;
@@ -37,40 +37,47 @@ public void testsThreadSettings() {
         var settingsHolder = new AtomicReference<ThreadSettings>();
         var processor = new PyTorchResultProcessor("deployment-foo", settingsHolder::set);
 
-        var settings = new ThreadSettings(1, 1, "thread-setting");
+        var settings = new ThreadSettings(1, 1);
         processor.registerRequest("thread-setting", new AssertingResultListener(r -> assertEquals(settings, r.threadSettings())));
 
-        processor.process(mockNativeProcess(List.of(new PyTorchResult(null, settings, null)).iterator()));
+        processor.process(
+            mockNativeProcess(List.of(new PyTorchResult("thread-setting", null, null, null, settings, null, null)).iterator())
+        );
 
         assertEquals(settings, settingsHolder.get());
     }
 
     public void testResultsProcessing() {
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
-        var threadSettings = new ThreadSettings(1, 1, "b");
-        var errorResult = new ErrorResult("c", "a bad thing has happened");
+        var inferenceResult = new PyTorchInferenceResult(null);
+        var threadSettings = new ThreadSettings(1, 1);
+        var ack = new AckResult(true);
+        var errorResult = new ErrorResult("a bad thing has happened");
 
         var inferenceListener = new AssertingResultListener(r -> assertEquals(inferenceResult, r.inferenceResult()));
         var threadSettingsListener = new AssertingResultListener(r -> assertEquals(threadSettings, r.threadSettings()));
+        var ackListener = new AssertingResultListener(r -> assertEquals(ack, r.ackResult()));
         var errorListener = new AssertingResultListener(r -> assertEquals(errorResult, r.errorResult()));
 
         var processor = new PyTorchResultProcessor("foo", s -> {});
         processor.registerRequest("a", inferenceListener);
         processor.registerRequest("b", threadSettingsListener);
-        processor.registerRequest("c", errorListener);
+        processor.registerRequest("c", ackListener);
+        processor.registerRequest("d", errorListener);
 
         processor.process(
             mockNativeProcess(
                 List.of(
-                    new PyTorchResult(inferenceResult, null, null),
-                    new PyTorchResult(null, threadSettings, null),
-                    new PyTorchResult(null, null, errorResult)
+                    new PyTorchResult("a", true, 1000L, inferenceResult, null, null, null),
+                    new PyTorchResult("b", null, null, null, threadSettings, null, null),
+                    new PyTorchResult("c", null, null, null, null, ack, null),
+                    new PyTorchResult("d", null, null, null, null, null, errorResult)
                 ).iterator()
             )
        );
 
         assertTrue(inferenceListener.hasResponse);
         assertTrue(threadSettingsListener.hasResponse);
+        assertTrue(ackListener.hasResponse);
         assertTrue(errorListener.hasResponse);
     }
 
@@ -86,9 +93,9 @@ public void testPendingRequest() {
         );
         processor.registerRequest("b", calledOnShutdown);
 
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
+        var inferenceResult = new PyTorchInferenceResult(null);
 
-        processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
+        processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator()));
         assertSame(inferenceResult, resultHolder.get());
         assertTrue(calledOnShutdown.hasResponse);
     }
@@ -100,8 +107,8 @@ public void testCancelPendingRequest() {
 
         processor.ignoreResponseWithoutNotifying("a");
 
-        var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false);
-        processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator()));
+        var inferenceResult = new PyTorchInferenceResult(null);
+        processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator()));
     }
 
     public void testPendingRequestAreCalledAtShutdown() {
@@ -146,8 +153,8 @@ public void onFailure(Exception e) {
         }
     }
 
-    private PyTorchResult wrapInferenceResult(PyTorchInferenceResult result) {
-        return new PyTorchResult(result, null, null);
+    private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, long timeMs, PyTorchInferenceResult result) {
+        return new PyTorchResult(requestId, isCacheHit, timeMs, result, null, null, null);
     }
 
     public void testsStats() {
@@ -161,33 +168,33 @@ public void testsStats() {
         processor.registerRequest("b", pendingB);
         processor.registerRequest("c", pendingC);
 
-        var a = wrapInferenceResult(new PyTorchInferenceResult("a", null, 1000L, false));
-        var b = wrapInferenceResult(new PyTorchInferenceResult("b", null, 900L, false));
-        var c = wrapInferenceResult(new PyTorchInferenceResult("c", null, 200L, true));
+        var a = wrapInferenceResult("a", false, 1000L, new PyTorchInferenceResult(null));
+        var b = wrapInferenceResult("b", false, 900L, new PyTorchInferenceResult(null));
+        var c = wrapInferenceResult("c", true, 200L, new PyTorchInferenceResult(null));
 
         processor.processInferenceResult(a);
         var stats = processor.getResultStats();
-        assertThat(stats.errorCount(), comparesEqualTo(0));
+        assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(0L));
-        assertThat(stats.numberOfPendingResults(), comparesEqualTo(2));
-        assertThat(stats.timingStats().getCount(), comparesEqualTo(1L));
-        assertThat(stats.timingStats().getSum(), comparesEqualTo(1000L));
+        assertThat(stats.numberOfPendingResults(), equalTo(2));
+        assertThat(stats.timingStats().getCount(), equalTo(1L));
+        assertThat(stats.timingStats().getSum(), equalTo(1000L));
 
         processor.processInferenceResult(b);
         stats = processor.getResultStats();
-        assertThat(stats.errorCount(), comparesEqualTo(0));
+        assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(0L));
-        assertThat(stats.numberOfPendingResults(), comparesEqualTo(1));
-        assertThat(stats.timingStats().getCount(), comparesEqualTo(2L));
-        assertThat(stats.timingStats().getSum(), comparesEqualTo(1900L));
+        assertThat(stats.numberOfPendingResults(), equalTo(1));
+        assertThat(stats.timingStats().getCount(), equalTo(2L));
+        assertThat(stats.timingStats().getSum(), equalTo(1900L));
 
         processor.processInferenceResult(c);
         stats = processor.getResultStats();
-        assertThat(stats.errorCount(), comparesEqualTo(0));
+        assertThat(stats.errorCount(), equalTo(0));
         assertThat(stats.cacheHitCount(), equalTo(1L));
-        assertThat(stats.numberOfPendingResults(), comparesEqualTo(0));
-        assertThat(stats.timingStats().getCount(), comparesEqualTo(3L));
-        assertThat(stats.timingStats().getSum(), comparesEqualTo(2100L));
+        assertThat(stats.numberOfPendingResults(), equalTo(0));
+        assertThat(stats.timingStats().getCount(), equalTo(3L));
+        assertThat(stats.timingStats().getSum(), equalTo(2100L));
     }
 
     public void testsTimeDependentStats() {
@@ -227,9 +234,9 @@ public void testsTimeDependentStats() {
         var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier);
 
         // 1st period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null)));
         // first call has no results as is in the same period
         var stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
@@ -243,7 +250,7 @@ public void testsTimeDependentStats() {
         assertThat(stats.peakThroughput(), equalTo(3L));
 
         // 2nd period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 100L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 100L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -255,7 +262,7 @@ public void testsTimeDependentStats() {
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
 
         // 4th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 300L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 300L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(1L));
@@ -263,8 +270,8 @@ public void testsTimeDependentStats() {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9])));
 
         // 7th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 410L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 390L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 410L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 390L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertThat(stats.recentStats().requestsProcessed(), equalTo(0L));
         assertThat(stats.recentStats().avgInferenceTime(), nullValue());
@@ -275,9 +282,9 @@ public void testsTimeDependentStats() {
         assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12])));
 
         // 8th period
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 510L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 500L, false)));
-        processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 490L, false)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 510L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 500L, new PyTorchInferenceResult(null)));
+        processor.processInferenceResult(wrapInferenceResult("foo", false, 490L, new PyTorchInferenceResult(null)));
         stats = processor.getResultStats();
         assertNotNull(stats.recentStats());
         assertThat(stats.recentStats().requestsProcessed(), equalTo(3L));
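Annotation: a reading aid for the stats assertions above. Accessor names are exactly as exercised in `testsStats()`; the local variable types are an assumption:

```java
var stats = processor.getResultStats();
int errors = stats.errorCount();                       // bumped under the new synchronized block
long cacheHits = stats.cacheHitCount();                // derived from the top-level cache_hit flag
long inferenceCount = stats.timingStats().getCount();  // inference results seen
long totalMillis = stats.timingStats().getSum();       // summed top-level time_ms values
int pending = stats.numberOfPendingResults();
```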
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java
new file mode 100644
index 0000000000000..b1b83e4d18851
--- /dev/null
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.ml.inference.pytorch.results;
+
+import org.elasticsearch.test.AbstractXContentTestCase;
+import org.elasticsearch.xcontent.XContentParser;
+
+import java.io.IOException;
+
+public class AckResultTests extends AbstractXContentTestCase<AckResult> {
+
+    public static AckResult createRandom() {
+        return new AckResult(randomBoolean());
+    }
+
+    @Override
+    protected AckResult createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AckResult doParseInstance(XContentParser parser) throws IOException {
+        return AckResult.PARSER.parse(parser, null);
+    }
+
+    @Override
+    protected boolean supportsUnknownFields() {
+        return false;
+    }
+}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java
index 3c7dacd84afb4..ac197c898fdc7 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java
@@ -15,7 +15,7 @@ public class ErrorResultTests extends AbstractXContentTestCase<ErrorResult> {
 
     public static ErrorResult createRandom() {
-        return new ErrorResult(randomBoolean() ? null : randomAlphaOfLength(5), randomAlphaOfLength(5));
+        return new ErrorResult(randomAlphaOfLength(50));
     }
 
     @Override
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java
index 005271739dfc4..f7370f25e2e84 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java
@@ -30,8 +30,6 @@ protected PyTorchInferenceResult createTestInstance() {
     }
 
     public static PyTorchInferenceResult createRandom() {
-        String id = randomAlphaOfLength(6);
-
         int rows = randomIntBetween(1, 10);
         int columns = randomIntBetween(1, 10);
         int depth = randomIntBetween(1, 10);
@@ -43,6 +41,6 @@ public static PyTorchInferenceResult createRandom() {
                 }
             }
         }
-        return new PyTorchInferenceResult(id, arr, randomLong(), randomBoolean());
+        return new PyTorchInferenceResult(arr);
     }
 }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java
index 9325dbb5d3ebe..9281fbfc54d13 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java
@@ -16,11 +16,21 @@ public class PyTorchResultTests extends AbstractXContentTestCase<PyTorchResult> {
 
     @Override
     protected PyTorchResult createTestInstance() {
-        int type = randomIntBetween(0, 2);
+        String requestId = randomAlphaOfLength(5);
+        int type = randomIntBetween(0, 3);
         return switch (type) {
-            case 0 -> new PyTorchResult(PyTorchInferenceResultTests.createRandom(), null, null);
-            case 1 -> new PyTorchResult(null, ThreadSettingsTests.createRandom(), null);
-            default -> new PyTorchResult(null, null, ErrorResultTests.createRandom());
+            case 0 -> new PyTorchResult(
+                requestId,
+                randomBoolean(),
+                randomNonNegativeLong(),
+                PyTorchInferenceResultTests.createRandom(),
+                null,
+                null,
+                null
+            );
+            case 1 -> new PyTorchResult(requestId, null, null, null, ThreadSettingsTests.createRandom(), null, null);
+            case 2 -> new PyTorchResult(requestId, null, null, null, null, AckResultTests.createRandom(), null);
+            default -> new PyTorchResult(requestId, null, null, null, null, null, ErrorResultTests.createRandom());
         };
     }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java
index 62ec2a4da27f9..ce3b9d9fa07f3 100644
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java
@@ -15,11 +15,7 @@ public class ThreadSettingsTests extends AbstractXContentTestCase<ThreadSettings> {
 
     public static ThreadSettings createRandom() {
-        return new ThreadSettings(
-            randomIntBetween(1, Integer.MAX_VALUE),
-            randomIntBetween(1, Integer.MAX_VALUE),
-            randomBoolean() ? null : randomAlphaOfLength(5)
-        );
+        return new ThreadSettings(randomIntBetween(1, Integer.MAX_VALUE), randomIntBetween(1, Integer.MAX_VALUE));
     }
 
     @Override
diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
index bc4a36cef9ddd..6d0348b1fba92 100644
--- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
+++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml
@@ -76,8 +76,6 @@ setup:
 ---
 "Test start and stop deployment with cache":
   - skip:
-      version: all
-      reason: "@AwaitsFix https://github.com/elastic/ml-cpp/pull/2376"
       features: allowed_warnings
 
   - do:
diff --git a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java
index b0e624b470d0b..f1c7c04905bea 100644
--- a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java
+++ b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java
@@ -31,7 +31,6 @@
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
-@AbstractFullClusterRestartTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376")
 public class MLModelDeploymentFullClusterRestartIT extends AbstractFullClusterRestartTestCase {
 
     // See PyTorchModelIT for how this model was created
diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java
index 8109ce0f7d0f3..682875ae5a2e5 100644
--- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java
+++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java
@@ -29,7 +29,6 @@
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 
-@AbstractUpgradeTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376")
 public class MLModelDeploymentsUpgradeIT extends AbstractUpgradeTestCase {
 
     // See PyTorchModelIT for how this model was created