Skip to content

Commit 9e151a4

Browse files
committed
Introduce header provider trait
1 parent 63387fa commit 9e151a4

File tree

5 files changed

+79
-2
lines changed

5 files changed

+79
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
/target
22
/src/proto/
3+
/Cargo.lock

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ prost = "0.11.6"
1616
reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] }
1717
tokio = { version = "1", default-features = false, features = ["time"] }
1818
rand = "0.8.5"
19+
async-trait = "0.1.77"
1920

2021
[target.'cfg(genproto)'.build-dependencies]
2122
prost-build = { version = "0.11.3" }

src/client.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
use prost::Message;
2-
use reqwest;
32
use reqwest::header::CONTENT_TYPE;
43
use reqwest::Client;
4+
use std::collections::HashMap;
55
use std::default::Default;
6+
use std::sync::Arc;
67

78
use crate::error::VssError;
9+
use crate::headers::get_headermap;
10+
use crate::headers::FixedHeaders;
11+
use crate::headers::VssHeaderProvider;
812
use crate::types::{
913
DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, ListKeyVersionsRequest,
1014
ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
@@ -23,6 +27,7 @@ where
2327
base_url: String,
2428
client: Client,
2529
retry_policy: R,
30+
header_provider: Arc<dyn VssHeaderProvider>,
2631
}
2732

2833
impl<R: RetryPolicy<E = VssError>> VssClient<R> {
@@ -34,7 +39,19 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
3439

3540
/// Constructs a [`VssClient`] from a given [`reqwest::Client`], using `base_url` as the VSS server endpoint.
3641
pub fn from_client(base_url: &str, client: Client, retry_policy: R) -> Self {
37-
Self { base_url: String::from(base_url), client, retry_policy }
42+
Self {
43+
base_url: String::from(base_url),
44+
client,
45+
retry_policy,
46+
header_provider: Arc::new(FixedHeaders::new(HashMap::new())),
47+
}
48+
}
49+
50+
/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
51+
/// HTTP headers will be provided by the given `header_provider`.
52+
pub fn new_with_headers(base_url: &str, retry_policy: R, header_provider: Arc<dyn VssHeaderProvider>) -> Self {
53+
let client = Client::new();
54+
Self { base_url: String::from(base_url), client, retry_policy, header_provider }
3855
}
3956

4057
/// Returns the underlying base URL.
@@ -111,10 +128,17 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
111128

112129
async fn post_request<Rq: Message, Rs: Message + Default>(&self, request: &Rq, url: &str) -> Result<Rs, VssError> {
113130
let request_body = request.encode_to_vec();
131+
let headermap = self
132+
.header_provider
133+
.get_headers(&request_body)
134+
.await
135+
.and_then(get_headermap)
136+
.map_err(|e| VssError::AuthError(e.to_string()))?;
114137
let response_raw = self
115138
.client
116139
.post(url)
117140
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
141+
.headers(headermap)
118142
.body(request_body)
119143
.send()
120144
.await?;

src/headers/mod.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
use async_trait::async_trait;
2+
use reqwest::header::HeaderMap;
3+
use std::collections::HashMap;
4+
use std::io;
5+
use std::str::FromStr;
6+
7+
/// Defines a trait around how headers are provided for each VSS request.
8+
#[async_trait]
9+
pub trait VssHeaderProvider {
10+
/// Returns the HTTP headers to be used for a VSS request.
11+
/// This method is called on each request, and should likely perform some form of caching.
12+
///
13+
/// A reference to the serialized request body is given as `request`.
14+
/// It can be used to perform operations such as request signing.
15+
async fn get_headers(&self, request: &[u8]) -> io::Result<HashMap<String, String>>;
16+
}
17+
18+
/// A header provider returning an given, fixed set of headers.
19+
pub struct FixedHeaders {
20+
headers: HashMap<String, String>,
21+
}
22+
23+
impl FixedHeaders {
24+
/// Creates a new header provider returning the given, fixed set of headers.
25+
pub fn new(headers: HashMap<String, String>) -> FixedHeaders {
26+
FixedHeaders { headers }
27+
}
28+
}
29+
30+
#[async_trait]
31+
impl VssHeaderProvider for FixedHeaders {
32+
async fn get_headers(&self, _request: &[u8]) -> io::Result<HashMap<String, String>> {
33+
Ok(self.headers.clone())
34+
}
35+
}
36+
37+
pub(crate) fn get_headermap(headers: HashMap<String, String>) -> io::Result<HeaderMap> {
38+
let mut headermap = HeaderMap::new();
39+
for (name, value) in headers {
40+
headermap.insert(
41+
reqwest::header::HeaderName::from_str(&name)
42+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?,
43+
reqwest::header::HeaderValue::from_str(&value)
44+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?,
45+
);
46+
}
47+
Ok(headermap)
48+
}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ pub mod util;
2525

2626
// Encryption-Decryption related crate-only helpers.
2727
pub(crate) mod crypto;
28+
29+
/// A collection of header providers.
30+
pub mod headers;

0 commit comments

Comments
 (0)