From fb516c99e8f44d1f2e4084a81181e60fd7eb4ee2 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Thu, 28 Jul 2022 14:01:17 +0100 Subject: [PATCH 1/7] [ML] Move PyTorch request ID and cache hit indicator to top level This change will facilitate a performance improvement on the C++ side. The request ID and cache hit indicator are the parts that need to be changed when the C++ process responds to an inference request. Having them at the top level means we do not need to parse and manipulate the original response - we can simply cache the inner object of the response and add the outer fields around it when serializing it. --- .../process/PyTorchResultProcessor.java | 48 ++++++---- .../inference/pytorch/results/AckResult.java | 37 ++++++++ .../pytorch/results/ErrorResult.java | 8 +- .../results/PyTorchInferenceResult.java | 28 +----- .../pytorch/results/PyTorchResult.java | 28 +++++- .../pytorch/results/ThreadSettings.java | 8 +- .../inference/nlp/FillMaskProcessorTests.java | 4 +- .../ml/inference/nlp/NerProcessorTests.java | 10 +-- .../nlp/QuestionAnsweringProcessorTests.java | 2 +- .../nlp/TextClassificationProcessorTests.java | 4 +- .../process/PyTorchResultProcessorTests.java | 87 ++++++++++--------- .../pytorch/results/AckResultTests.java | 35 ++++++++ .../pytorch/results/ErrorResultTests.java | 2 +- .../results/PyTorchInferenceResultTests.java | 4 +- .../pytorch/results/PyTorchResultTests.java | 10 ++- .../pytorch/results/ThreadSettingsTests.java | 6 +- 16 files changed, 204 insertions(+), 117 deletions(-) create mode 100644 x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java create mode 100644 x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index c430d2a873a6f..47fb66011cb03 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.core.ml.utils.Intervals; +import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult; import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult; import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult; import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult; @@ -105,10 +106,12 @@ public void process(PyTorchProcess process) { threadSettingsConsumer.accept(threadSettings); processThreadSettings(result); } + if (result.ackResult() != null) { + processAcknowledgement(result); + } if (result.errorResult() != null) { processErrorResult(result); } - } } catch (Exception e) { // No need to report error as we're stopping @@ -118,10 +121,12 @@ public void process(PyTorchProcess process) { pendingResults.forEach( (id, pendingResult) -> pendingResult.listener.onResponse( new PyTorchResult( + id, + false, + null, null, null, new ErrorResult( - id, isStopping ? 
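/* Illustrative aside (not part of the patch): with the request ID promoted to the
 * top level, this failure path builds a PyTorchResult whose only populated fields
 * are the id and the ErrorResult; it would serialize as something like
 *   {"request_id":"42","error":{"error":"inference native process died ..."}}
 * with the id no longer nested inside the error object itself.
 */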
"inference canceled as process is stopping" : "inference native process died unexpectedly with failure [" + e.getMessage() + "]" @@ -133,7 +138,7 @@ public void process(PyTorchProcess process) { } finally { pendingResults.forEach( (id, pendingResult) -> pendingResult.listener.onResponse( - new PyTorchResult(null, null, new ErrorResult(id, "inference canceled as process is stopping")) + new PyTorchResult(id, false, null, null, null, new ErrorResult("inference canceled as process is stopping")) ) ); pendingResults.clear(); @@ -145,11 +150,11 @@ void processInferenceResult(PyTorchResult result) { PyTorchInferenceResult inferenceResult = result.inferenceResult(); assert inferenceResult != null; - logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, inferenceResult.getRequestId())); - processResult(inferenceResult); - PendingResult pendingResult = pendingResults.remove(inferenceResult.getRequestId()); + logger.trace(() -> format("[%s] Parsed inference result with id [%s]", deploymentId, result.requestId())); + processResult(inferenceResult, result.isCacheHit()); + PendingResult pendingResult = pendingResults.remove(result.requestId()); if (pendingResult == null) { - logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, inferenceResult.getRequestId())); + logger.debug(() -> format("[%s] no pending result for inference [%s]", deploymentId, result.requestId())); } else { pendingResult.listener.onResponse(result); } @@ -159,10 +164,23 @@ void processThreadSettings(PyTorchResult result) { ThreadSettings threadSettings = result.threadSettings(); assert threadSettings != null; - logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, threadSettings.requestId())); - PendingResult pendingResult = pendingResults.remove(threadSettings.requestId()); + logger.trace(() -> format("[%s] Parsed thread settings result with id [%s]", deploymentId, result.requestId())); + PendingResult pendingResult = pendingResults.remove(result.requestId()); + if (pendingResult == null) { + logger.debug(() -> format("[%s] no pending result for thread settings [%s]", deploymentId, result.requestId())); + } else { + pendingResult.listener.onResponse(result); + } + } + + void processAcknowledgement(PyTorchResult result) { + AckResult ack = result.ackResult(); + assert ack != null; + + logger.trace(() -> format("[%s] Parsed ack result with id [%s]", deploymentId, result.requestId())); + PendingResult pendingResult = pendingResults.remove(result.requestId()); if (pendingResult == null) { - logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, threadSettings.requestId())); + logger.debug(() -> format("[%s] no pending result for ack [%s]", deploymentId, result.requestId())); } else { pendingResult.listener.onResponse(result); } @@ -174,10 +192,10 @@ void processErrorResult(PyTorchResult result) { errorCount++; - logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, errorResult.requestId())); - PendingResult pendingResult = pendingResults.remove(errorResult.requestId()); + logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, result.requestId())); + PendingResult pendingResult = pendingResults.remove(result.requestId()); if (pendingResult == null) { - logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, errorResult.requestId())); + logger.debug(() -> format("[%s] no pending result for error [%s]", deploymentId, result.requestId())); } else { pendingResult.listener.onResponse(result); } @@ 
-218,7 +236,7 @@ public synchronized ResultStats getResultStats() { ); } - private synchronized void processResult(PyTorchInferenceResult result) { + private synchronized void processResult(PyTorchInferenceResult result, Boolean isCacheHit) { timingStats.accept(result.getTimeMs()); lastResultTimeMs = currentTimeMsSupplier.getAsLong();
@@ -248,7 +266,7 @@ private synchronized void processResult(PyTorchInferenceResult result) { lastPeriodSummaryStats.accept(result.getTimeMs()); } - if (result.isCacheHit()) { + if (Boolean.TRUE.equals(isCacheHit)) { cacheHitCount++; lastPeriodCacheHitCount++; }
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java new file mode 100644 index 0000000000000..9a9325b065d7c
--- /dev/null
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java
@@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.pytorch.results; + +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; + +public record AckResult(boolean acknowledged) implements ToXContentObject { + + public static final ParseField ACKNOWLEDGED = new ParseField("acknowledged"); + + public static ConstructingObjectParser<AckResult, Void> PARSER = new ConstructingObjectParser<>( + "ack", + a -> new AckResult((Boolean) a[0]) + ); + + static { + PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ACKNOWLEDGED); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(ACKNOWLEDGED.getPreferredName(), acknowledged); + builder.endObject(); + return builder; + } +}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java index 20e0855a50b3e..68fc5cc589231
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResult.java
@@ -14,26 +14,22 @@ import java.io.IOException; -public record ErrorResult(String requestId, String error) implements ToXContentObject { +public record ErrorResult(String error) implements ToXContentObject { public static final ParseField ERROR = new ParseField("error"); public static ConstructingObjectParser<ErrorResult, Void> PARSER = new ConstructingObjectParser<>( "error", - a -> new ErrorResult((String) a[0], (String) a[1]) + a -> new ErrorResult((String) a[0]) ); static { - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID); PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - if (requestId != null) { - builder.field(PyTorchResult.REQUEST_ID.getPreferredName(),
requestId); - } builder.field(ERROR.getPreferredName(), error); builder.endObject(); return builder; diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java index c4636f3110f4a..a38dbf720eab9 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java @@ -29,15 +29,13 @@ public class PyTorchInferenceResult implements ToXContentObject { private static final ParseField INFERENCE = new ParseField("inference"); private static final ParseField TIME_MS = new ParseField("time_ms"); - private static final ParseField CACHE_HIT = new ParseField("cache_hit"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "pytorch_inference_result", - a -> new PyTorchInferenceResult((String) a[0], (double[][][]) a[1], (Long) a[2], (Boolean) a[3]) + a -> new PyTorchInferenceResult((double[][][]) a[0], (Long) a[1]) ); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), PyTorchResult.REQUEST_ID); PARSER.declareField( ConstructingObjectParser.optionalConstructorArg(), (p, c) -> MlParserUtils.parse3DArrayOfDoubles(INFERENCE.getPreferredName(), p), @@ -45,27 +43,18 @@ public class PyTorchInferenceResult implements ToXContentObject { ObjectParser.ValueType.VALUE_ARRAY ); PARSER.declareLong(ConstructingObjectParser.constructorArg(), TIME_MS); - PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), CACHE_HIT); } public static PyTorchInferenceResult fromXContent(XContentParser parser) throws IOException { return PARSER.parse(parser, null); } - private final String requestId; private final double[][][] inference; private final long timeMs; - private final boolean cacheHit; - public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, long timeMs, boolean cacheHit) { - this.requestId = Objects.requireNonNull(requestId); + public PyTorchInferenceResult(@Nullable double[][][] inference, long timeMs) { this.inference = inference; this.timeMs = timeMs; - this.cacheHit = cacheHit; - } - - public String getRequestId() { - return requestId; } public double[][][] getInferenceResult() { @@ -76,14 +65,9 @@ public long getTimeMs() { return timeMs; } - public boolean isCacheHit() { - return cacheHit; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId); if (inference != null) { builder.startArray(INFERENCE.getPreferredName()); for (double[][] doubles : inference) { @@ -96,14 +80,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endArray(); } builder.field(TIME_MS.getPreferredName(), timeMs); - builder.field(CACHE_HIT.getPreferredName(), cacheHit); builder.endObject(); return builder; } @Override public int hashCode() { - return Objects.hash(requestId, timeMs, Arrays.deepHashCode(inference), cacheHit); + return Objects.hash(timeMs, Arrays.deepHashCode(inference)); } @Override @@ -112,9 +95,6 @@ public boolean equals(Object other) { if (other == null || getClass() != other.getClass()) return false; PyTorchInferenceResult that = (PyTorchInferenceResult) other; - return Objects.equals(requestId, 
that.requestId) - && Arrays.deepEquals(inference, that.inference) - && timeMs == that.timeMs - && cacheHit == that.cacheHit; + return Arrays.deepEquals(inference, that.inference) && timeMs == that.timeMs; } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java index f27037e38617d..b9a319b9eba3f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java @@ -19,24 +19,39 @@ * The top level object capturing output from the pytorch process. */ public record PyTorchResult( + String requestId, + Boolean isCacheHit, @Nullable PyTorchInferenceResult inferenceResult, @Nullable ThreadSettings threadSettings, + @Nullable AckResult ackResult, @Nullable ErrorResult errorResult ) implements ToXContentObject { - static final ParseField REQUEST_ID = new ParseField("request_id"); + private static final ParseField REQUEST_ID = new ParseField("request_id"); + private static final ParseField CACHE_HIT = new ParseField("cache_hit"); private static final ParseField RESULT = new ParseField("result"); private static final ParseField THREAD_SETTINGS = new ParseField("thread_settings"); + private static final ParseField ACK = new ParseField("ack"); public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "pytorch_result", - a -> new PyTorchResult((PyTorchInferenceResult) a[0], (ThreadSettings) a[1], (ErrorResult) a[2]) + a -> new PyTorchResult( + (String) a[0], + (Boolean) a[1], + (PyTorchInferenceResult) a[2], + (ThreadSettings) a[3], + (AckResult) a[4], + (ErrorResult) a[5] + ) ); static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID); + PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CACHE_HIT); PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PyTorchInferenceResult.PARSER, RESULT); PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ThreadSettings.PARSER, THREAD_SETTINGS); + PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), AckResult.PARSER, ACK); PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ErrorResult.PARSER, ErrorResult.ERROR); } @@ -47,12 +62,21 @@ public boolean isError() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); + if (requestId != null) { + builder.field(REQUEST_ID.getPreferredName(), requestId); + } + if (isCacheHit != null) { + builder.field(CACHE_HIT.getPreferredName(), isCacheHit); + } if (inferenceResult != null) { builder.field(RESULT.getPreferredName(), inferenceResult); } if (threadSettings != null) { builder.field(THREAD_SETTINGS.getPreferredName(), threadSettings); } + if (ackResult != null) { + builder.field(ACK.getPreferredName(), ackResult); + } if (errorResult != null) { builder.field(ErrorResult.ERROR.getPreferredName(), errorResult); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java index 3d2ad6997545d..9154d33c04574 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java 
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettings.java @@ -14,20 +14,19 @@ import java.io.IOException; -public record ThreadSettings(int numThreadsPerAllocation, int numAllocations, String requestId) implements ToXContentObject { +public record ThreadSettings(int numThreadsPerAllocation, int numAllocations) implements ToXContentObject { private static final ParseField NUM_ALLOCATIONS = new ParseField("num_allocations"); private static final ParseField NUM_THREADS_PER_ALLOCATION = new ParseField("num_threads_per_allocation"); public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "thread_settings", - a -> new ThreadSettings((int) a[0], (int) a[1], (String) a[2]) + a -> new ThreadSettings((int) a[0], (int) a[1]) ); static { PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_THREADS_PER_ALLOCATION); PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_ALLOCATIONS); - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID); } @Override @@ -35,9 +34,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field(NUM_THREADS_PER_ALLOCATION.getPreferredName(), numThreadsPerAllocation); builder.field(NUM_ALLOCATIONS.getPreferredName(), numAllocations); - if (requestId != null) { - builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId); - } builder.endObject(); return builder; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index 61c037b712406..b924f8891a8ef 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -66,7 +66,7 @@ public void testProcessResults() { String resultsField = randomAlphaOfLength(10); FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult( tokenization, - new PyTorchInferenceResult("1", scores, 0L, false), + new PyTorchInferenceResult(scores, 0L), tokenizer, 4, resultsField @@ -93,7 +93,7 @@ public void testProcessResults_GivenMissingTokens() { 0 ); - PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, false); + PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(new double[][][] { { {} } }, 0L); expectThrows( ElasticsearchStatusException.class, () -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java index 416beaee9d3db..6ee10f2596567 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java @@ -72,7 +72,7 @@ public void testProcessResults_GivenNoTokens() { var e = expectThrows( ElasticsearchStatusException.class, - () -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, false)) + () -> processor.processResult(tokenization, new PyTorchInferenceResult(null, 0L)) ); assertThat(e, 
instanceOf(ElasticsearchStatusException.class)); } @@ -113,7 +113,7 @@ public void testProcessResultsWithSpecialTokens() { { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); @@ -141,7 +141,7 @@ public void testProcessResults() { { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); @@ -178,7 +178,7 @@ public void testProcessResults_withIobMap() { { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); @@ -211,7 +211,7 @@ public void testProcessResults_withCustomIobMap() { { 0, 0, 0, 0, 5 }, // in { 6, 0, 0, 0, 0 } // london } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, false)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java index 3fd07c65f25d0..61344f5dec6a5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java @@ -87,7 +87,7 @@ public void testProcessor() throws IOException { assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7)); double[][][] scores = { { START_TOKEN_SCORES }, { END_TOKEN_SCORES } }; NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config); - PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false); + PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores, 1L); QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult( tokenizationResult, pyTorchResult diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java 
b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java index 357a0bd1bd611..dea51cf06be1c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java @@ -32,7 +32,7 @@ public class TextClassificationProcessorTests extends ESTestCase { public void testInvalidResult() { { - PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, false); + PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] {}, 0L); var e = expectThrows( ElasticsearchStatusException.class, () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10)) @@ -41,7 +41,7 @@ public void testInvalidResult() { assertThat(e.getMessage(), containsString("Text classification result has no data")); } { - PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, false); + PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0 } } }, 0L); var e = expectThrows( ElasticsearchStatusException.class, () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10)) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java index 11eb75ac91bdf..976881f6010b8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult; import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult; import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult; import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult; @@ -24,7 +25,6 @@ import static org.elasticsearch.xpack.ml.inference.pytorch.process.PyTorchResultProcessor.REPORTING_PERIOD_MS; import static org.hamcrest.Matchers.closeTo; -import static org.hamcrest.Matchers.comparesEqualTo; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.nullValue; @@ -37,40 +37,45 @@ public void testsThreadSettings() { var settingsHolder = new AtomicReference(); var processor = new PyTorchResultProcessor("deployment-foo", settingsHolder::set); - var settings = new ThreadSettings(1, 1, "thread-setting"); + var settings = new ThreadSettings(1, 1); processor.registerRequest("thread-setting", new AssertingResultListener(r -> assertEquals(settings, r.threadSettings()))); - processor.process(mockNativeProcess(List.of(new PyTorchResult(null, settings, null)).iterator())); + processor.process(mockNativeProcess(List.of(new PyTorchResult("thread-setting", null, null, settings, null, null)).iterator())); assertEquals(settings, settingsHolder.get()); } public void testResultsProcessing() { - var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, 
false); - var threadSettings = new ThreadSettings(1, 1, "b"); - var errorResult = new ErrorResult("c", "a bad thing has happened"); + var inferenceResult = new PyTorchInferenceResult(null, 1000L); + var threadSettings = new ThreadSettings(1, 1); + var ack = new AckResult(true); + var errorResult = new ErrorResult("a bad thing has happened"); var inferenceListener = new AssertingResultListener(r -> assertEquals(inferenceResult, r.inferenceResult())); var threadSettingsListener = new AssertingResultListener(r -> assertEquals(threadSettings, r.threadSettings())); + var ackListener = new AssertingResultListener(r -> assertEquals(ack, r.ackResult())); var errorListener = new AssertingResultListener(r -> assertEquals(errorResult, r.errorResult())); var processor = new PyTorchResultProcessor("foo", s -> {}); processor.registerRequest("a", inferenceListener); processor.registerRequest("b", threadSettingsListener); - processor.registerRequest("c", errorListener); + processor.registerRequest("c", ackListener); + processor.registerRequest("d", errorListener); processor.process( mockNativeProcess( List.of( - new PyTorchResult(inferenceResult, null, null), - new PyTorchResult(null, threadSettings, null), - new PyTorchResult(null, null, errorResult) + new PyTorchResult("a", true, inferenceResult, null, null, null), + new PyTorchResult("b", null, null, threadSettings, null, null), + new PyTorchResult("c", null, null, null, ack, null), + new PyTorchResult("d", null, null, null, null, errorResult) ).iterator() ) ); assertTrue(inferenceListener.hasResponse); assertTrue(threadSettingsListener.hasResponse); + assertTrue(ackListener.hasResponse); assertTrue(errorListener.hasResponse); } @@ -86,9 +91,9 @@ public void testPendingRequest() { ); processor.registerRequest("b", calledOnShutdown); - var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false); + var inferenceResult = new PyTorchInferenceResult(null, 1000L); - processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator())); + processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, inferenceResult, null, null, null)).iterator())); assertSame(inferenceResult, resultHolder.get()); assertTrue(calledOnShutdown.hasResponse); } @@ -100,8 +105,8 @@ public void testCancelPendingRequest() { processor.ignoreResponseWithoutNotifying("a"); - var inferenceResult = new PyTorchInferenceResult("a", null, 1000L, false); - processor.process(mockNativeProcess(List.of(new PyTorchResult(inferenceResult, null, null)).iterator())); + var inferenceResult = new PyTorchInferenceResult(null, 1000L); + processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, inferenceResult, null, null, null)).iterator())); } public void testPendingRequestAreCalledAtShutdown() { @@ -146,8 +151,8 @@ public void onFailure(Exception e) { } } - private PyTorchResult wrapInferenceResult(PyTorchInferenceResult result) { - return new PyTorchResult(result, null, null); + private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, PyTorchInferenceResult result) { + return new PyTorchResult(requestId, isCacheHit, result, null, null, null); } public void testsStats() { @@ -161,33 +166,33 @@ public void testsStats() { processor.registerRequest("b", pendingB); processor.registerRequest("c", pendingC); - var a = wrapInferenceResult(new PyTorchInferenceResult("a", null, 1000L, false)); - var b = wrapInferenceResult(new PyTorchInferenceResult("b", null, 900L, false)); - var c = wrapInferenceResult(new 
PyTorchInferenceResult("c", null, 200L, true)); + var a = wrapInferenceResult("a", false, new PyTorchInferenceResult(null, 1000L)); + var b = wrapInferenceResult("b", false, new PyTorchInferenceResult(null, 900L)); + var c = wrapInferenceResult("c", true, new PyTorchInferenceResult(null, 200L)); processor.processInferenceResult(a); var stats = processor.getResultStats(); - assertThat(stats.errorCount(), comparesEqualTo(0)); + assertThat(stats.errorCount(), equalTo(0)); assertThat(stats.cacheHitCount(), equalTo(0L)); - assertThat(stats.numberOfPendingResults(), comparesEqualTo(2)); - assertThat(stats.timingStats().getCount(), comparesEqualTo(1L)); - assertThat(stats.timingStats().getSum(), comparesEqualTo(1000L)); + assertThat(stats.numberOfPendingResults(), equalTo(2)); + assertThat(stats.timingStats().getCount(), equalTo(1L)); + assertThat(stats.timingStats().getSum(), equalTo(1000L)); processor.processInferenceResult(b); stats = processor.getResultStats(); - assertThat(stats.errorCount(), comparesEqualTo(0)); + assertThat(stats.errorCount(), equalTo(0)); assertThat(stats.cacheHitCount(), equalTo(0L)); - assertThat(stats.numberOfPendingResults(), comparesEqualTo(1)); - assertThat(stats.timingStats().getCount(), comparesEqualTo(2L)); - assertThat(stats.timingStats().getSum(), comparesEqualTo(1900L)); + assertThat(stats.numberOfPendingResults(), equalTo(1)); + assertThat(stats.timingStats().getCount(), equalTo(2L)); + assertThat(stats.timingStats().getSum(), equalTo(1900L)); processor.processInferenceResult(c); stats = processor.getResultStats(); - assertThat(stats.errorCount(), comparesEqualTo(0)); + assertThat(stats.errorCount(), equalTo(0)); assertThat(stats.cacheHitCount(), equalTo(1L)); - assertThat(stats.numberOfPendingResults(), comparesEqualTo(0)); - assertThat(stats.timingStats().getCount(), comparesEqualTo(3L)); - assertThat(stats.timingStats().getSum(), comparesEqualTo(2100L)); + assertThat(stats.numberOfPendingResults(), equalTo(0)); + assertThat(stats.timingStats().getCount(), equalTo(3L)); + assertThat(stats.timingStats().getSum(), equalTo(2100L)); } public void testsTimeDependentStats() { @@ -227,9 +232,9 @@ public void testsTimeDependentStats() { var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier); // 1st period - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false))); - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false))); - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 200L, false))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 200L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 200L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 200L))); // first call has no results as is in the same period var stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); @@ -243,7 +248,7 @@ public void testsTimeDependentStats() { assertThat(stats.peakThroughput(), equalTo(3L)); // 2nd period - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 100L, false))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 100L))); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); 
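// Illustrative note (not part of the patch): these assertions are driven by the
// injected timeSupplier, which stands in for the wall clock so results land
// deterministically in successive REPORTING_PERIOD_MS buckets; recentStats
// covers only the most recent bucket, while timingStats accumulates lifetime
// counts and sums.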
assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -255,7 +260,7 @@ public void testsTimeDependentStats() { assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); // 4th period - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 300L, false))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 300L))); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -263,8 +268,8 @@ public void testsTimeDependentStats() { assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9]))); // 7th period - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 410L, false))); - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 390L, false))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 410L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 390L))); stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); assertThat(stats.recentStats().avgInferenceTime(), nullValue()); @@ -275,9 +280,9 @@ public void testsTimeDependentStats() { assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12]))); // 8th period - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 510L, false))); - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 500L, false))); - processor.processInferenceResult(wrapInferenceResult(new PyTorchInferenceResult("foo", null, 490L, false))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 510L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 500L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 490L))); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(3L)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java new file mode 100644 index 0000000000000..b1b83e4d18851 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResultTests.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.inference.pytorch.results; + +import org.elasticsearch.test.AbstractXContentTestCase; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; + +public class AckResultTests extends AbstractXContentTestCase<AckResult> { + + public static AckResult createRandom() { + return new AckResult(randomBoolean()); + } + + @Override + protected AckResult createTestInstance() { + return createRandom(); + } + + @Override + protected AckResult doParseInstance(XContentParser parser) throws IOException { + return AckResult.PARSER.parse(parser, null); + } + + @Override + protected boolean supportsUnknownFields() { + return false; + } +}
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java index 3c7dacd84afb4..ac197c898fdc7
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ErrorResultTests.java
@@ -15,7 +15,7 @@ public class ErrorResultTests extends AbstractXContentTestCase<ErrorResult> { public static ErrorResult createRandom() { - return new ErrorResult(randomBoolean() ? null : randomAlphaOfLength(5), randomAlphaOfLength(5)); + return new ErrorResult(randomAlphaOfLength(50)); } @Override
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java index 005271739dfc4..fe2a5bab07a92
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java
@@ -30,8 +30,6 @@ protected PyTorchInferenceResult createTestInstance() { } public static PyTorchInferenceResult createRandom() { - String id = randomAlphaOfLength(6); - int rows = randomIntBetween(1, 10); int columns = randomIntBetween(1, 10); int depth = randomIntBetween(1, 10);
@@ -43,6 +41,6 @@ public static PyTorchInferenceResult createRandom() { } } } - return new PyTorchInferenceResult(id, arr, randomLong(), randomBoolean()); + return new PyTorchInferenceResult(arr, randomNonNegativeLong()); } }
diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java index 9325dbb5d3ebe..13137505ef5a0
--- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java
+++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java
@@ -16,11 +16,13 @@ public class PyTorchResultTests extends AbstractXContentTestCase<PyTorchResult> @Override protected PyTorchResult createTestInstance() { - int type = randomIntBetween(0, 2); + String requestId = randomAlphaOfLength(5); + int type = randomIntBetween(0, 3); return switch (type) { - case 0 -> new PyTorchResult(PyTorchInferenceResultTests.createRandom(), null, null); - case 1 -> new PyTorchResult(null, ThreadSettingsTests.createRandom(), null); - default -> new PyTorchResult(null, null, ErrorResultTests.createRandom()); + case 0 -> new
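/* Illustrative note (not part of the patch): each switch arm now sets the
 * top-level requestId and exactly one payload slot - inference result, thread
 * settings, ack, or error - one arm per variant the parser accepts. */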
PyTorchResult(requestId, randomBoolean(), PyTorchInferenceResultTests.createRandom(), null, null, null); + case 1 -> new PyTorchResult(requestId, null, null, ThreadSettingsTests.createRandom(), null, null); + case 2 -> new PyTorchResult(requestId, null, null, null, AckResultTests.createRandom(), null); + default -> new PyTorchResult(requestId, null, null, null, null, ErrorResultTests.createRandom()); }; } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java index 62ec2a4da27f9..ce3b9d9fa07f3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/ThreadSettingsTests.java @@ -15,11 +15,7 @@ public class ThreadSettingsTests extends AbstractXContentTestCase { public static ThreadSettings createRandom() { - return new ThreadSettings( - randomIntBetween(1, Integer.MAX_VALUE), - randomIntBetween(1, Integer.MAX_VALUE), - randomBoolean() ? null : randomAlphaOfLength(5) - ); + return new ThreadSettings(randomIntBetween(1, Integer.MAX_VALUE), randomIntBetween(1, Integer.MAX_VALUE)); } @Override From c7aa1a29d8ed964119f992201708256d0c5c3cbd Mon Sep 17 00:00:00 2001 From: David Roberts Date: Thu, 28 Jul 2022 17:25:51 +0100 Subject: [PATCH 2/7] Move time_ms up a level too --- .../process/PyTorchResultProcessor.java | 22 +++++--- .../results/PyTorchInferenceResult.java | 18 ++----- .../pytorch/results/PyTorchResult.java | 15 ++++-- .../inference/nlp/FillMaskProcessorTests.java | 4 +- .../ml/inference/nlp/NerProcessorTests.java | 10 ++-- .../nlp/QuestionAnsweringProcessorTests.java | 2 +- .../nlp/TextClassificationProcessorTests.java | 4 +- .../process/PyTorchResultProcessorTests.java | 52 ++++++++++--------- .../results/PyTorchInferenceResultTests.java | 2 +- .../pytorch/results/PyTorchResultTests.java | 16 ++++-- 10 files changed, 79 insertions(+), 66 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index 47fb66011cb03..858809d82e482 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -122,7 +122,8 @@ public void process(PyTorchProcess process) { (id, pendingResult) -> pendingResult.listener.onResponse( new PyTorchResult( id, - false, + null, + null, null, null, null, @@ -138,7 +139,7 @@ public void process(PyTorchProcess process) { } finally { pendingResults.forEach( (id, pendingResult) -> pendingResult.listener.onResponse( - new PyTorchResult(id, false, null, null, null, new ErrorResult("inference canceled as process is stopping")) + new PyTorchResult(id, false, null, null, null, null, new ErrorResult("inference canceled as process is stopping")) ) ); pendingResults.clear(); @@ -149,9 +150,14 @@ public void process(PyTorchProcess process) { void processInferenceResult(PyTorchResult result) { PyTorchInferenceResult inferenceResult = result.inferenceResult(); assert inferenceResult != null; + Long timeMs = result.timeMs(); + if (timeMs == null) { + assert false : "time_ms should be 
set for an inference result"; + timeMs = 0L; + } logger.trace(() -> format("[%s] Parsed inference result with id [%s]", deploymentId, result.requestId())); - processResult(inferenceResult, result.isCacheHit()); + processResult(inferenceResult, timeMs, Boolean.TRUE.equals(result.isCacheHit())); PendingResult pendingResult = pendingResults.remove(result.requestId()); if (pendingResult == null) { logger.debug(() -> format("[%s] no pending result for inference [%s]", deploymentId, result.requestId())); @@ -236,8 +242,8 @@ public synchronized ResultStats getResultStats() { ); } - private synchronized void processResult(PyTorchInferenceResult result, Boolean isCacheHit) { - timingStats.accept(result.getTimeMs()); + private synchronized void processResult(PyTorchInferenceResult result, long timeMs, boolean isCacheHit) { + timingStats.accept(timeMs); lastResultTimeMs = currentTimeMsSupplier.getAsLong(); if (lastResultTimeMs > currentPeriodEndTimeMs) { @@ -258,15 +264,15 @@ private synchronized void processResult(PyTorchInferenceResult result, Boolean i lastPeriodCacheHitCount = 0; lastPeriodSummaryStats = new LongSummaryStatistics(); - lastPeriodSummaryStats.accept(result.getTimeMs()); + lastPeriodSummaryStats.accept(timeMs); // set to the end of the current bucket currentPeriodEndTimeMs = startTime + Intervals.alignToCeil(lastResultTimeMs - startTime, REPORTING_PERIOD_MS); } else { - lastPeriodSummaryStats.accept(result.getTimeMs()); + lastPeriodSummaryStats.accept(timeMs); } - if (Boolean.TRUE.equals(isCacheHit)) { + if (isCacheHit) { cacheHitCount++; lastPeriodCacheHitCount++; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java index a38dbf720eab9..a1482851fc21d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResult.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Arrays; -import java.util.Objects; /** * All results must have a request_id. 
@@ -28,11 +27,10 @@ public class PyTorchInferenceResult implements ToXContentObject { private static final ParseField INFERENCE = new ParseField("inference"); - private static final ParseField TIME_MS = new ParseField("time_ms"); public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "pytorch_inference_result", - a -> new PyTorchInferenceResult((double[][][]) a[0], (Long) a[1]) + a -> new PyTorchInferenceResult((double[][][]) a[0]) ); static { @@ -42,7 +40,6 @@ public class PyTorchInferenceResult implements ToXContentObject { INFERENCE, ObjectParser.ValueType.VALUE_ARRAY ); - PARSER.declareLong(ConstructingObjectParser.constructorArg(), TIME_MS); } public static PyTorchInferenceResult fromXContent(XContentParser parser) throws IOException { @@ -50,21 +47,15 @@ public static PyTorchInferenceResult fromXContent(XContentParser parser) throws } private final double[][][] inference; - private final long timeMs; - public PyTorchInferenceResult(@Nullable double[][][] inference, long timeMs) { + public PyTorchInferenceResult(@Nullable double[][][] inference) { this.inference = inference; - this.timeMs = timeMs; } public double[][][] getInferenceResult() { return inference; } - public long getTimeMs() { - return timeMs; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); @@ -79,14 +70,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.endArray(); } - builder.field(TIME_MS.getPreferredName(), timeMs); builder.endObject(); return builder; } @Override public int hashCode() { - return Objects.hash(timeMs, Arrays.deepHashCode(inference)); + return Arrays.deepHashCode(inference); } @Override @@ -95,6 +85,6 @@ public boolean equals(Object other) { if (other == null || getClass() != other.getClass()) return false; PyTorchInferenceResult that = (PyTorchInferenceResult) other; - return Arrays.deepEquals(inference, that.inference) && timeMs == that.timeMs; + return Arrays.deepEquals(inference, that.inference); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java index b9a319b9eba3f..11340d0bf542d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResult.java @@ -21,6 +21,7 @@ public record PyTorchResult( String requestId, Boolean isCacheHit, + Long timeMs, @Nullable PyTorchInferenceResult inferenceResult, @Nullable ThreadSettings threadSettings, @Nullable AckResult ackResult, @@ -29,6 +30,7 @@ public record PyTorchResult( private static final ParseField REQUEST_ID = new ParseField("request_id"); private static final ParseField CACHE_HIT = new ParseField("cache_hit"); + private static final ParseField TIME_MS = new ParseField("time_ms"); private static final ParseField RESULT = new ParseField("result"); private static final ParseField THREAD_SETTINGS = new ParseField("thread_settings"); @@ -39,16 +41,18 @@ public record PyTorchResult( a -> new PyTorchResult( (String) a[0], (Boolean) a[1], - (PyTorchInferenceResult) a[2], - (ThreadSettings) a[3], - (AckResult) a[4], - (ErrorResult) a[5] + (Long) a[2], + (PyTorchInferenceResult) a[3], + (ThreadSettings) a[4], + (AckResult) a[5], + (ErrorResult) a[6] ) ); static { 
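// Illustrative note (not part of the patch): ConstructingObjectParser assigns
// constructor-argument slots in declaration order, so the declare* calls below
// must stay aligned with the a[0]..a[6] positions used in the lambda above:
// request_id, cache_hit, time_ms, result, thread_settings, ack, error.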
PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID); PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CACHE_HIT); + PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), TIME_MS); PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PyTorchInferenceResult.PARSER, RESULT); PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ThreadSettings.PARSER, THREAD_SETTINGS); PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), AckResult.PARSER, ACK); @@ -68,6 +72,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (isCacheHit != null) { builder.field(CACHE_HIT.getPreferredName(), isCacheHit); } + if (timeMs != null) { + builder.field(TIME_MS.getPreferredName(), timeMs); + } if (inferenceResult != null) { builder.field(RESULT.getPreferredName(), inferenceResult); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java index b924f8891a8ef..3ea5f8634a0de 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java @@ -66,7 +66,7 @@ public void testProcessResults() { String resultsField = randomAlphaOfLength(10); FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult( tokenization, - new PyTorchInferenceResult(scores, 0L), + new PyTorchInferenceResult(scores), tokenizer, 4, resultsField @@ -93,7 +93,7 @@ public void testProcessResults_GivenMissingTokens() { 0 ); - PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(new double[][][] { { {} } }, 0L); + PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(new double[][][] { { {} } }); expectThrows( ElasticsearchStatusException.class, () -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java index 6ee10f2596567..389a4fab802a0 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java @@ -72,7 +72,7 @@ public void testProcessResults_GivenNoTokens() { var e = expectThrows( ElasticsearchStatusException.class, - () -> processor.processResult(tokenization, new PyTorchInferenceResult(null, 0L)) + () -> processor.processResult(tokenization, new PyTorchInferenceResult(null)) ); assertThat(e, instanceOf(ElasticsearchStatusException.class)); } @@ -113,7 +113,7 @@ public void testProcessResultsWithSpecialTokens() { { 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london { 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores)); assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); @@ -141,7 +141,7 @@ public void testProcessResults() { { 0, 0, 0, 0, 
0, 0, 0, 0, 0 }, // in { 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores)); assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); @@ -178,7 +178,7 @@ public void testProcessResults_withIobMap() { { 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in { 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores)); assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); @@ -211,7 +211,7 @@ public void testProcessResults_withCustomIobMap() { { 0, 0, 0, 0, 5 }, // in { 6, 0, 0, 0, 0 } // london } }; - NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L)); + NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores)); assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)")); assertThat(result.getEntityGroups().size(), equalTo(2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java index 61344f5dec6a5..b53c688918e2c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/QuestionAnsweringProcessorTests.java @@ -87,7 +87,7 @@ public void testProcessor() throws IOException { assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7)); double[][][] scores = { { START_TOKEN_SCORES }, { END_TOKEN_SCORES } }; NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config); - PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores, 1L); + PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores); QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult( tokenizationResult, pyTorchResult diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java index dea51cf06be1c..3b48e75846243 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java @@ -32,7 +32,7 @@ public class TextClassificationProcessorTests extends ESTestCase { public void testInvalidResult() { { - PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] {}, 0L); + PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] {}); var e = expectThrows( ElasticsearchStatusException.class, () -> 
TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10)) @@ -41,7 +41,7 @@ public void testInvalidResult() { assertThat(e.getMessage(), containsString("Text classification result has no data")); } { - PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0 } } }, 0L); + PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0 } } }); var e = expectThrows( ElasticsearchStatusException.class, () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10)) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java index 976881f6010b8..98da8da4b686a 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java @@ -40,13 +40,15 @@ public void testsThreadSettings() { var settings = new ThreadSettings(1, 1); processor.registerRequest("thread-setting", new AssertingResultListener(r -> assertEquals(settings, r.threadSettings()))); - processor.process(mockNativeProcess(List.of(new PyTorchResult("thread-setting", null, null, settings, null, null)).iterator())); + processor.process( + mockNativeProcess(List.of(new PyTorchResult("thread-setting", null, null, null, settings, null, null)).iterator()) + ); assertEquals(settings, settingsHolder.get()); } public void testResultsProcessing() { - var inferenceResult = new PyTorchInferenceResult(null, 1000L); + var inferenceResult = new PyTorchInferenceResult(null); var threadSettings = new ThreadSettings(1, 1); var ack = new AckResult(true); var errorResult = new ErrorResult("a bad thing has happened"); @@ -65,10 +67,10 @@ public void testResultsProcessing() { processor.process( mockNativeProcess( List.of( - new PyTorchResult("a", true, inferenceResult, null, null, null), - new PyTorchResult("b", null, null, threadSettings, null, null), - new PyTorchResult("c", null, null, null, ack, null), - new PyTorchResult("d", null, null, null, null, errorResult) + new PyTorchResult("a", true, 1000L, inferenceResult, null, null, null), + new PyTorchResult("b", null, null, null, threadSettings, null, null), + new PyTorchResult("c", null, null, null, null, ack, null), + new PyTorchResult("d", null, null, null, null, null, errorResult) ).iterator() ) ); @@ -91,9 +93,9 @@ public void testPendingRequest() { ); processor.registerRequest("b", calledOnShutdown); - var inferenceResult = new PyTorchInferenceResult(null, 1000L); + var inferenceResult = new PyTorchInferenceResult(null); - processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, inferenceResult, null, null, null)).iterator())); + processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator())); assertSame(inferenceResult, resultHolder.get()); assertTrue(calledOnShutdown.hasResponse); } @@ -105,8 +107,8 @@ public void testCancelPendingRequest() { processor.ignoreResponseWithoutNotifying("a"); - var inferenceResult = new PyTorchInferenceResult(null, 1000L); - processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, inferenceResult, null, null, 
null)).iterator())); + var inferenceResult = new PyTorchInferenceResult(null); + processor.process(mockNativeProcess(List.of(new PyTorchResult("a", false, 1000L, inferenceResult, null, null, null)).iterator())); } public void testPendingRequestAreCalledAtShutdown() { @@ -151,8 +153,8 @@ public void onFailure(Exception e) { } } - private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, PyTorchInferenceResult result) { - return new PyTorchResult(requestId, isCacheHit, result, null, null, null); + private PyTorchResult wrapInferenceResult(String requestId, boolean isCacheHit, long timeMs, PyTorchInferenceResult result) { + return new PyTorchResult(requestId, isCacheHit, timeMs, result, null, null, null); } public void testsStats() { @@ -166,9 +168,9 @@ public void testsStats() { processor.registerRequest("b", pendingB); processor.registerRequest("c", pendingC); - var a = wrapInferenceResult("a", false, new PyTorchInferenceResult(null, 1000L)); - var b = wrapInferenceResult("b", false, new PyTorchInferenceResult(null, 900L)); - var c = wrapInferenceResult("c", true, new PyTorchInferenceResult(null, 200L)); + var a = wrapInferenceResult("a", false, 1000L, new PyTorchInferenceResult(null)); + var b = wrapInferenceResult("b", false, 900L, new PyTorchInferenceResult(null)); + var c = wrapInferenceResult("c", true, 200L, new PyTorchInferenceResult(null)); processor.processInferenceResult(a); var stats = processor.getResultStats(); @@ -232,9 +234,9 @@ public void testsTimeDependentStats() { var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier); // 1st period - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 200L))); - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 200L))); - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 200L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 200L, new PyTorchInferenceResult(null))); // first call has no results as is in the same period var stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); @@ -248,7 +250,7 @@ public void testsTimeDependentStats() { assertThat(stats.peakThroughput(), equalTo(3L)); // 2nd period - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 100L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 100L, new PyTorchInferenceResult(null))); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -260,7 +262,7 @@ public void testsTimeDependentStats() { assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); // 4th period - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 300L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 300L, new PyTorchInferenceResult(null))); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -268,8 +270,8 @@ public void testsTimeDependentStats() { assertThat(stats.lastUsed(), 
equalTo(Instant.ofEpochMilli(resultTimestamps[9]))); // 7th period - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 410L))); - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 390L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 410L, new PyTorchInferenceResult(null))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 390L, new PyTorchInferenceResult(null))); stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); assertThat(stats.recentStats().avgInferenceTime(), nullValue()); @@ -280,9 +282,9 @@ public void testsTimeDependentStats() { assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12]))); // 8th period - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 510L))); - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 500L))); - processor.processInferenceResult(wrapInferenceResult("foo", false, new PyTorchInferenceResult(null, 490L))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 510L, new PyTorchInferenceResult(null))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 500L, new PyTorchInferenceResult(null))); + processor.processInferenceResult(wrapInferenceResult("foo", false, 490L, new PyTorchInferenceResult(null))); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(3L)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java index fe2a5bab07a92..f7370f25e2e84 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchInferenceResultTests.java @@ -41,6 +41,6 @@ public static PyTorchInferenceResult createRandom() { } } } - return new PyTorchInferenceResult(arr, randomNonNegativeLong()); + return new PyTorchInferenceResult(arr); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java index 13137505ef5a0..9281fbfc54d13 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/results/PyTorchResultTests.java @@ -19,10 +19,18 @@ protected PyTorchResult createTestInstance() { String requestId = randomAlphaOfLength(5); int type = randomIntBetween(0, 3); return switch (type) { - case 0 -> new PyTorchResult(requestId, randomBoolean(), PyTorchInferenceResultTests.createRandom(), null, null, null); - case 1 -> new PyTorchResult(requestId, null, null, ThreadSettingsTests.createRandom(), null, null); - case 2 -> new PyTorchResult(requestId, null, null, null, AckResultTests.createRandom(), null); - default -> new PyTorchResult(requestId, null, null, null, null, ErrorResultTests.createRandom()); + case 0 -> new PyTorchResult( + requestId, + randomBoolean(), + randomNonNegativeLong(), + 
PyTorchInferenceResultTests.createRandom(), + null, + null, + null + ); + case 1 -> new PyTorchResult(requestId, null, null, null, ThreadSettingsTests.createRandom(), null, null); + case 2 -> new PyTorchResult(requestId, null, null, null, null, AckResultTests.createRandom(), null); + default -> new PyTorchResult(requestId, null, null, null, null, null, ErrorResultTests.createRandom()); }; } From 1bdf3a0365005c8d65f597b847c986716fd52c45 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Tue, 2 Aug 2022 15:16:33 +0100 Subject: [PATCH 3/7] Address review comment --- .../ml/inference/pytorch/process/PyTorchResultProcessor.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index 858809d82e482..89bc976dbb60b 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -196,7 +196,10 @@ void processErrorResult(PyTorchResult result) { ErrorResult errorResult = result.errorResult(); assert errorResult != null; - errorCount++; + // Only one result is processed at any time, but we need to stop this happening part way through another thread getting stats + synchronized (this) { + errorCount++; + } logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, result.requestId())); PendingResult pendingResult = pendingResults.remove(result.requestId()); From 5a5442a08f4b134b3051663ad1c2ea68e6d20664 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Tue, 2 Aug 2022 15:17:02 +0100 Subject: [PATCH 4/7] Fix compilation --- .../xpack/ml/inference/nlp/TextSimilarityProcessorTests.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java index 5601fd6b8baa8..10be6225163b6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextSimilarityProcessorTests.java @@ -51,7 +51,7 @@ public void testProcessor() throws IOException { assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7)); double[][][] scores = { { { 42 } } }; NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(textSimilarityConfig); - PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false); + PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores); TextSimilarityInferenceResults result = (TextSimilarityInferenceResults) resultProcessor.processResult( tokenizationResult, pyTorchResult @@ -74,7 +74,7 @@ public void testResultFunctions() { TextSimilarityProcessor processor = new TextSimilarityProcessor(tokenizer); NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(textSimilarityConfig); double[][][] scores = { { { 42 }, { 12 }, { 100 } } }; - PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, false); + PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores); TextSimilarityInferenceResults 
result = (TextSimilarityInferenceResults) resultProcessor.processResult( new BertTokenizationResult(List.of(), List.of(), 1), pyTorchResult From bb7fd516ed4435f85b698403e57c9518ea7a8c45 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Wed, 3 Aug 2022 11:07:41 +0100 Subject: [PATCH 5/7] Mute tests that will fail without the corresponding C++ change --- .../org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java | 1 + .../resources/rest-api-spec/test/ml/3rd_party_deployment.yml | 2 ++ .../xpack/restart/MLModelDeploymentFullClusterRestartIT.java | 1 + .../org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java | 1 + 4 files changed, 5 insertions(+) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index 40eb8a77913b0..9498b58bb5b22 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -75,6 +75,7 @@ * torch.jit.save(traced_model, "simplemodel.pt") * ## End Python */ +@ESRestTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376") public class PyTorchModelIT extends ESRestTestCase { private static final String BASIC_AUTH_VALUE_SUPER_USER = UsernamePasswordToken.basicAuthHeaderValue( diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml index 6d0348b1fba92..bc4a36cef9ddd 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml @@ -76,6 +76,8 @@ setup: --- "Test start and stop deployment with cache": - skip: + version: all + reason: "@AwaitsFix https://github.com/elastic/ml-cpp/pull/2376" features: allowed_warnings - do: diff --git a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java index f1c7c04905bea..b0e624b470d0b 100644 --- a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java +++ b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java @@ -31,6 +31,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +@AbstractFullClusterRestartTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376") public class MLModelDeploymentFullClusterRestartIT extends AbstractFullClusterRestartTestCase { // See PyTorchModelIT for how this model was created diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java index 682875ae5a2e5..8109ce0f7d0f3 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java @@ 
-29,6 +29,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; +@AbstractUpgradeTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376") public class MLModelDeploymentsUpgradeIT extends AbstractUpgradeTestCase { // See PyTorchModelIT for how this model was created From 813bd1ab8359a4985ae42e854a258bf9d2635de8 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Wed, 3 Aug 2022 12:39:12 +0100 Subject: [PATCH 6/7] Fix typo --- .../xpack/ml/inference/pytorch/results/AckResult.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java index 9a9325b065d7c..06f2679c56c2d 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/results/AckResult.java @@ -19,7 +19,7 @@ public record AckResult(boolean acknowledged) implements ToXContentObject { public static final ParseField ACKNOWLEDGED = new ParseField("acknowledged"); public static ConstructingObjectParser<AckResult, Void> PARSER = new ConstructingObjectParser<>( - "error", + "ack", a -> new AckResult((Boolean) a[0]) ); From 805bb8478f36ea36081a94d47418d01e599c73d9 Mon Sep 17 00:00:00 2001 From: David Roberts Date: Wed, 3 Aug 2022 12:40:43 +0100 Subject: [PATCH 7/7] Revert "Mute tests that will fail without the corresponding C++ change" This reverts commit bb7fd516ed4435f85b698403e57c9518ea7a8c45. --- .../org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java | 1 - .../resources/rest-api-spec/test/ml/3rd_party_deployment.yml | 2 -- .../xpack/restart/MLModelDeploymentFullClusterRestartIT.java | 1 - .../org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java | 1 - 4 files changed, 5 deletions(-) diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java index 9498b58bb5b22..40eb8a77913b0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java @@ -75,7 +75,6 @@ * torch.jit.save(traced_model, "simplemodel.pt") * ## End Python */ -@ESRestTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376") public class PyTorchModelIT extends ESRestTestCase { private static final String BASIC_AUTH_VALUE_SUPER_USER = UsernamePasswordToken.basicAuthHeaderValue( diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml index bc4a36cef9ddd..6d0348b1fba92 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/3rd_party_deployment.yml @@ -76,8 +76,6 @@ setup: --- "Test start and stop deployment with cache": - skip: - version: all - reason: "@AwaitsFix https://github.com/elastic/ml-cpp/pull/2376" features: allowed_warnings - do: diff --git
a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java index b0e624b470d0b..f1c7c04905bea 100644 --- a/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java +++ b/x-pack/qa/full-cluster-restart/src/test/java/org/elasticsearch/xpack/restart/MLModelDeploymentFullClusterRestartIT.java @@ -31,7 +31,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -@AbstractFullClusterRestartTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376") public class MLModelDeploymentFullClusterRestartIT extends AbstractFullClusterRestartTestCase { // See PyTorchModelIT for how this model was created diff --git a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java index 8109ce0f7d0f3..682875ae5a2e5 100644 --- a/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java +++ b/x-pack/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/MLModelDeploymentsUpgradeIT.java @@ -29,7 +29,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; -@AbstractUpgradeTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376") public class MLModelDeploymentsUpgradeIT extends AbstractUpgradeTestCase { // See PyTorchModelIT for how this model was created
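
A note on the constructor shape the updated tests rely on: every call site above passes seven positional arguments, with the request ID, cache-hit flag, and inference time first and the four payloads after; in those tests exactly one payload is non-null per response. The record below is a hypothetical stand-in sketching that shape for reference only, not the real PyTorchResult class (which also implements ToXContentObject and, per the toXContent hunk above, emits the CACHE_HIT and TIME_MS fields only when non-null). All nested type names here are invented placeholders.

    // Hypothetical mirror of the seven-argument result shape exercised by the
    // updated tests; exactly one of the four payloads is non-null per response
    // in those tests. All type names are invented placeholders.
    record SketchInference(double[][][] inference) {}
    record SketchThreadSettings(int numAllocations, int numThreadsPerAllocation) {}
    record SketchAck(boolean acknowledged) {}
    record SketchError(String error) {}

    record SketchPyTorchResult(
        String requestId,   // top-level request ID
        Boolean isCacheHit, // optional top-level cache-hit indicator
        Long timeMs,        // optional top-level inference time in milliseconds
        SketchInference inferenceResult,
        SketchThreadSettings threadSettings,
        SketchAck ackResult,
        SketchError errorResult
    ) {
        public static void main(String[] args) {
            // The two argument patterns seen most often in the tests above.
            var inference = new SketchPyTorchResult("a", true, 1000L, new SketchInference(new double[][][] {}), null, null, null);
            var error = new SketchPyTorchResult("d", null, null, null, null, null, new SketchError("a bad thing has happened"));
            System.out.println(inference.requestId() + " cacheHit=" + inference.isCacheHit() + " timeMs=" + inference.timeMs());
            System.out.println(error.requestId() + " error=" + error.errorResult().error());
        }
    }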
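
The synchronized block added in PATCH 3/7 also merits a brief illustration. As the commit's comment says, results are processed one at a time, but the error counter is read by other threads when stats are collected, so the increment and the stats read have to share a monitor. The following is a minimal, self-contained sketch of that pattern; the class and method names are illustrative stand-ins, not the real PyTorchResultProcessor API.

    import java.util.concurrent.CountDownLatch;

    // Stand-in for the counter handling in PATCH 3/7: one thread performs the
    // increments while another may snapshot the stats at any time, and both
    // paths synchronize on the same monitor so a snapshot never interleaves
    // with a half-finished update.
    class CounterSketch {
        private long requestsProcessed; // guarded by "this"
        private long errorCount;        // guarded by "this"

        void onInferenceResult() {
            synchronized (this) {
                requestsProcessed++;
            }
        }

        void onErrorResult() {
            synchronized (this) {
                errorCount++;
            }
        }

        // The stats path: reads both counters under the same lock.
        synchronized long[] snapshot() {
            return new long[] { requestsProcessed, errorCount };
        }

        public static void main(String[] args) throws InterruptedException {
            CounterSketch sketch = new CounterSketch();
            CountDownLatch done = new CountDownLatch(1);

            Thread resultThread = new Thread(() -> {
                for (int i = 0; i < 100_000; i++) {
                    sketch.onInferenceResult();
                    if (i % 1_000 == 0) {
                        sketch.onErrorResult();
                    }
                }
                done.countDown();
            });
            resultThread.start();

            for (int i = 0; i < 10_000; i++) {
                long[] s = sketch.snapshot();
                // Each completed increment happens-before this read via the
                // shared monitor, so errors never race ahead of requests.
                if (s[1] > s[0]) {
                    throw new AssertionError("inconsistent snapshot: " + s[1] + " > " + s[0]);
                }
            }

            done.await();
            long[] finalStats = sketch.snapshot();
            System.out.println("processed=" + finalStats[0] + " errors=" + finalStats[1]);
        }
    }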