1919import com .openai .client .OpenAIClient ;
2020import com .openai .client .OpenAIClientAsync ;
2121import com .openai .core .JsonArray ;
22+ import com .openai .core .JsonField ;
2223import com .openai .core .JsonValue ;
2324import com .openai .models .FunctionDefinition ;
2425import com .openai .models .FunctionParameters ;
6162import org .springframework .util .Assert ;
6263import org .springframework .util .CollectionUtils ;
6364import reactor .core .publisher .Flux ;
64- import reactor .core .publisher .Mono ;
6565import reactor .core .scheduler .Schedulers ;
6666
6767import java .util .ArrayList ;
@@ -92,9 +92,9 @@ public class OpenAiOfficialChatModel implements ChatModel {
9292
9393 private final Logger logger = LoggerFactory .getLogger (OpenAiOfficialChatModel .class );
9494
95- private final OpenAIClient openAiClient ;
95+ private OpenAIClient openAiClient ;
9696
97- private final OpenAIClientAsync openAiClientAsync ;
97+ private OpenAIClientAsync openAiClientAsync ;
9898
9999 private final OpenAiOfficialChatOptions options ;
100100
@@ -154,13 +154,15 @@ public OpenAiOfficialChatModel(OpenAIClient openAiClient, OpenAIClientAsync open
154154 this .options .getOrganizationId (), this .options .isAzure (), this .options .isGitHubModels (),
155155 this .options .getModel (), this .options .getTimeout (), this .options .getMaxRetries (),
156156 this .options .getProxy (), this .options .getCustomHeaders ()));
157+
157158 this .openAiClientAsync = Objects .requireNonNullElseGet (openAiClientAsync ,
158159 () -> setupAsyncClient (this .options .getBaseUrl (), this .options .getApiKey (),
159160 this .options .getCredential (), this .options .getAzureDeploymentName (),
160161 this .options .getAzureOpenAIServiceVersion (), this .options .getOrganizationId (),
161162 this .options .isAzure (), this .options .isGitHubModels (), this .options .getModel (),
162163 this .options .getTimeout (), this .options .getMaxRetries (), this .options .getProxy (),
163164 this .options .getCustomHeaders ()));
165+
164166 this .observationRegistry = Objects .requireNonNullElse (observationRegistry , ObservationRegistry .NOOP );
165167 this .toolCallingManager = Objects .requireNonNullElse (toolCallingManager , DEFAULT_TOOL_CALLING_MANAGER );
166168 this .toolExecutionEligibilityPredicate = Objects .requireNonNullElse (toolExecutionEligibilityPredicate ,
@@ -173,8 +175,10 @@ public OpenAiOfficialChatOptions getOptions() {
173175
174176 @ Override
175177 public ChatResponse call (Prompt prompt ) {
176- // Before moving any further, build the final request Prompt,
177- // merging runtime and default options.
178+ if (this .openAiClient == null ) {
179+ throw new IllegalStateException (
180+ "OpenAI sync client is not configured. Have you set the 'streamUsage' option to false?" );
181+ }
178182 Prompt requestPrompt = buildRequestPrompt (prompt );
179183 return this .internalCall (requestPrompt , null );
180184 }
@@ -248,8 +252,10 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
248252
249253 @ Override
250254 public Flux <ChatResponse > stream (Prompt prompt ) {
251- // Before moving any further, build the final request Prompt,
252- // merging runtime and default options.
255+ if (this .openAiClientAsync == null ) {
256+ throw new IllegalStateException (
257+ "OpenAI async client is not configured. Streaming is not supported with the current configuration. Have you set the 'streamUsage' option to true?" );
258+ }
253259 Prompt requestPrompt = buildRequestPrompt (prompt );
254260 return internalStream (requestPrompt , null );
255261 }
@@ -273,72 +279,68 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
273279
274280 observation .parentObservation (contextView .getOrDefault (ObservationThreadLocalAccessor .KEY , null )).start ();
275281
276- Flux <ChatResponse > chatResponse = Flux .empty ();
277- // Convert the ChatCompletionChunk into a ChatCompletion to be able to reuse
278- // the function call handling logic.
279- this .openAiClientAsync .chat ().completions ().createStreaming (request ).subscribe (chunk -> {
280- ChatCompletion chatCompletion = chunkToChatCompletion (chunk );
281- Mono .just (chatCompletion ).map (chatCompletion2 -> {
282+ Flux <ChatResponse > chatResponses = Flux .create (sink -> {
283+ this .openAiClientAsync .chat ().completions ().createStreaming (request ).subscribe (chunk -> {
282284 try {
285+ ChatCompletion chatCompletion = chunkToChatCompletion (chunk );
283286 // If an id is not provided, set to "NO_ID" (for compatible APIs).
284- chatCompletion2 .id ();
285- String id = chatCompletion2 .id ();
286-
287- List <Generation > generations = chatCompletion2 .choices ().stream ().map (choice -> { // @formatter:off
288- roleMap .putIfAbsent (id , choice .message ()._role ().asString ().isPresent () ? choice .message ()._role ().asStringOrThrow () : "" );
289- Map <String , Object > metadata = Map .of (
290- "id" , id ,
291- "role" , roleMap .getOrDefault (id , "" ),
292- "index" , choice .index (),
293- "finishReason" , choice .finishReason ().asString (),
294- "refusal" , choice .message ().refusal ().isPresent () ? choice .message ().refusal () : "" ,
295- "annotations" , choice .message ().annotations ().isPresent () ? choice .message ().annotations () : List .of ());
296- return buildGeneration (choice , metadata );
297- }).toList ();
298-
299- Optional <CompletionUsage > usage = chatCompletion2 .usage ();
300- Usage currentChatResponseUsage = usage .isPresent ()? getDefaultUsage (usage .get ()) : new EmptyUsage ();
301- Usage accumulatedUsage = UsageCalculator .getCumulativeUsage (currentChatResponseUsage ,
302- previousChatResponse );
303- return new ChatResponse (generations , from (chatCompletion2 , accumulatedUsage ));
304- }
305- catch (Exception e ) {
306- logger .error ("Error processing chat completion" , e );
307- return new ChatResponse (List .of ());
308- }
309- })
310- .flux ()
311- .buffer (2 , 1 )
312- .map (bufferList -> {
313- ChatResponse firstResponse = bufferList .get (0 );
314- if (request .streamOptions ().isPresent ()) {
315- if (bufferList .size () == 2 ) {
316- ChatResponse secondResponse = bufferList .get (1 );
317- if (secondResponse !=null ) {
318- // This is the usage from the final Chat response for a
319- // given Chat request.
320- Usage usage = secondResponse .getMetadata ().getUsage ();
321- if (!UsageCalculator .isEmpty (usage )) {
322- // Store the usage from the final response to the
323- // penultimate response for accumulation.
324- return new ChatResponse (firstResponse .getResults (),
325- from (firstResponse .getMetadata (), usage ));
326- }
287+ chatCompletion .id ();
288+ String id = chatCompletion .id ();
289+
290+ List <Generation > generations = chatCompletion .choices ().stream ().map (choice -> { // @formatter:off
291+ roleMap .putIfAbsent (id , choice .message ()._role ().asString ().isPresent () ? choice .message ()._role ().asStringOrThrow () : "" );
292+ Map <String , Object > metadata = Map .of (
293+ "id" , id ,
294+ "role" , roleMap .getOrDefault (id , "" ),
295+ "index" , choice .index (),
296+ "finishReason" , choice .finishReason ().asString (),
297+ "refusal" , choice .message ().refusal ().isPresent () ? choice .message ().refusal () : "" ,
298+ "annotations" , choice .message ().annotations ().isPresent () ? choice .message ().annotations () : List .of ());
299+ return buildGeneration (choice , metadata );
300+ }).toList ();
301+
302+ Optional <CompletionUsage > usage = chatCompletion .usage ();
303+ Usage currentChatResponseUsage = usage .isPresent ()? getDefaultUsage (usage .get ()) : new EmptyUsage ();
304+ Usage accumulatedUsage = UsageCalculator .getCumulativeUsage (currentChatResponseUsage ,
305+ previousChatResponse );
306+ ChatResponse response = new ChatResponse (generations , from (chatCompletion , accumulatedUsage ));
307+ sink .next (response );
308+ }
309+ catch (Exception e ) {
310+ logger .error ("Error processing chat completion" , e );
311+ sink .error (e );
312+ }
313+ }).onCompleteFuture ().whenComplete ((unused , throwable ) -> {
314+ if (throwable != null ) {
315+ sink .error (throwable );
316+ } else {
317+ sink .complete ();
318+ }
319+ });
320+ })
321+ .buffer (2 , 1 )
322+ .map (bufferList -> {
323+ ChatResponse firstResponse = (ChatResponse ) bufferList .get (0 );
324+ if (request .streamOptions ().isPresent ()) {
325+ if (bufferList .size () == 2 ) {
326+ ChatResponse secondResponse = (ChatResponse ) bufferList .get (1 );
327+ if (secondResponse !=null ) {
328+ // This is the usage from the final Chat response for a
329+ // given Chat request.
330+ Usage usage = secondResponse .getMetadata ().getUsage ();
331+ if (!UsageCalculator .isEmpty (usage )) {
332+ // Store the usage from the final response to the
333+ // penultimate response for accumulation.
334+ return new ChatResponse (firstResponse .getResults (),
335+ from (firstResponse .getMetadata (), usage ));
327336 }
328337 }
329338 }
330- return firstResponse ;
331- });
332- })
333- .onCompleteFuture ()
334- .whenComplete ((unused , error ) -> {
335- if (error != null ) {
336- logger .error (error .getMessage (), error );
337- throw new RuntimeException (error );
338- }
339+ }
340+ return firstResponse ;
339341 });
340342
341- Flux <ChatResponse > flux = chatResponse .flatMap (response -> {
343+ Flux <ChatResponse > flux = chatResponses .flatMap (response -> {
342344 assert prompt .getOptions () != null ;
343345 if (this .toolExecutionEligibilityPredicate .isToolExecutionRequired (prompt .getOptions (), response )) {
344346 // FIXME: bounded elastic needs to be used since tool calling
@@ -432,19 +434,40 @@ private ChatResponseMetadata from(ChatResponseMetadata chatResponseMetadata, Usa
432434 private ChatCompletion chunkToChatCompletion (ChatCompletionChunk chunk ) {
433435 List <ChatCompletion .Choice > choices = chunk .choices ()
434436 .stream ()
435- .map (chunkChoice -> ChatCompletion .Choice .builder ()
437+ .map (chunkChoice -> {
438+ ChatCompletion .Choice .Builder choiceBuilder = ChatCompletion .Choice .builder ()
436439 .finishReason (ChatCompletion .Choice .FinishReason .of (chunkChoice .finishReason ().toString ()))
437- .index (chunkChoice .index ())
438- .message (ChatCompletionMessage .builder ().content (chunkChoice .delta ().content ()).build ())
439- .build ())
440+ .index (chunkChoice .index ())
441+ .message (ChatCompletionMessage .builder ()
442+ .content (chunkChoice .delta ().content ())
443+ .refusal (chunkChoice .delta ().refusal ())
444+ .build ());
445+
446+ // Handle optional logprobs
447+ if (chunkChoice .logprobs ().isPresent ()) {
448+ var logprobs = chunkChoice .logprobs ().get ();
449+ choiceBuilder .logprobs (ChatCompletion .Choice .Logprobs .builder ()
450+ .content (logprobs .content ())
451+ .refusal (logprobs .refusal ())
452+ .build ());
453+ } else {
454+ // Provide empty logprobs when not present
455+ choiceBuilder .logprobs (ChatCompletion .Choice .Logprobs .builder ()
456+ .content (List .of ())
457+ .refusal (List .of ())
458+ .build ());
459+ }
460+
461+ return choiceBuilder .build ();
462+ })
440463 .toList ();
441464
442465 return ChatCompletion .builder ()
443466 .id (chunk .id ())
444467 .choices (choices )
445468 .created (chunk .created ())
446469 .model (chunk .model ())
447- .usage (Objects . requireNonNull ( chunk .usage ().orElse (null )))
470+ .usage (chunk .usage ().orElse (CompletionUsage . builder (). promptTokens ( 0 ). completionTokens ( 0 ). totalTokens ( 0 ). build ( )))
448471 .build ();
449472 }
450473
@@ -606,7 +629,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
606629 .stream ()
607630 .map (toolResponse -> ChatCompletionMessage .builder ()
608631 .role (JsonValue .from (MessageType .TOOL ))
609- .content (ChatCompletionMessage .builder ().content (toolResponse .responseData ()).build ().content ())
632+ .content (ChatCompletionMessage .builder ().content (toolResponse .responseData ()).refusal ( Optional . ofNullable ( message . getMetadata (). get ( "refusal" )). map ( Object :: toString ). orElse ( "" )). build ().content ())
610633 .refusal (JsonValue .from (Optional .ofNullable (message .getMetadata ().get ("refusal" )).map (Object ::toString ).orElse ("" )))
611634 .build ())
612635 .toList ();
@@ -710,7 +733,12 @@ else if (requestOptions.getModel() != null) {
710733 streamOptionsBuilder .includeObfuscation (requestOptions .getStreamOptions ().includeObfuscation ().get ());
711734 }
712735 streamOptionsBuilder .additionalProperties (requestOptions .getStreamOptions ()._additionalProperties ());
736+ streamOptionsBuilder .includeUsage (requestOptions .getStreamUsage ());
713737 builder .streamOptions (streamOptionsBuilder .build ());
738+ } else {
739+ builder .streamOptions (ChatCompletionStreamOptions .builder ()
740+ .includeUsage (true ) // Include usage by default for streaming
741+ .build ());
714742 }
715743 }
716744
@@ -752,22 +780,39 @@ else if (mediaContentData instanceof String text) {
752780
753781 private List <ChatCompletionTool > getChatCompletionTools (List <ToolDefinition > toolDefinitions ) {
754782 return toolDefinitions .stream ()
755- .map (toolDefinition -> {
756- FunctionParameters .Builder parametersBuilder = FunctionParameters .builder ();
757- parametersBuilder .putAdditionalProperty ("type" , JsonValue .from ("object" ));
758- if (!toolDefinition .inputSchema ().isEmpty ()) {
759- parametersBuilder .putAdditionalProperty ("strict" , JsonValue .from (true )); // TODO allow to have non-strict schemas
760- parametersBuilder .putAdditionalProperty ("json_schema" , JsonValue .from (toolDefinition .inputSchema ()));
761- }
762- FunctionDefinition functionDefinition = FunctionDefinition .builder ()
763- .name (toolDefinition .name ())
764- .description (toolDefinition .description ())
765- .parameters (parametersBuilder .build ())
766- .build ();
767-
768- return ChatCompletionTool .ofFunction (ChatCompletionFunctionTool .builder ().function (functionDefinition ).build ());
769- } )
770- .toList ();
783+ .map (toolDefinition -> {
784+ FunctionParameters .Builder parametersBuilder = FunctionParameters .builder ();
785+
786+ if (!toolDefinition .inputSchema ().isEmpty ()) {
787+ // Parse the schema and add its properties directly
788+ try {
789+ com .fasterxml .jackson .databind .ObjectMapper mapper = new com .fasterxml .jackson .databind .ObjectMapper ();
790+ @ SuppressWarnings ("unchecked" )
791+ Map <String , Object > schemaMap = mapper .readValue (toolDefinition .inputSchema (), Map .class );
792+
793+ // Add each property from the schema to the parameters
794+ schemaMap .forEach ((key , value ) ->
795+ parametersBuilder .putAdditionalProperty (key , JsonValue .from (value ))
796+ );
797+
798+ // Add strict mode
799+ parametersBuilder .putAdditionalProperty ("strict" , JsonValue .from (true )); // TODO allow non-strict mode
800+ } catch (Exception e ) {
801+ logger .error ("Failed to parse tool schema" , e );
802+ }
803+ }
804+
805+ FunctionDefinition functionDefinition = FunctionDefinition .builder ()
806+ .name (toolDefinition .name ())
807+ .description (toolDefinition .description ())
808+ .parameters (parametersBuilder .build ())
809+ .build ();
810+
811+ return ChatCompletionTool .ofFunction (
812+ ChatCompletionFunctionTool .builder ().function (functionDefinition ).build ()
813+ );
814+ })
815+ .toList ();
771816 }
772817
773818 @ Override
0 commit comments