- 
                Notifications
    You must be signed in to change notification settings 
- Fork 184
combine json chunks from requests #4317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -17,8 +17,8 @@ | |
| import static org.opensearch.ml.utils.RestActionUtils.isAsync; | ||
| import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; | ||
|  | ||
| import java.io.ByteArrayOutputStream; | ||
| import java.io.IOException; | ||
| import java.io.UncheckedIOException; | ||
| import java.nio.ByteBuffer; | ||
| import java.util.LinkedHashMap; | ||
| import java.util.List; | ||
|  | @@ -48,7 +48,6 @@ | |
| import org.opensearch.ml.common.MLModel; | ||
| import org.opensearch.ml.common.agent.MLAgent; | ||
| import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
| import org.opensearch.ml.common.exception.MLException; | ||
| import org.opensearch.ml.common.input.Input; | ||
| import org.opensearch.ml.common.input.MLInput; | ||
| import org.opensearch.ml.common.input.execute.agent.AgentMLInput; | ||
|  | @@ -158,10 +157,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client | |
| ); | ||
| channel.prepareResponse(RestStatus.OK, headers); | ||
|  | ||
| Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> { | ||
| final CompletableFuture<HttpChunk> future = new CompletableFuture<>(); | ||
| Flux.from(channel).ofType(HttpChunk.class).collectList().flatMap(chunks -> { | ||
| try { | ||
| MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content()); | ||
| BytesReference completeContent = combineChunks(chunks); | ||
| MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, completeContent); | ||
|  | ||
| final CompletableFuture<HttpChunk> future = new CompletableFuture<>(); | ||
| StreamTransportResponseHandler<MLTaskResponse> handler = new StreamTransportResponseHandler<MLTaskResponse>() { | ||
| @Override | ||
| public void handleStreamResponse(StreamTransportResponse<MLTaskResponse> streamResponse) { | ||
|  | @@ -214,19 +215,23 @@ public MLTaskResponse read(StreamInput in) throws IOException { | |
| handler | ||
| ); | ||
|  | ||
| } catch (IOException e) { | ||
| throw new MLException("Got an exception in flux.", e); | ||
| return Mono.fromCompletionStage(future); | ||
| } catch (Exception e) { | ||
| log.error("Failed to parse or process request", e); | ||
| return Mono.error(e); | ||
| } | ||
|  | ||
| return Mono.fromCompletionStage(future); | ||
| }).doOnNext(channel::sendChunk).onErrorComplete(ex -> { | ||
| // Error handling | ||
| }).doOnNext(channel::sendChunk).onErrorResume(ex -> { | ||
| log.error("Error occurred", ex); | ||
| try { | ||
| channel.sendResponse(new BytesRestResponse(channel, (Exception) ex)); | ||
| return true; | ||
| } catch (final IOException e) { | ||
| throw new UncheckedIOException(e); | ||
| String errorMessage = ex instanceof IOException | ||
| ? "Failed to parse request: " + ex.getMessage() | ||
| : "Error processing request: " + ex.getMessage(); | ||
| HttpChunk errorChunk = createHttpChunk("data: {\"error\": \"" + errorMessage.replace("\"", "\\\"") + "\"}\n\n", true); | ||
| channel.sendChunk(errorChunk); | ||
| } catch (Exception e) { | ||
| log.error("Failed to send error chunk", e); | ||
| } | ||
| return Mono.empty(); | ||
| }).subscribe(); | ||
| }; | ||
|  | ||
|  | @@ -402,6 +407,20 @@ private String extractTensorResult(MLTaskResponse response, String tensorName) { | |
| return Map.of(); | ||
| } | ||
|  | ||
| @VisibleForTesting | ||
| BytesReference combineChunks(List<HttpChunk> chunks) { | ||
| try { | ||
| ByteArrayOutputStream buffer = new ByteArrayOutputStream(); | ||
| for (HttpChunk chunk : chunks) { | ||
| chunk.content().writeTo(buffer); | ||
| } | ||
| return BytesReference.fromByteBuffer(ByteBuffer.wrap(buffer.toByteArray())); | ||
| } catch (IOException e) { | ||
| log.error("Failed to combine chunks", e); | ||
| throw new OpenSearchStatusException("Failed to combine request chunks", RestStatus.INTERNAL_SERVER_ERROR, e); | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We spoke about not using this OpenSearchStatusException as it will throw 500 errors. Have you tested on your end? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can define HTTP error code in OpenSearchStatusException. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume this is Server-side I/O failure while handling a request. If that's the case, then 500 should be the right status code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think 500 makes sense here as the chunks are failing to be combined which is I/O failure. However when testing with invalid JSON once all the chunks are combined this 400 doesn't seem to be getting through. Will enhance error handling on this part instead. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was concerning if the chunks were malformed or the content is corrupted, then it can be a 400 bad request. But if the chunk is validated in earlier phase and during combine, if it's network issues or jvm issues, it makes sense to be 500 error | ||
| } | ||
| } | ||
|  | ||
| private HttpChunk createHttpChunk(String sseData, boolean isLast) { | ||
| BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes())); | ||
| return new HttpChunk() { | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this possible to add some unit tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not possible to create UT for prepareRequest as chunk collection happens async. However, I did try sanity test with chunked and non-chunked request, see my latest comment.