@@ -272,21 +272,36 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
272272 : SYCLBINContentCopy{ContentCopy (SYCLBINContent, SYCLBINSize)},
273273 SYCLBINContentCopySize{SYCLBINSize},
274274 ParsedSYCLBIN (SYCLBIN{SYCLBINContentCopy.get (), SYCLBINSize}) {
275- size_t NumJITBinaries = 0 , NumNativeBinaries = 0 ;
276- for (const SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules ) {
277- NumJITBinaries += AM.IRModules .size ();
278- NumNativeBinaries += AM.NativeDeviceCodeImages .size ();
279- }
280- DeviceBinaries.reserve (NumJITBinaries + NumNativeBinaries);
281- JITDeviceBinaryImages.reserve (NumJITBinaries);
282- NativeDeviceBinaryImages.reserve (NumNativeBinaries);
275+ AbstractModuleDescriptors = std::unique_ptr<AbstractModuleDesc[]>(
276+ new AbstractModuleDesc[ParsedSYCLBIN.AbstractModules .size ()]);
277+
278+ size_t NumBinaries = 0 ;
279+ for (const SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules )
280+ NumBinaries += AM.IRModules .size () + AM.NativeDeviceCodeImages .size ();
281+ DeviceBinaries.reserve (NumBinaries);
282+ BinaryImages = std::unique_ptr<RTDeviceBinaryImage[]>(
283+ new RTDeviceBinaryImage[NumBinaries]);
284+
285+ RTDeviceBinaryImage *CurrentBinaryImagesStart = BinaryImages.get ();
286+ for (size_t I = 0 ; I < getNumAbstractModules (); ++I) {
287+ SYCLBIN::AbstractModule &AM = ParsedSYCLBIN.AbstractModules [I];
288+ AbstractModuleDesc &AMDesc = AbstractModuleDescriptors[I];
289+
290+ // Set up the abstract module descriptor.
291+ AMDesc.NumJITBinaries = AM.IRModules .size ();
292+ AMDesc.NumNativeBinaries = AM.NativeDeviceCodeImages .size ();
293+ AMDesc.JITBinaries = CurrentBinaryImagesStart;
294+ AMDesc.NativeBinaries = CurrentBinaryImagesStart + AMDesc.NumJITBinaries ;
295+ CurrentBinaryImagesStart +=
296+ AMDesc.NumJITBinaries + AM.NativeDeviceCodeImages .size ();
283297
284- for (SYCLBIN::AbstractModule &AM : ParsedSYCLBIN.AbstractModules ) {
285298 // Construct properties from SYCLBIN metadata.
286299 std::vector<_sycl_device_binary_property_set_struct> &BinPropertySets =
287300 convertAbstractModuleProperties (AM);
288301
289- for (SYCLBIN::IRModule &IRM : AM.IRModules ) {
302+ for (size_t J = 0 ; J < AM.IRModules .size (); ++J) {
303+ SYCLBIN::IRModule &IRM = AM.IRModules [J];
304+
290305 sycl_device_binary_struct &DeviceBinary = DeviceBinaries.emplace_back ();
291306 DeviceBinary.Version = SYCL_DEVICE_BINARY_VERSION;
292307 DeviceBinary.Kind = 4 ;
@@ -309,11 +324,12 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
309324 DeviceBinary.PropertySetsEnd =
310325 BinPropertySets.data () + BinPropertySets.size ();
311326 // Create an image from it.
312- JITDeviceBinaryImages. emplace_back ( &DeviceBinary) ;
327+ AMDesc. JITBinaries [J] = RTDeviceBinaryImage{ &DeviceBinary} ;
313328 }
314329
315- for (const SYCLBIN::NativeDeviceCodeImage &NDCI :
316- AM.NativeDeviceCodeImages ) {
330+ for (size_t J = 0 ; J < AM.NativeDeviceCodeImages .size (); ++J) {
331+ const SYCLBIN::NativeDeviceCodeImage &NDCI = AM.NativeDeviceCodeImages [J];
332+
317333 assert (NDCI.Metadata != nullptr );
318334 PropertySet &NDCIMetadataProps = (*NDCI.Metadata )
319335 [PropertySetRegistry::SYCLBIN_NATIVE_DEVICE_CODE_IMAGE_METADATA];
@@ -346,7 +362,7 @@ SYCLBINBinaries::SYCLBINBinaries(const char *SYCLBINContent, size_t SYCLBINSize)
346362 DeviceBinary.PropertySetsEnd =
347363 BinPropertySets.data () + BinPropertySets.size ();
348364 // Create an image from it.
349- NativeDeviceBinaryImages. emplace_back ( &DeviceBinary) ;
365+ AMDesc. NativeBinaries [J] = RTDeviceBinaryImage{ &DeviceBinary} ;
350366 }
351367 }
352368}
@@ -394,33 +410,44 @@ SYCLBINBinaries::convertAbstractModuleProperties(SYCLBIN::AbstractModule &AM) {
394410}
395411
396412std::vector<const RTDeviceBinaryImage *>
397- SYCLBINBinaries::getBestCompatibleImages (device_impl &Dev) {
398- auto SelectCompatibleImages =
399- [&](const std::vector<RTDeviceBinaryImage> &Imgs) {
400- std::vector<const RTDeviceBinaryImage *> CompatImgs;
401- for (const RTDeviceBinaryImage &Img : Imgs)
402- if (doesDevSupportDeviceRequirements (Dev, Img) &&
403- doesImageTargetMatchDevice (Img, Dev))
404- CompatImgs.push_back (&Img);
405- return CompatImgs;
406- };
407-
408- // Try with native images first.
409- std::vector<const RTDeviceBinaryImage *> NativeImgs =
410- SelectCompatibleImages (NativeDeviceBinaryImages);
411- if (!NativeImgs.empty ())
412- return NativeImgs;
413-
414- // If there were no native images, pick JIT images.
415- return SelectCompatibleImages (JITDeviceBinaryImages);
413+ SYCLBINBinaries::getBestCompatibleImages (device_impl &Dev, bundle_state State) {
414+ auto GetCompatibleImage = [&](const RTDeviceBinaryImage *Imgs,
415+ size_t NumImgs) {
416+ const RTDeviceBinaryImage *CompatImagePtr =
417+ std::find_if (Imgs, Imgs + NumImgs, [&](const RTDeviceBinaryImage &Img) {
418+ return doesDevSupportDeviceRequirements (Dev, Img) &&
419+ doesImageTargetMatchDevice (Img, Dev);
420+ });
421+ return (CompatImagePtr != Imgs + NumImgs) ? CompatImagePtr : nullptr ;
422+ };
423+
424+ std::vector<const RTDeviceBinaryImage *> Images;
425+ for (size_t I = 0 ; I < getNumAbstractModules (); ++I) {
426+ const AbstractModuleDesc &AMDesc = AbstractModuleDescriptors[I];
427+ // If the target state is executable, try with native images first.
428+ if (State == bundle_state::executable) {
429+ if (const RTDeviceBinaryImage *CompatImagePtr = GetCompatibleImage (
430+ AMDesc.NativeBinaries , AMDesc.NumNativeBinaries )) {
431+ Images.push_back (CompatImagePtr);
432+ continue ;
433+ }
434+ }
435+
436+ // Otherwise, select the first compatible JIT binary.
437+ if (const RTDeviceBinaryImage *CompatImagePtr =
438+ GetCompatibleImage (AMDesc.JITBinaries , AMDesc.NumJITBinaries ))
439+ Images.push_back (CompatImagePtr);
440+ }
441+ return Images;
416442}
417443
418444std::vector<const RTDeviceBinaryImage *>
419- SYCLBINBinaries::getBestCompatibleImages (devices_range Devs) {
445+ SYCLBINBinaries::getBestCompatibleImages (devices_range Devs,
446+ bundle_state State) {
420447 std::set<const RTDeviceBinaryImage *> Images;
421448 for (device_impl &Dev : Devs) {
422449 std::vector<const RTDeviceBinaryImage *> BestImagesForDev =
423- getBestCompatibleImages (Dev);
450+ getBestCompatibleImages (Dev, State );
424451 Images.insert (BestImagesForDev.cbegin (), BestImagesForDev.cend ());
425452 }
426453 return {Images.cbegin (), Images.cend ()};
0 commit comments