@@ -55,6 +55,8 @@ pi_result cuda_piKernelGetGroupInfo(pi_kernel kernel, pi_device device,
 /// \endcond
 }
 
+using _pi_stream_guard = std::unique_lock<std::mutex>;
+
 /// A PI platform stores all known PI devices,
 /// in the CUDA plugin this is just a vector of
 /// available devices since initialization is done
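The new `_pi_stream_guard` alias is the lock type that the stream-selection overload further down hands back to callers. A minimal sketch of that pattern, with a hypothetical `stream_pool` standing in for `_pi_queue` and an `int` standing in for `CUstream`: the pool returns a resource together with a `std::unique_lock`, and the resource stays reserved exactly as long as the caller keeps the guard alive.

```cpp
#include <mutex>

using _pi_stream_guard = std::unique_lock<std::mutex>;

// Hypothetical pool illustrating the guard pattern: the caller receives a
// resource together with a lock, and the resource stays reserved for as
// long as the returned guard remains locked.
struct stream_pool {
  std::mutex mutex_;
  int stream_ = 42; // stand-in for a CUstream handle

  int acquire(_pi_stream_guard &guard) {
    guard = _pi_stream_guard(mutex_); // locks; ownership moves to the caller
    return stream_;
  }
};

int main() {
  stream_pool pool;
  _pi_stream_guard guard;
  int stream = pool.acquire(guard);
  // ... enqueue work on `stream`; the pool keeps it reserved ...
  (void)stream;
} // guard unlocks here, releasing the reservation
```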
@@ -387,6 +389,11 @@ struct _pi_queue {
 
   std::vector<native_type> compute_streams_;
   std::vector<native_type> transfer_streams_;
+  // delay_compute_ keeps track of which streams have been recently reused, so
+  // their next use should be delayed. If a stream has been recently reused, it
+  // will be skipped the next time it would be selected round-robin style. When
+  // skipped, its delay flag is cleared.
+  std::vector<bool> delay_compute_;
   _pi_context *context_;
   _pi_device *device_;
   pi_queue_properties properties_;
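A minimal sketch of how a delay flag like `delay_compute_` interacts with round-robin selection (a hypothetical `round_robin_pool`, not the plugin's actual selection code): a stream marked as recently reused is passed over once, and the skip itself clears the flag so the stream becomes eligible again on the following pass.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical round-robin pool with per-stream delay flags.
struct round_robin_pool {
  std::vector<int> streams_; // stand-ins for CUstream handles
  std::vector<bool> delay_;  // parallels delay_compute_
  std::uint32_t idx_ = 0;

  explicit round_robin_pool(std::size_t n) : streams_(n), delay_(n, false) {}

  int next() {
    std::uint32_t i = idx_++ % streams_.size();
    while (delay_[i]) {
      delay_[i] = false; // clear the flag as we skip this stream
      i = idx_++ % streams_.size();
    }
    return streams_[i];
  }

  // Called when a command reuses the stream at slot i.
  void mark_reused(std::uint32_t i) { delay_[i] = true; }
};
```

Because every skip clears a flag, `next()` terminates after at most one pass over the pool even if every stream was recently reused.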
@@ -399,6 +406,10 @@ struct _pi_queue {
   unsigned int last_sync_compute_streams_;
   unsigned int last_sync_transfer_streams_;
   unsigned int flags_;
+  // When compute_stream_sync_mutex_ and compute_stream_mutex_ both need to be
+  // locked at the same time, compute_stream_sync_mutex_ should be locked first
+  // to avoid deadlocks.
+  std::mutex compute_stream_sync_mutex_;
   std::mutex compute_stream_mutex_;
   std::mutex transfer_stream_mutex_;
   bool has_ownership_;
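The ordering rule in that comment is the standard way to make a pair of mutexes deadlock-free: every thread that ever needs both must acquire them in the same order, so no two threads can each hold one lock while waiting on the other. A small standalone illustration (freestanding mutexes rather than the queue members):

```cpp
#include <mutex>

std::mutex compute_stream_sync_mutex; // always taken first
std::mutex compute_stream_mutex;      // always taken second

void sync_then_select() {
  std::lock_guard<std::mutex> sync_guard(compute_stream_sync_mutex);
  std::lock_guard<std::mutex> stream_guard(compute_stream_mutex);
  // ... bookkeeping that touches both the sync state and the stream pool ...
}
```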
@@ -408,7 +419,8 @@ struct _pi_queue {
             _pi_device *device, pi_queue_properties properties,
             unsigned int flags, bool backend_owns = true)
       : compute_streams_{std::move(compute_streams)},
-        transfer_streams_{std::move(transfer_streams)}, context_{context},
+        transfer_streams_{std::move(transfer_streams)},
+        delay_compute_(compute_streams_.size(), false), context_{context},
         device_{device}, properties_{properties}, refCount_{1}, eventCount_{0},
         compute_stream_idx_{0}, transfer_stream_idx_{0},
         num_compute_streams_{0}, num_transfer_streams_{0},
@@ -425,10 +437,47 @@ struct _pi_queue {
 
   // get_next_compute/transfer_stream() functions return streams from
   // appropriate pools in round-robin fashion
-  native_type get_next_compute_stream();
+  native_type get_next_compute_stream(pi_uint32 *stream_token = nullptr);
+  // this overload tries to select a stream used by one of the dependencies.
+  // If that is not possible it returns a new stream. If a stream is reused,
+  // it returns a lock that needs to remain locked as long as it is in use
+  native_type get_next_compute_stream(pi_uint32 num_events_in_wait_list,
+                                      const pi_event *event_wait_list,
+                                      _pi_stream_guard &guard,
+                                      pi_uint32 *stream_token = nullptr);
   native_type get_next_transfer_stream();
   native_type get() { return get_next_compute_stream(); };
 
+  bool has_been_synchronized(pi_uint32 stream_token) {
+    // stream token not associated with one of the compute streams
+    if (stream_token == std::numeric_limits<pi_uint32>::max()) {
+      return false;
+    }
+    return last_sync_compute_streams_ >= stream_token;
+  }
+
+  bool can_reuse_stream(pi_uint32 stream_token) {
+    // stream token not associated with one of the compute streams
+    if (stream_token == std::numeric_limits<pi_uint32>::max()) {
+      return true;
+    }
+    // If the command represented by the stream token was not the last command
+    // enqueued to the stream, we cannot reuse the stream - we need to allow
+    // commands enqueued after it and the one we are about to enqueue to run
+    // concurrently
+    bool is_last_command =
+        (compute_stream_idx_ - stream_token) <= compute_streams_.size();
+    // If there was a barrier enqueued to the queue after the command
+    // represented by the stream token, we should not reuse the stream, as we
+    // cannot take that stream into account in the bookkeeping for the next
+    // barrier - such a stream would not be synchronized with. Performance-wise
+    // it does not matter that we do not reuse the stream, as the work
+    // represented by the stream token is guaranteed to be complete by the
+    // barrier before any work we are about to enqueue to the stream starts,
+    // so the event does not need to be synchronized with.
+    return is_last_command && !has_been_synchronized(stream_token);
+  }
+
   template <typename T> void for_each_stream(T &&f) {
     {
       std::lock_guard<std::mutex> compute_guard(compute_stream_mutex_);
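The `is_last_command` test relies on `compute_stream_idx_` being a monotonically increasing counter whose value at enqueue time became the stream token, so the unsigned difference counts how many streams have been handed out since. A worked example with assumed values (a pool of 4 streams):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const std::uint32_t pool_size = 4;

  std::uint32_t token = 10;         // counter value when the command enqueued
  std::uint32_t idx = 13;           // current counter
  assert(idx - token <= pool_size); // 3 <= 4: still the last command

  idx = 15;                            // 5 more streams handed out since then,
  assert(!(idx - token <= pool_size)); // so the slot was reused by a newer
                                       // command and cannot be reused again

  // Unsigned arithmetic also covers counter wrap-around:
  token = 0xFFFFFFFEu; // enqueued just before the counter wraps
  idx = 1;             // counter wrapped past zero
  assert(idx - token == 3); // the modular difference is still correct
}
```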
@@ -451,30 +500,39 @@ struct _pi_queue {
   }
 
   template <typename T> void sync_streams(T &&f) {
-    auto sync = [&f](const std::vector<CUstream> &streams, unsigned int start,
-                     unsigned int stop) {
+    auto sync_compute = [&f, &streams = compute_streams_,
+                         &delay = delay_compute_](unsigned int start,
+                                                  unsigned int stop) {
+      for (unsigned int i = start; i < stop; i++) {
+        f(streams[i]);
+        delay[i] = false;
+      }
+    };
+    auto sync_transfer = [&f, &streams = transfer_streams_](unsigned int start,
+                                                            unsigned int stop) {
       for (unsigned int i = start; i < stop; i++) {
         f(streams[i]);
       }
     };
     {
       unsigned int size = static_cast<unsigned int>(compute_streams_.size());
+      std::lock_guard compute_sync_guard(compute_stream_sync_mutex_);
       std::lock_guard<std::mutex> compute_guard(compute_stream_mutex_);
       unsigned int start = last_sync_compute_streams_;
       unsigned int end = num_compute_streams_ < size
                              ? num_compute_streams_
                              : compute_stream_idx_.load();
       last_sync_compute_streams_ = end;
       if (end - start >= size) {
-        sync(compute_streams_, 0, size);
+        sync_compute(0, size);
       } else {
         start %= size;
         end %= size;
         if (start < end) {
-          sync(compute_streams_, start, end);
+          sync_compute(start, end);
         } else {
-          sync(compute_streams_, start, size);
-          sync(compute_streams_, 0, end);
+          sync_compute(start, size);
+          sync_compute(0, end);
         }
       }
     }
@@ -488,15 +546,15 @@ struct _pi_queue {
                              : transfer_stream_idx_.load();
       last_sync_transfer_streams_ = end;
       if (end - start >= size) {
-        sync(transfer_streams_, 0, size);
+        sync_transfer(0, size);
       } else {
         start %= size;
         end %= size;
         if (start < end) {
-          sync(transfer_streams_, start, end);
+          sync_transfer(start, end);
         } else {
-          sync(transfer_streams_, start, size);
-          sync(transfer_streams_, 0, end);
+          sync_transfer(start, size);
+          sync_transfer(0, end);
         }
       }
     }
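Both the compute and transfer branches implement the same range split: `[start, end)` is a window over a ring of `size` streams, and once both bounds are reduced modulo `size` the window is either contiguous or wraps and must be flushed in two pieces. A standalone sketch of just that logic (a hypothetical `sync_range`, with a capture-free callback standing in for `f`):

```cpp
#include <cstdio>

// [start, end) is a window over a ring of `size` slots; sync every slot the
// window covers, splitting it in two when it wraps past the end of the ring.
void sync_range(unsigned start, unsigned end, unsigned size,
                void (*sync)(unsigned, unsigned)) {
  if (end - start >= size) {
    sync(0, size); // the window covers every slot
  } else {
    start %= size;
    end %= size;
    if (start < end) {
      sync(start, end);  // contiguous window
    } else {
      sync(start, size); // wrapped window: the tail piece ...
      sync(0, end);      // ... then the head piece
    }
  }
}

int main() {
  auto print = [](unsigned a, unsigned b) { std::printf("[%u, %u)\n", a, b); };
  sync_range(6, 9, 4, print); // prints [2, 4) then [0, 1): the window wrapped
}
```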
@@ -538,6 +596,8 @@ struct _pi_event {
 
   CUstream get_stream() const noexcept { return stream_; }
 
+  pi_uint32 get_stream_token() const noexcept { return streamToken_; }
+
   pi_command_type get_command_type() const noexcept { return commandType_; }
 
   pi_uint32 get_reference_count() const noexcept { return refCount_; }
@@ -581,9 +641,11 @@ struct _pi_event {
   pi_uint64 get_end_time() const;
 
   // construct a native CUDA event. This maps closely to the underlying CUDA event.
-  static pi_event make_native(pi_command_type type, pi_queue queue,
-                              CUstream stream) {
-    return new _pi_event(type, queue->get_context(), queue, stream);
+  static pi_event
+  make_native(pi_command_type type, pi_queue queue, CUstream stream,
+              pi_uint32 stream_token = std::numeric_limits<pi_uint32>::max()) {
+    return new _pi_event(type, queue->get_context(), queue, stream,
+                         stream_token);
   }
 
   pi_result release();
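The `std::numeric_limits<pi_uint32>::max()` default acts as a sentinel meaning "no associated compute stream"; `has_been_synchronized` and `can_reuse_stream` above both test for it before doing any token arithmetic. A tiny sketch of the convention (aliasing `pi_uint32` locally so the snippet is self-contained):

```cpp
#include <cstdint>
#include <limits>

using pi_uint32 = std::uint32_t; // matches the PI typedef

// Sentinel for events not created via a compute-stream enqueue.
constexpr pi_uint32 no_stream_token = std::numeric_limits<pi_uint32>::max();

bool is_compute_stream_event(pi_uint32 stream_token) {
  return stream_token != no_stream_token;
}
```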
@@ -594,7 +656,7 @@ struct _pi_event {
   // This constructor is private to force programmers to use the make_native /
   // make_user static members in order to create a pi_event for CUDA.
   _pi_event(pi_command_type type, pi_context context, pi_queue queue,
-            CUstream stream);
+            CUstream stream, pi_uint32 stream_token);
 
   pi_command_type commandType_; // The type of command associated with event.
 
@@ -610,6 +672,7 @@ struct _pi_event {
   // PI event has started or not
   //
 
+  pi_uint32 streamToken_;
   pi_uint32 eventId_; // Queue identifier of the event.
 
   native_type evEnd_; // CUDA event handle. If this _pi_event represents a user