Skip to content

Commit 56a2af7

Browse files
gabotechs and alamb authored
Propagate .execute() calls immediately in RepartitionExec (#16093)
* Propagate .execute() calls immediately instead of lazily on the first RecordBatch poll
* Address race condition: make consume_input_streams lazily initialize the RepartitionExecState if it was not initialized
* Remove atomic bool for checking if the state was initialized

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent 33a2531 commit 56a2af7

File tree

1 file changed

+132
-72
lines changed
  • datafusion/physical-plan/src/repartition

1 file changed

+132
-72
lines changed

datafusion/physical-plan/src/repartition/mod.rs

Lines changed: 132 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
//! partitions to M output partitions based on a partitioning scheme, optionally
2020
//! maintaining the order of the input rows in the output.
2121
22+
use std::fmt::{Debug, Formatter};
2223
use std::pin::Pin;
2324
use std::sync::Arc;
2425
use std::task::{Context, Poll};
@@ -45,7 +46,7 @@ use arrow::compute::take_arrays;
4546
use arrow::datatypes::{SchemaRef, UInt32Type};
4647
use datafusion_common::config::ConfigOptions;
4748
use datafusion_common::utils::transpose;
48-
use datafusion_common::HashMap;
49+
use datafusion_common::{internal_err, HashMap};
4950
use datafusion_common::{not_impl_err, DataFusionError, Result};
5051
use datafusion_common_runtime::SpawnedTask;
5152
use datafusion_execution::memory_pool::MemoryConsumer;
@@ -67,9 +68,8 @@ type MaybeBatch = Option<Result<RecordBatch>>;
6768
type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
6869
type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
6970

70-
/// Inner state of [`RepartitionExec`].
7171
#[derive(Debug)]
72-
struct RepartitionExecState {
72+
struct ConsumingInputStreamsState {
7373
/// Channels for sending batches from input partitions to output partitions.
7474
/// Key is the partition number.
7575
channels: HashMap<
@@ -85,16 +85,97 @@ struct RepartitionExecState {
8585
abort_helper: Arc<Vec<SpawnedTask<()>>>,
8686
}
8787

88+
/// Inner state of [`RepartitionExec`].
89+
enum RepartitionExecState {
90+
/// Not initialized yet. This is the default state stored in the RepartitionExec node
91+
/// upon instantiation.
92+
NotInitialized,
93+
/// Input streams are initialized, but they are still not being consumed. The node
94+
/// transitions to this state when the arrow's RecordBatch stream is created in
95+
/// RepartitionExec::execute(), but before any message is polled.
96+
InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
97+
/// The input streams are being consumed. The node transitions to this state when
98+
/// the first message in the arrow's RecordBatch stream is consumed.
99+
ConsumingInputStreams(ConsumingInputStreamsState),
100+
}
101+
102+
impl Default for RepartitionExecState {
103+
fn default() -> Self {
104+
Self::NotInitialized
105+
}
106+
}
107+
108+
impl Debug for RepartitionExecState {
109+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
110+
match self {
111+
RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
112+
RepartitionExecState::InputStreamsInitialized(v) => {
113+
write!(f, "InputStreamsInitialized({:?})", v.len())
114+
}
115+
RepartitionExecState::ConsumingInputStreams(v) => {
116+
write!(f, "ConsumingInputStreams({v:?})")
117+
}
118+
}
119+
}
120+
}
121+
88122
impl RepartitionExecState {
89-
fn new(
123+
fn ensure_input_streams_initialized(
124+
&mut self,
125+
input: Arc<dyn ExecutionPlan>,
126+
metrics: ExecutionPlanMetricsSet,
127+
output_partitions: usize,
128+
ctx: Arc<TaskContext>,
129+
) -> Result<()> {
130+
if !matches!(self, RepartitionExecState::NotInitialized) {
131+
return Ok(());
132+
}
133+
134+
let num_input_partitions = input.output_partitioning().partition_count();
135+
let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
136+
137+
for i in 0..num_input_partitions {
138+
let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);
139+
140+
let timer = metrics.fetch_time.timer();
141+
let stream = input.execute(i, Arc::clone(&ctx))?;
142+
timer.done();
143+
144+
streams_and_metrics.push((stream, metrics));
145+
}
146+
*self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
147+
Ok(())
148+
}
149+
150+
fn consume_input_streams(
151+
&mut self,
90152
input: Arc<dyn ExecutionPlan>,
91-
partitioning: Partitioning,
92153
metrics: ExecutionPlanMetricsSet,
154+
partitioning: Partitioning,
93155
preserve_order: bool,
94156
name: String,
95157
context: Arc<TaskContext>,
96-
) -> Self {
97-
let num_input_partitions = input.output_partitioning().partition_count();
158+
) -> Result<&mut ConsumingInputStreamsState> {
159+
let streams_and_metrics = match self {
160+
RepartitionExecState::NotInitialized => {
161+
self.ensure_input_streams_initialized(
162+
input,
163+
metrics,
164+
partitioning.partition_count(),
165+
Arc::clone(&context),
166+
)?;
167+
let RepartitionExecState::InputStreamsInitialized(value) = self else {
168+
// This cannot happen, as ensure_input_streams_initialized() was just called,
169+
// but the compiler does not know.
170+
return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
171+
};
172+
value
173+
}
174+
RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
175+
RepartitionExecState::InputStreamsInitialized(value) => value,
176+
};
177+
178+
let num_input_partitions = streams_and_metrics.len();
98179
let num_output_partitions = partitioning.partition_count();
99180

100181
let (txs, rxs) = if preserve_order {
@@ -129,23 +210,21 @@ impl RepartitionExecState {
129210

130211
// launch one async task per *input* partition
131212
let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
132-
for i in 0..num_input_partitions {
213+
for (i, (stream, metrics)) in
214+
std::mem::take(streams_and_metrics).into_iter().enumerate()
215+
{
133216
let txs: HashMap<_, _> = channels
134217
.iter()
135218
.map(|(partition, (tx, _rx, reservation))| {
136219
(*partition, (tx[i].clone(), Arc::clone(reservation)))
137220
})
138221
.collect();
139222

140-
let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics);
141-
142223
let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
143-
Arc::clone(&input),
144-
i,
224+
stream,
145225
txs.clone(),
146226
partitioning.clone(),
147-
r_metrics,
148-
Arc::clone(&context),
227+
metrics,
149228
));
150229

151230
// In a separate task, wait for each input to be done
@@ -158,28 +237,17 @@ impl RepartitionExecState {
158237
));
159238
spawned_tasks.push(wait_for_task);
160239
}
161-
162-
Self {
240+
*self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
163241
channels,
164242
abort_helper: Arc::new(spawned_tasks),
243+
});
244+
match self {
245+
RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
246+
_ => unreachable!(),
165247
}
166248
}
167249
}
168250

