Skip to content

Commit 5c2fdbf

Browse files
committed
Implementation of the OpenAI Java SDK
- Improve authentication, mostly with Azure OpenAI Signed-off-by: Julien Dubois <[email protected]>
1 parent b97a535 commit 5c2fdbf

File tree

3 files changed

+38
-90
lines changed

3 files changed

+38
-90
lines changed

models/spring-ai-openai-official/src/main/java/org/springframework/ai/openaiofficial/OpenAiOfficialChatModel.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@
6060
import org.springframework.ai.tool.definition.ToolDefinition;
6161
import org.springframework.util.Assert;
6262
import org.springframework.util.CollectionUtils;
63-
import org.springframework.util.StringUtils;
6463
import reactor.core.publisher.Flux;
6564
import reactor.core.publisher.Mono;
6665
import reactor.core.scheduler.Schedulers;

models/spring-ai-openai-official/src/main/java/org/springframework/ai/openaiofficial/setup/OpenAiOfficialSetup.java

Lines changed: 38 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.springframework.ai.openaiofficial.setup;
1818

1919
import com.openai.azure.AzureOpenAIServiceVersion;
20-
import com.openai.azure.credential.AzureApiKeyCredential;
2120
import com.openai.client.OpenAIClient;
2221
import com.openai.client.OpenAIClientAsync;
2322
import com.openai.client.okhttp.OpenAIOkHttpClient;
@@ -45,6 +44,8 @@
4544
public class OpenAiOfficialSetup {
4645

4746
static final String OPENAI_URL = "https://api.openai.com/v1";
47+
static final String OPENAI_API_KEY = "OPENAI_API_KEY";
48+
static final String AZURE_OPENAI_KEY = "AZURE_OPENAI_KEY";
4849
static final String GITHUB_MODELS_URL = "https://models.inference.ai.azure.com";
4950
static final String GITHUB_TOKEN = "GITHUB_TOKEN";
5051
static final String DEFAULT_USER_AGENT = "spring-ai-openai-official";
@@ -75,21 +76,23 @@ public static OpenAIClient setupSyncClient(String baseUrl, String apiKey, Creden
7576
if (maxRetries == null) {
7677
maxRetries = DEFAULT_MAX_RETRIES;
7778
}
78-
7979
OpenAIOkHttpClient.Builder builder = OpenAIOkHttpClient.builder();
8080
builder
8181
.baseUrl(calculateBaseUrl(baseUrl, modelHost, modelName, azureDeploymentName, azureOpenAiServiceVersion));
8282

83-
Credential calculatedCredential = calculateCredential(modelHost, apiKey, credential);
84-
String calculatedApiKey = calculateApiKey(modelHost, apiKey);
85-
if (calculatedCredential == null && calculatedApiKey == null) {
86-
throw new IllegalArgumentException("Either apiKey or credential must be set to authenticate");
87-
}
88-
else if (calculatedCredential != null) {
89-
builder.credential(calculatedCredential);
83+
String calculatedApiKey = apiKey != null ? apiKey : detectApiKey(modelHost);
84+
if (calculatedApiKey != null) {
85+
builder.apiKey(calculatedApiKey);
9086
}
9187
else {
92-
builder.apiKey(calculatedApiKey);
88+
if (credential != null) {
89+
builder.credential(credential);
90+
}
91+
else if (modelHost == ModelHost.AZURE_OPENAI) {
92+
// If no API key is provided for Azure OpenAI, we try to use passwordless
93+
// authentication
94+
builder.credential(azureAuthentication());
95+
}
9396
}
9497
builder.organization(organizationId);
9598

@@ -131,21 +134,23 @@ public static OpenAIClientAsync setupAsyncClient(String baseUrl, String apiKey,
131134
if (maxRetries == null) {
132135
maxRetries = DEFAULT_MAX_RETRIES;
133136
}
134-
135137
OpenAIOkHttpClientAsync.Builder builder = OpenAIOkHttpClientAsync.builder();
136138
builder
137139
.baseUrl(calculateBaseUrl(baseUrl, modelHost, modelName, azureDeploymentName, azureOpenAiServiceVersion));
138140

139-
Credential calculatedCredential = calculateCredential(modelHost, apiKey, credential);
140-
String calculatedApiKey = calculateApiKey(modelHost, apiKey);
141-
if (calculatedCredential == null && calculatedApiKey == null) {
142-
throw new IllegalArgumentException("Either apiKey or credential must be set to authenticate");
143-
}
144-
else if (calculatedCredential != null) {
145-
builder.credential(calculatedCredential);
141+
String calculatedApiKey = apiKey != null ? apiKey : detectApiKey(modelHost);
142+
if (calculatedApiKey != null) {
143+
builder.apiKey(calculatedApiKey);
146144
}
147145
else {
148-
builder.apiKey(calculatedApiKey);
146+
if (credential != null) {
147+
builder.credential(credential);
148+
}
149+
else if (modelHost == ModelHost.AZURE_OPENAI) {
150+
// If no API key is provided for Azure OpenAI, we try to use passwordless
151+
// authentication
152+
builder.credential(azureAuthentication());
153+
}
149154
}
150155
builder.organization(organizationId);
151156

@@ -241,42 +246,25 @@ else if (modelHost == ModelHost.AZURE_OPENAI) {
241246
}
242247
}
243248

244-
static Credential calculateCredential(ModelHost modelHost, String apiKey, Credential credential) {
245-
if (apiKey != null) {
246-
if (modelHost == ModelHost.AZURE_OPENAI) {
247-
return AzureApiKeyCredential.create(apiKey);
248-
}
249-
}
250-
else if (credential != null) {
251-
return credential;
249+
static Credential azureAuthentication() {
250+
try {
251+
return AzureInternalOpenAiOfficialHelper.getAzureCredential();
252252
}
253-
else if (modelHost == ModelHost.AZURE_OPENAI) {
254-
try {
255-
return AzureInternalOpenAiOfficialHelper.getAzureCredential();
256-
}
257-
catch (NoClassDefFoundError e) {
258-
throw new IllegalArgumentException("Azure OpenAI was detected, but no credential was provided. "
259-
+ "If you want to use passwordless authentication, you need to add the Azure Identity library (groupId=`com.azure`, artifactId=`azure-identity`) to your classpath.");
260-
}
253+
catch (NoClassDefFoundError e) {
254+
throw new IllegalArgumentException("Azure OpenAI was detected, but no credential was provided. "
255+
+ "If you want to use passwordless authentication, you need to add the Azure Identity library (groupId=`com.azure`, artifactId=`azure-identity`) to your classpath.");
261256
}
262-
return null;
263257
}
264258

265-
static String calculateApiKey(ModelHost modelHost, String apiKey) {
266-
if (apiKey == null) {
267-
var openAiKey = System.getenv("OPENAI_API_KEY");
268-
if (openAiKey != null) {
269-
apiKey = openAiKey;
270-
logger.debug("OpenAI API Key detected from environment variable OPENAI_API_KEY.");
271-
}
272-
var azureOpenAiKey = System.getenv("AZURE_OPENAI_KEY");
273-
if (azureOpenAiKey != null) {
274-
apiKey = azureOpenAiKey;
275-
logger.debug("Azure OpenAI Key detected from environment variable AZURE_OPENAI_KEY.");
276-
}
259+
static String detectApiKey(ModelHost modelHost) {
260+
if (modelHost == ModelHost.OPENAI && System.getenv(OPENAI_API_KEY) != null) {
261+
return System.getenv(OPENAI_API_KEY);
262+
}
263+
else if (modelHost == ModelHost.AZURE_OPENAI && System.getenv(AZURE_OPENAI_KEY) != null) {
264+
return System.getenv(AZURE_OPENAI_KEY);
277265
}
278-
if (modelHost != ModelHost.AZURE_OPENAI && apiKey != null) {
279-
return apiKey;
266+
else if (modelHost == ModelHost.AZURE_OPENAI && System.getenv(OPENAI_API_KEY) != null) {
267+
return System.getenv(OPENAI_API_KEY);
280268
}
281269
else if (modelHost == ModelHost.GITHUB_MODELS && System.getenv(GITHUB_TOKEN) != null) {
282270
return System.getenv(GITHUB_TOKEN);

models/spring-ai-openai-official/src/test/java/org/springframework/ai/openaiofficial/chat/OpenAiOfficialChatModelIT.java

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,65 +16,26 @@
1616

1717
package org.springframework.ai.openaiofficial.chat;
1818

19-
import org.assertj.core.data.Percentage;
2019
import org.junit.jupiter.api.Test;
2120
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
22-
import org.junit.jupiter.params.ParameterizedTest;
23-
import org.junit.jupiter.params.provider.ValueSource;
2421
import org.slf4j.Logger;
2522
import org.slf4j.LoggerFactory;
26-
import org.springframework.ai.chat.client.ChatClient;
27-
import org.springframework.ai.chat.memory.ChatMemory;
28-
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
29-
import org.springframework.ai.chat.messages.AssistantMessage;
3023
import org.springframework.ai.chat.messages.Message;
31-
import org.springframework.ai.chat.messages.SystemMessage;
3224
import org.springframework.ai.chat.messages.UserMessage;
33-
import org.springframework.ai.chat.metadata.DefaultUsage;
34-
import org.springframework.ai.chat.metadata.EmptyUsage;
35-
import org.springframework.ai.chat.metadata.Usage;
3625
import org.springframework.ai.chat.model.ChatResponse;
37-
import org.springframework.ai.chat.model.Generation;
38-
import org.springframework.ai.chat.prompt.ChatOptions;
3926
import org.springframework.ai.chat.prompt.Prompt;
40-
import org.springframework.ai.chat.prompt.PromptTemplate;
4127
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
42-
import org.springframework.ai.content.Media;
43-
import org.springframework.ai.converter.BeanOutputConverter;
44-
import org.springframework.ai.converter.ListOutputConverter;
45-
import org.springframework.ai.converter.MapOutputConverter;
46-
import org.springframework.ai.model.tool.DefaultToolCallingManager;
47-
import org.springframework.ai.model.tool.ToolCallingChatOptions;
48-
import org.springframework.ai.model.tool.ToolCallingManager;
49-
import org.springframework.ai.model.tool.ToolExecutionResult;
5028
import org.springframework.ai.openaiofficial.OpenAiOfficialChatModel;
5129
import org.springframework.ai.openaiofficial.OpenAiOfficialTestConfiguration;
52-
import org.springframework.ai.support.ToolCallbacks;
53-
import org.springframework.ai.tool.annotation.Tool;
54-
import org.springframework.ai.tool.function.FunctionToolCallback;
5530
import org.springframework.beans.factory.annotation.Autowired;
5631
import org.springframework.beans.factory.annotation.Value;
5732
import org.springframework.boot.test.context.SpringBootTest;
58-
import org.springframework.core.convert.support.DefaultConversionService;
59-
import org.springframework.core.io.ClassPathResource;
6033
import org.springframework.core.io.Resource;
61-
import org.springframework.util.MimeTypeUtils;
62-
import reactor.core.publisher.Flux;
6334

64-
import java.io.IOException;
65-
import java.net.URI;
66-
import java.util.ArrayList;
67-
import java.util.Arrays;
6835
import java.util.List;
6936
import java.util.Map;
70-
import java.util.UUID;
71-
import java.util.concurrent.CountDownLatch;
72-
import java.util.concurrent.TimeUnit;
73-
import java.util.stream.Collectors;
74-
import java.util.stream.IntStream;
7537

7638
import static org.assertj.core.api.Assertions.assertThat;
77-
import static org.assertj.core.api.Assertions.assertThatThrownBy;
7839

7940
/**
8041
* Integration tests for {@link OpenAiOfficialChatModel}.

0 commit comments

Comments
 (0)