|
17 | 17 | package org.springframework.ai.openaiofficial.setup; |
18 | 18 |
|
19 | 19 | import com.openai.azure.AzureOpenAIServiceVersion; |
20 | | -import com.openai.azure.credential.AzureApiKeyCredential; |
21 | 20 | import com.openai.client.OpenAIClient; |
22 | 21 | import com.openai.client.OpenAIClientAsync; |
23 | 22 | import com.openai.client.okhttp.OpenAIOkHttpClient; |
|
45 | 44 | public class OpenAiOfficialSetup { |
46 | 45 |
|
47 | 46 | static final String OPENAI_URL = "https://api.openai.com/v1"; |
| 47 | + static final String OPENAI_API_KEY = "OPENAI_API_KEY"; |
| 48 | + static final String AZURE_OPENAI_KEY = "AZURE_OPENAI_KEY"; |
48 | 49 | static final String GITHUB_MODELS_URL = "https://models.inference.ai.azure.com"; |
49 | 50 | static final String GITHUB_TOKEN = "GITHUB_TOKEN"; |
50 | 51 | static final String DEFAULT_USER_AGENT = "spring-ai-openai-official"; |
@@ -75,21 +76,23 @@ public static OpenAIClient setupSyncClient(String baseUrl, String apiKey, Creden |
75 | 76 | if (maxRetries == null) { |
76 | 77 | maxRetries = DEFAULT_MAX_RETRIES; |
77 | 78 | } |
78 | | - |
79 | 79 | OpenAIOkHttpClient.Builder builder = OpenAIOkHttpClient.builder(); |
80 | 80 | builder |
81 | 81 | .baseUrl(calculateBaseUrl(baseUrl, modelHost, modelName, azureDeploymentName, azureOpenAiServiceVersion)); |
82 | 82 |
|
83 | | - Credential calculatedCredential = calculateCredential(modelHost, apiKey, credential); |
84 | | - String calculatedApiKey = calculateApiKey(modelHost, apiKey); |
85 | | - if (calculatedCredential == null && calculatedApiKey == null) { |
86 | | - throw new IllegalArgumentException("Either apiKey or credential must be set to authenticate"); |
87 | | - } |
88 | | - else if (calculatedCredential != null) { |
89 | | - builder.credential(calculatedCredential); |
| 83 | + String calculatedApiKey = apiKey != null ? apiKey : detectApiKey(modelHost); |
| 84 | + if (calculatedApiKey != null) { |
| 85 | + builder.apiKey(calculatedApiKey); |
90 | 86 | } |
91 | 87 | else { |
92 | | - builder.apiKey(calculatedApiKey); |
| 88 | + if (credential != null) { |
| 89 | + builder.credential(credential); |
| 90 | + } |
| 91 | + else if (modelHost == ModelHost.AZURE_OPENAI) { |
| 92 | + // If no API key is provided for Azure OpenAI, we try to use passwordless |
| 93 | + // authentication |
| 94 | + builder.credential(azureAuthentication()); |
| 95 | + } |
93 | 96 | } |
94 | 97 | builder.organization(organizationId); |
95 | 98 |
|
@@ -131,21 +134,23 @@ public static OpenAIClientAsync setupAsyncClient(String baseUrl, String apiKey, |
131 | 134 | if (maxRetries == null) { |
132 | 135 | maxRetries = DEFAULT_MAX_RETRIES; |
133 | 136 | } |
134 | | - |
135 | 137 | OpenAIOkHttpClientAsync.Builder builder = OpenAIOkHttpClientAsync.builder(); |
136 | 138 | builder |
137 | 139 | .baseUrl(calculateBaseUrl(baseUrl, modelHost, modelName, azureDeploymentName, azureOpenAiServiceVersion)); |
138 | 140 |
|
139 | | - Credential calculatedCredential = calculateCredential(modelHost, apiKey, credential); |
140 | | - String calculatedApiKey = calculateApiKey(modelHost, apiKey); |
141 | | - if (calculatedCredential == null && calculatedApiKey == null) { |
142 | | - throw new IllegalArgumentException("Either apiKey or credential must be set to authenticate"); |
143 | | - } |
144 | | - else if (calculatedCredential != null) { |
145 | | - builder.credential(calculatedCredential); |
| 141 | + String calculatedApiKey = apiKey != null ? apiKey : detectApiKey(modelHost); |
| 142 | + if (calculatedApiKey != null) { |
| 143 | + builder.apiKey(calculatedApiKey); |
146 | 144 | } |
147 | 145 | else { |
148 | | - builder.apiKey(calculatedApiKey); |
| 146 | + if (credential != null) { |
| 147 | + builder.credential(credential); |
| 148 | + } |
| 149 | + else if (modelHost == ModelHost.AZURE_OPENAI) { |
| 150 | + // If no API key is provided for Azure OpenAI, we try to use passwordless |
| 151 | + // authentication |
| 152 | + builder.credential(azureAuthentication()); |
| 153 | + } |
149 | 154 | } |
150 | 155 | builder.organization(organizationId); |
151 | 156 |
|
@@ -241,42 +246,25 @@ else if (modelHost == ModelHost.AZURE_OPENAI) { |
241 | 246 | } |
242 | 247 | } |
243 | 248 |
|
244 | | - static Credential calculateCredential(ModelHost modelHost, String apiKey, Credential credential) { |
245 | | - if (apiKey != null) { |
246 | | - if (modelHost == ModelHost.AZURE_OPENAI) { |
247 | | - return AzureApiKeyCredential.create(apiKey); |
248 | | - } |
249 | | - } |
250 | | - else if (credential != null) { |
251 | | - return credential; |
| 249 | + static Credential azureAuthentication() { |
| 250 | + try { |
| 251 | + return AzureInternalOpenAiOfficialHelper.getAzureCredential(); |
252 | 252 | } |
253 | | - else if (modelHost == ModelHost.AZURE_OPENAI) { |
254 | | - try { |
255 | | - return AzureInternalOpenAiOfficialHelper.getAzureCredential(); |
256 | | - } |
257 | | - catch (NoClassDefFoundError e) { |
258 | | - throw new IllegalArgumentException("Azure OpenAI was detected, but no credential was provided. " |
259 | | - + "If you want to use passwordless authentication, you need to add the Azure Identity library (groupId=`com.azure`, artifactId=`azure-identity`) to your classpath."); |
260 | | - } |
| 253 | + catch (NoClassDefFoundError e) { |
| 254 | + throw new IllegalArgumentException("Azure OpenAI was detected, but no credential was provided. " |
| 255 | + + "If you want to use passwordless authentication, you need to add the Azure Identity library (groupId=`com.azure`, artifactId=`azure-identity`) to your classpath."); |
261 | 256 | } |
262 | | - return null; |
263 | 257 | } |
264 | 258 |
|
265 | | - static String calculateApiKey(ModelHost modelHost, String apiKey) { |
266 | | - if (apiKey == null) { |
267 | | - var openAiKey = System.getenv("OPENAI_API_KEY"); |
268 | | - if (openAiKey != null) { |
269 | | - apiKey = openAiKey; |
270 | | - logger.debug("OpenAI API Key detected from environment variable OPENAI_API_KEY."); |
271 | | - } |
272 | | - var azureOpenAiKey = System.getenv("AZURE_OPENAI_KEY"); |
273 | | - if (azureOpenAiKey != null) { |
274 | | - apiKey = azureOpenAiKey; |
275 | | - logger.debug("Azure OpenAI Key detected from environment variable AZURE_OPENAI_KEY."); |
276 | | - } |
| 259 | + static String detectApiKey(ModelHost modelHost) { |
| 260 | + if (modelHost == ModelHost.OPENAI && System.getenv(OPENAI_API_KEY) != null) { |
| 261 | + return System.getenv(OPENAI_API_KEY); |
| 262 | + } |
| 263 | + else if (modelHost == ModelHost.AZURE_OPENAI && System.getenv(AZURE_OPENAI_KEY) != null) { |
| 264 | + return System.getenv(AZURE_OPENAI_KEY); |
277 | 265 | } |
278 | | - if (modelHost != ModelHost.AZURE_OPENAI && apiKey != null) { |
279 | | - return apiKey; |
| 266 | + else if (modelHost == ModelHost.AZURE_OPENAI && System.getenv(OPENAI_API_KEY) != null) { |
| 267 | + return System.getenv(OPENAI_API_KEY); |
280 | 268 | } |
281 | 269 | else if (modelHost == ModelHost.GITHUB_MODELS && System.getenv(GITHUB_TOKEN) != null) { |
282 | 270 | return System.getenv(GITHUB_TOKEN); |
|
0 commit comments