Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::pin::Pin;
use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::multipart::Form;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use reqwest_eventsource::{retry::ExponentialBackoff, Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

use crate::{
Expand All @@ -12,6 +12,7 @@ use crate::{
file::Files,
image::Images,
moderation::Moderations,
streaming_backoff::StreamingBackoff,
traits::AsyncTryFrom,
Assistants, Audio, AuditLogs, Batches, Chat, Completions, Embeddings, FineTuning, Invites,
Models, Projects, Responses, Threads, Uploads, Users, VectorStores,
Expand Down Expand Up @@ -414,7 +415,7 @@ impl<C: Config> Client<C> {
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
let mut event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
Expand All @@ -423,6 +424,9 @@ impl<C: Config> Client<C> {
.eventsource()
.unwrap();

let retry_policy: StreamingBackoff = self.backoff.clone().into();
event_source.set_retry_policy(Box::new(retry_policy));

stream(event_source).await
}

Expand All @@ -436,7 +440,7 @@ impl<C: Config> Client<C> {
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = self
let mut event_source = self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
Expand All @@ -445,6 +449,9 @@ impl<C: Config> Client<C> {
.eventsource()
.unwrap();

let retry_policy: StreamingBackoff = self.backoff.clone().into();
event_source.set_retry_policy(Box::new(retry_policy));

stream_mapped_raw_events(event_source, event_mapper).await
}

Expand Down
1 change: 1 addition & 0 deletions async-openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ mod projects;
mod responses;
mod runs;
mod steps;
mod streaming_backoff;
mod threads;
pub mod traits;
pub mod types;
Expand Down
64 changes: 64 additions & 0 deletions async-openai/src/streaming_backoff.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::time::Duration;

use reqwest::StatusCode;
use reqwest_eventsource::retry::RetryPolicy;

/// Wraps `backoff::ExponentialBackoff` to provide a custom backoff suitable for
/// reqwest_eventsource
pub struct StreamingBackoff(backoff::ExponentialBackoff);

impl StreamingBackoff {
fn should_retry(&self, error: &reqwest_eventsource::Error) -> bool {
// Errors at the connection level only
if let reqwest_eventsource::Error::Transport(error) = error {
// TODO: We can't inspect the response body as reading consumes it.
// This is problematic because quota exceeded errors are also 429.
return error
.status()
.as_ref()
.is_some_and(StatusCode::is_server_error)
|| error.status() == Some(reqwest::StatusCode::TOO_MANY_REQUESTS);
}

true
}
}

impl From<backoff::ExponentialBackoff> for StreamingBackoff {
fn from(backoff: backoff::ExponentialBackoff) -> Self {
Self(backoff)
}
}

impl RetryPolicy for StreamingBackoff {
fn retry(
&self,
error: &reqwest_eventsource::Error,
last_retry: Option<(usize, Duration)>,
) -> Option<Duration> {
if !self.should_retry(error) {
return None;
};

// Ignoring backoff randomization factor for simplicity
// Basically reimplements the retry policy from eventsource
if let Some((_retry_num, last_duration)) = last_retry {
let duration = last_duration.mul_f64(self.0.multiplier);

if let Some(max_duration) = self.0.max_elapsed_time {
Some(duration.min(max_duration))
} else {
Some(duration)
}
} else {
Some(self.0.initial_interval)
}
}

fn set_reconnection_time(&mut self, duration: Duration) {
self.0.initial_interval = duration;
if let Some(max_elapsed_time) = self.0.max_elapsed_time {
self.0.max_elapsed_time = Some(max_elapsed_time.max(duration))
}
}
}