@@ -48,18 +48,17 @@ bool is_source_kernel_bundle_supported(
4848
4949namespace detail {
5050
51- static bool checkAllDevicesAreInContext (const std::vector<device> & Devices,
51+ inline bool checkAllDevicesAreInContext (devices_range Devices,
5252 const context &Context) {
53- return std::all_of (
54- Devices. begin (), Devices. end (), [&Context](const device &Dev) {
55- return getSyclObjImpl (Context)->isDeviceValid (* getSyclObjImpl ( Dev) );
56- });
53+ return std::all_of (Devices. begin (), Devices. end (),
54+ [&Context](device_impl &Dev) {
55+ return getSyclObjImpl (Context)->isDeviceValid (Dev);
56+ });
5757}
5858
59- static bool checkAllDevicesHaveAspect (const std::vector<device> &Devices,
60- aspect Aspect) {
59+ inline bool checkAllDevicesHaveAspect (devices_range Devices, aspect Aspect) {
6160 return std::all_of (Devices.begin (), Devices.end (),
62- [&Aspect](const device &Dev) { return Dev.has (Aspect); });
61+ [&Aspect](device_impl &Dev) { return Dev.has (Aspect); });
6362}
6463
6564namespace syclex = sycl::ext::oneapi::experimental;
@@ -100,9 +99,10 @@ class kernel_bundle_impl
10099 }
101100
102101public:
103- kernel_bundle_impl (context Ctx, std::vector<device> Devs, bundle_state State,
102+ kernel_bundle_impl (context Ctx, devices_range Devs, bundle_state State,
104103 private_tag)
105- : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
104+ : MContext(std::move(Ctx)),
105+ MDevices (Devs.to<std::vector<device_impl *>>()), MState(State) {
106106
107107 common_ctor_checks ();
108108
@@ -112,8 +112,9 @@ class kernel_bundle_impl
112112 }
113113
114114 // Interop constructor used by make_kernel
115- kernel_bundle_impl (context Ctx, std::vector<device> Devs, private_tag)
116- : MContext(Ctx), MDevices(Devs), MState(bundle_state::executable) {
115+ kernel_bundle_impl (context Ctx, devices_range Devs, private_tag)
116+ : MContext(Ctx), MDevices(Devs.to<std::vector<device_impl *>>()),
117+ MState(bundle_state::executable) {
117118 if (!checkAllDevicesAreInContext (Devs, Ctx))
118119 throw sycl::exception (
119120 make_error_code (errc::invalid),
@@ -122,9 +123,9 @@ class kernel_bundle_impl
122123 }
123124
124125 // Interop constructor
125- kernel_bundle_impl (context Ctx, std::vector<device> Devs,
126+ kernel_bundle_impl (context Ctx, devices_range Devs,
126127 device_image_plain &DevImage, private_tag Tag)
127- : kernel_bundle_impl(std::move(Ctx), std::move( Devs) , Tag) {
128+ : kernel_bundle_impl(std::move(Ctx), Devs, Tag) {
128129 MDeviceImages.emplace_back (DevImage);
129130 MUniqueDeviceImages.emplace_back (DevImage);
130131 }
@@ -133,22 +134,19 @@ class kernel_bundle_impl
133134 // Have one constructor because sycl::build and sycl::compile have the same
134135 // signature
135136 kernel_bundle_impl (const kernel_bundle<bundle_state::input> &InputBundle,
136- std::vector<device> Devs, const property_list &PropList,
137+ devices_range Devs, const property_list &PropList,
137138 bundle_state TargetState, private_tag)
138- : MContext(InputBundle.get_context()), MDevices(std::move(Devs)),
139- MState (TargetState) {
139+ : MContext(InputBundle.get_context()),
140+ MDevices(Devs.to<std::vector<device_impl *>>()), MState(TargetState) {
140141
141142 kernel_bundle_impl &InputBundleImpl = *getSyclObjImpl (InputBundle);
142143 MSpecConstValues = InputBundleImpl.get_spec_const_map_ref ();
143144
144- const std::vector<device> &InputBundleDevices =
145- InputBundleImpl.get_devices ();
145+ devices_range InputBundleDevices = InputBundleImpl.get_devices ();
146146 const bool AllDevsAssociatedWithInputBundle =
147- std::all_of (MDevices.begin (), MDevices.end (),
148- [&InputBundleDevices](const device &Dev) {
149- return InputBundleDevices.end () !=
150- std::find (InputBundleDevices.begin (),
151- InputBundleDevices.end (), Dev);
147+ std::all_of (get_devices ().begin (), get_devices ().end (),
148+ [&InputBundleDevices](device_impl &Dev) {
149+ return InputBundleDevices.contains (Dev);
152150 });
153151 if (MDevices.empty () || !AllDevsAssociatedWithInputBundle)
154152 throw sycl::exception (
@@ -163,8 +161,8 @@ class kernel_bundle_impl
163161 for (const DevImgPlainWithDeps &DevImgWithDeps :
164162 InputBundleImpl.MDeviceImages ) {
165163 // Skip images which are not compatible with devices provided
166- if (std::none_of (MDevices .begin (), MDevices .end (),
167- [&DevImgWithDeps](const device &Dev) {
164+ if (std::none_of (get_devices () .begin (), get_devices () .end (),
165+ [&DevImgWithDeps](device_impl &Dev) {
168166 return getSyclObjImpl (DevImgWithDeps.getMain ())
169167 ->compatible_with_device (Dev);
170168 }))
@@ -206,8 +204,9 @@ class kernel_bundle_impl
206204 // Matches sycl::link
207205 kernel_bundle_impl (
208206 const std::vector<kernel_bundle<bundle_state::object>> &ObjectBundles,
209- std::vector<device> Devs, const property_list &PropList, private_tag)
210- : MDevices(std::move(Devs)), MState(bundle_state::executable) {
207+ devices_range Devs, const property_list &PropList, private_tag)
208+ : MDevices(Devs.to<std::vector<device_impl *>>()),
209+ MState(bundle_state::executable) {
211210 if (MDevices.empty ())
212211 throw sycl::exception (make_error_code (errc::invalid),
213212 " Vector of devices is empty" );
@@ -226,16 +225,15 @@ class kernel_bundle_impl
226225 // Check if any of the devices in devs are not in the set of associated
227226 // devices for any of the bundles in ObjectBundles
228227 const bool AllDevsAssociatedWithInputBundles = std::all_of (
229- MDevices.begin (), MDevices.end (), [&ObjectBundles](const device &Dev) {
228+ get_devices ().begin (), get_devices ().end (),
229+ [&ObjectBundles](device_impl &Dev) {
230230 // Number of devices is expected to be small
231231 return std::all_of (
232232 ObjectBundles.begin (), ObjectBundles.end (),
233233 [&Dev](const kernel_bundle<bundle_state::object> &KernelBundle) {
234- const std::vector<device> & BundleDevices =
234+ devices_range BundleDevices =
235235 getSyclObjImpl (KernelBundle)->get_devices ();
236- return BundleDevices.end () != std::find (BundleDevices.begin (),
237- BundleDevices.end (),
238- Dev);
236+ return BundleDevices.contains (Dev);
239237 });
240238 });
241239 if (!AllDevsAssociatedWithInputBundles)
@@ -363,41 +361,33 @@ class kernel_bundle_impl
363361 }
364362
365363 // Create a link graph and clone it for each device.
366- device_impl &FirstDevice = *getSyclObjImpl (MDevices[0 ]);
367- std::map<std::shared_ptr<device_impl>, LinkGraph<device_image_plain>>
368- DevImageLinkGraphs;
364+ device_impl &FirstDevice = get_devices ().front ();
365+ std::map<device_impl *, LinkGraph<device_image_plain>> DevImageLinkGraphs;
369366 const auto &FirstGraph =
370367 DevImageLinkGraphs
371- .emplace (FirstDevice. shared_from_this () ,
368+ .emplace (& FirstDevice,
372369 LinkGraph<device_image_plain>{DevImages, Dependencies})
373370 .first ->second ;
374- for (size_t I = 1 ; I < MDevices.size (); ++I)
375- DevImageLinkGraphs.emplace (getSyclObjImpl (MDevices[I]),
376- FirstGraph.Clone ());
371+ for (device_impl &Dev : get_devices ())
372+ DevImageLinkGraphs.emplace (&Dev, FirstGraph.Clone ());
377373
378374 // Poison the images based on whether the corresponding device supports it.
379375 for (auto &GraphIt : DevImageLinkGraphs) {
380- device Dev = createSyclObjFromImpl<device>( GraphIt.first ) ;
376+ device_impl & Dev = * GraphIt.first ;
381377 GraphIt.second .Poison ([&Dev](const device_image_plain &DevImg) {
382378 return !getSyclObjImpl (DevImg)->compatible_with_device (Dev);
383379 });
384380 }
385381
386382 // Unify graphs after poisoning.
387- std::map<std::vector<std::shared_ptr<device_impl>>,
388- LinkGraph<device_image_plain>>
383+ std::map<std::vector<device_impl *>, LinkGraph<device_image_plain>>
389384 UnifiedGraphs = UnifyGraphs (DevImageLinkGraphs);
390385
391386 // Link based on the resulting graphs.
392387 for (auto &GraphIt : UnifiedGraphs) {
393- std::vector<device> DeviceGroup;
394- DeviceGroup.reserve (GraphIt.first .size ());
395- for (const auto &DeviceImgImpl : GraphIt.first )
396- DeviceGroup.emplace_back (createSyclObjFromImpl<device>(DeviceImgImpl));
397-
398388 std::vector<device_image_plain> LinkedResults =
399389 detail::ProgramManager::getInstance ().link (
400- GraphIt.second .GetNodeValues (), DeviceGroup , PropList);
390+ GraphIt.second .GetNodeValues (), GraphIt. first , PropList);
401391 MDeviceImages.insert (MDeviceImages.end (), LinkedResults.begin (),
402392 LinkedResults.end ());
403393 MUniqueDeviceImages.insert (MUniqueDeviceImages.end (),
@@ -410,8 +400,8 @@ class kernel_bundle_impl
410400 for (const DevImgPlainWithDeps *DeviceImageWithDeps :
411401 ImagesWithSpecConsts) {
412402 // Skip images which are not compatible with devices provided
413- if (std::none_of (MDevices .begin (), MDevices .end (),
414- [DeviceImageWithDeps](const device &Dev) {
403+ if (std::none_of (get_devices () .begin (), get_devices () .end (),
404+ [DeviceImageWithDeps](device_impl &Dev) {
415405 return getSyclObjImpl (DeviceImageWithDeps->getMain ())
416406 ->compatible_with_device (Dev);
417407 }))
@@ -438,10 +428,11 @@ class kernel_bundle_impl
438428 }
439429 }
440430
441- kernel_bundle_impl (context Ctx, std::vector<device> Devs,
431+ kernel_bundle_impl (context Ctx, devices_range Devs,
442432 const std::vector<kernel_id> &KernelIDs,
443433 bundle_state State, private_tag)
444- : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
434+ : MContext(std::move(Ctx)),
435+ MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {
445436
446437 common_ctor_checks ();
447438
@@ -450,10 +441,11 @@ class kernel_bundle_impl
450441 fillUniqueDeviceImages ();
451442 }
452443
453- kernel_bundle_impl (context Ctx, std::vector<device> Devs,
444+ kernel_bundle_impl (context Ctx, devices_range Devs,
454445 const DevImgSelectorImpl &Selector, bundle_state State,
455446 private_tag)
456- : MContext(std::move(Ctx)), MDevices(std::move(Devs)), MState(State) {
447+ : MContext(std::move(Ctx)),
448+ MDevices(Devs.to<std::vector<device_impl *>>()), MState(State) {
457449
458450 common_ctor_checks ();
459451
@@ -548,7 +540,9 @@ class kernel_bundle_impl
548540 kernel_bundle_impl (const context &Context, syclex::source_language Lang,
549541 const std::string &Src, include_pairs_t IncludePairsVec,
550542 private_tag)
551- : MContext(Context), MDevices(Context.get_devices()),
543+ : MContext(Context), MDevices(getSyclObjImpl(Context)
544+ ->getDevices()
545+ .to<std::vector<device_impl *>>()),
552546 MDeviceImages{device_image_plain{device_image_impl::create (
553547 Src, MContext, MDevices, Lang, std::move (IncludePairsVec))}},
554548 MUniqueDeviceImages{MDeviceImages[0 ].getMain ()},
@@ -560,7 +554,9 @@ class kernel_bundle_impl
560554 // construct from source bytes
561555 kernel_bundle_impl (const context &Context, syclex::source_language Lang,
562556 const std::vector<std::byte> &Bytes, private_tag)
563- : MContext(Context), MDevices(Context.get_devices()),
557+ : MContext(Context), MDevices(getSyclObjImpl(Context)
558+ ->getDevices ()
559+ .to<std::vector<device_impl *>>()),
564560 MDeviceImages{device_image_plain{
565561 device_image_impl::create (Bytes, MContext, MDevices, Lang)}},
566562 MUniqueDeviceImages{MDeviceImages[0 ].getMain ()},
@@ -571,11 +567,11 @@ class kernel_bundle_impl
571567 // oneapi_ext_kernel_compiler
572568 // construct from built source files
573569 kernel_bundle_impl (
574- const context &Context, const std::vector<device> & Devs,
570+ const context &Context, devices_range Devs,
575571 std::vector<device_image_plain> &&DevImgs,
576572 std::vector<std::shared_ptr<ManagedDeviceBinaries>> &&DevBinaries,
577573 bundle_state State, private_tag)
578- : MContext(Context), MDevices(Devs),
574+ : MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>() ),
579575 MSharedDeviceBinaries (std::move(DevBinaries)),
580576 MUniqueDeviceImages(std::move(DevImgs)), MState(State) {
581577 common_ctor_checks ();
@@ -587,10 +583,11 @@ class kernel_bundle_impl
587583 }
588584
589585 // SYCLBIN constructor
590- kernel_bundle_impl (const context &Context, const std::vector<device> & Devs,
586+ kernel_bundle_impl (const context &Context, devices_range Devs,
591587 const sycl::span<char > Bytes, bundle_state State,
592588 private_tag)
593- : MContext(Context), MDevices(Devs), MState(State) {
589+ : MContext(Context), MDevices(Devs.to<std::vector<device_impl *>>()),
590+ MState(State) {
594591 common_ctor_checks ();
595592
596593 auto &SYCLBIN = MSYCLBINs.emplace_back (
@@ -622,7 +619,7 @@ class kernel_bundle_impl
622619 }
623620
624621 std::shared_ptr<kernel_bundle_impl> build_from_source (
625- const std::vector<device> Devices,
622+ devices_range Devices,
626623 const std::vector<sycl::detail::string_view> &BuildOptions,
627624 std::string *LogPtr,
628625 const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
@@ -645,7 +642,7 @@ class kernel_bundle_impl
645642 }
646643
647644 std::shared_ptr<kernel_bundle_impl> compile_from_source (
648- const std::vector<device> Devices,
645+ devices_range Devices,
649646 const std::vector<sycl::detail::string_view> &CompileOptions,
650647 std::string *LogPtr,
651648 const std::vector<sycl::detail::string_view> &RegisteredKernelNames) {
@@ -733,8 +730,9 @@ class kernel_bundle_impl
733730 void *ext_oneapi_get_device_global_address (const std::string &Name,
734731 const device &Dev) const {
735732 DeviceGlobalMapEntry *Entry = getDeviceGlobalEntry (Name);
733+ device_impl &DeviceImpl = *getSyclObjImpl (Dev);
736734
737- if (std::find (MDevices. begin (), MDevices. end (), Dev) == MDevices. end ( )) {
735+ if (! get_devices (). contains (DeviceImpl )) {
738736 throw sycl::exception (make_error_code (errc::invalid),
739737 " kernel_bundle not built for device" );
740738 }
@@ -745,7 +743,6 @@ class kernel_bundle_impl
745743 " 'device_image_scope' property" );
746744 }
747745
748- device_impl &DeviceImpl = *getSyclObjImpl (Dev);
749746 bool SupportContextMemcpy = false ;
750747 DeviceImpl.getAdapter ().call <UrApiKind::urDeviceGetInfo>(
751748 DeviceImpl.getHandleRef (),
@@ -772,7 +769,7 @@ class kernel_bundle_impl
772769
773770 context get_context () const noexcept { return MContext; }
774771
775- const std::vector<device> & get_devices () const noexcept { return MDevices; }
772+ devices_range get_devices () const noexcept { return MDevices; }
776773
777774 std::vector<kernel_id> get_kernel_ids () const {
778775 // Collect kernel ids from all device images, then remove duplicates
@@ -1111,7 +1108,7 @@ class kernel_bundle_impl
11111108 }
11121109
11131110 context MContext;
1114- std::vector<device > MDevices;
1111+ std::vector<device_impl * > MDevices;
11151112
11161113 // For sycl_jit, building from source may have produced sycl binaries that
11171114 // the kernel_bundles now manage.
0 commit comments