Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion dtls/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use tokio::time::Duration;
Expand Down Expand Up @@ -131,7 +133,8 @@ pub(crate) const DEFAULT_MTU: usize = 1200; // bytes

// PSKCallback is called once we have the remote's psk_identity_hint.
// If the remote provided none it will be nil
pub(crate) type PskCallback = Arc<dyn (Fn(&[u8]) -> Result<Vec<u8>>) + Send + Sync>;
pub(crate) type PskCallback =
Arc<dyn (Fn(&[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>>) + Send + Sync>;

// ClientAuthType declares the policy the server will follow for
// TLS Client Authentication.
Expand Down
24 changes: 15 additions & 9 deletions dtls/src/conn/conn_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::future::Future;
use std::pin::Pin;
use std::time::SystemTime;

use rand::Rng;
use rustls::pki_types::CertificateDer;
use tokio::time::sleep;
use util::conn::conn_pipe::*;
use util::KeyingMaterialExporter;

Expand Down Expand Up @@ -79,24 +82,27 @@ async fn pipe_conn(
Ok((client, sever))
}

fn psk_callback_client(hint: &[u8]) -> Result<Vec<u8>> {
fn psk_callback_client(hint: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
trace!(
"Server's hint: {}",
String::from_utf8(hint.to_vec()).unwrap()
);
Ok(vec![0xAB, 0xC1, 0x23])
Box::pin(async move { Ok(vec![0xAB, 0xC1, 0x23]) })
}

fn psk_callback_server(hint: &[u8]) -> Result<Vec<u8>> {
fn psk_callback_server(hint: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
trace!(
"Client's hint: {}",
String::from_utf8(hint.to_vec()).unwrap()
);
Ok(vec![0xAB, 0xC1, 0x23])
Box::pin(async move {
sleep(Duration::from_millis(1)).await; // Now it's possible to await in the psk callback
Ok(vec![0xAB, 0xC1, 0x23])
})
}

fn psk_callback_hint_fail(_hint: &[u8]) -> Result<Vec<u8>> {
Err(Error::Other(ERR_PSK_REJECTED.to_owned()))
fn psk_callback_hint_fail(_hint: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
Box::pin(async move { Err(Error::Other(ERR_PSK_REJECTED.to_owned())) })
}

async fn create_test_client(
Expand Down Expand Up @@ -1617,7 +1623,7 @@ async fn test_cipher_suite_configuration() -> Result<()> {
assert!(cipher_suite.is_some(), "{name} expected some, but got none");
if let Some(cs) = &*cipher_suite {
assert_eq!(cs.id(), want_cs,
"test_cipher_suite_configuration: Server Selected Bad Cipher Suite '{}': expected({}) actual({})",
"test_cipher_suite_configuration: Server Selected Bad Cipher Suite '{}': expected({}) actual({})",
name, want_cs, cs.id());
}
}
Expand All @@ -1630,8 +1636,8 @@ async fn test_cipher_suite_configuration() -> Result<()> {
Ok(())
}

fn psk_callback(_b: &[u8]) -> Result<Vec<u8>> {
Ok(vec![0x00, 0x01, 0x02])
fn psk_callback(_b: &[u8]) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send>> {
Box::pin(async move { Ok(vec![0x00, 0x01, 0x02]) })
}

#[tokio::test]
Expand Down
6 changes: 3 additions & 3 deletions dtls/src/flight/flight3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ impl Flight for Flight3 {
}
};

if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h) {
if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h).await {
return Err((alert, err));
}
}
Expand Down Expand Up @@ -411,13 +411,13 @@ impl Flight for Flight3 {
}
}

pub(crate) fn handle_server_key_exchange(
pub(crate) async fn handle_server_key_exchange(
state: &mut State,
cfg: &HandshakeConfig,
h: &HandshakeMessageServerKeyExchange,
) -> Result<(), (Option<Alert>, Option<Error>)> {
if let Some(local_psk_callback) = &cfg.local_psk_callback {
let psk = match local_psk_callback(&h.identity_hint) {
let psk = match local_psk_callback(&h.identity_hint).await {
Ok(psk) => psk,
Err(err) => {
return Err((
Expand Down
3 changes: 2 additions & 1 deletion dtls/src/flight/flight4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,8 @@ impl Flight for Flight4 {

let mut pre_master_secret = vec![];
if let Some(local_psk_callback) = &cfg.local_psk_callback {
let psk = match local_psk_callback(&client_key_exchange.identity_hint) {
let psk = match local_psk_callback(&client_key_exchange.identity_hint).await
{
Ok(psk) => psk,
Err(err) => {
return Err((
Expand Down
3 changes: 2 additions & 1 deletion dtls/src/flight/flight5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ impl Flight for Flight5 {

// handshakeMessageServerKeyExchange is optional for PSK
if server_key_exchange_data.is_empty() {
if let Err((alert, err)) = handle_server_key_exchange(state, cfg, &server_key_exchange)
if let Err((alert, err)) =
handle_server_key_exchange(state, cfg, &server_key_exchange).await
{
return Err((alert, err));
}
Expand Down
9 changes: 6 additions & 3 deletions examples/examples/dtls/dial/psk/dial_psk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ async fn main() -> Result<(), Error> {
println!("connecting {server}..");

let config = Config {
psk: Some(Arc::new(|hint: &[u8]| -> Result<Vec<u8>, Error> {
println!("Server's hint: {}", String::from_utf8(hint.to_vec())?);
Ok(vec![0xAB, 0xC1, 0x23])
psk: Some(Arc::new(|hint: &[u8]| {
let hint = hint.to_owned();
Box::pin(async move {
println!("Server's hint: {}", String::from_utf8(hint.to_vec())?);
Ok(vec![0xAB, 0xC1, 0x23])
})
})),
psk_identity_hint: Some("webrtc-rs DTLS Server".as_bytes().to_vec()),
cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8],
Expand Down
9 changes: 6 additions & 3 deletions examples/examples/dtls/listen/psk/listen_psk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ async fn main() -> Result<(), Error> {
let host = matches.value_of("host").unwrap().to_owned();

let cfg = Config {
psk: Some(Arc::new(|hint: &[u8]| -> Result<Vec<u8>, Error> {
println!("Client's hint: {}", String::from_utf8(hint.to_vec())?);
Ok(vec![0xAB, 0xC1, 0x23])
psk: Some(Arc::new(|hint: &[u8]| {
let hint = hint.to_owned();
Box::pin(async move {
println!("Client's hint: {}", String::from_utf8(hint.to_vec())?);
Ok(vec![0xAB, 0xC1, 0x23])
})
})),
psk_identity_hint: Some("webrtc-rs DTLS Client".as_bytes().to_vec()),
cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8],
Expand Down
Loading