Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.springframework.ai.retry.autoconfigure;

import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;

import org.slf4j.Logger;
Expand All @@ -30,6 +31,7 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.lang.NonNull;
import org.springframework.retry.RetryCallback;
Expand Down Expand Up @@ -87,6 +89,12 @@ public boolean hasError(@NonNull ClientHttpResponse response) throws IOException
}

@Override
public void handleError(@NonNull URI url, @NonNull HttpMethod method, @NonNull ClientHttpResponse response)
throws IOException {
handleError(response);
}

@SuppressWarnings("removal")
public void handleError(@NonNull ClientHttpResponse response) throws IOException {
if (!response.getStatusCode().isError()) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,20 @@ private AnthropicApi(String baseUrl, String completionsPath, ApiKey anthropicApi
.build();
}

/**
* Create a new client api.
* @param completionsPath path to append to the base URL.
* @param restClient RestClient instance.
* @param webClient WebClient instance.
* @param apiKey Anthropic api Key.
*/
public AnthropicApi(String completionsPath, RestClient restClient, WebClient webClient, ApiKey apiKey) {
this.completionsPath = completionsPath;
this.restClient = restClient;
this.webClient = webClient;
this.apiKey = apiKey;
}

/**
* Creates a model response for the given chat conversation.
* @param chatRequest The chat completion request.
Expand Down Expand Up @@ -176,7 +190,7 @@ public ResponseEntity<ChatCompletionResponse> chatCompletionEntity(ChatCompletio
return this.restClient.post()
.uri(this.completionsPath)
.headers(headers -> {
headers.addAll(additionalHttpHeader);
headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader));
addDefaultHeadersIfMissing(headers);
})
.body(chatRequest)
Expand Down Expand Up @@ -217,7 +231,7 @@ public Flux<ChatCompletionResponse> chatCompletionStream(ChatCompletionRequest c
return this.webClient.post()
.uri(this.completionsPath)
.headers(headers -> {
headers.addAll(additionalHttpHeader);
headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader));
addDefaultHeadersIfMissing(headers);
}) // @formatter:off
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
Expand Down Expand Up @@ -256,7 +270,7 @@ public Flux<ChatCompletionResponse> chatCompletionStream(ChatCompletionRequest c
}

private void addDefaultHeadersIfMissing(HttpHeaders headers) {
if (!headers.containsKey(HEADER_X_API_KEY)) {
if (null == headers.getFirst(HEADER_X_API_KEY)) {
String apiKeyValue = this.apiKey.getValue();
if (StringUtils.hasText(apiKeyValue)) {
headers.add(HEADER_X_API_KEY, apiKeyValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public class DeepSeekApi {

private final WebClient webClient;

private DeepSeekStreamFunctionCallingHelper chunkMerger = new DeepSeekStreamFunctionCallingHelper();
private final DeepSeekStreamFunctionCallingHelper chunkMerger = new DeepSeekStreamFunctionCallingHelper();

/**
* Create a new chat completion api.
Expand All @@ -90,21 +90,39 @@ public DeepSeekApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String>

this.completionsPath = completionsPath;
this.betaPrefixPath = betaPrefixPath;
// @formatter:off

Consumer<HttpHeaders> finalHeaders = h -> {
h.setBearerAuth(apiKey.getValue());
h.setContentType(MediaType.APPLICATION_JSON);
h.addAll(headers);
h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));
};
this.restClient = restClientBuilder.baseUrl(baseUrl)
.defaultHeaders(finalHeaders)
.defaultStatusHandler(responseErrorHandler)
.build();

this.webClient = webClientBuilder
.baseUrl(baseUrl)
.defaultHeaders(finalHeaders)
.build(); // @formatter:on
this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(finalHeaders).build();

}

/**
* Create a new chat completion api.
* @param completionsPath the path to the chat completions endpoint.
* @param betaPrefixPath the prefix path to the beta feature endpoint.
* @param restClient RestClient instance.
* @param webClient WebClient instance.
*/
public DeepSeekApi(String completionsPath, String betaPrefixPath, RestClient restClient, WebClient webClient) {

Assert.hasText(completionsPath, "Completions Path must not be null");
Assert.hasText(betaPrefixPath, "Beta feature path must not be null");
Assert.notNull(restClient, "RestClient must not be null");
Assert.notNull(webClient, "WebClient must not be null");

this.completionsPath = completionsPath;
this.betaPrefixPath = betaPrefixPath;
this.restClient = restClient;
this.webClient = webClient;
}

/**
Expand Down Expand Up @@ -153,7 +171,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat

return this.webClient.post()
.uri(this.getEndpoint(chatRequest))
.headers(headers -> headers.addAll(additionalHttpHeader))
.headers(headers -> headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader)))
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
.retrieve()
.bodyToFlux(String.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private ElevenLabsApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
if (!(apiKey instanceof NoopApiKey)) {
h.set("xi-api-key", apiKey.getValue());
}
h.addAll(headers);
h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));
h.setContentType(MediaType.APPLICATION_JSON);
};

Expand All @@ -82,6 +82,16 @@ private ElevenLabsApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
this.webClient = webClientBuilder.baseUrl(baseUrl).defaultHeaders(jsonContentHeaders).build();
}

/**
* Create a new ElevenLabs API client.
* @param restClient Spring RestClient instance.
* @param webClient Spring WebClient instance.
*/
public ElevenLabsApi(RestClient restClient, WebClient webClient) {
this.restClient = restClient;
this.webClient = webClient;
}

public static Builder builder() {
return new Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ public ElevenLabsVoicesApi(String baseUrl, ApiKey apiKey, MultiValueMap<String,
if (!(apiKey instanceof NoopApiKey)) {
h.set("xi-api-key", apiKey.getValue());
}
h.addAll(headers);
h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));
h.setContentType(MediaType.APPLICATION_JSON);
};

