From 718247a100ee3773851d0bf342d3a1d8df86c74f Mon Sep 17 00:00:00 2001 From: Marc Delling Date: Mon, 3 Nov 2025 12:50:31 +0100 Subject: [PATCH 1/2] Make psk callback async-capable --- dtls/src/config.rs | 5 ++++- dtls/src/conn/conn_test.rs | 24 +++++++++++++++--------- dtls/src/flight/flight3.rs | 6 +++--- dtls/src/flight/flight4.rs | 3 ++- dtls/src/flight/flight5.rs | 3 ++- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/dtls/src/config.rs b/dtls/src/config.rs index 0e1e23c6f..17edd608e 100644 --- a/dtls/src/config.rs +++ b/dtls/src/config.rs @@ -1,3 +1,5 @@ +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use tokio::time::Duration; @@ -131,7 +133,8 @@ pub(crate) const DEFAULT_MTU: usize = 1200; // bytes // PSKCallback is called once we have the remote's psk_identity_hint. // If the remote provided none it will be nil -pub(crate) type PskCallback = Arc Result>) + Send + Sync>; +pub(crate) type PskCallback = + Arc Pin>> + Send>>) + Send + Sync>; // ClientAuthType declares the policy the server will follow for // TLS Client Authentication. diff --git a/dtls/src/conn/conn_test.rs b/dtls/src/conn/conn_test.rs index af0c8c811..23fe43993 100644 --- a/dtls/src/conn/conn_test.rs +++ b/dtls/src/conn/conn_test.rs @@ -1,7 +1,10 @@ +use std::future::Future; +use std::pin::Pin; use std::time::SystemTime; use rand::Rng; use rustls::pki_types::CertificateDer; +use tokio::time::sleep; use util::conn::conn_pipe::*; use util::KeyingMaterialExporter; @@ -79,24 +82,27 @@ async fn pipe_conn( Ok((client, sever)) } -fn psk_callback_client(hint: &[u8]) -> Result> { +fn psk_callback_client(hint: &[u8]) -> Pin>> + Send>> { trace!( "Server's hint: {}", String::from_utf8(hint.to_vec()).unwrap() ); - Ok(vec![0xAB, 0xC1, 0x23]) + Box::pin(async move { Ok(vec![0xAB, 0xC1, 0x23]) }) } -fn psk_callback_server(hint: &[u8]) -> Result> { +fn psk_callback_server(hint: &[u8]) -> Pin>> + Send>> { trace!( "Client's hint: {}", String::from_utf8(hint.to_vec()).unwrap() ); - Ok(vec![0xAB, 0xC1, 0x23]) + Box::pin(async move { + sleep(Duration::from_millis(1)).await; // Now it's possible to await in the psk callback + Ok(vec![0xAB, 0xC1, 0x23]) + }) } -fn psk_callback_hint_fail(_hint: &[u8]) -> Result> { - Err(Error::Other(ERR_PSK_REJECTED.to_owned())) +fn psk_callback_hint_fail(_hint: &[u8]) -> Pin>> + Send>> { + Box::pin(async move { Err(Error::Other(ERR_PSK_REJECTED.to_owned())) }) } async fn create_test_client( @@ -1617,7 +1623,7 @@ async fn test_cipher_suite_configuration() -> Result<()> { assert!(cipher_suite.is_some(), "{name} expected some, but got none"); if let Some(cs) = &*cipher_suite { assert_eq!(cs.id(), want_cs, - "test_cipher_suite_configuration: Server Selected Bad Cipher Suite '{}': expected({}) actual({})", + "test_cipher_suite_configuration: Server Selected Bad Cipher Suite '{}': expected({}) actual({})", name, want_cs, cs.id()); } } @@ -1630,8 +1636,8 @@ async fn test_cipher_suite_configuration() -> Result<()> { Ok(()) } -fn psk_callback(_b: &[u8]) -> Result> { - Ok(vec![0x00, 0x01, 0x02]) +fn psk_callback(_b: &[u8]) -> Pin>> + Send>> { + Box::pin(async move { Ok(vec![0x00, 0x01, 0x02]) }) } #[tokio::test] diff --git a/dtls/src/flight/flight3.rs b/dtls/src/flight/flight3.rs index d84e3bdd0..a4e2725cd 100644 --- a/dtls/src/flight/flight3.rs +++ b/dtls/src/flight/flight3.rs @@ -319,7 +319,7 @@ impl Flight for Flight3 { } }; - if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h) { + if let Err((alert, err)) = handle_server_key_exchange(state, cfg, h).await { return Err((alert, err)); } } @@ -411,13 +411,13 @@ impl Flight for Flight3 { } } -pub(crate) fn handle_server_key_exchange( +pub(crate) async fn handle_server_key_exchange( state: &mut State, cfg: &HandshakeConfig, h: &HandshakeMessageServerKeyExchange, ) -> Result<(), (Option, Option)> { if let Some(local_psk_callback) = &cfg.local_psk_callback { - let psk = match local_psk_callback(&h.identity_hint) { + let psk = match local_psk_callback(&h.identity_hint).await { Ok(psk) => psk, Err(err) => { return Err(( diff --git a/dtls/src/flight/flight4.rs b/dtls/src/flight/flight4.rs index 11a77ffe1..b93e78e47 100644 --- a/dtls/src/flight/flight4.rs +++ b/dtls/src/flight/flight4.rs @@ -290,7 +290,8 @@ impl Flight for Flight4 { let mut pre_master_secret = vec![]; if let Some(local_psk_callback) = &cfg.local_psk_callback { - let psk = match local_psk_callback(&client_key_exchange.identity_hint) { + let psk = match local_psk_callback(&client_key_exchange.identity_hint).await + { Ok(psk) => psk, Err(err) => { return Err(( diff --git a/dtls/src/flight/flight5.rs b/dtls/src/flight/flight5.rs index 264cd9e0a..c794680f9 100644 --- a/dtls/src/flight/flight5.rs +++ b/dtls/src/flight/flight5.rs @@ -273,7 +273,8 @@ impl Flight for Flight5 { // handshakeMessageServerKeyExchange is optional for PSK if server_key_exchange_data.is_empty() { - if let Err((alert, err)) = handle_server_key_exchange(state, cfg, &server_key_exchange) + if let Err((alert, err)) = + handle_server_key_exchange(state, cfg, &server_key_exchange).await { return Err((alert, err)); } From ef3b2e42e01861be24b440a63aa02ef6a09a2823 Mon Sep 17 00:00:00 2001 From: Marc Delling Date: Mon, 3 Nov 2025 13:22:55 +0100 Subject: [PATCH 2/2] Amend examples --- examples/examples/dtls/dial/psk/dial_psk.rs | 9 ++++++--- examples/examples/dtls/listen/psk/listen_psk.rs | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/examples/examples/dtls/dial/psk/dial_psk.rs b/examples/examples/dtls/dial/psk/dial_psk.rs index 05e7d16e9..dc9803771 100644 --- a/examples/examples/dtls/dial/psk/dial_psk.rs +++ b/examples/examples/dtls/dial/psk/dial_psk.rs @@ -62,9 +62,12 @@ async fn main() -> Result<(), Error> { println!("connecting {server}.."); let config = Config { - psk: Some(Arc::new(|hint: &[u8]| -> Result, Error> { - println!("Server's hint: {}", String::from_utf8(hint.to_vec())?); - Ok(vec![0xAB, 0xC1, 0x23]) + psk: Some(Arc::new(|hint: &[u8]| { + let hint = hint.to_owned(); + Box::pin(async move { + println!("Server's hint: {}", String::from_utf8(hint.to_vec())?); + Ok(vec![0xAB, 0xC1, 0x23]) + }) })), psk_identity_hint: Some("webrtc-rs DTLS Server".as_bytes().to_vec()), cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8], diff --git a/examples/examples/dtls/listen/psk/listen_psk.rs b/examples/examples/dtls/listen/psk/listen_psk.rs index 8846b33a0..4f7699dbf 100644 --- a/examples/examples/dtls/listen/psk/listen_psk.rs +++ b/examples/examples/dtls/listen/psk/listen_psk.rs @@ -57,9 +57,12 @@ async fn main() -> Result<(), Error> { let host = matches.value_of("host").unwrap().to_owned(); let cfg = Config { - psk: Some(Arc::new(|hint: &[u8]| -> Result, Error> { - println!("Client's hint: {}", String::from_utf8(hint.to_vec())?); - Ok(vec![0xAB, 0xC1, 0x23]) + psk: Some(Arc::new(|hint: &[u8]| { + let hint = hint.to_owned(); + Box::pin(async move { + println!("Client's hint: {}", String::from_utf8(hint.to_vec())?); + Ok(vec![0xAB, 0xC1, 0x23]) + }) })), psk_identity_hint: Some("webrtc-rs DTLS Client".as_bytes().to_vec()), cipher_suites: vec![CipherSuiteId::Tls_Psk_With_Aes_128_Ccm_8],