Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 70 additions & 16 deletions datafusion/common-runtime/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@
// specific language governing permissions and limitations
// under the License.

use std::future::Future;
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};

use crate::JoinSet;
use tokio::task::JoinError;
use tokio::task::{JoinError, JoinHandle};
Comment on lines -20 to +24
Copy link
Contributor

@gabotechs gabotechs Apr 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think not using crate::JoinSet is going to render the work being done by @geoffreyclaude in #14547 pretty much useless. Is this right @geoffreyclaude?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can just add something similar to the spawned task

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch @gabotechs! Yes, pretty much. For both to be compatible, the same wrapping I implemented over the tokio::task::JoinSet probably needs to be done over tokio::task::spawn and tokio::task::spawn_blocking.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ashdnazg I'll try to add unit tests for the instrumentation feature tomorrow which you could use to validate your change doesn't introduce any regression.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@geoffreyclaude I updated the PR with the tracing functions wrapping the spawned tasks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ashdnazg Looks good to me! I've opened #15673 to add the required regression tests on my feature.


use crate::trace_utils::{trace_block, trace_future};

/// Helper that provides a simple API to spawn a single task and join it.
/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
/// Note that if the task was spawned with `spawn_blocking`, it will only be
/// aborted if it hasn't started yet.
///
/// Technically, it's just a wrapper of `JoinSet` (with size=1).
/// Technically, it's just a wrapper of a `JoinHandle` overriding drop.
#[derive(Debug)]
pub struct SpawnedTask<R> {
inner: JoinSet<R>,
inner: JoinHandle<R>,
}

impl<R: 'static> SpawnedTask<R> {
Expand All @@ -36,8 +43,9 @@ impl<R: 'static> SpawnedTask<R> {
T: Send + 'static,
R: Send,
{
let mut inner = JoinSet::new();
inner.spawn(task);
// Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop
#[allow(clippy::disallowed_methods)]
let inner = tokio::task::spawn(trace_future(task));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think some rationale about why it is ok to use the disallowed methods would help. Something like:

Suggested change
let inner = tokio::task::spawn(trace_future(task));
// Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop
let inner = tokio::task::spawn(trace_future(task));

Self { inner }
}

Expand All @@ -47,22 +55,21 @@ impl<R: 'static> SpawnedTask<R> {
T: Send + 'static,
R: Send,
{
let mut inner = JoinSet::new();
inner.spawn_blocking(task);
// Ok to use spawn_blocking here as SpawnedTask handles aborting/cancelling the task on Drop
#[allow(clippy::disallowed_methods)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#[allow(clippy::disallowed_methods)]
// Ok to use spawn here as SpawnedTask handles aborting/cancelling the task on Drop
#[allow(clippy::disallowed_methods)]

let inner = tokio::task::spawn_blocking(trace_block(task));
Self { inner }
}

/// Joins the task, returning the result of join (`Result<R, JoinError>`).
pub async fn join(mut self) -> Result<R, JoinError> {
self.inner
.join_next()
.await
.expect("`SpawnedTask` instance always contains exactly 1 task")
/// Same as awaiting the spawned task, but left for backwards compatibility.
pub async fn join(self) -> Result<R, JoinError> {
self.await
}

/// Joins the task and unwinds the panic if it happens.
pub async fn join_unwind(self) -> Result<R, JoinError> {
self.join().await.map_err(|e| {
self.await.map_err(|e| {
// `JoinError` can be caused either by panic or cancellation. We have to handle panics:
if e.is_panic() {
std::panic::resume_unwind(e.into_panic());
Expand All @@ -77,17 +84,32 @@ impl<R: 'static> SpawnedTask<R> {
}
}

impl<R> Future for SpawnedTask<R> {
type Output = Result<R, JoinError>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.inner).poll(cx)
}
}

impl<R> Drop for SpawnedTask<R> {
fn drop(&mut self) {
self.inner.abort();
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::future::{pending, Pending};

use tokio::runtime::Runtime;
use tokio::{runtime::Runtime, sync::oneshot};

#[tokio::test]
async fn runtime_shutdown() {
let rt = Runtime::new().unwrap();
#[allow(clippy::async_yields_async)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is ok -- the only reason clippy picks this up now is that SpawnedTask is actually a future where previously one had to call join().await -- so TLDR this change makes sense to me (though the test is perhaps somewhat suspect)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want, I can change the rt.spawn to rt.spawn_blocking and then we get rid of both the unnecessary surrounding async and the clippy thing.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can add a second test -- I think we should leave the existing test as is as part of ensuring this PR doesn't introduce any regressiosn

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A second test wouldn't test anything new IMO so keeping as is.

let task = rt
.spawn(async {
SpawnedTask::spawn(async {
Expand Down Expand Up @@ -119,4 +141,36 @@ mod tests {
.await
.ok();
}

#[tokio::test]
async fn cancel_not_started_task() {
let (sender, receiver) = oneshot::channel::<i32>();
let task = SpawnedTask::spawn(async {
// Shouldn't be reached.
sender.send(42).unwrap();
});

drop(task);

// If the task was cancelled, the sender was also dropped,
// and awaiting the receiver should result in an error.
assert!(receiver.await.is_err());
}

#[tokio::test]
async fn cancel_ongoing_task() {
let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
let task = SpawnedTask::spawn(async move {
sender.send(1).await.unwrap();
// This line will never be reached because the channel has a buffer
// of 1.
sender.send(2).await.unwrap();
});
// Let the task start.
assert_eq!(receiver.recv().await.unwrap(), 1);
drop(task);

// The sender was dropped so we receive `None`.
assert!(receiver.recv().await.is_none());
}
}