Commit 149aa9c

Implementation of the OpenAI Java SDK
- More tests passing for ChatModel

Signed-off-by: Julien Dubois <[email protected]>
1 parent 4000dd9 commit 149aa9c

6 files changed: +863 -99 lines changed

models/spring-ai-openai-official/src/main/java/org/springframework/ai/openaiofficial/OpenAiOfficialChatModel.java

Lines changed: 133 additions & 88 deletions
@@ -19,6 +19,7 @@
 import com.openai.client.OpenAIClient;
 import com.openai.client.OpenAIClientAsync;
 import com.openai.core.JsonArray;
+import com.openai.core.JsonField;
 import com.openai.core.JsonValue;
 import com.openai.models.FunctionDefinition;
 import com.openai.models.FunctionParameters;
@@ -61,7 +62,6 @@
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
 import reactor.core.scheduler.Schedulers;
 
 import java.util.ArrayList;
@@ -92,9 +92,9 @@ public class OpenAiOfficialChatModel implements ChatModel {
 
     private final Logger logger = LoggerFactory.getLogger(OpenAiOfficialChatModel.class);
 
-    private final OpenAIClient openAiClient;
+    private OpenAIClient openAiClient;
 
-    private final OpenAIClientAsync openAiClientAsync;
+    private OpenAIClientAsync openAiClientAsync;
 
     private final OpenAiOfficialChatOptions options;
 
@@ -154,13 +154,15 @@ public OpenAiOfficialChatModel(OpenAIClient openAiClient, OpenAIClientAsync open
                 this.options.getOrganizationId(), this.options.isAzure(), this.options.isGitHubModels(),
                 this.options.getModel(), this.options.getTimeout(), this.options.getMaxRetries(),
                 this.options.getProxy(), this.options.getCustomHeaders()));
+
         this.openAiClientAsync = Objects.requireNonNullElseGet(openAiClientAsync,
                 () -> setupAsyncClient(this.options.getBaseUrl(), this.options.getApiKey(),
                         this.options.getCredential(), this.options.getAzureDeploymentName(),
                         this.options.getAzureOpenAIServiceVersion(), this.options.getOrganizationId(),
                         this.options.isAzure(), this.options.isGitHubModels(), this.options.getModel(),
                         this.options.getTimeout(), this.options.getMaxRetries(), this.options.getProxy(),
                         this.options.getCustomHeaders()));
+
         this.observationRegistry = Objects.requireNonNullElse(observationRegistry, ObservationRegistry.NOOP);
         this.toolCallingManager = Objects.requireNonNullElse(toolCallingManager, DEFAULT_TOOL_CALLING_MANAGER);
         this.toolExecutionEligibilityPredicate = Objects.requireNonNullElse(toolExecutionEligibilityPredicate,
@@ -173,8 +175,10 @@ public OpenAiOfficialChatOptions getOptions() {
 
     @Override
     public ChatResponse call(Prompt prompt) {
-        // Before moving any further, build the final request Prompt,
-        // merging runtime and default options.
+        if (this.openAiClient == null) {
+            throw new IllegalStateException(
+                    "OpenAI sync client is not configured. Have you set the 'streamUsage' option to false?");
+        }
         Prompt requestPrompt = buildRequestPrompt(prompt);
         return this.internalCall(requestPrompt, null);
     }
@@ -248,8 +252,10 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
 
     @Override
     public Flux<ChatResponse> stream(Prompt prompt) {
-        // Before moving any further, build the final request Prompt,
-        // merging runtime and default options.
+        if (this.openAiClientAsync == null) {
+            throw new IllegalStateException(
+                    "OpenAI async client is not configured. Streaming is not supported with the current configuration. Have you set the 'streamUsage' option to true?");
+        }
         Prompt requestPrompt = buildRequestPrompt(prompt);
         return internalStream(requestPrompt, null);
     }
@@ -273,72 +279,68 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
 
             observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
 
-            Flux<ChatResponse> chatResponse = Flux.empty();
-            // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
-            // the function call handling logic.
-            this.openAiClientAsync.chat().completions().createStreaming(request).subscribe(chunk -> {
-                ChatCompletion chatCompletion = chunkToChatCompletion(chunk);
-                Mono.just(chatCompletion).map(chatCompletion2 -> {
+            Flux<ChatResponse> chatResponses = Flux.create(sink -> {
+                this.openAiClientAsync.chat().completions().createStreaming(request).subscribe(chunk -> {
                     try {
+                        ChatCompletion chatCompletion = chunkToChatCompletion(chunk);
                         // If an id is not provided, set to "NO_ID" (for compatible APIs).
-                        chatCompletion2.id();
-                        String id = chatCompletion2.id();
-
-                        List<Generation> generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off
-                            roleMap.putIfAbsent(id, choice.message()._role().asString().isPresent() ? choice.message()._role().asStringOrThrow() : "");
-                            Map<String, Object> metadata = Map.of(
-                                    "id", id,
-                                    "role", roleMap.getOrDefault(id, ""),
-                                    "index", choice.index(),
-                                    "finishReason", choice.finishReason().asString(),
-                                    "refusal", choice.message().refusal().isPresent() ? choice.message().refusal() : "",
-                                    "annotations", choice.message().annotations().isPresent() ? choice.message().annotations() : List.of());
-                            return buildGeneration(choice, metadata);
-                        }).toList();
-
-                        Optional<CompletionUsage> usage = chatCompletion2.usage();
-                        Usage currentChatResponseUsage = usage.isPresent()? getDefaultUsage(usage.get()) : new EmptyUsage();
-                        Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
-                                previousChatResponse);
-                        return new ChatResponse(generations, from(chatCompletion2, accumulatedUsage));
-                    }
-                    catch (Exception e) {
-                        logger.error("Error processing chat completion", e);
-                        return new ChatResponse(List.of());
-                    }
-                })
-                    .flux()
-                    .buffer(2, 1)
-                    .map(bufferList -> {
-                        ChatResponse firstResponse = bufferList.get(0);
-                        if (request.streamOptions().isPresent()) {
-                            if (bufferList.size() == 2) {
-                                ChatResponse secondResponse = bufferList.get(1);
-                                if (secondResponse!=null) {
-                                    // This is the usage from the final Chat response for a
-                                    // given Chat request.
-                                    Usage usage = secondResponse.getMetadata().getUsage();
-                                    if (!UsageCalculator.isEmpty(usage)) {
-                                        // Store the usage from the final response to the
-                                        // penultimate response for accumulation.
-                                        return new ChatResponse(firstResponse.getResults(),
-                                                from(firstResponse.getMetadata(), usage));
-                                    }
+                        chatCompletion.id();
+                        String id = chatCompletion.id();
+
+                        List<Generation> generations = chatCompletion.choices().stream().map(choice -> { // @formatter:off
+                            roleMap.putIfAbsent(id, choice.message()._role().asString().isPresent() ? choice.message()._role().asStringOrThrow() : "");
+                            Map<String, Object> metadata = Map.of(
+                                    "id", id,
+                                    "role", roleMap.getOrDefault(id, ""),
+                                    "index", choice.index(),
+                                    "finishReason", choice.finishReason().asString(),
+                                    "refusal", choice.message().refusal().isPresent() ? choice.message().refusal() : "",
+                                    "annotations", choice.message().annotations().isPresent() ? choice.message().annotations() : List.of());
+                            return buildGeneration(choice, metadata);
+                        }).toList();
+
+                        Optional<CompletionUsage> usage = chatCompletion.usage();
+                        Usage currentChatResponseUsage = usage.isPresent()? getDefaultUsage(usage.get()) : new EmptyUsage();
+                        Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage,
+                                previousChatResponse);
+                        ChatResponse response = new ChatResponse(generations, from(chatCompletion, accumulatedUsage));
+                        sink.next(response);
+                    }
+                    catch (Exception e) {
+                        logger.error("Error processing chat completion", e);
+                        sink.error(e);
+                    }
+                }).onCompleteFuture().whenComplete((unused, throwable) -> {
+                    if (throwable != null) {
+                        sink.error(throwable);
+                    } else {
+                        sink.complete();
+                    }
+                });
+            })
+                .buffer(2, 1)
+                .map(bufferList -> {
+                    ChatResponse firstResponse = (ChatResponse) bufferList.get(0);
+                    if (request.streamOptions().isPresent()) {
+                        if (bufferList.size() == 2) {
+                            ChatResponse secondResponse = (ChatResponse) bufferList.get(1);
+                            if (secondResponse!=null) {
+                                // This is the usage from the final Chat response for a
+                                // given Chat request.
+                                Usage usage = secondResponse.getMetadata().getUsage();
+                                if (!UsageCalculator.isEmpty(usage)) {
+                                    // Store the usage from the final response to the
+                                    // penultimate response for accumulation.
+                                    return new ChatResponse(firstResponse.getResults(),
+                                            from(firstResponse.getMetadata(), usage));
                                 }
                             }
                         }
-                        return firstResponse;
-                    });
-            })
-                .onCompleteFuture()
-                .whenComplete((unused, error) -> {
-                    if (error != null) {
-                        logger.error(error.getMessage(), error);
-                        throw new RuntimeException(error);
-                    }
+                    }
+                    return firstResponse;
                 });
 
-        Flux<ChatResponse> flux = chatResponse.flatMap(response -> {
+        Flux<ChatResponse> flux = chatResponses.flatMap(response -> {
             assert prompt.getOptions() != null;
             if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
                 // FIXME: bounded elastic needs to be used since tool calling
@@ -432,19 +434,40 @@ private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usa
     private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
         List<ChatCompletion.Choice> choices = chunk.choices()
             .stream()
-            .map(chunkChoice -> ChatCompletion.Choice.builder()
+            .map(chunkChoice -> {
+                ChatCompletion.Choice.Builder choiceBuilder = ChatCompletion.Choice.builder()
                     .finishReason(ChatCompletion.Choice.FinishReason.of(chunkChoice.finishReason().toString()))
-                .index(chunkChoice.index())
-                .message(ChatCompletionMessage.builder().content(chunkChoice.delta().content()).build())
-                .build())
+                    .index(chunkChoice.index())
+                    .message(ChatCompletionMessage.builder()
+                        .content(chunkChoice.delta().content())
+                        .refusal(chunkChoice.delta().refusal())
+                        .build());
+
+                // Handle optional logprobs
+                if (chunkChoice.logprobs().isPresent()) {
+                    var logprobs = chunkChoice.logprobs().get();
+                    choiceBuilder.logprobs(ChatCompletion.Choice.Logprobs.builder()
+                        .content(logprobs.content())
+                        .refusal(logprobs.refusal())
+                        .build());
+                } else {
+                    // Provide empty logprobs when not present
+                    choiceBuilder.logprobs(ChatCompletion.Choice.Logprobs.builder()
+                        .content(List.of())
+                        .refusal(List.of())
+                        .build());
+                }
+
+                return choiceBuilder.build();
+            })
             .toList();
 
         return ChatCompletion.builder()
             .id(chunk.id())
             .choices(choices)
             .created(chunk.created())
             .model(chunk.model())
-            .usage(Objects.requireNonNull(chunk.usage().orElse(null)))
+            .usage(chunk.usage().orElse(CompletionUsage.builder().promptTokens(0).completionTokens(0).totalTokens(0).build()))
             .build();
     }
 
@@ -606,7 +629,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
                 .stream()
                 .map(toolResponse -> ChatCompletionMessage.builder()
                     .role(JsonValue.from(MessageType.TOOL))
-                    .content(ChatCompletionMessage.builder().content(toolResponse.responseData()).build().content())
+                    .content(ChatCompletionMessage.builder().content(toolResponse.responseData()).refusal(Optional.ofNullable(message.getMetadata().get("refusal")).map(Object::toString).orElse("")).build().content())
                     .refusal(JsonValue.from(Optional.ofNullable(message.getMetadata().get("refusal")).map(Object::toString).orElse("")))
                     .build())
                 .toList();
@@ -710,7 +733,12 @@ else if (requestOptions.getModel() != null) {
                     streamOptionsBuilder.includeObfuscation(requestOptions.getStreamOptions().includeObfuscation().get());
                 }
                 streamOptionsBuilder.additionalProperties(requestOptions.getStreamOptions()._additionalProperties());
+                streamOptionsBuilder.includeUsage(requestOptions.getStreamUsage());
                 builder.streamOptions(streamOptionsBuilder.build());
+            } else {
+                builder.streamOptions(ChatCompletionStreamOptions.builder()
+                    .includeUsage(true) // Include usage by default for streaming
+                    .build());
             }
         }
 
@@ -752,22 +780,39 @@ else if (mediaContentData instanceof String text) {
 
     private List<ChatCompletionTool> getChatCompletionTools(List<ToolDefinition> toolDefinitions) {
         return toolDefinitions.stream()
-            .map(toolDefinition -> {
-                FunctionParameters.Builder parametersBuilder = FunctionParameters.builder();
-                parametersBuilder.putAdditionalProperty("type", JsonValue.from("object"));
-                if (!toolDefinition.inputSchema().isEmpty()) {
-                    parametersBuilder.putAdditionalProperty("strict", JsonValue.from(true)); // TODO allow to have non-strict schemas
-                    parametersBuilder.putAdditionalProperty("json_schema", JsonValue.from(toolDefinition.inputSchema()));
-                }
-                FunctionDefinition functionDefinition = FunctionDefinition.builder()
-                    .name(toolDefinition.name())
-                    .description(toolDefinition.description())
-                    .parameters(parametersBuilder.build())
-                    .build();
-
-                return ChatCompletionTool.ofFunction(ChatCompletionFunctionTool.builder().function(functionDefinition).build());
-            } )
-            .toList();
+            .map(toolDefinition -> {
+                FunctionParameters.Builder parametersBuilder = FunctionParameters.builder();
+
+                if (!toolDefinition.inputSchema().isEmpty()) {
+                    // Parse the schema and add its properties directly
+                    try {
+                        com.fasterxml.jackson.databind.ObjectMapper mapper = new com.fasterxml.jackson.databind.ObjectMapper();
+                        @SuppressWarnings("unchecked")
+                        Map<String, Object> schemaMap = mapper.readValue(toolDefinition.inputSchema(), Map.class);
+
+                        // Add each property from the schema to the parameters
+                        schemaMap.forEach((key, value) ->
+                            parametersBuilder.putAdditionalProperty(key, JsonValue.from(value))
+                        );
+
+                        // Add strict mode
+                        parametersBuilder.putAdditionalProperty("strict", JsonValue.from(true)); // TODO allow non-strict mode
+                    } catch (Exception e) {
+                        logger.error("Failed to parse tool schema", e);
+                    }
+                }
+
+                FunctionDefinition functionDefinition = FunctionDefinition.builder()
+                    .name(toolDefinition.name())
+                    .description(toolDefinition.description())
+                    .parameters(parametersBuilder.build())
+                    .build();
+
+                return ChatCompletionTool.ofFunction(
+                    ChatCompletionFunctionTool.builder().function(functionDefinition).build()
+                );
+            })
+            .toList();
     }
 
     @Override
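
Note for readers skimming the diff: the streaming rework in internalStream above replaces the previous Mono-based conversion with a Flux.create bridge, adapting the SDK's callback-style subscription (subscribe(...) feeding a sink, with onCompleteFuture().whenComplete(...) closing it) into a Reactor Flux. Below is a minimal, self-contained sketch of that bridging pattern only, assuming Reactor is on the classpath; the AsyncSource interface, the toFlux helper, and the sample data are hypothetical stand-ins for illustration and are not part of the OpenAI Java SDK or of this commit.

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

import reactor.core.publisher.Flux;

public class StreamingBridgeSketch {

    // Hypothetical stand-in for a callback-style streaming API: it pushes items
    // to a consumer and exposes a future that completes when the stream ends.
    interface AsyncSource<T> {
        CompletableFuture<Void> subscribe(Consumer<T> onNext);
    }

    // Bridge the callback-style subscription into a Flux, mirroring the
    // Flux.create(...) / onCompleteFuture().whenComplete(...) shape used above.
    static <T> Flux<T> toFlux(AsyncSource<T> source) {
        return Flux.create(sink -> source.subscribe(item -> {
            try {
                sink.next(item); // emit each item as it arrives
            }
            catch (Exception e) {
                sink.error(e); // surface per-item processing failures
            }
        }).whenComplete((unused, throwable) -> {
            if (throwable != null) {
                sink.error(throwable); // propagate transport-level errors
            }
            else {
                sink.complete(); // close the Flux once the upstream is done
            }
        }));
    }

    public static void main(String[] args) {
        // Tiny in-memory source standing in for the chunked chat completion stream.
        AsyncSource<String> chunks = onNext -> CompletableFuture
            .runAsync(() -> List.of("Hello", ", ", "world").forEach(onNext));

        toFlux(chunks).doOnNext(System.out::print).blockLast();
    }
}

The error and completion handling mirrors the whenComplete block in the diff: a failed future terminates the Flux with an error, and normal completion of the future completes it.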
