@@ -149,6 +149,8 @@ pi_result _pi_event::start() {
149
149
}
150
150
151
151
isStarted_ = true ;
152
+ // let observers know that the event is "submitted"
153
+ trigger_callback (get_execution_status ());
152
154
return result;
153
155
}
154
156
@@ -195,6 +197,22 @@ pi_result _pi_event::record() {
195
197
196
198
try {
197
199
result = PI_CHECK_ERROR (cuEventRecord (evEnd_, cuStream));
200
+
201
+ result = cuda_piEventRetain (this );
202
+ try {
203
+ result = PI_CHECK_ERROR (cuLaunchHostFunc (
204
+ cuStream,
205
+ [](void *userData) {
206
+ pi_event event = reinterpret_cast <pi_event>(userData);
207
+ event->set_event_complete ();
208
+ cuda_piEventRelease (event);
209
+ },
210
+ this ));
211
+ } catch (...) {
212
+ // If host function fails to enqueue we must release the event here
213
+ result = cuda_piEventRelease (this );
214
+ throw ;
215
+ }
198
216
} catch (pi_result error) {
199
217
result = error;
200
218
}
@@ -215,6 +233,7 @@ pi_result _pi_event::wait() {
215
233
if (is_native_event ()) {
216
234
try {
217
235
retErr = PI_CHECK_ERROR (cuEventSynchronize (evEnd_));
236
+ isCompleted_ = true ;
218
237
} catch (pi_result error) {
219
238
retErr = error;
220
239
}
@@ -226,30 +245,12 @@ pi_result _pi_event::wait() {
226
245
retErr = PI_SUCCESS;
227
246
}
228
247
229
- return retErr;
230
- }
231
-
232
- pi_event_status _pi_event::get_execution_status () const noexcept {
248
+ auto is_success = retErr == PI_SUCCESS;
249
+ auto status = is_success ? get_execution_status () : pi_int32 (retErr);
233
250
234
- if (!is_recorded ()) {
235
- return PI_EVENT_SUBMITTED;
236
- }
237
-
238
- if (is_native_event ()) {
239
- // native event status
240
-
241
- auto status = cuEventQuery (get ());
242
- if (status == CUDA_ERROR_NOT_READY) {
243
- return PI_EVENT_RUNNING;
244
- } else if (status != CUDA_SUCCESS) {
245
- cl::sycl::detail::pi::die (" Invalid CUDA event status" );
246
- }
247
- return PI_EVENT_COMPLETE;
248
- } else {
249
- // user event status
251
+ trigger_callback (status);
250
252
251
- return is_completed () ? PI_EVENT_COMPLETE : PI_EVENT_RUNNING;
252
- }
253
+ return retErr;
253
254
}
254
255
255
256
// iterates over the event wait list, returns correct pi_result error codes.
@@ -2516,24 +2517,21 @@ pi_result cuda_piEventGetInfo(pi_event event, pi_event_info param_name,
2516
2517
2517
2518
switch (param_name) {
2518
2519
case PI_EVENT_INFO_COMMAND_QUEUE:
2519
- return getInfo<pi_queue> (param_value_size, param_value,
2520
- param_value_size_ret, event->get_queue ());
2520
+ return getInfo (param_value_size, param_value, param_value_size_ret ,
2521
+ event->get_queue ());
2521
2522
case PI_EVENT_INFO_COMMAND_TYPE:
2522
- return getInfo<pi_command_type>(param_value_size, param_value,
2523
- param_value_size_ret,
2524
- event->get_command_type ());
2523
+ return getInfo (param_value_size, param_value, param_value_size_ret,
2524
+ event->get_command_type ());
2525
2525
case PI_EVENT_INFO_REFERENCE_COUNT:
2526
- return getInfo<pi_uint32>(param_value_size, param_value,
2527
- param_value_size_ret,
2528
- event->get_reference_count ());
2526
+ return getInfo (param_value_size, param_value, param_value_size_ret,
2527
+ event->get_reference_count ());
2529
2528
case PI_EVENT_INFO_COMMAND_EXECUTION_STATUS: {
2530
- return getInfo<pi_event_status>(param_value_size, param_value,
2531
- param_value_size_ret,
2532
- event->get_execution_status ());
2529
+ return getInfo (param_value_size, param_value, param_value_size_ret,
2530
+ static_cast <pi_event_status>(event->get_execution_status ()));
2533
2531
}
2534
2532
case PI_EVENT_INFO_CONTEXT:
2535
- return getInfo<pi_context> (param_value_size, param_value,
2536
- param_value_size_ret, event->get_context ());
2533
+ return getInfo (param_value_size, param_value, param_value_size_ret ,
2534
+ event->get_context ());
2537
2535
default :
2538
2536
PI_HANDLE_UNKNOWN_PARAM_NAME (param_name);
2539
2537
}
@@ -2568,13 +2566,21 @@ pi_result cuda_piEventGetProfilingInfo(
2568
2566
return {};
2569
2567
}
2570
2568
2571
- pi_result cuda_piEventSetCallback (
2572
- pi_event event, pi_int32 command_exec_callback_type,
2573
- void (*pfn_notify)(pi_event event, pi_int32 event_command_status,
2574
- void *user_data),
2575
- void *user_data) {
2576
- cl::sycl::detail::pi::die (" cuda_piEventSetCallback not implemented" );
2577
- return {};
2569
+ pi_result cuda_piEventSetCallback (pi_event event,
2570
+ pi_int32 command_exec_callback_type,
2571
+ pfn_notify notify, void *user_data) {
2572
+
2573
+ assert (event);
2574
+ assert (notify);
2575
+ assert (command_exec_callback_type == PI_EVENT_SUBMITTED ||
2576
+ command_exec_callback_type == PI_EVENT_RUNNING ||
2577
+ command_exec_callback_type == PI_EVENT_COMPLETE);
2578
+ event_callback callback (pi_event_status (command_exec_callback_type), notify,
2579
+ user_data);
2580
+
2581
+ event->set_event_callback (callback);
2582
+
2583
+ return PI_SUCCESS;
2578
2584
}
2579
2585
2580
2586
pi_result cuda_piEventSetStatus (pi_event event, pi_int32 execution_status) {
@@ -2587,7 +2593,7 @@ pi_result cuda_piEventSetStatus(pi_event event, pi_int32 execution_status) {
2587
2593
}
2588
2594
2589
2595
if (execution_status == PI_EVENT_COMPLETE) {
2590
- return event->set_user_event_complete ();
2596
+ return event->set_event_complete ();
2591
2597
} else if (execution_status < 0 ) {
2592
2598
// TODO: A negative integer value causes all enqueued commands that wait
2593
2599
// on this user event to be terminated.
0 commit comments