Expand All @@ -73,6 +73,14 @@ public ElevenLabsVoicesApi(String baseUrl, ApiKey apiKey, MultiValueMap<String,

}

/**
* Create a new ElevenLabs Voices API client.
* @param restClient Spring RestClient instance.
*/
public ElevenLabsVoicesApi(RestClient restClient) {
this.restClient = restClient;
}

public static Builder builder() {
return new Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public static Builder builder() {

private final WebClient webClient;

private OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();
private final OpenAiStreamFunctionCallingHelper chunkMerger = new OpenAiStreamFunctionCallingHelper();

/**
* Create a new chat completion api.
Expand Down Expand Up @@ -145,7 +145,7 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> he
Consumer<HttpHeaders> finalHeaders = h -> {
h.setContentType(MediaType.APPLICATION_JSON);
h.set(HTTP_USER_AGENT_HEADER, SPRING_AI_USER_AGENT);
h.addAll(headers);
h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));
};
this.restClient = restClientBuilder.clone()
.baseUrl(baseUrl)
Expand All @@ -159,6 +159,30 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> he
.build(); // @formatter:on
}

/**
* Create a new chat completion api.
* @param baseUrl api base URL.
* @param apiKey OpenAI apiKey.
* @param headers the http headers to use.
* @param completionsPath the path to the chat completions endpoint.
* @param embeddingsPath the path to the embeddings endpoint.
* @param restClient RestClient instance.
* @param webClient WebClient instance.
* @param responseErrorHandler Response error handler.
*/
public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers, String completionsPath,
String embeddingsPath, ResponseErrorHandler responseErrorHandler, RestClient restClient,
WebClient webClient) {
this.baseUrl = baseUrl;
this.apiKey = apiKey;
this.headers = headers;
this.completionsPath = completionsPath;
this.embeddingsPath = embeddingsPath;
this.responseErrorHandler = responseErrorHandler;
this.restClient = restClient;
this.webClient = webClient;
}

/**
* Returns a string containing all text values from the given media content list. Only
* elements of type "text" are processed and concatenated in order.
Expand Down Expand Up @@ -204,7 +228,7 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
return this.restClient.post()
.uri(this.completionsPath)
.headers(headers -> {
headers.addAll(additionalHttpHeader);
headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader));
addDefaultHeadersIfMissing(headers);
})
.body(chatRequest)
Expand Down Expand Up @@ -243,7 +267,7 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
return this.webClient.post()
.uri(this.completionsPath)
.headers(headers -> {
headers.addAll(additionalHttpHeader);
headers.addAll(HttpHeaders.readOnlyHttpHeaders(additionalHttpHeader));
addDefaultHeadersIfMissing(headers);
}) // @formatter:on
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
Expand Down Expand Up @@ -328,7 +352,7 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
}

private void addDefaultHeadersIfMissing(HttpHeaders headers) {
if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
if (null == headers.getFirst(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
headers.setBearerAuth(this.apiKey.getValue());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
RestClient.Builder restClientBuilder, WebClient.Builder webClientBuilder,
ResponseErrorHandler responseErrorHandler) {

Consumer<HttpHeaders> authHeaders = h -> h.addAll(headers);
Consumer<HttpHeaders> authHeaders = h -> h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));

// @formatter:off
this.restClient = restClientBuilder.clone()
Expand All @@ -98,6 +98,16 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
.build(); // @formatter:on
}

/**
* Create a new audio api.
* @param restClient RestClient instance.
* @param webClient WebClient instance.
*/
public OpenAiAudioApi(RestClient restClient, WebClient webClient) {
this.restClient = restClient;
this.webClient = webClient;
}

public static Builder builder() {
return new Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public class OpenAiFileApi {

public OpenAiFileApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
Consumer<HttpHeaders> authHeaders = h -> h.addAll(headers);
Consumer<HttpHeaders> authHeaders = h -> h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));

this.restClient = restClientBuilder.clone()
.baseUrl(baseUrl)
Expand All @@ -65,6 +65,10 @@ public OpenAiFileApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String
.build();
}

