Skip to content

Commit 598a83e

Browse files
committed
Implementation of the OpenAI Java SDK
- Added support for authentication with OpenAI, Azure OpenAI and GitHub models. - Added support for complex OpenAI client configuration - Refactored the existing code to use those new methods Signed-off-by: Julien Dubois <[email protected]>
1 parent d01b466 commit 598a83e

15 files changed

+684
-147
lines changed

models/spring-ai-openai-official/pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@
5353
<version>${openai-official.version}</version>
5454
</dependency>
5555

56+
<dependency>
57+
<groupId>com.azure</groupId>
58+
<artifactId>azure-identity</artifactId>
59+
<version>${azure-identity.version}</version>
60+
<optional>true</optional>
61+
</dependency>
62+
5663
<!-- Spring Framework -->
5764
<dependency>
5865
<groupId>org.springframework</groupId>
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
package org.springframework.ai.openaiofficial;
2+
3+
import com.openai.azure.AzureOpenAIServiceVersion;
4+
import com.openai.credential.Credential;
5+
6+
import java.net.Proxy;
7+
import java.time.Duration;
8+
import java.util.Map;
9+
10+
public class AbstractOpenAiOfficialOptions {
11+
12+
/**
13+
* The deployment URL to connect to OpenAI.
14+
*/
15+
private String baseUrl;
16+
17+
/**
18+
* The API key to connect to OpenAI.
19+
*/
20+
private String apiKey;
21+
22+
/**
23+
* Credentials used to connect to Azure OpenAI.
24+
*/
25+
private Credential credential;
26+
27+
/**
28+
* The model name used. When using Azure AI Foundry, this is also used as the default
29+
* deployment name.
30+
*/
31+
private String model;
32+
33+
/**
34+
* The deployment name as defined in Azure AI Foundry. On Azure AI Foundry, the
35+
* default deployment name is the same as the model name. When using OpenAI directly,
36+
* this value isn't used.
37+
*/
38+
private String azureDeploymentName;
39+
40+
/**
41+
* The Azure OpenAI Service version to use when connecting to Azure AI Foundry.
42+
*/
43+
private AzureOpenAIServiceVersion azureOpenAIServiceVersion;
44+
45+
/**
46+
* The organization ID to use when connecting to Azure OpenAI.
47+
*/
48+
private String organizationId;
49+
50+
/**
51+
* Whether Azure OpenAI is detected.
52+
*/
53+
private boolean isAzure;
54+
55+
/**
56+
* Whether GitHub Models is detected.
57+
*/
58+
private boolean isGitHubModels;
59+
60+
/**
61+
* Request timeout for OpenAI client.
62+
*/
63+
private Duration timeout;
64+
65+
/**
66+
* Maximum number of retries for OpenAI client.
67+
*/
68+
private Integer maxRetries;
69+
70+
/**
71+
* Proxy settings for OpenAI client.
72+
*/
73+
private Proxy proxy;
74+
75+
/**
76+
* Custom headers to add to OpenAI client requests.
77+
*/
78+
private Map<String, String> customHeaders;
79+
80+
public String getBaseUrl() {
81+
return baseUrl;
82+
}
83+
84+
public void setBaseUrl(String baseUrl) {
85+
this.baseUrl = baseUrl;
86+
}
87+
88+
public String getApiKey() {
89+
return apiKey;
90+
}
91+
92+
public void setApiKey(String apiKey) {
93+
this.apiKey = apiKey;
94+
}
95+
96+
public Credential getCredential() {
97+
return credential;
98+
}
99+
100+
public void setCredential(Credential credential) {
101+
this.credential = credential;
102+
}
103+
104+
public String getModel() {
105+
return model;
106+
}
107+
108+
public void setModel(String model) {
109+
this.model = model;
110+
}
111+
112+
public String getAzureDeploymentName() {
113+
return azureDeploymentName;
114+
}
115+
116+
public void setAzureDeploymentName(String azureDeploymentName) {
117+
this.azureDeploymentName = azureDeploymentName;
118+
}
119+
120+
/**
121+
* Alias for getAzureDeploymentName()
122+
*/
123+
public String getDeploymentName() {
124+
return azureDeploymentName;
125+
}
126+
127+
/**
128+
* Alias for setAzureDeploymentName()
129+
*/
130+
public void setDeploymentName(String azureDeploymentName) {
131+
this.azureDeploymentName = azureDeploymentName;
132+
}
133+
134+
public AzureOpenAIServiceVersion getAzureOpenAIServiceVersion() {
135+
return azureOpenAIServiceVersion;
136+
}
137+
138+
public void setAzureOpenAIServiceVersion(AzureOpenAIServiceVersion azureOpenAIServiceVersion) {
139+
this.azureOpenAIServiceVersion = azureOpenAIServiceVersion;
140+
}
141+
142+
public String getOrganizationId() {
143+
return organizationId;
144+
}
145+
146+
public void setOrganizationId(String organizationId) {
147+
this.organizationId = organizationId;
148+
}
149+
150+
public boolean isAzure() {
151+
return isAzure;
152+
}
153+
154+
public void setAzure(boolean azure) {
155+
isAzure = azure;
156+
}
157+
158+
public boolean isGitHubModels() {
159+
return isGitHubModels;
160+
}
161+
162+
public void setGitHubModels(boolean gitHubModels) {
163+
isGitHubModels = gitHubModels;
164+
}
165+
166+
public Duration getTimeout() {
167+
return timeout;
168+
}
169+
170+
public void setTimeout(Duration timeout) {
171+
this.timeout = timeout;
172+
}
173+
174+
public Integer getMaxRetries() {
175+
return maxRetries;
176+
}
177+
178+
public void setMaxRetries(Integer maxRetries) {
179+
this.maxRetries = maxRetries;
180+
}
181+
182+
public Proxy getProxy() {
183+
return proxy;
184+
}
185+
186+
public void setProxy(Proxy proxy) {
187+
this.proxy = proxy;
188+
}
189+
190+
public Map<String, String> getCustomHeaders() {
191+
return customHeaders;
192+
}
193+
194+
public void setCustomHeaders(Map<String, String> customHeaders) {
195+
this.customHeaders = customHeaders;
196+
}
197+
198+
}

