Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,10 @@ public void process(PyTorchProcess process) {
pendingResults.forEach(
(id, pendingResult) -> pendingResult.listener.onResponse(
new PyTorchResult(
id,
null,
null,
new ErrorResult(
id,
isStopping
? "inference canceled as process is stopping"
: "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
Expand All @@ -132,7 +132,7 @@ public void process(PyTorchProcess process) {
} finally {
pendingResults.forEach(
(id, pendingResult) -> pendingResult.listener.onResponse(
new PyTorchResult(null, null, new ErrorResult(id, "inference canceled as process is stopping"))
new PyTorchResult(id, null, null, new ErrorResult("inference canceled as process is stopping"))
)
);
pendingResults.clear();
Expand All @@ -144,11 +144,11 @@ void processInferenceResult(PyTorchResult result) {
PyTorchInferenceResult inferenceResult = result.inferenceResult();
assert inferenceResult != null;

logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, inferenceResult.getRequestId()));
logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, result.requestId()));
processResult(inferenceResult);
PendingResult pendingResult = pendingResults.remove(inferenceResult.getRequestId());
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, inferenceResult.getRequestId()));
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, result.requestId()));
} else {
pendingResult.listener.onResponse(result);
}
Expand All @@ -158,10 +158,10 @@ void processThreadSettings(PyTorchResult result) {
ThreadSettings threadSettings = result.threadSettings();
assert threadSettings != null;

logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, threadSettings.requestId()));
PendingResult pendingResult = pendingResults.remove(threadSettings.requestId());
logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, result.requestId()));
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, threadSettings.requestId()));
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, result.requestId()));
} else {
pendingResult.listener.onResponse(result);
}
Expand All @@ -173,10 +173,10 @@ void processErrorResult(PyTorchResult result) {

errorCount++;

logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, errorResult.requestId()));
PendingResult pendingResult = pendingResults.remove(errorResult.requestId());
logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, result.requestId()));
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, errorResult.requestId()));
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, result.requestId()));
} else {
pendingResult.listener.onResponse(result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public record ErrorResult(String requestId, String error) implements ToXContentObject {
public class ErrorResult implements ToXContentObject {

public static final ParseField ERROR = new ParseField("error");

Expand All @@ -28,6 +30,27 @@ public record ErrorResult(String requestId, String error) implements ToXContentO
PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR);
}

private final String requestId;
private final String error;

ErrorResult(String requestId, String error) {
this.requestId = requestId;
this.error = error;
}

public ErrorResult(String error) {
this.requestId = null;
this.error = error;
}

public String error() {
return error;
}

Optional<String> requestId() {
return Optional.ofNullable(requestId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand All @@ -38,4 +61,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ErrorResult that = (ErrorResult) o;
return Objects.equals(requestId, that.requestId) && Objects.equals(error, that.error);
}

@Override
public int hashCode() {
return Objects.hash(requestId, error);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class PyTorchInferenceResult implements ToXContentObject {
);

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), PyTorchResult.REQUEST_ID);
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID);
PARSER.declareField(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> MlParserUtils.parse3DArrayOfDoubles(INFERENCE.getPreferredName(), p),
Expand All @@ -57,14 +57,26 @@ public static PyTorchInferenceResult fromXContent(XContentParser parser) throws
private final Long timeMs;
private final boolean cacheHit;

public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, @Nullable Long timeMs, @Nullable Boolean cacheHit) {
this.requestId = Objects.requireNonNull(requestId);
PyTorchInferenceResult(
@Nullable String requestId,
@Nullable double[][][] inference,
@Nullable Long timeMs,
@Nullable Boolean cacheHit
) {
this.requestId = requestId;
this.inference = inference;
this.timeMs = timeMs;
this.cacheHit = cacheHit != null && cacheHit;
}

public String getRequestId() {
public PyTorchInferenceResult(@Nullable double[][][] inference, @Nullable Long timeMs, @Nullable Boolean cacheHit) {
this.requestId = null;
this.inference = inference;
this.timeMs = timeMs;
this.cacheHit = cacheHit != null && cacheHit;
}

String getRequestId() {
return requestId;
}

Expand All @@ -83,7 +95,9 @@ public boolean isCacheHit() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
if (requestId != null) {
builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
}
if (inference != null) {
builder.startArray(INFERENCE.getPreferredName());
for (double[][] doubles : inference) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* The top level object capturing output from the pytorch process.
*/
public record PyTorchResult(
@Nullable String requestId,
@Nullable PyTorchInferenceResult inferenceResult,
@Nullable ThreadSettings threadSettings,
@Nullable ErrorResult errorResult
Expand All @@ -29,12 +30,27 @@ public record PyTorchResult(
private static final ParseField RESULT = new ParseField("result");
private static final ParseField THREAD_SETTINGS = new ParseField("thread_settings");

public static ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>(
"pytorch_result",
a -> new PyTorchResult((PyTorchInferenceResult) a[0], (ThreadSettings) a[1], (ErrorResult) a[2])
);
public static ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>("pytorch_result", a -> {
String outerId = (String) a[0];
PyTorchInferenceResult inferenceResult = (PyTorchInferenceResult) a[1];
ThreadSettings threadSettings = (ThreadSettings) a[2];
ErrorResult errorResult = (ErrorResult) a[3];
if (outerId == null) {
if (inferenceResult != null) {
outerId = inferenceResult.getRequestId();
}
if (threadSettings != null) {
outerId = threadSettings.requestId().orElse(null);
}
if (errorResult != null) {
outerId = errorResult.requestId().orElse(null);
}
}
return new PyTorchResult(outerId, inferenceResult, threadSettings, errorResult);
});

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), REQUEST_ID);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PyTorchInferenceResult.PARSER, RESULT);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ThreadSettings.PARSER, THREAD_SETTINGS);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ErrorResult.PARSER, ErrorResult.ERROR);
Expand All @@ -56,6 +72,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (errorResult != null) {
builder.field(ErrorResult.ERROR.getPreferredName(), errorResult);
}
if (requestId != null) {
builder.field(REQUEST_ID.getPreferredName(), requestId);
}

builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public record ThreadSettings(int numThreadsPerAllocation, int numAllocations, String requestId) implements ToXContentObject {
public class ThreadSettings implements ToXContentObject {

private static final ParseField NUM_ALLOCATIONS = new ParseField("num_allocations");
private static final ParseField NUM_THREADS_PER_ALLOCATION = new ParseField("num_threads_per_allocation");
Expand All @@ -30,6 +32,34 @@ public record ThreadSettings(int numThreadsPerAllocation, int numAllocations, St
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID);
}

private final int numThreadsPerAllocation;
private final int numAllocations;
private final String requestId;

ThreadSettings(int numThreadsPerAllocation, int numAllocations, String requestId) {
this.numThreadsPerAllocation = numThreadsPerAllocation;
this.numAllocations = numAllocations;
this.requestId = requestId;
}

public ThreadSettings(int numThreadsPerAllocation, int numAllocations) {
this.numThreadsPerAllocation = numThreadsPerAllocation;
this.numAllocations = numAllocations;
this.requestId = null;
}

public int numThreadsPerAllocation() {
return numThreadsPerAllocation;
}

public int numAllocations() {
return numAllocations;
}

Optional<String> requestId() {
return Optional.ofNullable(requestId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
Expand All @@ -41,4 +71,19 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.endObject();
return builder;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ThreadSettings that = (ThreadSettings) o;
return numThreadsPerAllocation == that.numThreadsPerAllocation
&& numAllocations == that.numAllocations
&& Objects.equals(requestId, that.requestId);
}

@Override
public int hashCode() {
return Objects.hash(numThreadsPerAllocation, numAllocations, requestId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void testProcessResults() {
String resultsField = randomAlphaOfLength(10);
FillMaskResults result = (FillMaskResults) FillMaskProcessor.processResult(
tokenization,
new PyTorchInferenceResult("1", scores, 0L, null),
new PyTorchInferenceResult(scores, 0L, null),
tokenizer,
4,
resultsField
Expand All @@ -93,7 +93,7 @@ public void testProcessResults_GivenMissingTokens() {
0
);

PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null);
PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(new double[][][] { { {} } }, 0L, null);
expectThrows(
ElasticsearchStatusException.class,
() -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public void testProcessResults_GivenNoTokens() {

var e = expectThrows(
ElasticsearchStatusException.class,
() -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null))
() -> processor.processResult(tokenization, new PyTorchInferenceResult(null, 0L, null))
);
assertThat(e, instanceOf(ElasticsearchStatusException.class));
}
Expand Down Expand Up @@ -113,7 +113,7 @@ public void testProcessResultsWithSpecialTokens() {
{ 0, 0, 0, 0, 0, 0, 0, 6, 0 }, // london
{ 7, 0, 0, 0, 0, 0, 0, 0, 0 } // sep
} };
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L, null));

assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
assertThat(result.getEntityGroups().size(), equalTo(2));
Expand Down Expand Up @@ -141,7 +141,7 @@ public void testProcessResults() {
{ 0, 0, 0, 0, 0, 0, 0, 0, 0 }, // in
{ 0, 0, 0, 0, 0, 0, 0, 6, 0 } // london
} };
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L, null));

