1919//! partitions to M output partitions based on a partitioning scheme, optionally
2020//! maintaining the order of the input rows in the output.
2121
22+ use std:: fmt:: { Debug , Formatter } ;
2223use std:: pin:: Pin ;
2324use std:: sync:: Arc ;
2425use std:: task:: { Context , Poll } ;
@@ -45,7 +46,7 @@ use arrow::compute::take_arrays;
4546use arrow:: datatypes:: { SchemaRef , UInt32Type } ;
4647use datafusion_common:: config:: ConfigOptions ;
4748use datafusion_common:: utils:: transpose;
48- use datafusion_common:: HashMap ;
49+ use datafusion_common:: { internal_err , HashMap } ;
4950use datafusion_common:: { not_impl_err, DataFusionError , Result } ;
5051use datafusion_common_runtime:: SpawnedTask ;
5152use datafusion_execution:: memory_pool:: MemoryConsumer ;
@@ -67,9 +68,8 @@ type MaybeBatch = Option<Result<RecordBatch>>;
6768type InputPartitionsToCurrentPartitionSender = Vec < DistributionSender < MaybeBatch > > ;
6869type InputPartitionsToCurrentPartitionReceiver = Vec < DistributionReceiver < MaybeBatch > > ;
6970
70- /// Inner state of [`RepartitionExec`].
7171#[ derive( Debug ) ]
72- struct RepartitionExecState {
72+ struct ConsumingInputStreamsState {
7373 /// Channels for sending batches from input partitions to output partitions.
7474 /// Key is the partition number.
7575 channels : HashMap <
@@ -85,16 +85,97 @@ struct RepartitionExecState {
8585 abort_helper : Arc < Vec < SpawnedTask < ( ) > > > ,
8686}
8787
88+ /// Inner state of [`RepartitionExec`].
89+ enum RepartitionExecState {
90+ /// Not initialized yet. This is the default state stored in the RepartitionExec node
91+ /// upon instantiation.
92+ NotInitialized ,
93+ /// Input streams are initialized, but they are still not being consumed. The node
94+ /// transitions to this state when the arrow's RecordBatch stream is created in
95+ /// RepartitionExec::execute(), but before any message is polled.
96+ InputStreamsInitialized ( Vec < ( SendableRecordBatchStream , RepartitionMetrics ) > ) ,
97+ /// The input streams are being consumed. The node transitions to this state when
98+ /// the first message in the arrow's RecordBatch stream is consumed.
99+ ConsumingInputStreams ( ConsumingInputStreamsState ) ,
100+ }
101+
102+ impl Default for RepartitionExecState {
103+ fn default ( ) -> Self {
104+ Self :: NotInitialized
105+ }
106+ }
107+
108+ impl Debug for RepartitionExecState {
109+ fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
110+ match self {
111+ RepartitionExecState :: NotInitialized => write ! ( f, "NotInitialized" ) ,
112+ RepartitionExecState :: InputStreamsInitialized ( v) => {
113+ write ! ( f, "InputStreamsInitialized({:?})" , v. len( ) )
114+ }
115+ RepartitionExecState :: ConsumingInputStreams ( v) => {
116+ write ! ( f, "ConsumingInputStreams({v:?})" )
117+ }
118+ }
119+ }
120+ }
121+
88122impl RepartitionExecState {
89- fn new (
123+ fn ensure_input_streams_initialized (
124+ & mut self ,
125+ input : Arc < dyn ExecutionPlan > ,
126+ metrics : ExecutionPlanMetricsSet ,
127+ output_partitions : usize ,
128+ ctx : Arc < TaskContext > ,
129+ ) -> Result < ( ) > {
130+ if !matches ! ( self , RepartitionExecState :: NotInitialized ) {
131+ return Ok ( ( ) ) ;
132+ }
133+
134+ let num_input_partitions = input. output_partitioning ( ) . partition_count ( ) ;
135+ let mut streams_and_metrics = Vec :: with_capacity ( num_input_partitions) ;
136+
137+ for i in 0 ..num_input_partitions {
138+ let metrics = RepartitionMetrics :: new ( i, output_partitions, & metrics) ;
139+
140+ let timer = metrics. fetch_time . timer ( ) ;
141+ let stream = input. execute ( i, Arc :: clone ( & ctx) ) ?;
142+ timer. done ( ) ;
143+
144+ streams_and_metrics. push ( ( stream, metrics) ) ;
145+ }
146+ * self = RepartitionExecState :: InputStreamsInitialized ( streams_and_metrics) ;
147+ Ok ( ( ) )
148+ }
149+
150+ fn consume_input_streams (
151+ & mut self ,
90152 input : Arc < dyn ExecutionPlan > ,
91- partitioning : Partitioning ,
92153 metrics : ExecutionPlanMetricsSet ,
154+ partitioning : Partitioning ,
93155 preserve_order : bool ,
94156 name : String ,
95157 context : Arc < TaskContext > ,
96- ) -> Self {
97- let num_input_partitions = input. output_partitioning ( ) . partition_count ( ) ;
158+ ) -> Result < & mut ConsumingInputStreamsState > {
159+ let streams_and_metrics = match self {
160+ RepartitionExecState :: NotInitialized => {
161+ self . ensure_input_streams_initialized (
162+ input,
163+ metrics,
164+ partitioning. partition_count ( ) ,
165+ Arc :: clone ( & context) ,
166+ ) ?;
167+ let RepartitionExecState :: InputStreamsInitialized ( value) = self else {
168+ // This cannot happen, as ensure_input_streams_initialized() was just called,
169+ // but the compiler does not know.
170+ return internal_err ! ( "Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized" ) ;
171+ } ;
172+ value
173+ }
174+ RepartitionExecState :: ConsumingInputStreams ( value) => return Ok ( value) ,
175+ RepartitionExecState :: InputStreamsInitialized ( value) => value,
176+ } ;
177+
178+ let num_input_partitions = streams_and_metrics. len ( ) ;
98179 let num_output_partitions = partitioning. partition_count ( ) ;
99180
100181 let ( txs, rxs) = if preserve_order {
@@ -129,23 +210,21 @@ impl RepartitionExecState {
129210
130211 // launch one async task per *input* partition
131212 let mut spawned_tasks = Vec :: with_capacity ( num_input_partitions) ;
132- for i in 0 ..num_input_partitions {
213+ for ( i, ( stream, metrics) ) in
214+ std:: mem:: take ( streams_and_metrics) . into_iter ( ) . enumerate ( )
215+ {
133216 let txs: HashMap < _ , _ > = channels
134217 . iter ( )
135218 . map ( |( partition, ( tx, _rx, reservation) ) | {
136219 ( * partition, ( tx[ i] . clone ( ) , Arc :: clone ( reservation) ) )
137220 } )
138221 . collect ( ) ;
139222
140- let r_metrics = RepartitionMetrics :: new ( i, num_output_partitions, & metrics) ;
141-
142223 let input_task = SpawnedTask :: spawn ( RepartitionExec :: pull_from_input (
143- Arc :: clone ( & input) ,
144- i,
224+ stream,
145225 txs. clone ( ) ,
146226 partitioning. clone ( ) ,
147- r_metrics,
148- Arc :: clone ( & context) ,
227+ metrics,
149228 ) ) ;
150229
151230 // In a separate task, wait for each input to be done
@@ -158,28 +237,17 @@ impl RepartitionExecState {
158237 ) ) ;
159238 spawned_tasks. push ( wait_for_task) ;
160239 }
161-
162- Self {
240+ * self = Self :: ConsumingInputStreams ( ConsumingInputStreamsState {
163241 channels,
164242 abort_helper : Arc :: new ( spawned_tasks) ,
243+ } ) ;
244+ match self {
245+ RepartitionExecState :: ConsumingInputStreams ( value) => Ok ( value) ,
246+ _ => unreachable ! ( ) ,
165247 }
166248 }
167249}
168250
169- /// Lazily initialized state
170- ///
171- /// Note that the state is initialized ONCE for all partitions by a single task(thread).
172- /// This may take a short while. It is also like that multiple threads
173- /// call execute at the same time, because we have just started "target partitions" tasks
174- /// which is commonly set to the number of CPU cores and all call execute at the same time.
175- ///
176- /// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
177- /// in a mutex lock but instead allow other threads to do something useful.
178- ///
179- /// Uses a parking_lot `Mutex` to control other accesses as they are very short duration
180- /// (e.g. removing channels on completion) where the overhead of `await` is not warranted.
181- type LazyState = Arc < tokio:: sync:: OnceCell < Mutex < RepartitionExecState > > > ;
182-
183251/// A utility that can be used to partition batches based on [`Partitioning`]
184252pub struct BatchPartitioner {
185253 state : BatchPartitionerState ,
@@ -406,8 +474,9 @@ impl BatchPartitioner {
406474pub struct RepartitionExec {
407475 /// Input execution plan
408476 input : Arc < dyn ExecutionPlan > ,
409- /// Inner state that is initialized when the first output stream is created.
410- state : LazyState ,
477+ /// Inner state that is initialized when the parent calls .execute() on this node
478+ /// and consumed as soon as the parent starts consuming this node.
479+ state : Arc < Mutex < RepartitionExecState > > ,
411480 /// Execution metrics
412481 metrics : ExecutionPlanMetricsSet ,
413482 /// Boolean flag to decide whether to preserve ordering. If true means
@@ -486,11 +555,7 @@ impl RepartitionExec {
486555}
487556
488557impl DisplayAs for RepartitionExec {
489- fn fmt_as (
490- & self ,
491- t : DisplayFormatType ,
492- f : & mut std:: fmt:: Formatter ,
493- ) -> std:: fmt:: Result {
558+ fn fmt_as ( & self , t : DisplayFormatType , f : & mut Formatter ) -> std:: fmt:: Result {
494559 match t {
495560 DisplayFormatType :: Default | DisplayFormatType :: Verbose => {
496561 write ! (
@@ -583,7 +648,6 @@ impl ExecutionPlan for RepartitionExec {
583648 partition
584649 ) ;
585650
586- let lazy_state = Arc :: clone ( & self . state ) ;
587651 let input = Arc :: clone ( & self . input ) ;
588652 let partitioning = self . partitioning ( ) . clone ( ) ;
589653 let metrics = self . metrics . clone ( ) ;
@@ -595,30 +659,31 @@ impl ExecutionPlan for RepartitionExec {
595659 // Get existing ordering to use for merging
596660 let sort_exprs = self . sort_exprs ( ) . cloned ( ) . unwrap_or_default ( ) ;
597661
662+ let state = Arc :: clone ( & self . state ) ;
663+ if let Some ( mut state) = state. try_lock ( ) {
664+ state. ensure_input_streams_initialized (
665+ Arc :: clone ( & input) ,
666+ metrics. clone ( ) ,
667+ partitioning. partition_count ( ) ,
668+ Arc :: clone ( & context) ,
669+ ) ?;
670+ }
671+
598672 let stream = futures:: stream:: once ( async move {
599673 let num_input_partitions = input. output_partitioning ( ) . partition_count ( ) ;
600674
601- let input_captured = Arc :: clone ( & input) ;
602- let metrics_captured = metrics. clone ( ) ;
603- let name_captured = name. clone ( ) ;
604- let context_captured = Arc :: clone ( & context) ;
605- let state = lazy_state
606- . get_or_init ( || async move {
607- Mutex :: new ( RepartitionExecState :: new (
608- input_captured,
609- partitioning,
610- metrics_captured,
611- preserve_order,
612- name_captured,
613- context_captured,
614- ) )
615- } )
616- . await ;
617-
618675 // lock scope
619676 let ( mut rx, reservation, abort_helper) = {
620677 // lock mutexes
621678 let mut state = state. lock ( ) ;
679+ let state = state. consume_input_streams (
680+ Arc :: clone ( & input) ,
681+ metrics. clone ( ) ,
682+ partitioning,
683+ preserve_order,
684+ name. clone ( ) ,
685+ Arc :: clone ( & context) ,
686+ ) ?;
622687
623688 // now return stream for the specified *output* partition which will
624689 // read from the channel
@@ -853,24 +918,17 @@ impl RepartitionExec {
853918 ///
854919 /// txs hold the output sending channels for each output partition
855920 async fn pull_from_input (
856- input : Arc < dyn ExecutionPlan > ,
857- partition : usize ,
921+ mut stream : SendableRecordBatchStream ,
858922 mut output_channels : HashMap <
859923 usize ,
860924 ( DistributionSender < MaybeBatch > , SharedMemoryReservation ) ,
861925 > ,
862926 partitioning : Partitioning ,
863927 metrics : RepartitionMetrics ,
864- context : Arc < TaskContext > ,
865928 ) -> Result < ( ) > {
866929 let mut partitioner =
867930 BatchPartitioner :: try_new ( partitioning, metrics. repartition_time . clone ( ) ) ?;
868931
869- // execute the child operator
870- let timer = metrics. fetch_time . timer ( ) ;
871- let mut stream = input. execute ( partition, context) ?;
872- timer. done ( ) ;
873-
874932 // While there are still outputs to send to, keep pulling inputs
875933 let mut batches_until_yield = partitioner. num_partitions ( ) ;
876934 while !output_channels. is_empty ( ) {
@@ -1118,6 +1176,7 @@ mod tests {
11181176 use datafusion_common_runtime:: JoinSet ;
11191177 use datafusion_execution:: runtime_env:: RuntimeEnvBuilder ;
11201178 use insta:: assert_snapshot;
1179+ use itertools:: Itertools ;
11211180
11221181 #[ tokio:: test]
11231182 async fn one_to_many_round_robin ( ) -> Result < ( ) > {
@@ -1298,15 +1357,9 @@ mod tests {
12981357 let partitioning = Partitioning :: RoundRobinBatch ( 1 ) ;
12991358 let exec = RepartitionExec :: try_new ( Arc :: new ( input) , partitioning) . unwrap ( ) ;
13001359
1301- // Note: this should pass (the stream can be created) but the
1302- // error when the input is executed should get passed back
1303- let output_stream = exec. execute ( 0 , task_ctx) . unwrap ( ) ;
1304-
13051360 // Expect that an error is returned
1306- let result_string = crate :: common:: collect ( output_stream)
1307- . await
1308- . unwrap_err ( )
1309- . to_string ( ) ;
1361+ let result_string = exec. execute ( 0 , task_ctx) . err ( ) . unwrap ( ) . to_string ( ) ;
1362+
13101363 assert ! (
13111364 result_string. contains( "ErrorExec, unsurprisingly, errored in partition 0" ) ,
13121365 "actual: {result_string}"
@@ -1496,7 +1549,14 @@ mod tests {
14961549 } ) ;
14971550 let batches_with_drop = crate :: common:: collect ( output_stream1) . await . unwrap ( ) ;
14981551
1499- assert_eq ! ( batches_without_drop, batches_with_drop) ;
1552+ fn sort ( batch : Vec < RecordBatch > ) -> Vec < RecordBatch > {
1553+ batch
1554+ . into_iter ( )
1555+ . sorted_by_key ( |b| format ! ( "{b:?}" ) )
1556+ . collect ( )
1557+ }
1558+
1559+ assert_eq ! ( sort( batches_without_drop) , sort( batches_with_drop) ) ;
15001560 }
15011561
15021562 fn str_batches_to_vec ( batches : & [ RecordBatch ] ) -> Vec < & str > {
0 commit comments