From b47dbba4bc2dd1f9f43d9798def44c669d255a8a Mon Sep 17 00:00:00 2001
From: Eshed Schacham
Date: Wed, 9 Apr 2025 11:10:15 +0200
Subject: [PATCH] Implement Future for SpawnedTask.

It allows polling a SpawnedTask, instead of just joining it.
The implementation is changed from `JoinSet` to a `JoinHandle` to
simplify the code, as `JoinSet` doesn't provide any additional benefits.
---
 datafusion/common-runtime/src/common.rs | 86 ++++++++++++++++++++-----
 1 file changed, 70 insertions(+), 16 deletions(-)

diff --git a/datafusion/common-runtime/src/common.rs b/datafusion/common-runtime/src/common.rs
index 361f6af95cf1..e7aba1d455ee 100644
--- a/datafusion/common-runtime/src/common.rs
+++ b/datafusion/common-runtime/src/common.rs
@@ -15,18 +15,25 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::future::Future;
+use std::{
+    future::Future,
+    pin::Pin,
+    task::{Context, Poll},
+};
 
-use crate::JoinSet;
-use tokio::task::JoinError;
+use tokio::task::{JoinError, JoinHandle};
+
+use crate::trace_utils::{trace_block, trace_future};
 
 /// Helper that provides a simple API to spawn a single task and join it.
 /// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
+/// Note that if the task was spawned with `spawn_blocking`, it will only be
+/// aborted if it hasn't started yet.
 ///
-/// Technically, it's just a wrapper of `JoinSet` (with size=1).
+/// Technically, it's just a wrapper of a `JoinHandle` overriding drop.
 #[derive(Debug)]
 pub struct SpawnedTask<R> {
-    inner: JoinSet<R>,
+    inner: JoinHandle<R>,
 }
 
 impl<R: 'static> SpawnedTask<R> {
@@ -36,8 +43,9 @@ impl<R: 'static> SpawnedTask<R> {
         T: Send + 'static,
         R: Send,
     {
-        let mut inner = JoinSet::new();
-        inner.spawn(task);
+        // Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop
+        #[allow(clippy::disallowed_methods)]
+        let inner = tokio::task::spawn(trace_future(task));
         Self { inner }
     }
 
@@ -47,22 +55,21 @@ impl<R: 'static> SpawnedTask<R> {
         T: Send + 'static,
         R: Send,
     {
-        let mut inner = JoinSet::new();
-        inner.spawn_blocking(task);
+        // Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop
+        #[allow(clippy::disallowed_methods)]
+        let inner = tokio::task::spawn_blocking(trace_block(task));
         Self { inner }
     }
 
     /// Joins the task, returning the result of join (`Result<R, JoinError>`).
-    pub async fn join(mut self) -> Result<R, JoinError> {
-        self.inner
-            .join_next()
-            .await
-            .expect("`SpawnedTask` instance always contains exactly 1 task")
+    /// Same as awaiting the spawned task, but left for backwards compatibility.
+    pub async fn join(self) -> Result<R, JoinError> {
+        self.await
     }
 
     /// Joins the task and unwinds the panic if it happens.
     pub async fn join_unwind(self) -> Result<R, JoinError> {
-        self.join().await.map_err(|e| {
+        self.await.map_err(|e| {
             // `JoinError` can be caused either by panic or cancellation. We have to handle panics:
             if e.is_panic() {
                 std::panic::resume_unwind(e.into_panic());
@@ -77,17 +84,32 @@ impl<R: 'static> SpawnedTask<R> {
     }
 }
 
+impl<R> Future for SpawnedTask<R> {
+    type Output = Result<R, JoinError>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        Pin::new(&mut self.inner).poll(cx)
+    }
+}
+
+impl<R> Drop for SpawnedTask<R> {
+    fn drop(&mut self) {
+        self.inner.abort();
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
 
     use std::future::{pending, Pending};
 
-    use tokio::runtime::Runtime;
+    use tokio::{runtime::Runtime, sync::oneshot};
 
     #[tokio::test]
     async fn runtime_shutdown() {
         let rt = Runtime::new().unwrap();
+        #[allow(clippy::async_yields_async)]
         let task = rt
             .spawn(async {
                 SpawnedTask::spawn(async {
@@ -119,4 +141,36 @@ mod tests {
             .await
             .ok();
     }
+
+    #[tokio::test]
+    async fn cancel_not_started_task() {
+        let (sender, receiver) = oneshot::channel::<i32>();
+        let task = SpawnedTask::spawn(async {
+            // Shouldn't be reached.
+            sender.send(42).unwrap();
+        });
+
+        drop(task);
+
+        // If the task was cancelled, the sender was also dropped,
+        // and awaiting the receiver should result in an error.
+        assert!(receiver.await.is_err());
+    }
+
+    #[tokio::test]
+    async fn cancel_ongoing_task() {
+        let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
+        let task = SpawnedTask::spawn(async move {
+            sender.send(1).await.unwrap();
+            // This line will never be reached because the channel has a buffer
+            // of 1.
+            sender.send(2).await.unwrap();
+        });
+        // Let the task start.
+        assert_eq!(receiver.recv().await.unwrap(), 1);
+        drop(task);
+
+        // The sender was dropped so we receive `None`.
+        assert!(receiver.recv().await.is_none());
+    }
 }