models/spring-ai-openai-official/src/main/java/org/springframework/ai/openaiofficial/OpenAiOfficialEmbeddingModel.java

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
117
package org.springframework.ai.openaiofficial;
218

319
import com.openai.client.OpenAIClient;
@@ -27,6 +43,8 @@
2743
import java.util.List;
2844
import java.util.Objects;
2945

46+
import static org.springframework.ai.openaiofficial.setup.OpenAiOfficialSetup.setupSyncClient;
47+
3048
/**
3149
* Embedding Model implementation using the OpenAI Java SDK.
3250
*
@@ -42,39 +60,66 @@ public class OpenAiOfficialEmbeddingModel extends AbstractEmbeddingModel {
4260

4361
private final OpenAIClient openAiClient;
4462

45-
private final OpenAiOfficialEmbeddingOptions defaultOptions;
63+
private final OpenAiOfficialEmbeddingOptions options;
4664

4765
private final MetadataMode metadataMode;
4866

4967
private final ObservationRegistry observationRegistry;
5068

5169
private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;
5270

71+
public OpenAiOfficialEmbeddingModel() {
72+
this(null, null, null, null);
73+
}
74+
75+
public OpenAiOfficialEmbeddingModel(OpenAiOfficialEmbeddingOptions options) {
76+
this(null, null, options, null);
77+
}
78+
79+
public OpenAiOfficialEmbeddingModel(MetadataMode metadataMode, OpenAiOfficialEmbeddingOptions options) {
80+
this(null, metadataMode, options, null);
81+
}
82+
83+
public OpenAiOfficialEmbeddingModel(OpenAiOfficialEmbeddingOptions options,
84+
ObservationRegistry observationRegistry) {
85+
this(null, null, options, observationRegistry);
86+
}
87+
88+
public OpenAiOfficialEmbeddingModel(MetadataMode metadataMode, OpenAiOfficialEmbeddingOptions options,
89+
ObservationRegistry observationRegistry) {
90+
this(null, metadataMode, options, observationRegistry);
91+
}
92+
5393
public OpenAiOfficialEmbeddingModel(OpenAIClient openAiClient) {
54-
this(openAiClient, MetadataMode.EMBED);
94+
this(openAiClient, null, null, null);
5595
}
5696

5797
public OpenAiOfficialEmbeddingModel(OpenAIClient openAiClient, MetadataMode metadataMode) {
58-
this(openAiClient, metadataMode, OpenAiOfficialEmbeddingOptions.builder().model(DEFAULT_MODEL_NAME).build());
98+
this(openAiClient, metadataMode, null, null);
5999
}
60100

61101
public OpenAiOfficialEmbeddingModel(OpenAIClient openAiClient, MetadataMode metadataMode,
62102
OpenAiOfficialEmbeddingOptions options) {
63-
this(openAiClient, metadataMode, options, ObservationRegistry.NOOP);
103+
this(openAiClient, metadataMode, options, null);
64104
}
65105

66106
public OpenAiOfficialEmbeddingModel(OpenAIClient openAiClient, MetadataMode metadataMode,
67107
OpenAiOfficialEmbeddingOptions options, ObservationRegistry observationRegistry) {
68108

69-
Assert.notNull(openAiClient, "com.openai.client.OpenAIClient must not be null");
70-
Assert.notNull(metadataMode, "Metadata mode must not be null");
71-
Assert.notNull(options, "Options must not be null");
72-
Assert.notNull(options.getModel(), "Model name must not be null");
73-
Assert.notNull(observationRegistry, "Observation registry must not be null");
74-
this.openAiClient = openAiClient;
75-
this.metadataMode = metadataMode;
76-
this.defaultOptions = options;
77-
this.observationRegistry = observationRegistry;
109+
if (options == null) {
110+
this.options = OpenAiOfficialEmbeddingOptions.builder().model(DEFAULT_MODEL_NAME).build();
111+
}
112+
else {
113+
this.options = options;
114+
}
115+
this.openAiClient = Objects.requireNonNullElseGet(openAiClient,
116+
() -> setupSyncClient(this.options.getBaseUrl(), this.options.getApiKey(), this.options.getCredential(),
117+
this.options.getAzureDeploymentName(), this.options.getAzureOpenAIServiceVersion(),
118+
this.options.getOrganizationId(), this.options.isAzure(), this.options.isGitHubModels(),
119+
this.options.getModel(), this.options.getTimeout(), this.options.getMaxRetries(),
120+
this.options.getProxy(), this.options.getCustomHeaders()));
121+
this.metadataMode = Objects.requireNonNullElse(metadataMode, MetadataMode.EMBED);
122+
this.observationRegistry = Objects.requireNonNullElse(observationRegistry, ObservationRegistry.NOOP);
78123
}
79124

80125
@Override
@@ -91,7 +136,7 @@ public float[] embed(Document document) {
91136
@Override
92137
public EmbeddingResponse call(EmbeddingRequest embeddingRequest) {
93138
OpenAiOfficialEmbeddingOptions options = OpenAiOfficialEmbeddingOptions.builder()
94-
.from(this.defaultOptions)
139+
.from(this.options)
95140
.merge(embeddingRequest.getOptions())
96141
.build();
97142

@@ -150,8 +195,8 @@ private List<Embedding> generateEmbeddingList(List<com.openai.models.embeddings.
150195
return data;
151196
}
152197

153-
public OpenAiOfficialEmbeddingOptions getDefaultOptions() {
154-
return this.defaultOptions;
198+
public OpenAiOfficialEmbeddingOptions getOptions() {
199+
return this.options;
155200
}
156201

157202
/**

0 commit comments

Comments
 (0)