@@ -615,22 +615,25 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
615615}
616616
617617// Quick check to see whether BinImage is a compiler-generated device image.
618- static bool isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
618+ bool ProgramManager:: isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
619619 // SYCL devicelib image.
620- if (BinImage->getDeviceLibMetadata ().isAvailable ())
620+ if ((m_Bfloat16DeviceLibImages[0 ].get () == BinImage) ||
621+ m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
621622 return true ;
622623
623624 return false ;
624625}
625626
626- static bool isSpecialDeviceImageShouldBeUsed (RTDeviceBinaryImage *BinImage,
627- const device &Dev) {
627+ bool ProgramManager:: isSpecialDeviceImageShouldBeUsed (
628+ RTDeviceBinaryImage *BinImage, const device &Dev) {
628629 // Decide whether a devicelib image should be used.
629- if (BinImage->getDeviceLibMetadata ().isAvailable ()) {
630- const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
631- BinImage->getDeviceLibMetadata ();
632- uint32_t DeviceLibMeta =
633- DeviceBinaryProperty (*(DeviceLibMetaProp.begin ())).asUint32 ();
630+ int Bfloat16DeviceLibVersion = -1 ;
631+ if (m_Bfloat16DeviceLibImages[0 ].get () == BinImage)
632+ Bfloat16DeviceLibVersion = 0 ;
633+ else if (m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
634+ Bfloat16DeviceLibVersion = 1 ;
635+
636+ if (Bfloat16DeviceLibVersion != -1 ) {
634637 // Currently, only bfloat conversion devicelib are supported, so the prop
635638 // DeviceLibMeta are only used to represent fallback or native version.
636639 // For bfloat16 conversion devicelib, we have fallback and native version.
@@ -644,7 +647,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
644647 detail::getSyclObjImpl (Dev);
645648 std::string NativeBF16ExtName = " cl_intel_bfloat16_conversions" ;
646649 bool NativeBF16Supported = (DeviceImpl->has_extension (NativeBF16ExtName));
647- return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
650+ return NativeBF16Supported ==
651+ (Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
648652 }
649653
650654 return false ;
@@ -1838,87 +1842,69 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
18381842 return {};
18391843}
18401844
1841- static bool shouldSkipEmptyImage (sycl_device_binary RawImg, bool IsRTC) {
1842- // For bfloat16 device library image, we should keep it. However, in some
1843- // scenario, __sycl_register_lib can be called multiple times and the same
1844- // bfloat16 device library image may be handled multiple times which is not
1845- // needed. 2 static bool variables are created to record whether native or
1846- // fallback bfloat16 device library image has been handled, if yes, we just
1847- // need to skip it.
1848- // We cannot prevent redundant loads of device library images if they are part
1849- // of a runtime-compiled device binary, as these will be freed when the
1850- // corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
1851- // on the presence of RTC device library images.
1845+ static bool isBfloat16DeviceLibImage (sycl_device_binary RawImg,
1846+ uint32_t *LibVersion = nullptr ) {
18521847 sycl_device_binary_property_set ImgPS;
1853- static bool IsNativeBF16DeviceLibHandled = false ;
1854- static bool IsFallbackBF16DeviceLibHandled = false ;
18551848 for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
18561849 ++ImgPS) {
18571850 if (ImgPS->Name &&
18581851 !strcmp (__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name )) {
1852+ if (!LibVersion)
1853+ return true ;
1854+
1855+ // Valid version for bfloat16 device library is 0(fallback), 1(native).
1856+ *LibVersion = 2 ;
18591857 sycl_device_binary_property ImgP;
18601858 for (ImgP = ImgPS->PropertiesBegin ; ImgP != ImgPS->PropertiesEnd ;
18611859 ++ImgP) {
18621860 if (ImgP->Name && !strcmp (" bfloat16" , ImgP->Name ) &&
18631861 (ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
18641862 break ;
18651863 }
1866- if (ImgP == ImgPS->PropertiesEnd )
1867- return true ;
1868-
1869- // A valid bfloat16 device library image is found here.
1870- // If it originated from RTC, we cannot skip it, but do not mark it as
1871- // being present.
1872- if (IsRTC)
1873- return false ;
1874-
1875- // Otherwise, we need to check whether it has been handled already.
1876- uint32_t BF16NativeVal = DeviceBinaryProperty (ImgP).asUint32 ();
1877- if (((BF16NativeVal == 0 ) && IsFallbackBF16DeviceLibHandled) ||
1878- ((BF16NativeVal == 1 ) && IsNativeBF16DeviceLibHandled))
1879- return true ;
1880-
1881- if (BF16NativeVal == 0 )
1882- IsFallbackBF16DeviceLibHandled = true ;
1883- else
1884- IsNativeBF16DeviceLibHandled = true ;
1885-
1886- return false ;
1864+ if (ImgP != ImgPS->PropertiesEnd )
1865+ *LibVersion = DeviceBinaryProperty (ImgP).asUint32 ();
1866+ return true ;
18871867 }
18881868 }
1889- return true ;
1869+
1870+ return false ;
18901871}
18911872
1892- static bool isCompiledAtRuntime (sycl_device_binaries DeviceBinary) {
1893- // Check whether the first device binary contains a legacy format offload
1894- // entry with a `$` in its name.
1895- if (DeviceBinary->NumDeviceBinaries > 0 ) {
1896- sycl_device_binary Binary = DeviceBinary->DeviceBinaries ;
1897- if (Binary->EntriesBegin != Binary->EntriesEnd ) {
1898- sycl_offload_entry Entry = Binary->EntriesBegin ;
1899- if (!Entry->IsNewOffloadEntryType () &&
1900- std::string_view{Entry->name }.find (' $' ) != std::string_view::npos) {
1901- return true ;
1902- }
1903- }
1873+ static sycl_device_binary_property_set
1874+ getExportedSymbolPS (sycl_device_binary RawImg) {
1875+ sycl_device_binary_property_set ImgPS;
1876+ for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
1877+ ++ImgPS) {
1878+ if (ImgPS->Name &&
1879+ !strcmp (__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name ))
1880+ return ImgPS;
19041881 }
1905- return false ;
1882+
1883+ return nullptr ;
1884+ }
1885+
1886+ static bool shouldSkipEmptyImage (sycl_device_binary RawImg) {
1887+ // For bfloat16 device library image, we should keep it although it doesn't
1888+ // include any kernel.
1889+ if (isBfloat16DeviceLibImage (RawImg))
1890+ return false ;
1891+
1892+ // We may extend the logic here other than bfloat16 device library image.
1893+ return true ;
19061894}
19071895
19081896void ProgramManager::addImages (sycl_device_binaries DeviceBinary) {
19091897 const bool DumpImages = std::getenv (" SYCL_DUMP_IMAGES" ) && !m_UseSpvFile;
1910- const bool IsRTC = isCompiledAtRuntime (DeviceBinary);
19111898 for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
19121899 sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
19131900 const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
19141901 const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
1915- // If the image does not contain kernels, skip it unless it is one of the
1916- // bfloat16 device libraries, and it wasn't loaded before or resulted from
1917- // runtime compilation.
1918- if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg, IsRTC))
1902+ if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg))
19191903 continue ;
19201904
19211905 std::unique_ptr<RTDeviceBinaryImage> Img;
1906+ bool IsBfloat16DeviceLib = false ;
1907+ uint32_t Bfloat16DeviceLibVersion = 0 ;
19221908 if (isDeviceImageCompressed (RawImg))
19231909#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
19241910 Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1928,25 +1914,63 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
19281914 " SYCL RT was built without ZSTD support."
19291915 " Aborting. " );
19301916#endif
1931- else
1932- Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1917+ else {
1918+ IsBfloat16DeviceLib =
1919+ isBfloat16DeviceLibImage (RawImg, &Bfloat16DeviceLibVersion);
1920+ if (!IsBfloat16DeviceLib)
1921+ Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1922+ }
19331923
19341924 static uint32_t SequenceID = 0 ;
19351925
1936- // Fill the kernel argument mask map
1937- const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1938- Img->getKernelParamOptInfo ();
1939- if (KPOIRange.isAvailable ()) {
1940- KernelNameToArgMaskMap &ArgMaskMap =
1941- m_EliminatedKernelArgMasks[Img.get ()];
1942- for (const auto &Info : KPOIRange)
1943- ArgMaskMap[Info->Name ] =
1944- createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1926+ // Fill the kernel argument mask map, no need to do this for bfloat16
1927+ // device library image since it doesn't include any kernel.
1928+ if (!IsBfloat16DeviceLib) {
1929+ const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1930+ Img->getKernelParamOptInfo ();
1931+ if (KPOIRange.isAvailable ()) {
1932+ KernelNameToArgMaskMap &ArgMaskMap =
1933+ m_EliminatedKernelArgMasks[Img.get ()];
1934+ for (const auto &Info : KPOIRange)
1935+ ArgMaskMap[Info->Name ] =
1936+ createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1937+ }
19451938 }
19461939
19471940 // Fill maps for kernel bundles
19481941 std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
19491942
1943+ // For bfloat16 device library image, it doesn't include any kernel, device
1944+ // global, virtual function, so just skip adding it to any related maps.
1945+ // The bfloat16 device library are provided by compiler and may be used by
1946+ // different sycl device images, program manager will own single copy for
1947+ // native and fallback version bfloat16 device library, these device
1948+ // library images will not be erased unless program manager is destroyed.
1949+ {
1950+ if (IsBfloat16DeviceLib) {
1951+ assert ((Bfloat16DeviceLibVersion < 2 ) &&
1952+ " Invalid Bfloat16 Device Library Index." );
1953+ if (m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion].get ())
1954+ continue ;
1955+ size_t ImgSize =
1956+ static_cast <size_t >(RawImg->BinaryEnd - RawImg->BinaryStart );
1957+ std::unique_ptr<char []> Data (new char [ImgSize]);
1958+ std::memcpy (Data.get (), RawImg->BinaryStart , ImgSize);
1959+ auto DynBfloat16DeviceLibImg =
1960+ std::make_unique<DynRTDeviceBinaryImage>(std::move (Data), ImgSize);
1961+ auto ESPropSet = getExportedSymbolPS (RawImg);
1962+ sycl_device_binary_property ESProp;
1963+ for (ESProp = ESPropSet->PropertiesBegin ;
1964+ ESProp != ESPropSet->PropertiesEnd ; ++ESProp) {
1965+ m_ExportedSymbolImages.insert (
1966+ {ESProp->Name , DynBfloat16DeviceLibImg.get ()});
1967+ }
1968+ m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion] =
1969+ std::move (DynBfloat16DeviceLibImg);
1970+ continue ;
1971+ }
1972+ }
1973+
19501974 // Register all exported symbols
19511975 for (const sycl_device_binary_property &ESProp :
19521976 Img->getExportedSymbols ()) {
@@ -2111,19 +2135,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
21112135}
21122136
21132137void ProgramManager::removeImages (sycl_device_binaries DeviceBinary) {
2114- bool IsRTC = isCompiledAtRuntime (DeviceBinary);
21152138 for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
21162139 sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
21172140 auto DevImgIt = m_DeviceImages.find (RawImg);
21182141 if (DevImgIt == m_DeviceImages.end ())
21192142 continue ;
21202143 const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
21212144 const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
2122- // Skip clean up if there are no offload entries, unless `DeviceBinary`
2123- // resulted from runtime compilation: Then, this is one of the `bfloat16`
2124- // device libraries, so we want to make sure that the image and its exported
2125- // symbols are removed from the program manager's maps.
2126- if (EntriesB == EntriesE && !IsRTC)
2145+ if (EntriesB == EntriesE)
21272146 continue ;
21282147
21292148 RTDeviceBinaryImage *Img = DevImgIt->second .get ();
@@ -2651,7 +2670,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26512670 std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
26522671 {
26532672 std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
2654- DepKernelIDs = m_BinImg2KernelIDs[Dep];
2673+ // For device library images, they are not in m_BinImg2KernelIDs since
2674+ // no kernel is included.
2675+ auto DepIt = m_BinImg2KernelIDs.find (Dep);
2676+ if (DepIt != m_BinImg2KernelIDs.end ())
2677+ DepKernelIDs = DepIt->second ;
26552678 }
26562679
26572680 assert (ImgInfoPair.second .State == getBinImageState (Dep) &&
@@ -2865,9 +2888,10 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
28652888 const std::shared_ptr<device_image_impl> &DeviceImageImpl =
28662889 getSyclObjImpl (Img);
28672890 // Duplicates are not expected here, otherwise urProgramLink should fail
2868- KernelIDs.insert (KernelIDs.end (),
2869- DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2870- DeviceImageImpl->get_kernel_ids_ptr ()->end ());
2891+ if (DeviceImageImpl->get_kernel_ids_ptr ())
2892+ KernelIDs.insert (KernelIDs.end (),
2893+ DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2894+ DeviceImageImpl->get_kernel_ids_ptr ()->end ());
28712895 // To be able to answer queries about specialziation constants, the new
28722896 // device image should have the specialization constants from all the linked
28732897 // images.
0 commit comments