@@ -746,19 +746,28 @@ class device_image_impl {
746746 MRTCBinInfo->MIncludePairs , BuildOptions, LogPtr);
747747
748748 auto &PM = detail::ProgramManager::getInstance ();
749- std::vector<std::shared_ptr<device_image_impl>> Result;
750- Result.reserve (Binaries->NumDeviceBinaries );
749+
750+ // Add all binaries and keep the images for processing.
751+ std::vector<std::pair<RTDeviceBinaryImage *,
752+ std::shared_ptr<std::vector<kernel_id>>>>
753+ NewImages;
754+ NewImages.reserve (Binaries->NumDeviceBinaries );
751755 for (int I = 0 ; I < Binaries->NumDeviceBinaries ; I++) {
752756 sycl_device_binary Binary = &(Binaries->DeviceBinaries [I]);
753-
754757 RTDeviceBinaryImage *NewImage = nullptr ;
755758 auto KernelIDs = std::make_shared<std::vector<kernel_id>>();
756759 PM.addImage (Binary, &NewImage, KernelIDs.get ());
760+ if (NewImage)
761+ NewImages.push_back (
762+ std::make_pair (std::move (NewImage), std::move (KernelIDs)));
763+ }
757764
758- // If the image is empty, we can skip it.
759- if (!NewImage)
760- continue ;
761-
765+ // Now bring all images into the proper state. Note that we do this in a
766+ // separate pass over NewImages to make sure dependency images have been
767+ // registered beforehand.
768+ std::vector<std::shared_ptr<device_image_impl>> Result;
769+ Result.reserve (NewImages.size ());
770+ for (auto &[NewImage, KernelIDs] : NewImages) {
762771 std::set<std::string> KernelNames;
763772 std::unordered_map<std::string, std::string> MangledKernelNames;
764773 std::unordered_set<std::string> DeviceGlobalIDSet;
@@ -843,7 +852,26 @@ class device_image_impl {
843852 std::move (KernelNames), std::move (MangledKernelNames),
844853 std::string{Prefix}, std::move (DGRegs));
845854
846- DevImgPlainWithDeps ImgWithDeps{DevImgImpl};
855+ // Resolve dependencies.
856+ // TODO: Consider making a collectDeviceImageDeps variant that takes a
857+ // set reference and inserts into that instead.
858+ std::set<RTDeviceBinaryImage *> ImgDeps;
859+ for (const device &Device : Devices) {
860+ std::set<RTDeviceBinaryImage *> DevImgDeps =
861+ PM.collectDeviceImageDeps (*NewImage, Device);
862+ ImgDeps.insert (DevImgDeps.begin (), DevImgDeps.end ());
863+ }
864+
865+ // Pack main image and dependencies together.
866+ std::vector<device_image_plain> NewImageAndDeps;
867+ NewImageAndDeps.reserve (1 + ImgDeps.size ());
868+ NewImageAndDeps.push_back (std::move (
869+ createSyclObjFromImpl<device_image_plain>(std::move (DevImgImpl))));
870+ for (RTDeviceBinaryImage *ImgDep : ImgDeps)
871+ NewImageAndDeps.push_back (PM.createDependencyImage (
872+ MContext, Devices, ImgDep, bundle_state::input));
873+
874+ DevImgPlainWithDeps ImgWithDeps (std::move (NewImageAndDeps));
847875 PM.bringSYCLDeviceImageToState (ImgWithDeps, bundle_state::executable);
848876 Result.push_back (getSyclObjImpl (ImgWithDeps.getMain ()));
849877 }
0 commit comments