Skip to content

Commit 435aa76

Browse files
authored
[Libomptarget] Rework device initialization and image registration (#93844)
Summary: Currently, we register images into a linear table according to the logical OpenMP device identifier. We then initialize all of these images as one block. This logic requires that images are compatible with *all* devices instead of just the one that it can run on. This prevents us from running on systems with heterogeneous devices (i.e. image 1 runs on device 0 image 0 runs on device 1). This patch reworks the logic by instead making the compatibility check a per-device query. We then scan every device to see if it's compatible and do it as they come.
1 parent e5c93ed commit 435aa76

File tree

8 files changed

+232
-218
lines changed

8 files changed

+232
-218
lines changed

offload/include/PluginManager.h

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@ struct PluginManager {
6464
std::make_unique<DeviceImageTy>(TgtBinDesc, TgtDeviceImage));
6565
}
6666

67-
/// Initialize as many devices as possible for this plugin. Devices that fail
68-
/// to initialize are ignored.
69-
void initDevices(GenericPluginTy &RTL);
70-
7167
/// Return the device presented to the user as device \p DeviceNo if it is
7268
/// initialized and ready. Otherwise return an error explaining the problem.
7369
llvm::Expected<DeviceTy &> getDevice(uint32_t DeviceNo);
@@ -117,32 +113,41 @@ struct PluginManager {
117113
return Devices.getExclusiveAccessor();
118114
}
119115

120-
int getNumUsedPlugins() const { return DeviceOffsets.size(); }
121-
122116
// Initialize all plugins.
123117
void initAllPlugins();
124118

125119
/// Iterator range for all plugins (in use or not, but always valid).
126120
auto plugins() { return llvm::make_pointee_range(Plugins); }
127121

122+
/// Iterator range for all plugins (in use or not, but always valid).
123+
auto plugins() const { return llvm::make_pointee_range(Plugins); }
124+
128125
/// Return the user provided requirements.
129126
int64_t getRequirements() const { return Requirements.getRequirements(); }
130127

131128
/// Add \p Flags to the user provided requirements.
132129
void addRequirements(int64_t Flags) { Requirements.addRequirements(Flags); }
133130

131+
/// Returns the number of plugins that are active.
132+
int getNumActivePlugins() const {
133+
int count = 0;
134+
for (auto &R : plugins())
135+
if (R.is_initialized())
136+
++count;
137+
138+
return count;
139+
}
140+
134141
private:
135142
bool RTLsLoaded = false;
136143
llvm::SmallVector<__tgt_bin_desc *> DelayedBinDesc;
137144

138145
// List of all plugins, in use or not.
139146
llvm::SmallVector<std::unique_ptr<GenericPluginTy>> Plugins;
140147

141-
// Mapping of plugins to offsets in the device table.
142-
llvm::DenseMap<const GenericPluginTy *, int32_t> DeviceOffsets;
143-
144-
// Mapping of plugins to the number of used devices.
145-
llvm::DenseMap<const GenericPluginTy *, int32_t> DeviceUsed;
148+
// Mapping of plugins to the OpenMP device identifier.
149+
llvm::DenseMap<std::pair<const GenericPluginTy *, int32_t>, int32_t>
150+
DeviceIds;
146151

