Skip to content

Commit 23eddd6

Browse files
committed
Reuse VssError
1 parent bf1c8d0 commit 23eddd6

File tree

2 files changed

+34
-94
lines changed

2 files changed

+34
-94
lines changed

src/headers/lnurl_auth_jwt.rs

Lines changed: 31 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
use crate::error::VssError;
12
use crate::headers::get_headermap;
23
use crate::headers::VssHeaderProvider;
3-
use crate::headers::VssHeaderProviderError;
44
use async_trait::async_trait;
55
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
66
use base64::Engine;
@@ -68,37 +68,31 @@ impl LnurlAuthJwt {
6868
/// The JWT token will be returned in response to the signed LNURL request under a token field.
6969
/// The given set of headers will be used for LNURL requests, and will also be returned together
7070
/// with the JWT authorization header for VSS requests.
71-
pub fn new(
72-
seed: &[u8], url: String, default_headers: HashMap<String, String>,
73-
) -> Result<LnurlAuthJwt, VssHeaderProviderError> {
71+
pub fn new(seed: &[u8], url: String, default_headers: HashMap<String, String>) -> Result<LnurlAuthJwt, VssError> {
7472
let engine = Secp256k1::new();
75-
let master = ExtendedPrivKey::new_master(Network::Testnet, seed).map_err(VssHeaderProviderError::from)?;
76-
let child_number =
77-
ChildNumber::from_hardened_idx(PARENT_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?;
78-
let parent_key = master
79-
.derive_priv(&engine, &vec![child_number])
80-
.map_err(VssHeaderProviderError::from)?;
81-
let default_headermap =
82-
get_headermap(&default_headers).map_err(|error| VssHeaderProviderError::InvalidData { error })?;
73+
let master = ExtendedPrivKey::new_master(Network::Testnet, seed).map_err(VssError::from)?;
74+
let child_number = ChildNumber::from_hardened_idx(PARENT_DERIVATION_INDEX).map_err(VssError::from)?;
75+
let parent_key = master.derive_priv(&engine, &vec![child_number]).map_err(VssError::from)?;
76+
let default_headermap = get_headermap(&default_headers).map_err(|error| VssError::InternalError(error))?;
8377
let client = reqwest::Client::builder()
8478
.default_headers(default_headermap)
8579
.build()
86-
.map_err(VssHeaderProviderError::from)?;
80+
.map_err(VssError::from)?;
8781

8882
Ok(LnurlAuthJwt { engine, parent_key, url, default_headers, client, cached_jwt_token: RwLock::new(None) })
8983
}
9084

91-
async fn fetch_jwt_token(&self) -> Result<JwtToken, VssHeaderProviderError> {
85+
async fn fetch_jwt_token(&self) -> Result<JwtToken, VssError> {
9286
// Fetch the LNURL.
9387
let lnurl_str = self
9488
.client
9589
.get(&self.url)
9690
.send()
9791
.await
98-
.map_err(VssHeaderProviderError::from)?
92+
.map_err(VssError::from)?
9993
.text()
10094
.await
101-
.map_err(VssHeaderProviderError::from)?;
95+
.map_err(VssError::from)?;
10296

10397
// Sign the LNURL and perform the request.
10498
let signed_lnurl = sign_lnurl(&self.engine, &self.parent_key, &lnurl_str)?;
@@ -107,28 +101,26 @@ impl LnurlAuthJwt {
107101
.get(&signed_lnurl)
108102
.send()
109103
.await
110-
.map_err(VssHeaderProviderError::from)?
104+
.map_err(VssError::from)?
111105
.json()
112106
.await
113-
.map_err(VssHeaderProviderError::from)?;
107+
.map_err(VssError::from)?;
114108

115109
let untrusted_token = match lnurl_auth_response {
116110
LnurlAuthResponse { token: Some(token), .. } => token,
117111
LnurlAuthResponse { reason: Some(reason), .. } => {
118-
return Err(VssHeaderProviderError::ApplicationError {
119-
error: format!("LNURL Auth failed, reason is: {}", reason.escape_debug()),
120-
});
112+
return Err(VssError::AuthError(format!("LNURL Auth failed, reason is: {}", reason.escape_debug())));
121113
}
122114
_ => {
123-
return Err(VssHeaderProviderError::InvalidData {
124-
error: "LNURL Auth response did not contain a token nor an error".to_string(),
125-
});
115+
return Err(VssError::AuthError(
116+
"LNURL Auth response did not contain a token nor an error".to_string(),
117+
));
126118
}
127119
};
128120
parse_jwt_token(untrusted_token)
129121
}
130122

131-
async fn get_jwt_token(&self, force_refresh: bool) -> Result<String, VssHeaderProviderError> {
123+
async fn get_jwt_token(&self, force_refresh: bool) -> Result<String, VssError> {
132124
let cached_token_str = if force_refresh {
133125
None
134126
} else {
@@ -147,24 +139,23 @@ impl LnurlAuthJwt {
147139

148140
#[async_trait]
149141
impl VssHeaderProvider for LnurlAuthJwt {
150-
async fn get_headers(&self) -> Result<HashMap<String, String>, VssHeaderProviderError> {
142+
async fn get_headers(&self) -> Result<HashMap<String, String>, VssError> {
151143
let jwt_token = self.get_jwt_token(false).await?;
152144
let mut headers = self.default_headers.clone();
153145
headers.insert(AUTHORIZATION.to_string(), format!("Bearer {}", jwt_token));
154146
Ok(headers)
155147
}
156148
}
157149

158-
fn hashing_key(engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey) -> Result<PrivateKey, VssHeaderProviderError> {
159-
let hashing_child_number =
160-
ChildNumber::from_normal_idx(HASHING_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?;
150+
fn hashing_key(engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey) -> Result<PrivateKey, VssError> {
151+
let hashing_child_number = ChildNumber::from_normal_idx(HASHING_DERIVATION_INDEX).map_err(VssError::from)?;
161152
parent_key
162153
.derive_priv(engine, &vec![hashing_child_number])
163154
.map(|xpriv| xpriv.to_priv())
164-
.map_err(VssHeaderProviderError::from)
155+
.map_err(VssError::from)
165156
}
166157

167-
fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result<DerivationPath, VssHeaderProviderError> {
158+
fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result<DerivationPath, VssError> {
168159
let mut engine = HmacEngine::<sha256::Hash>::new(&hashing_key.inner[..]);
169160
engine.input(domain_name.as_bytes());
170161
let result = Hmac::<sha256::Hash>::from_engine(engine).to_byte_array();
@@ -174,12 +165,9 @@ fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result<Deriv
174165
Ok(DerivationPath::from_iter(children))
175166
}
176167

177-
fn sign_lnurl(
178-
engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey, lnurl_str: &str,
179-
) -> Result<String, VssHeaderProviderError> {
168+
fn sign_lnurl(engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey, lnurl_str: &str) -> Result<String, VssError> {
180169
// Parse k1 parameter to sign.
181-
let invalid_lnurl =
182-
|| VssHeaderProviderError::InvalidData { error: format!("invalid lnurl: {}", lnurl_str.escape_debug()) };
170+
let invalid_lnurl = || VssError::AuthError(format!("invalid lnurl: {}", lnurl_str.escape_debug()));
183171
let mut lnurl = Url::parse(lnurl_str).map_err(|_| invalid_lnurl())?;
184172
let domain = lnurl.domain().ok_or(invalid_lnurl())?;
185173
let k1_str = lnurl
@@ -195,11 +183,10 @@ fn sign_lnurl(
195183
let linking_key_path = linking_key_path(&hashing_private_key, domain)?;
196184
let linking_private_key = parent_key
197185
.derive_priv(engine, &linking_key_path)
198-
.map_err(VssHeaderProviderError::from)?
186+
.map_err(VssError::from)?
199187
.to_priv();
200188
let linking_public_key = linking_private_key.public_key(engine);
201-
let message = Message::from_slice(&k1)
202-
.map_err(|_| VssHeaderProviderError::InvalidData { error: format!("invalid k1: {:?}", k1) })?;
189+
let message = Message::from_slice(&k1).map_err(|_| VssError::AuthError(format!("invalid k1: {:?}", k1)))?;
203190
let sig = engine.sign_ecdsa(&message, &linking_private_key.inner);
204191

205192
// Compose LNURL with signature and linking public key.
@@ -221,10 +208,9 @@ struct ExpiryClaim {
221208
exp: Option<u64>,
222209
}
223210

224-
fn parse_jwt_token(jwt_token: String) -> Result<JwtToken, VssHeaderProviderError> {
211+
fn parse_jwt_token(jwt_token: String) -> Result<JwtToken, VssError> {
225212
let parts: Vec<&str> = jwt_token.split('.').collect();
226-
let invalid =
227-
|| VssHeaderProviderError::InvalidData { error: format!("invalid JWT token: {}", jwt_token.escape_debug()) };
213+
let invalid = || VssError::AuthError(format!("invalid JWT token: {}", jwt_token.escape_debug()));
228214
if parts.len() != 3 {
229215
return Err(invalid());
230216
}
@@ -235,15 +221,9 @@ fn parse_jwt_token(jwt_token: String) -> Result<JwtToken, VssHeaderProviderError
235221
Ok(JwtToken { token_str: jwt_token, expiry: claim.exp })
236222
}
237223

238-
impl From<bitcoin::bip32::Error> for VssHeaderProviderError {
239-
fn from(e: bitcoin::bip32::Error) -> VssHeaderProviderError {
240-
VssHeaderProviderError::InvalidData { error: e.to_string() }
241-
}
242-
}
243-
244-
impl From<reqwest::Error> for VssHeaderProviderError {
245-
fn from(e: reqwest::Error) -> VssHeaderProviderError {
246-
VssHeaderProviderError::RequestError { error: e.to_string() }
224+
impl From<bitcoin::bip32::Error> for VssError {
225+
fn from(e: bitcoin::bip32::Error) -> VssError {
226+
VssError::InternalError(e.to_string())
247227
}
248228
}
249229

src/headers/mod.rs

Lines changed: 3 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1+
use crate::error::VssError;
12
use async_trait::async_trait;
23
use reqwest::header::HeaderMap;
34
use std::collections::HashMap;
4-
use std::error::Error;
5-
use std::fmt::Display;
6-
use std::fmt::Formatter;
75
use std::str::FromStr;
86

97
#[cfg(feature = "lnurl-auth")]
@@ -12,50 +10,12 @@ mod lnurl_auth_jwt;
1210
#[cfg(feature = "lnurl-auth")]
1311
pub use lnurl_auth_jwt::LnurlAuthJwt;
1412

15-
/// Errors around providing headers for each VSS request.
16-
#[derive(Debug)]
17-
pub enum VssHeaderProviderError {
18-
/// Invalid data was encountered.
19-
InvalidData {
20-
/// The error message.
21-
error: String,
22-
},
23-
/// An external request failed.
24-
RequestError {
25-
/// The error message.
26-
error: String,
27-
},
28-
/// An application-level error occurred specific to the header provider functionality.
29-
ApplicationError {
30-
/// The error message.
31-
error: String,
32-
},
33-
}
34-
35-
impl Display for VssHeaderProviderError {
36-
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
37-
match self {
38-
Self::InvalidData { error } => {
39-
write!(f, "invalid data: {}", error)
40-
}
41-
Self::RequestError { error } => {
42-
write!(f, "error making request: {}", error)
43-
}
44-
Self::ApplicationError { error } => {
45-
write!(f, "application error providing headers: {}", error)
46-
}
47-
}
48-
}
49-
}
50-
51-
impl Error for VssHeaderProviderError {}
52-
5313
/// Defines a trait around how headers are provided for each VSS request.
5414
#[async_trait]
5515
pub trait VssHeaderProvider {
5616
/// Returns the HTTP headers to be used for a VSS request.
5717
/// This method is called on each request, and should likely perform some form of caching.
58-
async fn get_headers(&self) -> Result<HashMap<String, String>, VssHeaderProviderError>;
18+
async fn get_headers(&self) -> Result<HashMap<String, String>, VssError>;
5919
}
6020

6121
/// A header provider returning an given, fixed set of headers.
@@ -72,7 +32,7 @@ impl FixedHeaders {
7232

7333
#[async_trait]
7434
impl VssHeaderProvider for FixedHeaders {
75-
async fn get_headers(&self) -> Result<HashMap<String, String>, VssHeaderProviderError> {
35+
async fn get_headers(&self) -> Result<HashMap<String, String>, VssError> {
7636
Ok(self.headers.clone())
7737
}
7838
}

0 commit comments

Comments
 (0)