assertThat(result.getAnnotatedResult(), equalTo("Many use [Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
assertThat(result.getEntityGroups().size(), equalTo(2));
Expand Down Expand Up @@ -178,7 +178,7 @@ public void testProcessResults_withIobMap() {
{ 0, 0, 0, 0, 0, 0, 0, 0, 5 }, // in
{ 6, 0, 0, 0, 0, 0, 0, 0, 0 } // london
} };
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L, null));

assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](ORG&Elasticsearch) in [London](LOC&London)"));
assertThat(result.getEntityGroups().size(), equalTo(2));
Expand Down Expand Up @@ -211,7 +211,7 @@ public void testProcessResults_withCustomIobMap() {
{ 0, 0, 0, 0, 5 }, // in
{ 6, 0, 0, 0, 0 } // london
} };
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult("1", scores, 1L, null));
NerResults result = (NerResults) processor.processResult(tokenization, new PyTorchInferenceResult(scores, 1L, null));

assertThat(result.getAnnotatedResult(), equalTo("[Elasticsearch](SOFTWARE&Elasticsearch) in [London](LOC&London)"));
assertThat(result.getEntityGroups().size(), equalTo(2));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public void testProcessor() throws IOException {
assertThat(tokenizationResult.getTokenization(0).seqPairOffset(), equalTo(7));
double[][][] scores = { { START_TOKEN_SCORES }, { END_TOKEN_SCORES } };
NlpTask.ResultProcessor resultProcessor = processor.getResultProcessor(config);
PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", scores, 1L, null);
PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult(scores, 1L, null);
QuestionAnsweringInferenceResults result = (QuestionAnsweringInferenceResults) resultProcessor.processResult(
tokenizationResult,
pyTorchResult
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class TextClassificationProcessorTests extends ESTestCase {

public void testInvalidResult() {
{
PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, null);
PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] {}, 0L, null);
var e = expectThrows(
ElasticsearchStatusException.class,
() -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
Expand All @@ -41,7 +41,7 @@ public void testInvalidResult() {
assertThat(e.getMessage(), containsString("Text classification result has no data"));
}
{
PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
PyTorchInferenceResult torchResult = new PyTorchInferenceResult(new double[][][] { { { 1.0 } } }, 0L, null);
var e = expectThrows(
ElasticsearchStatusException.class,
() -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
Expand Down
Loading