Skip to content

Commit 88823a5

Browse files
committed
Introduce header provider trait
1 parent 63387fa commit 88823a5

File tree

5 files changed

+83
-2
lines changed

5 files changed

+83
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/target
22
/src/proto/
3+
/Cargo.lock

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ prost = "0.11.6"
1616
reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] }
1717
tokio = { version = "1", default-features = false, features = ["time"] }
1818
rand = "0.8.5"
19+
async-trait = "0.1.77"
1920

2021
[target.'cfg(genproto)'.build-dependencies]
2122
prost-build = { version = "0.11.3" }

src/client.rs

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
use prost::Message;
2-
use reqwest;
32
use reqwest::header::CONTENT_TYPE;
43
use reqwest::Client;
4+
use std::collections::HashMap;
55
use std::default::Default;
6+
use std::sync::Arc;
67

78
use crate::error::VssError;
9+
use crate::headers::get_headermap;
10+
use crate::headers::FixedHeaders;
11+
use crate::headers::VssHeaderProvider;
812
use crate::types::{
913
DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, ListKeyVersionsRequest,
1014
ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
@@ -23,6 +27,7 @@ where
2327
base_url: String,
2428
client: Client,
2529
retry_policy: R,
30+
header_provider: Arc<dyn VssHeaderProvider>,
2631
}
2732

2833
impl<R: RetryPolicy<E = VssError>> VssClient<R> {
@@ -34,7 +39,19 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
3439

3540
/// Constructs a [`VssClient`] from a given [`reqwest::Client`], using `base_url` as the VSS server endpoint.
3641
pub fn from_client(base_url: &str, client: Client, retry_policy: R) -> Self {
37-
Self { base_url: String::from(base_url), client, retry_policy }
42+
Self {
43+
base_url: String::from(base_url),
44+
client,
45+
retry_policy,
46+
header_provider: Arc::new(FixedHeaders::new(HashMap::new())),
47+
}
48+
}
49+
50+
/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
51+
/// HTTP headers will be provided by the given `header_provider`.
52+
pub fn new_with_headers(base_url: &str, retry_policy: R, header_provider: Arc<dyn VssHeaderProvider>) -> Self {
53+
let client = Client::new();
54+
Self { base_url: String::from(base_url), client, retry_policy, header_provider }
3855
}
3956

4057
/// Returns the underlying base URL.
@@ -111,10 +128,20 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
111128

112129
async fn post_request<Rq: Message, Rs: Message + Default>(&self, request: &Rq, url: &str) -> Result<Rs, VssError> {
113130
let request_body = request.encode_to_vec();
131+
let headermap = self
132+
.header_provider
133+
.get_headers(&request_body)
134+
.await
135+
.and_then(get_headermap)
136+
.map_err(|e| match e {
137+
e @ VssError::AuthError(_) => e,
138+
e => VssError::AuthError(e.to_string()),
139+
})?;
114140
let response_raw = self
115141
.client
116142
.post(url)
117143
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
144+
.headers(headermap)
118145
.body(request_body)
119146
.send()
120147
.await?;

src/headers/mod.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
use crate::error::VssError;
2+
use async_trait::async_trait;
3+
use reqwest::header::HeaderMap;
4+
use std::collections::HashMap;
5+
use std::str::FromStr;
6+
7+
/// Defines a trait around how headers are provided for each VSS request.
8+
#[async_trait]
9+
pub trait VssHeaderProvider {
10+
/// Returns the HTTP headers to be used for a VSS request.
11+
/// This method is called on each request, and should likely perform some form of caching.
12+
///
13+
/// A reference to the serialized request body is given as `request`.
14+
/// It can be used to perform operations such as request signing.
15+
///
16+
/// Any returned errors should be of the `VssError::AuthError` variant,
17+
/// or will be mapped to it.
18+
async fn get_headers(&self, request: &[u8]) -> Result<HashMap<String, String>, VssError>;
19+
}
20+
21+
/// A header provider returning an given, fixed set of headers.
22+
pub struct FixedHeaders {
23+
headers: HashMap<String, String>,
24+
}
25+
26+
impl FixedHeaders {
27+
/// Creates a new header provider returning the given, fixed set of headers.
28+
pub fn new(headers: HashMap<String, String>) -> FixedHeaders {
29+
FixedHeaders { headers }
30+
}
31+
}
32+
33+
#[async_trait]
34+
impl VssHeaderProvider for FixedHeaders {
35+
async fn get_headers(&self, _request: &[u8]) -> Result<HashMap<String, String>, VssError> {
36+
Ok(self.headers.clone())
37+
}
38+
}
39+
40+
pub(crate) fn get_headermap(headers: HashMap<String, String>) -> Result<HeaderMap, VssError> {
41+
let mut headermap = HeaderMap::new();
42+
for (name, value) in headers {
43+
headermap.insert(
44+
reqwest::header::HeaderName::from_str(&name).map_err(|e| VssError::AuthError(e.to_string()))?,
45+
reqwest::header::HeaderValue::from_str(&value).map_err(|e| VssError::AuthError(e.to_string()))?,
46+
);
47+
}
48+
Ok(headermap)
49+
}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ pub mod util;
2525

2626
// Encryption-Decryption related crate-only helpers.
2727
pub(crate) mod crypto;
28+
29+
/// A collection of header providers.
30+
pub mod headers;

0 commit comments

Comments
 (0)