public OpenAiFileApi(RestClient restClient) {
this.restClient = restClient;
}

public static Builder builder() {
return new Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
.baseUrl(baseUrl)
.defaultHeaders(h -> {
h.setContentType(MediaType.APPLICATION_JSON);
h.addAll(headers);
h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));
})
.defaultStatusHandler(responseErrorHandler)
.defaultRequest(requestHeadersSpec -> {
Expand All @@ -82,6 +82,16 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
this.imagesPath = imagesPath;
}

/**
* Create a new OpenAI Image API with the provided rest client.
* @param restClient the rest client instance to use.
* @param imagesPath the images path to use.
*/
public OpenAiImageApi(RestClient restClient, String imagesPath) {
this.restClient = restClient;
this.imagesPath = imagesPath;
}

public ResponseEntity<OpenAiImageResponse> createImage(OpenAiImageRequest openAiImageRequest) {
Assert.notNull(openAiImageRequest, "Image request cannot be null.");
Assert.hasLength(openAiImageRequest.prompt(), "Prompt cannot be empty.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.ai.model.ApiKey;
import org.springframework.ai.model.NoopApiKey;
Expand Down Expand Up @@ -49,12 +47,8 @@ public class OpenAiModerationApi {

public static final String DEFAULT_MODERATION_MODEL = "omni-moderation-latest";

private static final String DEFAULT_BASE_URL = "https://api.openai.com";

private final RestClient restClient;

private final ObjectMapper objectMapper;

/**
* Create a new OpenAI Moderation API with the provided base URL.
* @param baseUrl the base URL for the OpenAI API.
Expand All @@ -64,14 +58,12 @@ public class OpenAiModerationApi {
public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> headers,
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {

this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);

// @formatter:off
this.restClient = restClientBuilder.clone()
.baseUrl(baseUrl)
.defaultHeaders(h -> {
h.setContentType(MediaType.APPLICATION_JSON);
h.addAll(headers);
h.addAll(HttpHeaders.readOnlyHttpHeaders(headers));
})
.defaultStatusHandler(responseErrorHandler)
.defaultRequest(requestHeadersSpec -> {
Expand All @@ -82,6 +74,14 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap<String,
.build(); // @formatter:on
}

/**
* Create a new OpenAI Moderation API with the provided rest client.
* @param restClient the rest client instance to use.
*/
public OpenAiModerationApi(RestClient restClient) {
this.restClient = restClient;
}

public ResponseEntity<OpenAiModerationResponse> createModeration(OpenAiModerationRequest openAiModerationRequest) {
Assert.notNull(openAiModerationRequest, "Moderation request cannot be null.");
Assert.hasLength(openAiModerationRequest.prompt(), "Prompt cannot be empty.");
Expand Down
Loading