169-
/// Lazily initialized state
170-
///
171-
/// Note that the state is initialized ONCE for all partitions by a single task(thread).
172-
/// This may take a short while. It is also like that multiple threads
173-
/// call execute at the same time, because we have just started "target partitions" tasks
174-
/// which is commonly set to the number of CPU cores and all call execute at the same time.
175-
///
176-
/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
177-
/// in a mutex lock but instead allow other threads to do something useful.
178-
///
179-
/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration
180-
/// (e.g. removing channels on completion) where the overhead of `await` is not warranted.
181-
type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;
182-
183251
/// A utility that can be used to partition batches based on [`Partitioning`]
184252
pub struct BatchPartitioner {
185253
state: BatchPartitionerState,
@@ -406,8 +474,9 @@ impl BatchPartitioner {
406474
pub struct RepartitionExec {
407475
/// Input execution plan
408476
input: Arc<dyn ExecutionPlan>,
409-
/// Inner state that is initialized when the first output stream is created.
410-
state: LazyState,
477+
/// Inner state that is initialized when the parent calls .execute() on this node
478+
/// and consumed as soon as the parent starts consuming this node.
479+
state: Arc<Mutex<RepartitionExecState>>,
411480
/// Execution metrics
412481
metrics: ExecutionPlanMetricsSet,
413482
/// Boolean flag to decide whether to preserve ordering. If true means
@@ -486,11 +555,7 @@ impl RepartitionExec {
486555
}
487556

488557
impl DisplayAs for RepartitionExec {
489-
fn fmt_as(
490-
&self,
491-
t: DisplayFormatType,
492-
f: &mut std::fmt::Formatter,
493-
) -> std::fmt::Result {
558+
fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
494559
match t {
495560
DisplayFormatType::Default | DisplayFormatType::Verbose => {
496561
write!(
@@ -583,7 +648,6 @@ impl ExecutionPlan for RepartitionExec {
583648
partition
584649
);
585650

586-
let lazy_state = Arc::clone(&self.state);
587651
let input = Arc::clone(&self.input);
588652
let partitioning = self.partitioning().clone();
589653
let metrics = self.metrics.clone();
@@ -595,30 +659,31 @@ impl ExecutionPlan for RepartitionExec {
595659
// Get existing ordering to use for merging
596660
let sort_exprs = self.sort_exprs().cloned().unwrap_or_default();
597661

662+
let state = Arc::clone(&self.state);
663+
if let Some(mut state) = state.try_lock() {
664+
state.ensure_input_streams_initialized(
665+
Arc::clone(&input),
666+
metrics.clone(),
667+
partitioning.partition_count(),
668+
Arc::clone(&context),
669+
)?;
670+
}
671+
598672
let stream = futures::stream::once(async move {
599673
let num_input_partitions = input.output_partitioning().partition_count();
600674

601-
let input_captured = Arc::clone(&input);
602-
let metrics_captured = metrics.clone();
603-
let name_captured = name.clone();
604-
let context_captured = Arc::clone(&context);
605-
let state = lazy_state
606-
.get_or_init(|| async move {
607-
Mutex::new(RepartitionExecState::new(
608-
input_captured,
609-
partitioning,
610-
metrics_captured,
611-
preserve_order,
612-
name_captured,
613-
context_captured,
614-
))
615-
})
616-
.await;
617-
618675
// lock scope
619676
let (mut rx, reservation, abort_helper) = {
620677
// lock mutexes
621678
let mut state = state.lock();
679+
let state = state.consume_input_streams(
680+
Arc::clone(&input),
681+
metrics.clone(),
682+
partitioning,
683+
preserve_order,
684+
name.clone(),
685+
Arc::clone(&context),
686+
)?;
622687

623688
// now return stream for the specified *output* partition which will
624689
// read from the channel
@@ -853,24 +918,17 @@ impl RepartitionExec {
853918
///
854919
/// txs hold the output sending channels for each output partition
855920
async fn pull_from_input(
856-
input: Arc<dyn ExecutionPlan>,
857-
partition: usize,
921+
mut stream: SendableRecordBatchStream,
858922
mut output_channels: HashMap<
859923
usize,
860924
(DistributionSender<MaybeBatch>, SharedMemoryReservation),
861925
>,
862926
partitioning: Partitioning,
863927
metrics: RepartitionMetrics,
864-
context: Arc<TaskContext>,
865928
) -> Result<()> {
866929
let mut partitioner =
867930
BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
868931

869-
// execute the child operator
870-
let timer = metrics.fetch_time.timer();
871-
let mut stream = input.execute(partition, context)?;
872-
timer.done();
873-
874932
// While there are still outputs to send to, keep pulling inputs
875933
let mut batches_until_yield = partitioner.num_partitions();
876934
while !output_channels.is_empty() {
@@ -1118,6 +1176,7 @@ mod tests {
11181176
use datafusion_common_runtime::JoinSet;
11191177
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
11201178
use insta::assert_snapshot;
1179+
use itertools::Itertools;
11211180

11221181
#[tokio::test]
11231182
async fn one_to_many_round_robin() -> Result<()> {
@@ -1298,15 +1357,9 @@ mod tests {
12981357
let partitioning = Partitioning::RoundRobinBatch(1);
12991358
let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
13001359

1301-
// Note: this should pass (the stream can be created) but the
1302-
// error when the input is executed should get passed back
1303-
let output_stream = exec.execute(0, task_ctx).unwrap();
1304-
13051360
// Expect that an error is returned
1306-
let result_string = crate::common::collect(output_stream)
1307-
.await
1308-
.unwrap_err()
1309-
.to_string();
1361+
let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
1362+
13101363
assert!(
13111364
result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
13121365
"actual: {result_string}"
@@ -1496,7 +1549,14 @@ mod tests {
14961549
});
14971550
let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
14981551

1499-
assert_eq!(batches_without_drop, batches_with_drop);
1552+
fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
1553+
batch
1554+
.into_iter()
1555+
.sorted_by_key(|b| format!("{b:?}"))
1556+
.collect()
1557+
}
1558+
1559+
assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
15001560
}
15011561

15021562
fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {

0 commit comments

Comments
 (0)