@@ -70,8 +70,18 @@ class WaitInfo {
   }
 };
 
+template <class T>
 inline static WaitInfo getWaitInfo(uint32_t numEventsInWaitList,
-                                   const ur_event_handle_t *phEventWaitList) {
+                                   const ur_event_handle_t *phEventWaitList,
+                                   const T &scheduler) {
+  if (numEventsInWaitList && !scheduler.CanWaitInThread()) {
+    // Waiting for dependent events in threads launched by the enqueue may
+    // not work correctly for some backends/schedulers, so we have the safe
+    // option here to wait in the main thread instead (potentially at the
+    // expense of performance).
+    urEventWait(numEventsInWaitList, phEventWaitList);
+    numEventsInWaitList = 0;
+  }
   return native_cpu::WaitInfo(numEventsInWaitList, phEventWaitList);
 }
 
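Note on the new scheduler parameter: getWaitInfo() can now ask the scheduler whether a scheduled task is allowed to block on its dependencies, and falls back to waiting up front when it is not. The sketch below only illustrates that contract; FakeScheduler and effectiveWaitListSize are made-up names for the example, not the adapter's real types.

// Hypothetical sketch: how a scheduler's CanWaitInThread() drives the
// "wait up front vs. wait inside the task" choice made in getWaitInfo().
#include <cstdint>
#include <cstdio>

struct FakeScheduler {
  bool InThreadWaitSupported; // illustrative flag, not a real adapter member
  bool CanWaitInThread() const { return InThreadWaitSupported; }
};

template <class T>
uint32_t effectiveWaitListSize(uint32_t numEvents, const T &scheduler) {
  if (numEvents && !scheduler.CanWaitInThread()) {
    // The real code calls urEventWait() here on the enqueueing thread and
    // then hands an empty wait list to the scheduled tasks.
    return 0;
  }
  return numEvents; // Tasks wait on the events themselves.
}

int main() {
  std::printf("%u\n", effectiveWaitListSize(3u, FakeScheduler{false})); // 0
  std::printf("%u\n", effectiveWaitListSize(3u, FakeScheduler{true}));  // 3
}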
@@ -151,7 +161,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
 
   auto &tp = hQueue->getDevice()->tp;
   const size_t numParallelThreads = tp.num_threads();
-  std::vector<std::future<void>> futures;
+  auto Tasks = native_cpu::getScheduler(tp);
   auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
   auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
   auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
@@ -162,7 +172,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
   auto kernel = std::make_unique<ur_kernel_handle_t_>(*hKernel);
   kernel->updateMemPool(numParallelThreads);
 
-  auto InEvents = native_cpu::getWaitInfo(numEventsInWaitList, phEventWaitList);
+  auto InEvents =
+      native_cpu::getWaitInfo(numEventsInWaitList, phEventWaitList, Tasks);
 
   const size_t numWG = numWG0 * numWG1 * numWG2;
   const size_t numWGPerThread = numWG / numParallelThreads;
@@ -177,42 +188,41 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     rangeEnd[0] = rangeEnd[3] % numWG0;
     rangeEnd[1] = (rangeEnd[3] / numWG0) % numWG1;
     rangeEnd[2] = rangeEnd[3] / (numWG0 * numWG1);
-    futures.emplace_back(tp.schedule_task(
-        [ndr, InEvents, &kernel = *kernel, rangeStart, rangeEnd = rangeEnd[3],
-         numWG0, numWG1, numParallelThreads](size_t threadId) {
-          auto state = getState(ndr);
-          InEvents.wait();
-          for (size_t g0 = rangeStart[0], g1 = rangeStart[1],
-                      g2 = rangeStart[2], g3 = rangeStart[3];
-               g3 < rangeEnd; ++g3) {
+    Tasks.schedule([ndr, InEvents, &kernel = *kernel, rangeStart,
+                    rangeEnd = rangeEnd[3], numWG0, numWG1,
+                    numParallelThreads](size_t threadId) {
+      auto state = getState(ndr);
+      InEvents.wait();
+      for (size_t g0 = rangeStart[0], g1 = rangeStart[1], g2 = rangeStart[2],
+                  g3 = rangeStart[3];
+           g3 < rangeEnd; ++g3) {
 #ifdef NATIVECPU_USE_OCK
-            state.update(g0, g1, g2);
-            kernel._subhandler(
-                kernel.getArgs(numParallelThreads, threadId).data(), &state);
+        state.update(g0, g1, g2);
+        kernel._subhandler(kernel.getArgs(numParallelThreads, threadId).data(),
+                           &state);
 #else
-            for (size_t local2 = 0; local2 < ndr.LocalSize[2]; ++local2) {
-              for (size_t local1 = 0; local1 < ndr.LocalSize[1]; ++local1) {
-                for (size_t local0 = 0; local0 < ndr.LocalSize[0]; ++local0) {
-                  state.update(g0, g1, g2, local0, local1, local2);
-                  kernel._subhandler(
-                      kernel.getArgs(numParallelThreads, threadId).data(),
-                      &state);
-                }
-              }
+        for (size_t local2 = 0; local2 < ndr.LocalSize[2]; ++local2) {
+          for (size_t local1 = 0; local1 < ndr.LocalSize[1]; ++local1) {
+            for (size_t local0 = 0; local0 < ndr.LocalSize[0]; ++local0) {
+              state.update(g0, g1, g2, local0, local1, local2);
+              kernel._subhandler(
+                  kernel.getArgs(numParallelThreads, threadId).data(), &state);
             }
+          }
+        }
 #endif
-            if (++g0 == numWG0) {
-              g0 = 0;
-              if (++g1 == numWG1) {
-                g1 = 0;
-                ++g2;
-              }
-            }
+        if (++g0 == numWG0) {
+          g0 = 0;
+          if (++g1 == numWG1) {
+            g1 = 0;
+            ++g2;
          }
-        }));
+        }
+      }
+    });
     rangeStart = rangeEnd;
   }
-  event->set_futures(futures);
+  event->set_tasksinfo(Tasks.getMovedTaskInfo());
 
   if (phEvent) {
     *phEvent = event;
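The launch loop above now routes each work-group range through Tasks.schedule() and lets the scheduler own the std::future bookkeeping that previously lived in a local futures vector. A minimal sketch of that pattern, assuming a scheduler shaped roughly like what getScheduler(tp) returns; the SimpleScheduler class and its members below are hypothetical, not the adapter's actual implementation.

#include <cstddef>
#include <functional>
#include <future>
#include <utility>
#include <vector>

// Hypothetical sketch only: the adapter's real scheduler is defined elsewhere.
class SimpleScheduler {
public:
  // Run the task asynchronously and keep its future so the enqueue's event
  // can later wait on every task scheduled for it.
  void schedule(std::function<void(size_t)> Task) {
    Futures.emplace_back(
        std::async(std::launch::async, std::move(Task), Futures.size()));
  }

  // What event->set_tasksinfo() would consume in this sketch.
  std::vector<std::future<void>> getMovedTaskInfo() {
    return std::move(Futures);
  }

private:
  std::vector<std::future<void>> Futures;
};

int main() {
  SimpleScheduler Tasks;
  for (int Chunk = 0; Chunk < 4; ++Chunk)
    Tasks.schedule([](size_t /*threadId*/) { /* run one work-group range */ });
  for (auto &F : Tasks.getMovedTaskInfo())
    F.wait(); // The event owner eventually waits on all scheduled tasks.
}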
@@ -248,14 +258,14 @@ withTimingEvent(ur_command_t command_type, ur_queue_handle_t hQueue,
     return result;
   }
   auto &tp = hQueue->getDevice()->tp;
-  std::vector<std::future<void>> futures;
+  auto Tasks = native_cpu::getScheduler(tp);
   auto InEvents =
-      native_cpu::getWaitInfo(numEventsInWaitList, phEventWaitList);
-  futures.emplace_back(tp.schedule_task([f, InEvents](size_t) {
+      native_cpu::getWaitInfo(numEventsInWaitList, phEventWaitList, Tasks);
+  Tasks.schedule([f, InEvents](size_t) {
     InEvents.wait();
     f();
-  }));
-  event->set_futures(futures);
+  });
+  event->set_tasksinfo(Tasks.getMovedTaskInfo());
   event->set_callback(
       [event, InEvents = InEvents.getUniquePtr()]() { event->tick_end(); });
   return UR_RESULT_SUCCESS;
@@ -465,7 +475,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
         // TODO: error checking
         // TODO: handle async
         void *startingPtr = hBuffer->_mem + offset;
-        unsigned steps = size / patternSize;
+        size_t steps = size / patternSize;
         for (unsigned i = 0; i < steps; i++) {
           memcpy(static_cast<int8_t *>(startingPtr) + i * patternSize, pPattern,
                  patternSize);
@@ -575,7 +585,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
           break;
         }
         default: {
-          for (unsigned int step{0}; step < size; step += patternSize) {
+          for (size_t step{0}; step < size; step += patternSize) {
            auto *dest = reinterpret_cast<void *>(
                reinterpret_cast<uint8_t *>(ptr) + step);
            memcpy(dest, pPattern, patternSize);
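Both fill fixes above widen a loop bound from unsigned to size_t, so a fill whose size / patternSize exceeds UINT_MAX no longer truncates the step count on LP64 targets where unsigned is 32 bits. A tiny standalone illustration of the truncation (the values are made up):

#include <cstddef>
#include <cstdio>

int main() {
  size_t size = 1ULL << 33; // hypothetical 8 GiB fill request
  size_t patternSize = 1;
  unsigned narrowSteps = size / patternSize; // wraps modulo 2^32, yields 0
  size_t wideSteps = size / patternSize;     // keeps the full count
  std::printf("narrow=%u wide=%zu\n", narrowSteps, wideSteps);
}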