diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index c0dbf5164e19..2ed5da7ced20 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -36,6 +36,7 @@ use crate::repartition::distributor_channels::{
     channels, partition_aware_channels, DistributionReceiver, DistributionSender,
 };
 use crate::sorts::streaming_merge;
+use crate::stream::RecordBatchStreamAdapter;
 use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, Statistics};

 use arrow::array::{ArrayRef, UInt64Builder};
@@ -48,7 +49,7 @@ use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr, PhysicalSortExpr};

 use futures::stream::Stream;
-use futures::{FutureExt, StreamExt};
+use futures::{FutureExt, StreamExt, TryStreamExt};
 use hashbrown::HashMap;
 use log::trace;
 use parking_lot::Mutex;
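The new `TryStreamExt` import is what brings `try_flatten` into scope for the `execute` rewrite below. The pattern is worth seeing in isolation: wrap the async setup in `futures::stream::once`, which yields a stream-of-a-stream, then flatten it so the setup runs only when the output stream is first polled. A minimal sketch with toy types (not DataFusion code):

```rust
use futures::{stream, StreamExt, TryStreamExt};

#[tokio::main]
async fn main() {
    // The caller can build this synchronously; the async block runs on first poll.
    let lazy = stream::once(async {
        // expensive async initialization would happen here
        Ok::<_, String>(stream::iter(vec![Ok::<i32, String>(1), Ok(2)]))
    })
    .try_flatten(); // stream of Result<stream of Result<T>> -> stream of Result<T>

    let items: Vec<_> = lazy.collect().await;
    assert_eq!(items, vec![Ok(1), Ok(2)]);
}
```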
@@ -77,6 +78,102 @@ struct RepartitionExecState {
     abort_helper: Arc<Vec<SpawnedTask<()>>>,
 }

+impl RepartitionExecState {
+    fn new(
+        input: Arc<dyn ExecutionPlan>,
+        partitioning: Partitioning,
+        metrics: ExecutionPlanMetricsSet,
+        preserve_order: bool,
+        name: String,
+        context: Arc<TaskContext>,
+    ) -> Self {
+        let num_input_partitions = input.output_partitioning().partition_count();
+        let num_output_partitions = partitioning.partition_count();
+
+        let (txs, rxs) = if preserve_order {
+            let (txs, rxs) =
+                partition_aware_channels(num_input_partitions, num_output_partitions);
+            // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
+            let txs = transpose(txs);
+            let rxs = transpose(rxs);
+            (txs, rxs)
+        } else {
+            // create one channel per *output* partition
+            // note we use a custom channel that ensures there is always data for each receiver
+            // but limits the amount of buffering if required.
+            let (txs, rxs) = channels(num_output_partitions);
+            // Clone sender for each input partition
+            let txs = txs
+                .into_iter()
+                .map(|item| vec![item; num_input_partitions])
+                .collect::<Vec<_>>();
+            let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
+            (txs, rxs)
+        };
+
+        let mut channels = HashMap::with_capacity(txs.len());
+        for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
+            let reservation = Arc::new(Mutex::new(
+                MemoryConsumer::new(format!("{}[{partition}]", name))
+                    .register(context.memory_pool()),
+            ));
+            channels.insert(partition, (tx, rx, reservation));
+        }
+
+        // launch one async task per *input* partition
+        let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
+        for i in 0..num_input_partitions {
+            let txs: HashMap<_, _> = channels
+                .iter()
+                .map(|(partition, (tx, _rx, reservation))| {
+                    (*partition, (tx[i].clone(), Arc::clone(reservation)))
+                })
+                .collect();
+
+            // TODO: metric input-output mapping is broken
+            let r_metrics = RepartitionMetrics::new(i, 0, &metrics);
+
+            let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
+                input.clone(),
+                i,
+                txs.clone(),
+                partitioning.clone(),
+                r_metrics,
+                context.clone(),
+            ));
+
+            // In a separate task, wait for each input to be done
+            // (and pass along any errors, including panic!s)
+            let wait_for_task = SpawnedTask::spawn(RepartitionExec::wait_for_task(
+                input_task,
+                txs.into_iter()
+                    .map(|(partition, (tx, _reservation))| (partition, tx))
+                    .collect(),
+            ));
+            spawned_tasks.push(wait_for_task);
+        }
+
+        Self {
+            channels,
+            abort_helper: Arc::new(spawned_tasks),
+        }
+    }
+}
+
+/// Lazily initialized state
+///
+/// Note that the state is initialized ONCE for all partitions by a single task (thread).
+/// This may take a short while. It is also likely that multiple threads
+/// call `execute` at the same time, because we have just started "target partitions" tasks,
+/// which is commonly set to the number of CPU cores, and they all call `execute` at the same time.
+///
+/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
+/// in a futex lock but instead allow other threads to do something useful.
+///
+/// Uses a parking_lot `Mutex` to control other accesses, as they are of very short duration
+/// (e.g. removing channels on completion) and the overhead of `await` is not warranted.
+type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;
+
 /// A utility that can be used to partition batches based on [`Partitioning`]
 pub struct BatchPartitioner {
     state: BatchPartitionerState,
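Because `OnceCell::get_or_init` takes an async initializer, concurrent callers of `execute` simply await the first caller's initialization instead of blocking a thread on a lock. A minimal sketch of the same shape as `LazyState` (illustrative names, not DataFusion's API):

```rust
use std::sync::Arc;

use parking_lot::Mutex;
use tokio::sync::OnceCell;

struct State {
    channels: Vec<u32>, // stand-in for the per-partition channel map
}

type Lazy = Arc<OnceCell<Mutex<State>>>;

async fn get_state(lazy: &Lazy) -> &Mutex<State> {
    lazy.get_or_init(|| async {
        // runs exactly once; other callers await instead of spinning on a lock
        Mutex::new(State { channels: vec![0, 1, 2] })
    })
    .await
}

#[tokio::main]
async fn main() {
    let lazy: Lazy = Default::default();
    let state = get_state(&lazy).await;
    // short critical section afterwards: no `.await` while the lock is held
    let last = state.lock().channels.pop();
    assert_eq!(last, Some(2));
}
```

This is also why `try_new` below can build the field with `state: Default::default()`: an `Arc<OnceCell<_>>` defaults to an empty, not-yet-initialized cell.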
@@ -298,7 +395,7 @@ pub struct RepartitionExec {
     /// Partitioning scheme to use
     partitioning: Partitioning,
     /// Inner state that is initialized when the first output stream is created.
-    state: Arc<Mutex<RepartitionExecState>>,
+    state: LazyState,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
     /// Boolean flag to decide whether to preserve ordering. If true means
@@ -453,134 +550,104 @@ impl ExecutionPlan for RepartitionExec {
             self.name(),
             partition
         );
-        // lock mutexes
-        let mut state = self.state.lock();
-
-        let num_input_partitions = self.input.output_partitioning().partition_count();
-        let num_output_partitions = self.partitioning.partition_count();
-
-        // if this is the first partition to be invoked then we need to set up initial state
-        if state.channels.is_empty() {
-            let (txs, rxs) = if self.preserve_order {
-                let (txs, rxs) =
-                    partition_aware_channels(num_input_partitions, num_output_partitions);
-                // Take transpose of senders and receivers. `state.channels` keeps track of entries per output partition
-                let txs = transpose(txs);
-                let rxs = transpose(rxs);
-                (txs, rxs)
-            } else {
-                // create one channel per *output* partition
-                // note we use a custom channel that ensures there is always data for each receiver
-                // but limits the amount of buffering if required.
-                let (txs, rxs) = channels(num_output_partitions);
-                // Clone sender for each input partitions
-                let txs = txs
-                    .into_iter()
-                    .map(|item| vec![item; num_input_partitions])
-                    .collect::<Vec<_>>();
-                let rxs = rxs.into_iter().map(|item| vec![item]).collect::<Vec<_>>();
-                (txs, rxs)
-            };
-            for (partition, (tx, rx)) in txs.into_iter().zip(rxs).enumerate() {
-                let reservation = Arc::new(Mutex::new(
-                    MemoryConsumer::new(format!("{}[{partition}]", self.name()))
-                        .register(context.memory_pool()),
-                ));
-                state.channels.insert(partition, (tx, rx, reservation));
-            }
-            // launch one async task per *input* partition
-            let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
-            for i in 0..num_input_partitions {
-                let txs: HashMap<_, _> = state
-                    .channels
-                    .iter()
-                    .map(|(partition, (tx, _rx, reservation))| {
-                        (*partition, (tx[i].clone(), Arc::clone(reservation)))
-                    })
-                    .collect();
-
-                let r_metrics = RepartitionMetrics::new(i, partition, &self.metrics);
-
-                let input_task = SpawnedTask::spawn(Self::pull_from_input(
-                    self.input.clone(),
-                    i,
-                    txs.clone(),
-                    self.partitioning.clone(),
-                    r_metrics,
-                    context.clone(),
-                ));
-
-                // In a separate task, wait for each input to be done
-                // (and pass along any errors, including panic!s)
-                let wait_for_task = SpawnedTask::spawn(Self::wait_for_task(
-                    input_task,
-                    txs.into_iter()
-                        .map(|(partition, (tx, _reservation))| (partition, tx))
-                        .collect(),
-                ));
-                spawned_tasks.push(wait_for_task);
-            }
+        let lazy_state = Arc::clone(&self.state);
+        let input = Arc::clone(&self.input);
+        let partitioning = self.partitioning.clone();
+        let metrics = self.metrics.clone();
+        let preserve_order = self.preserve_order;
+        let name = self.name().to_owned();
+        let schema = self.schema();
+        let schema_captured = Arc::clone(&schema);
+
+        // Get existing ordering to use for merging
+        let sort_exprs = self.sort_exprs().unwrap_or(&[]).to_owned();
+
+        let stream = futures::stream::once(async move {
+            let num_input_partitions = input.output_partitioning().partition_count();
+
+            let input_captured = Arc::clone(&input);
+            let metrics_captured = metrics.clone();
+            let name_captured = name.clone();
+            let context_captured = Arc::clone(&context);
+            let state = lazy_state
+                .get_or_init(|| async move {
+                    Mutex::new(RepartitionExecState::new(
+                        input_captured,
+                        partitioning,
+                        metrics_captured,
+                        preserve_order,
+                        name_captured,
+                        context_captured,
+                    ))
+                })
+                .await;

-            state.abort_helper = Arc::new(spawned_tasks)
-        }
+            // lock scope
+            let (mut rx, reservation, abort_helper) = {
+                // lock mutexes
+                let mut state = state.lock();

-        trace!(
-            "Before returning stream in {}::execute for partition: {}",
-            self.name(),
-            partition
-        );
+                // now return stream for the specified *output* partition which will
+                // read from the channel
+                let (_tx, rx, reservation) = state
+                    .channels
+                    .remove(&partition)
+                    .expect("partition not used yet");

-        // now return stream for the specified *output* partition which will
-        // read from the channel
-        let (_tx, mut rx, reservation) = state
-            .channels
-            .remove(&partition)
-            .expect("partition not used yet");
+                (rx, reservation, Arc::clone(&state.abort_helper))
+            };

-        if self.preserve_order {
-            // Store streams from all the input partitions:
-            let input_streams = rx
-                .into_iter()
-                .map(|receiver| {
-                    Box::pin(PerPartitionStream {
-                        schema: self.schema(),
-                        receiver,
-                        drop_helper: Arc::clone(&state.abort_helper),
-                        reservation: reservation.clone(),
-                    }) as SendableRecordBatchStream
-                })
-                .collect::<Vec<_>>();
-            // Note that receiver size (`rx.len()`) and `num_input_partitions` are same.
-
-            // Get existing ordering to use for merging
-            let sort_exprs = self.sort_exprs().unwrap_or(&[]);
-
-            // Merge streams (while preserving ordering) coming from
-            // input partitions to this partition:
-            let fetch = None;
-            let merge_reservation =
-                MemoryConsumer::new(format!("{}[Merge {partition}]", self.name()))
-                    .register(context.memory_pool());
-            streaming_merge(
-                input_streams,
-                self.schema(),
-                sort_exprs,
-                BaselineMetrics::new(&self.metrics, partition),
-                context.session_config().batch_size(),
-                fetch,
-                merge_reservation,
-            )
-        } else {
-            Ok(Box::pin(RepartitionStream {
-                num_input_partitions,
-                num_input_partitions_processed: 0,
-                schema: self.input.schema(),
-                input: rx.swap_remove(0),
-                drop_helper: Arc::clone(&state.abort_helper),
-                reservation,
-            }))
-        }
+            trace!(
+                "Before returning stream in {}::execute for partition: {}",
+                name,
+                partition
+            );
+
+            if preserve_order {
+                // Store streams from all the input partitions:
+                let input_streams = rx
+                    .into_iter()
+                    .map(|receiver| {
+                        Box::pin(PerPartitionStream {
+                            schema: Arc::clone(&schema_captured),
+                            receiver,
+                            drop_helper: Arc::clone(&abort_helper),
+                            reservation: reservation.clone(),
+                        }) as SendableRecordBatchStream
+                    })
+                    .collect::<Vec<_>>();
+                // Note that receiver size (`rx.len()`) and `num_input_partitions` are the same.
+
+                // Merge streams (while preserving ordering) coming from
+                // input partitions to this partition:
+                let fetch = None;
+                let merge_reservation =
+                    MemoryConsumer::new(format!("{}[Merge {partition}]", name))
+                        .register(context.memory_pool());
+                streaming_merge(
+                    input_streams,
+                    schema_captured,
+                    &sort_exprs,
+                    BaselineMetrics::new(&metrics, partition),
+                    context.session_config().batch_size(),
+                    fetch,
+                    merge_reservation,
+                )
+            } else {
+                Ok(Box::pin(RepartitionStream {
+                    num_input_partitions,
+                    num_input_partitions_processed: 0,
+                    schema: input.schema(),
+                    input: rx.swap_remove(0),
+                    drop_helper: abort_helper,
+                    reservation,
+                }) as SendableRecordBatchStream)
+            }
+        })
+        .try_flatten();
+        let stream = RecordBatchStreamAdapter::new(schema, stream);
+        Ok(Box::pin(stream))
     }

     fn metrics(&self) -> Option<MetricsSet> {
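One detail worth calling out in the rewrite above is the `// lock scope` block: everything the stream needs is taken out of the shared state inside a block, so the parking_lot guard is dropped before the code reaches any await point (a parking_lot guard is not `Send`, so holding it across `.await` would not compile in a `Send` future, and would risk stalling the executor in any case). A standalone sketch of the idiom, using toy state:

```rust
use std::sync::Arc;

use parking_lot::Mutex;

#[tokio::main]
async fn main() {
    let shared = Arc::new(Mutex::new(vec!["a".to_string(), "b".to_string()]));

    // Take what we need in a block so the guard drops before any `.await`.
    let taken = {
        let mut guard = shared.lock();
        guard.pop()
    }; // guard dropped here

    if let Some(value) = taken {
        tokio::task::yield_now().await; // safe: no lock held
        println!("{value}");
    }
}
```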
@@ -606,10 +673,7 @@ impl RepartitionExec {
         Ok(RepartitionExec {
             input,
             partitioning,
-            state: Arc::new(Mutex::new(RepartitionExecState {
-                channels: HashMap::new(),
-                abort_helper: Arc::new(Vec::new()),
-            })),
+            state: Default::default(),
             metrics: ExecutionPlanMetricsSet::new(),
             preserve_order,
             cache,
@@ -951,6 +1015,7 @@ mod tests {
     use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};

     use futures::FutureExt;
+    use tokio::task::JoinSet;

     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {
@@ -1240,7 +1305,10 @@ mod tests {
         std::mem::drop(output_stream0);

         // Now, start sending input
-        input.wait().await;
+        let mut background_task = JoinSet::new();
+        background_task.spawn(async move {
+            input.wait().await;
+        });

         // output stream 1 should *not* error and have one of the input batches
         let batches = crate::common::collect(output_stream1).await.unwrap();
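The test changes are a direct consequence of `execute` becoming lazy: nothing pulls from the barrier-guarded input until the output stream is actually polled, so awaiting `input.wait()` inline before `collect` could deadlock. Spawning the wait on a `JoinSet` lets it run concurrently with the collection. Roughly, with a sleep standing in for the barrier:

```rust
use std::time::Duration;

use tokio::task::JoinSet;

#[tokio::main]
async fn main() {
    let mut background_task = JoinSet::new();
    background_task.spawn(async move {
        // stand-in for `input.wait().await`
        tokio::time::sleep(Duration::from_millis(10)).await;
    });

    // the "collect" side keeps making progress while the background task runs
    tokio::time::sleep(Duration::from_millis(20)).await;

    // drain the set so panics in the background task are propagated
    while let Some(result) = background_task.join_next().await {
        result.unwrap();
    }
}
```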
@@ -1277,7 +1345,10 @@ mod tests {
         let input = Arc::new(make_barrier_exec());
         let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap();
         let output_stream1 = exec.execute(1, task_ctx.clone()).unwrap();
-        input.wait().await;
+        let mut background_task = JoinSet::new();
+        background_task.spawn(async move {
+            input.wait().await;
+        });

         let batches_without_drop = crate::common::collect(output_stream1).await.unwrap();

         // run some checks on the result
@@ -1299,7 +1370,10 @@ mod tests {
         // now, purposely drop output stream 0
         // *before* any outputs are produced
         std::mem::drop(output_stream0);
-        input.wait().await;
+        let mut background_task = JoinSet::new();
+        background_task.spawn(async move {
+            input.wait().await;
+        });

         let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();

         assert_eq!(batches_without_drop, batches_with_drop);
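For reference, the `transpose` helper that `RepartitionExecState::new` uses on the preserve_order path converts the sender/receiver matrix from being indexed by input partition to being indexed by output partition. The module defines its own version; an illustrative equivalent:

```rust
/// Turn a matrix indexed as [input_partition][output_partition] into one
/// indexed as [output_partition][input_partition].
fn transpose<T>(original: Vec<Vec<T>>) -> Vec<Vec<T>> {
    match original.as_slice() {
        [] => vec![],
        [first, ..] => {
            // one output row per column of the input
            let mut result: Vec<Vec<T>> = (0..first.len()).map(|_| vec![]).collect();
            for row in original {
                for (item, column) in row.into_iter().zip(&mut result) {
                    column.push(item);
                }
            }
            result
        }
    }
}

fn main() {
    let m = vec![vec![1, 2, 3], vec![4, 5, 6]]; // 2 inputs x 3 outputs
    assert_eq!(transpose(m), vec![vec![1, 4], vec![2, 5], vec![3, 6]]);
}
```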