PyTorchModelIT.java
@@ -75,7 +75,6 @@
* torch.jit.save(traced_model, "simplemodel.pt")
* ## End Python
*/
@ESRestTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/ml-cpp/pull/2376")
public class PyTorchModelIT extends ESRestTestCase {

private static final String BASIC_AUTH_VALUE_SUPER_USER = UsernamePasswordToken.basicAuthHeaderValue(
PyTorchResultProcessor.java
@@ -12,6 +12,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.xpack.core.ml.utils.Intervals;
import org.elasticsearch.xpack.ml.inference.pytorch.results.AckResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.ErrorResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchResult;
@@ -105,10 +106,12 @@ public void process(PyTorchProcess process) {
threadSettingsConsumer.accept(threadSettings);
processThreadSettings(result);
}
if (result.ackResult() != null) {
processAcknowledgement(result);
}
if (result.errorResult() != null) {
processErrorResult(result);
}

}
} catch (Exception e) {
// No need to report error as we're stopping
@@ -118,10 +121,13 @@ public void process(PyTorchProcess process) {
pendingResults.forEach(
(id, pendingResult) -> pendingResult.listener.onResponse(
new PyTorchResult(
id,
null,
null,
null,
null,
null,
new ErrorResult(
id,
isStopping
? "inference canceled as process is stopping"
: "inference native process died unexpectedly with failure [" + e.getMessage() + "]"
@@ -133,7 +139,7 @@ public void process(PyTorchProcess process) {
} finally {
pendingResults.forEach(
(id, pendingResult) -> pendingResult.listener.onResponse(
new PyTorchResult(null, null, new ErrorResult(id, "inference canceled as process is stopping"))
new PyTorchResult(id, false, null, null, null, null, new ErrorResult("inference canceled as process is stopping"))
)
);
pendingResults.clear();
@@ -144,12 +150,17 @@ public void process(PyTorchProcess process) {
void processInferenceResult(PyTorchResult result) {
PyTorchInferenceResult inferenceResult = result.inferenceResult();
assert inferenceResult != null;
Long timeMs = result.timeMs();
if (timeMs == null) {
assert false : "time_ms should be set for an inference result";
timeMs = 0L;
}

logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, inferenceResult.getRequestId()));
processResult(inferenceResult);
PendingResult pendingResult = pendingResults.remove(inferenceResult.getRequestId());
logger.trace(() -> format("[%s] Parsed inference result with id [%s]", deploymentId, result.requestId()));
processResult(inferenceResult, timeMs, Boolean.TRUE.equals(result.isCacheHit()));
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, inferenceResult.getRequestId()));
logger.debug(() -> format("[%s] no pending result for inference [%s]", deploymentId, result.requestId()));
} else {
pendingResult.listener.onResponse(result);
}
@@ -159,10 +170,23 @@ void processThreadSettings(PyTorchResult result) {
ThreadSettings threadSettings = result.threadSettings();
assert threadSettings != null;

logger.trace(() -> format("[%s] Parsed result with id [%s]", deploymentId, threadSettings.requestId()));
PendingResult pendingResult = pendingResults.remove(threadSettings.requestId());
logger.trace(() -> format("[%s] Parsed thread settings result with id [%s]", deploymentId, result.requestId()));
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for thread settings [%s]", deploymentId, result.requestId()));
} else {
pendingResult.listener.onResponse(result);
}
}

void processAcknowledgement(PyTorchResult result) {
AckResult ack = result.ackResult();
assert ack != null;

logger.trace(() -> format("[%s] Parsed ack result with id [%s]", deploymentId, result.requestId()));
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, threadSettings.requestId()));
logger.debug(() -> format("[%s] no pending result for ack [%s]", deploymentId, result.requestId()));
} else {
pendingResult.listener.onResponse(result);
}
@@ -172,12 +196,15 @@ void processErrorResult(PyTorchResult result) {
ErrorResult errorResult = result.errorResult();
assert errorResult != null;

errorCount++;
// Only one result is processed at any time, but we need to stop this happening part way through another thread getting stats
synchronized (this) {
errorCount++;
}

logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, errorResult.requestId()));
Review comment (Member):
Seeing how processResult is synchronized and processErrorResult isn't, errorCount can be woefully incorrect due to race conditions.
That can be fixed in a different commit. This particular change looks good.

Reply (Author):
Since it's quite a simple change, I addressed it in 1bdf3a0.

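A minimal sketch (not from the PR) of the race the reviewer is describing: the increment is a read-modify-write, so without synchronization a concurrent result thread can lose an update, and the synchronized stats reader can observe a stale count. The class below is hypothetical and only illustrates the pattern; the PR's fix takes the same monitor the stats getter uses.

class ErrorCounter {
    private int errorCount;

    void onErrorUnsafe() {
        errorCount++; // racy: read-modify-write, not guarded by the reader's monitor
    }

    void onErrorFixed() {
        synchronized (this) { // same monitor as the synchronized getter below
            errorCount++;
        }
    }

    synchronized int currentErrorCount() {
        return errorCount;
    }
}
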
PendingResult pendingResult = pendingResults.remove(errorResult.requestId());
logger.trace(() -> format("[%s] Parsed error with id [%s]", deploymentId, result.requestId()));
PendingResult pendingResult = pendingResults.remove(result.requestId());
if (pendingResult == null) {
logger.debug(() -> format("[%s] no pending result for [%s]", deploymentId, errorResult.requestId()));
logger.debug(() -> format("[%s] no pending result for error [%s]", deploymentId, result.requestId()));
} else {
pendingResult.listener.onResponse(result);
}
@@ -218,8 +245,8 @@ public synchronized ResultStats getResultStats() {
);
}

private synchronized void processResult(PyTorchInferenceResult result) {
timingStats.accept(result.getTimeMs());
private synchronized void processResult(PyTorchInferenceResult result, long timeMs, boolean isCacheHit) {
timingStats.accept(timeMs);

lastResultTimeMs = currentTimeMsSupplier.getAsLong();
if (lastResultTimeMs > currentPeriodEndTimeMs) {
@@ -240,15 +267,15 @@ private synchronized void processResult(PyTorchInferenceResult result) {

lastPeriodCacheHitCount = 0;
lastPeriodSummaryStats = new LongSummaryStatistics();
lastPeriodSummaryStats.accept(result.getTimeMs());
lastPeriodSummaryStats.accept(timeMs);

// set to the end of the current bucket
currentPeriodEndTimeMs = startTime + Intervals.alignToCeil(lastResultTimeMs - startTime, REPORTING_PERIOD_MS);
} else {
lastPeriodSummaryStats.accept(result.getTimeMs());
lastPeriodSummaryStats.accept(timeMs);
}

if (result.isCacheHit()) {
if (isCacheHit) {
cacheHitCount++;
lastPeriodCacheHitCount++;
}
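For reference, a worked example of the bucket-end computation above. The rounding behaviour of Intervals.alignToCeil is assumed here (round the value up to the next multiple of the interval), and the constants are invented for illustration:

// Assumed semantics of Intervals.alignToCeil, for illustration only:
static long alignToCeil(long value, long interval) {
    return interval * ((value + interval - 1) / interval);
}

// With startTime = 1_000, lastResultTimeMs = 13_400 and a hypothetical
// REPORTING_PERIOD_MS of 10_000:
//   alignToCeil(13_400 - 1_000, 10_000) = alignToCeil(12_400, 10_000) = 20_000
//   currentPeriodEndTimeMs = 1_000 + 20_000 = 21_000
// i.e. the current bucket ends at the first period boundary after the result.
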
AckResult.java (new file)
@@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.inference.pytorch.results;

import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;

public record AckResult(boolean acknowledged) implements ToXContentObject {

public static final ParseField ACKNOWLEDGED = new ParseField("acknowledged");

public static ConstructingObjectParser<AckResult, Void> PARSER = new ConstructingObjectParser<>(
"ack",
a -> new AckResult((Boolean) a[0])
);

static {
PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), ACKNOWLEDGED);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(ACKNOWLEDGED.getPreferredName(), acknowledged);
builder.endObject();
return builder;
}
}
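
A quick sketch of how this parser might be exercised in a test; the JsonXContent and XContentParserConfiguration plumbing is assumed from the x-content library and is not part of this PR:

import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;

// Parse a minimal ack payload and read the flag back.
String json = "{\"acknowledged\": true}";
try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, json)) {
    AckResult ack = AckResult.PARSER.apply(parser, null);
    assert ack.acknowledged();
}
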
ErrorResult.java
@@ -14,26 +14,22 @@

import java.io.IOException;

public record ErrorResult(String requestId, String error) implements ToXContentObject {
public record ErrorResult(String error) implements ToXContentObject {

public static final ParseField ERROR = new ParseField("error");

public static ConstructingObjectParser<ErrorResult, Void> PARSER = new ConstructingObjectParser<>(
"error",
a -> new ErrorResult((String) a[0], (String) a[1])
a -> new ErrorResult((String) a[0])
);

static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), PyTorchResult.REQUEST_ID);
PARSER.declareString(ConstructingObjectParser.constructorArg(), ERROR);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (requestId != null) {
builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
}
builder.field(ERROR.getPreferredName(), error);
builder.endObject();
return builder;
PyTorchInferenceResult.java
@@ -18,7 +18,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
* All results must have a request_id.
@@ -28,62 +27,38 @@
public class PyTorchInferenceResult implements ToXContentObject {

private static final ParseField INFERENCE = new ParseField("inference");
private static final ParseField TIME_MS = new ParseField("time_ms");
private static final ParseField CACHE_HIT = new ParseField("cache_hit");

public static final ConstructingObjectParser<PyTorchInferenceResult, Void> PARSER = new ConstructingObjectParser<>(
"pytorch_inference_result",
a -> new PyTorchInferenceResult((String) a[0], (double[][][]) a[1], (Long) a[2], (Boolean) a[3])
a -> new PyTorchInferenceResult((double[][][]) a[0])
);

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), PyTorchResult.REQUEST_ID);
PARSER.declareField(
ConstructingObjectParser.optionalConstructorArg(),
(p, c) -> MlParserUtils.parse3DArrayOfDoubles(INFERENCE.getPreferredName(), p),
INFERENCE,
ObjectParser.ValueType.VALUE_ARRAY
);
PARSER.declareLong(ConstructingObjectParser.constructorArg(), TIME_MS);
PARSER.declareBoolean(ConstructingObjectParser.constructorArg(), CACHE_HIT);
}

public static PyTorchInferenceResult fromXContent(XContentParser parser) throws IOException {
return PARSER.parse(parser, null);
}

private final String requestId;
private final double[][][] inference;
private final long timeMs;
private final boolean cacheHit;

public PyTorchInferenceResult(String requestId, @Nullable double[][][] inference, long timeMs, boolean cacheHit) {
this.requestId = Objects.requireNonNull(requestId);
public PyTorchInferenceResult(@Nullable double[][][] inference) {
this.inference = inference;
this.timeMs = timeMs;
this.cacheHit = cacheHit;
}

public String getRequestId() {
return requestId;
}

public double[][][] getInferenceResult() {
return inference;
}

public long getTimeMs() {
return timeMs;
}

public boolean isCacheHit() {
return cacheHit;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(PyTorchResult.REQUEST_ID.getPreferredName(), requestId);
if (inference != null) {
builder.startArray(INFERENCE.getPreferredName());
for (double[][] doubles : inference) {
@@ -95,15 +70,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
}
builder.endArray();
}
builder.field(TIME_MS.getPreferredName(), timeMs);
builder.field(CACHE_HIT.getPreferredName(), cacheHit);
builder.endObject();
return builder;
}

@Override
public int hashCode() {
return Objects.hash(requestId, timeMs, Arrays.deepHashCode(inference), cacheHit);
return Arrays.deepHashCode(inference);
}

@Override
@@ -112,9 +85,6 @@ public boolean equals(Object other) {
if (other == null || getClass() != other.getClass()) return false;

PyTorchInferenceResult that = (PyTorchInferenceResult) other;
return Objects.equals(requestId, that.requestId)
&& Arrays.deepEquals(inference, that.inference)
&& timeMs == that.timeMs
&& cacheHit == that.cacheHit;
return Arrays.deepEquals(inference, that.inference);
}
}
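
To illustrate the effect of this refactor: the inference payload now carries only the tensor, while request_id, cache_hit and time_ms live on the enclosing PyTorchResult. A sketch with invented values, the argument order taken from the record declaration in PyTorchResult.java below:

double[][][] inference = { { { 0.1, 0.2 } } };
PyTorchInferenceResult payload = new PyTorchInferenceResult(inference);

// (requestId, isCacheHit, timeMs, inferenceResult, threadSettings, ackResult, errorResult)
PyTorchResult result = new PyTorchResult("req-1", false, 42L, payload, null, null, null);
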
PyTorchResult.java
@@ -19,24 +19,43 @@
* The top level object capturing output from the pytorch process.
*/
public record PyTorchResult(
String requestId,
Boolean isCacheHit,
Long timeMs,
@Nullable PyTorchInferenceResult inferenceResult,
@Nullable ThreadSettings threadSettings,
@Nullable AckResult ackResult,
@Nullable ErrorResult errorResult
) implements ToXContentObject {

static final ParseField REQUEST_ID = new ParseField("request_id");
private static final ParseField REQUEST_ID = new ParseField("request_id");
private static final ParseField CACHE_HIT = new ParseField("cache_hit");
private static final ParseField TIME_MS = new ParseField("time_ms");

private static final ParseField RESULT = new ParseField("result");
private static final ParseField THREAD_SETTINGS = new ParseField("thread_settings");
private static final ParseField ACK = new ParseField("ack");

public static ConstructingObjectParser<PyTorchResult, Void> PARSER = new ConstructingObjectParser<>(
"pytorch_result",
a -> new PyTorchResult((PyTorchInferenceResult) a[0], (ThreadSettings) a[1], (ErrorResult) a[2])
a -> new PyTorchResult(
(String) a[0],
(Boolean) a[1],
(Long) a[2],
(PyTorchInferenceResult) a[3],
(ThreadSettings) a[4],
(AckResult) a[5],
(ErrorResult) a[6]
)
);

static {
PARSER.declareString(ConstructingObjectParser.constructorArg(), REQUEST_ID);
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CACHE_HIT);
PARSER.declareLong(ConstructingObjectParser.optionalConstructorArg(), TIME_MS);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), PyTorchInferenceResult.PARSER, RESULT);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ThreadSettings.PARSER, THREAD_SETTINGS);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), AckResult.PARSER, ACK);
PARSER.declareObject(ConstructingObjectParser.optionalConstructorArg(), ErrorResult.PARSER, ErrorResult.ERROR);
}

@@ -47,12 +66,24 @@ public boolean isError() {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (requestId != null) {
builder.field(REQUEST_ID.getPreferredName(), requestId);
}
if (isCacheHit != null) {
builder.field(CACHE_HIT.getPreferredName(), isCacheHit);
}
if (timeMs != null) {
builder.field(TIME_MS.getPreferredName(), timeMs);
}
if (inferenceResult != null) {
builder.field(RESULT.getPreferredName(), inferenceResult);
}
if (threadSettings != null) {
builder.field(THREAD_SETTINGS.getPreferredName(), threadSettings);
}
if (ackResult != null) {
builder.field(ACK.getPreferredName(), ackResult);
}
if (errorResult != null) {
builder.field(ErrorResult.ERROR.getPreferredName(), errorResult);
}
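
Taken together, the payloads the new top-level parser accepts look roughly like the following. A sketch with invented values; the field names come from the ParseFields declared above, and the JsonXContent plumbing is the same assumption as in the AckResult example earlier:

String inferenceJson = "{\"request_id\":\"a1\",\"cache_hit\":false,\"time_ms\":7,"
    + "\"result\":{\"inference\":[[[0.1,0.2]]]}}";
String ackJson = "{\"request_id\":\"a2\",\"ack\":{\"acknowledged\":true}}";
String errorJson = "{\"request_id\":\"a3\",\"error\":{\"error\":\"boom\"}}";

try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, errorJson)) {
    PyTorchResult result = PyTorchResult.PARSER.apply(parser, null);
    assert result.isError();
    assert "a3".equals(result.requestId());
}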