147152
// Set of all device images currently in use.
148153
llvm::DenseSet<const __tgt_device_image *> UsedImages;

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3163,25 +3163,24 @@ struct AMDGPUPluginTy final : public GenericPluginTy {
31633163
uint16_t getMagicElfBits() const override { return ELF::EM_AMDGPU; }
31643164

31653165
/// Check whether the image is compatible with an AMDGPU device.
3166-
Expected<bool> isELFCompatible(StringRef Image) const override {
3166+
Expected<bool> isELFCompatible(uint32_t DeviceId,
3167+
StringRef Image) const override {
31673168
// Get the associated architecture and flags from the ELF.
31683169
auto ElfOrErr = ELF64LEObjectFile::create(
31693170
MemoryBufferRef(Image, /*Identifier=*/""), /*InitContent=*/false);
31703171
if (!ElfOrErr)
31713172
return ElfOrErr.takeError();
31723173
std::optional<StringRef> Processor = ElfOrErr->tryGetCPUName();
3174+
if (!Processor)
3175+
return false;
31733176

3174-
for (hsa_agent_t Agent : KernelAgents) {
3175-
auto TargeTripleAndFeaturesOrError =
3176-
utils::getTargetTripleAndFeatures(Agent);
3177-
if (!TargeTripleAndFeaturesOrError)
3178-
return TargeTripleAndFeaturesOrError.takeError();
3179-
if (!utils::isImageCompatibleWithEnv(Processor ? *Processor : "",
3177+
auto TargeTripleAndFeaturesOrError =
3178+
utils::getTargetTripleAndFeatures(getKernelAgent(DeviceId));
3179+
if (!TargeTripleAndFeaturesOrError)
3180+
return TargeTripleAndFeaturesOrError.takeError();
3181+
return utils::isImageCompatibleWithEnv(Processor ? *Processor : "",
31803182
ElfOrErr->getPlatformFlags(),
3181-
*TargeTripleAndFeaturesOrError))
3182-
return false;
3183-
}
3184-
return true;
3183+
*TargeTripleAndFeaturesOrError);
31853184
}
31863185

31873186
bool isDataExchangable(int32_t SrcDeviceId, int32_t DstDeviceId) override {

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -993,11 +993,11 @@ struct GenericPluginTy {
993993
/// Get the number of active devices.
994994
int32_t getNumDevices() const { return NumDevices; }
995995

996-
/// Get the plugin-specific device identifier offset.
997-
int32_t getDeviceIdStartIndex() const { return DeviceIdStartIndex; }
998-
999-
/// Set the plugin-specific device identifier offset.
1000-
void setDeviceIdStartIndex(int32_t Offset) { DeviceIdStartIndex = Offset; }
996+
/// Get the plugin-specific device identifier.
997+
int32_t getUserId(int32_t DeviceId) const {
998+
assert(UserDeviceIds.contains(DeviceId) && "No user-id registered");
999+
return UserDeviceIds.at(DeviceId);
1000+
}
10011001

10021002
/// Get the ELF code to recognize the binary image of this plugin.
10031003
virtual uint16_t getMagicElfBits() const = 0;
@@ -1059,7 +1059,8 @@ struct GenericPluginTy {
10591059
/// Indicate if an image is compatible with the plugin devices. Notice that
10601060
/// this function may be called before actually initializing the devices. So
10611061
/// we could not move this function into GenericDeviceTy.
1062-
virtual Expected<bool> isELFCompatible(StringRef Image) const = 0;
1062+
virtual Expected<bool> isELFCompatible(uint32_t DeviceID,
1063+
StringRef Image) const = 0;
10631064

10641065
protected:
10651066
/// Indicate whether a device id is valid.
@@ -1070,11 +1071,18 @@ struct GenericPluginTy {
10701071
public:
10711072
// TODO: This plugin interface needs to be cleaned up.
10721073

1073-
/// Returns true if the plugin has been initialized.
1074+
/// Returns non-zero if the plugin runtime has been initialized.
10741075
int32_t is_initialized() const;
10751076

1076-
/// Returns non-zero if the provided \p Image can be executed by the runtime.
1077-
int32_t is_valid_binary(__tgt_device_image *Image, bool Initialized = true);
1077+
/// Returns non-zero if the \p Image is compatible with the plugin. This
1078+
/// function does not require the plugin to be initialized before use.
1079+
int32_t is_plugin_compatible(__tgt_device_image *Image);
1080+
1081+
/// Returns non-zero if the \p Image is compatible with the device.
1082+
int32_t is_device_compatible(int32_t DeviceId, __tgt_device_image *Image);
1083+
1084+
/// Returns non-zero if the plugin device has been initialized.
1085+
int32_t is_device_initialized(int32_t DeviceId) const;
10781086

10791087
/// Initialize the device inside of the plugin.
10801088
int32_t init_device(int32_t DeviceId);
@@ -1180,7 +1188,7 @@ struct GenericPluginTy {
11801188
const char **ErrStr);
11811189

11821190
/// Sets the offset into the devices for use by OMPT.
1183-
int32_t set_device_offset(int32_t DeviceIdOffset);
1191+
int32_t set_device_identifier(int32_t UserId, int32_t DeviceId);
11841192

11851193
/// Returns if the plugin can support auotmatic copy.
11861194
int32_t use_auto_zero_copy(int32_t DeviceId);
@@ -1200,10 +1208,8 @@ struct GenericPluginTy {
12001208
/// Number of devices available for the plugin.
12011209
int32_t NumDevices = 0;
12021210

1203-
/// Index offset, which when added to a DeviceId, will yield a unique
1204-
/// user-observable device identifier. This is especially important when
1205-
/// DeviceIds of multiple plugins / RTLs need to be distinguishable.
1206-
int32_t DeviceIdStartIndex = 0;
1211+
/// Map of plugin device identifiers to the user device identifier.
1212+
llvm::DenseMap<int32_t, int32_t> UserDeviceIds;
12071213

12081214
/// Array of pointers to the devices. Initially, they are all set to nullptr.
12091215
/// Once a device is initialized, the pointer is stored in the position given

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -748,8 +748,7 @@ Error GenericDeviceTy::init(GenericPluginTy &Plugin) {
748748
if (ompt::Initialized) {
749749
bool ExpectedStatus = false;
750750
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, true))
751-
performOmptCallback(device_initialize, /*device_num=*/DeviceId +
752-
Plugin.getDeviceIdStartIndex(),
751+
performOmptCallback(device_initialize, Plugin.getUserId(DeviceId),
753752
/*type=*/getComputeUnitKind().c_str(),
754753
/*device=*/reinterpret_cast<ompt_device_t *>(this),
755754
/*lookup=*/ompt::lookupCallbackByName,
@@ -847,9 +846,7 @@ Error GenericDeviceTy::deinit(GenericPluginTy &Plugin) {
847846
if (ompt::Initialized) {
848847
bool ExpectedStatus = true;
849848
if (OmptInitialized.compare_exchange_strong(ExpectedStatus, false))
850-
performOmptCallback(device_finalize,
851-
/*device_num=*/DeviceId +
852-
Plugin.getDeviceIdStartIndex());
849+
performOmptCallback(device_finalize, Plugin.getUserId(DeviceId));
853850
}
854851
#endif
855852

@@ -908,7 +905,7 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
908905
size_t Bytes =
909906
getPtrDiff(InputTgtImage->ImageEnd, InputTgtImage->ImageStart);
910907
performOmptCallback(
911-
device_load, /*device_num=*/DeviceId + Plugin.getDeviceIdStartIndex(),
908+
device_load, Plugin.getUserId(DeviceId),
912909
/*FileName=*/nullptr, /*FileOffset=*/0, /*VmaInFile=*/nullptr,
913910
/*ImgSize=*/Bytes, /*HostAddr=*/InputTgtImage->ImageStart,
914911
/*DeviceAddr=*/nullptr, /* FIXME: ModuleId */ 0);
@@ -1492,11 +1489,14 @@ Error GenericDeviceTy::syncEvent(void *EventPtr) {
14921489
bool GenericDeviceTy::useAutoZeroCopy() { return useAutoZeroCopyImpl(); }
14931490

14941491
Error GenericPluginTy::init() {
1492+
if (Initialized)
1493+
return Plugin::success();
1494+
14951495
auto NumDevicesOrErr = initImpl();
14961496
if (!NumDevicesOrErr)
14971497
return NumDevicesOrErr.takeError();
1498-
14991498
Initialized = true;
1499+
15001500
NumDevices = *NumDevicesOrErr;
15011501
if (NumDevices == 0)
15021502
return Plugin::success();
@@ -1517,6 +1517,8 @@ Error GenericPluginTy::init() {
15171517
}
15181518

15191519
Error GenericPluginTy::deinit() {
1520+
assert(Initialized && "Plugin was not initialized!");
1521+
15201522
// Deinitialize all active devices.
15211523
for (int32_t DeviceId = 0; DeviceId < NumDevices; ++DeviceId) {
15221524
if (Devices[DeviceId]) {
@@ -1537,7 +1539,11 @@ Error GenericPluginTy::deinit() {
15371539
delete RecordReplay;
15381540

15391541
// Perform last deinitializations on the plugin.
1540-
return deinitImpl();
1542+
if (Error Err = deinitImpl())
1543+
return Err;
1544+
Initialized = false;
1545+
1546+
return Plugin::success();
15411547
}
15421548

15431549
Error GenericPluginTy::initDevice(int32_t DeviceId) {
@@ -1599,8 +1605,7 @@ Expected<bool> GenericPluginTy::checkBitcodeImage(StringRef Image) const {
15991605

16001606
int32_t GenericPluginTy::is_initialized() const { return Initialized; }
16011607

1602-
int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
1603-
bool Initialized) {
1608+
int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) {
16041609
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
16051610
target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
16061611

@@ -1618,11 +1623,43 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
16181623
auto MatchOrErr = checkELFImage(Buffer);
16191624
if (Error Err = MatchOrErr.takeError())
16201625
return HandleError(std::move(Err));
1621-
if (!Initialized || !*MatchOrErr)
1622-
return *MatchOrErr;
1626+
return *MatchOrErr;
1627+
}
1628+
case file_magic::bitcode: {
1629+
auto MatchOrErr = checkBitcodeImage(Buffer);
1630+
if (Error Err = MatchOrErr.takeError())
1631+
return HandleError(std::move(Err));
1632+
return *MatchOrErr;
1633+
}
1634+
default:
1635+
return false;
1636+
}
1637+
}
1638+
1639+
int32_t GenericPluginTy::is_device_compatible(int32_t DeviceId,
1640+
__tgt_device_image *Image) {
1641+
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
1642+
target::getPtrDiff(Image->ImageEnd, Image->ImageStart));
1643+
1644+
auto HandleError = [&](Error Err) -> bool {
1645+
[[maybe_unused]] std::string ErrStr = toString(std::move(Err));
1646+
DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
1647+
return false;
1648+
};
1649+
switch (identify_magic(Buffer)) {
1650+
case file_magic::elf:
1651+
case file_magic::elf_relocatable:
1652+
case file_magic::elf_executable:
1653+
case file_magic::elf_shared_object:
1654+
case file_magic::elf_core: {
1655+
auto MatchOrErr = checkELFImage(Buffer);
1656+
if (Error Err = MatchOrErr.takeError())
1657+
return HandleError(std::move(Err));
1658+
if (!*MatchOrErr)
1659+
return false;
16231660

16241661
// Perform plugin-dependent checks for the specific architecture if needed.
1625-
auto CompatibleOrErr = isELFCompatible(Buffer);
1662+
auto CompatibleOrErr = isELFCompatible(DeviceId, Buffer);
16261663
if (Error Err = CompatibleOrErr.takeError())
16271664
return HandleError(std::move(Err));
16281665
return *CompatibleOrErr;
@@ -1638,6 +1675,10 @@ int32_t GenericPluginTy::is_valid_binary(__tgt_device_image *Image,
16381675
}
16391676
}
16401677

1678+
int32_t GenericPluginTy::is_device_initialized(int32_t DeviceId) const {
1679+
return isValidDeviceId(DeviceId) && Devices[DeviceId] != nullptr;
1680+
}
1681+
16411682
int32_t GenericPluginTy::init_device(int32_t DeviceId) {
16421683
auto Err = initDevice(DeviceId);
16431684
if (Err) {
@@ -1985,8 +2026,9 @@ int32_t GenericPluginTy::init_device_info(int32_t DeviceId,
19852026
return OFFLOAD_SUCCESS;
19862027
}
19872028

1988-
int32_t GenericPluginTy::set_device_offset(int32_t DeviceIdOffset) {
1989-
setDeviceIdStartIndex(DeviceIdOffset);
2029+
int32_t GenericPluginTy::set_device_identifier(int32_t UserId,
2030+
int32_t DeviceId) {
2031+
UserDeviceIds[DeviceId] = UserId;
19902032

19912033
return OFFLOAD_SUCCESS;
19922034
}

offload/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,8 +1388,9 @@ struct CUDAPluginTy final : public GenericPluginTy {
13881388

13891389
const char *getName() const override { return GETNAME(TARGET_NAME); }
13901390

1391-
/// Check whether the image is compatible with the available CUDA devices.
1392-
Expected<bool> isELFCompatible(StringRef Image) const override {
1391+
/// Check whether the image is compatible with a CUDA device.
1392+
Expected<bool> isELFCompatible(uint32_t DeviceId,
1393+
StringRef Image) const override {
13931394
auto ElfOrErr =
13941395
ELF64LEObjectFile::create(MemoryBufferRef(Image, /*Identifier=*/""),
13951396
/*InitContent=*/false);
@@ -1399,33 +1400,29 @@ struct CUDAPluginTy final : public GenericPluginTy {
13991400
// Get the numeric value for the image's `sm_` value.
14001401
auto SM = ElfOrErr->getPlatformFlags() & ELF::EF_CUDA_SM;
14011402

1402-
for (int32_t DevId = 0; DevId < getNumDevices(); ++DevId) {
1403-
CUdevice Device;
1404-
CUresult Res = cuDeviceGet(&Device, DevId);
1405-
if (auto Err = Plugin::check(Res, "Error in cuDeviceGet: %s"))
1406-
return std::move(Err);
1407-
1408-
int32_t Major, Minor;
1409-
Res = cuDeviceGetAttribute(
1410-
&Major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, Device);
1411-
if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
1412-
return std::move(Err);
1413-
1414-
Res = cuDeviceGetAttribute(
1415-
&Minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, Device);
1416-
if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
1417-
return std::move(Err);
1418-
1419-
int32_t ImageMajor = SM / 10;
1420-
int32_t ImageMinor = SM % 10;
1421-
1422-
// A cubin generated for a certain compute capability is supported to
1423-
// run on any GPU with the same major revision and same or higher minor
1424-
// revision.
1425-
if (Major != ImageMajor || Minor < ImageMinor)
1426-
return false;
1427-
}
1428-
return true;
1403+
CUdevice Device;
1404+
CUresult Res = cuDeviceGet(&Device, DeviceId);
1405+
if (auto Err = Plugin::check(Res, "Error in cuDeviceGet: %s"))
1406+
return std::move(Err);
1407+
1408+
int32_t Major, Minor;
1409+
Res = cuDeviceGetAttribute(
1410+
&Major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, Device);
1411+
if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
1412+
return std::move(Err);
1413+
1414+
Res = cuDeviceGetAttribute(
1415+
&Minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, Device);
1416+
if (auto Err = Plugin::check(Res, "Error in cuDeviceGetAttribute: %s"))
1417+
return std::move(Err);
1418+
1419+
int32_t ImageMajor = SM / 10;
1420+
int32_t ImageMinor = SM % 10;
1421+
1422+
// A cubin generated for a certain compute capability is supported to
1423+
// run on any GPU with the same major revision and same or higher minor
1424+
// revision.
1425+
return Major == ImageMajor && Minor >= ImageMinor;
14291426
}
14301427
};
14311428

offload/plugins-nextgen/host/src/rtl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,9 @@ struct GenELF64PluginTy final : public GenericPluginTy {
418418
}
419419

420420
/// All images (ELF-compatible) should be compatible with this plugin.
421-
Expected<bool> isELFCompatible(StringRef) const override { return true; }
421+
Expected<bool> isELFCompatible(uint32_t, StringRef) const override {
422+
return true;
423+
}
422424

423425
Triple::ArchType getTripleArch() const override {
424426
#if defined(__x86_64__)

0 commit comments

Comments
 (0)