-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[OpenMP] Move `__omp_rtl_data_environment' handling to OpenMP #157182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Summary: This operation is done every time we load a binary, this behavior should be moved into OpenMP since it concerns an OpenMP specific data struct. This is a little messy, because ideally we should only be using public APIs, but more can be extracted later.
|
@llvm/pr-subscribers-offload Author: Joseph Huber (jhuber6) ChangesSummary: Full diff: https://github.com/llvm/llvm-project/pull/157182.diff 5 Files Affected:
diff --git a/offload/include/device.h b/offload/include/device.h
index 1e85bb1876c83..bf93ce0460aef 100644
--- a/offload/include/device.h
+++ b/offload/include/device.h
@@ -33,7 +33,9 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
+#include "GlobalHandler.h"
#include "PluginInterface.h"
+
using GenericPluginTy = llvm::omp::target::plugin::GenericPluginTy;
// Forward declarations.
diff --git a/offload/libomptarget/device.cpp b/offload/libomptarget/device.cpp
index 6585286bf4285..71423ae0c94d9 100644
--- a/offload/libomptarget/device.cpp
+++ b/offload/libomptarget/device.cpp
@@ -37,6 +37,8 @@
using namespace llvm::omp::target::ompt;
#endif
+using namespace llvm::omp::target::plugin;
+
int HostDataToTargetTy::addEventIfNecessary(DeviceTy &Device,
AsyncInfoTy &AsyncInfo) const {
// First, check if the user disabled atomic map transfer/malloc/dealloc.
@@ -97,7 +99,55 @@ llvm::Error DeviceTy::init() {
return llvm::Error::success();
}
-// Load binary to device.
+// Extract the mapping of host function pointers to device function pointers
+// from the entry table. Functions marked as 'indirect' in OpenMP will have
+// offloading entries generated for them which map the host's function pointer
+// to a global containing the corresponding function pointer on the device.
+static llvm::Expected<std::pair<void *, uint64_t>>
+setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image,
+ __tgt_device_binary Binary) {
+ AsyncInfoTy AsyncInfo(Device);
+ llvm::ArrayRef<llvm::offloading::EntryTy> Entries(Image->EntriesBegin,
+ Image->EntriesEnd);
+ llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
+ for (const auto &Entry : Entries) {
+ if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP ||
+ Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
+ continue;
+
+ assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
+ auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
+
+ void *Ptr;
+ if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
+ return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
+ "failed to load %s", Entry.SymbolName);
+
+ HstPtr = Entry.Address;
+ if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
+ return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
+ "failed to load %s", Entry.SymbolName);
+ }
+
+ // If we do not have any indirect globals we exit early.
+ if (IndirectCallTable.empty())
+ return std::pair{nullptr, 0};
+
+ // Sort the array to allow for more efficient lookup of device pointers.
+ llvm::sort(IndirectCallTable,
+ [](const auto &x, const auto &y) { return x.first < y.first; });
+
+ uint64_t TableSize =
+ IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
+ void *DevicePtr = Device.allocData(TableSize, nullptr, TARGET_ALLOC_DEVICE);
+ if (Device.submitData(DevicePtr, IndirectCallTable.data(), TableSize,
+ AsyncInfo))
+ return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
+ "failed to copy data");
+ return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
+}
+
+// Load binary to device and perform global initialization if needed.
llvm::Expected<__tgt_device_binary>
DeviceTy::loadBinary(__tgt_device_image *Img) {
__tgt_device_binary Binary;
@@ -105,6 +155,38 @@ DeviceTy::loadBinary(__tgt_device_image *Img) {
if (RTL->load_binary(RTLDeviceID, Img, &Binary) != OFFLOAD_SUCCESS)
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load binary %p", Img);
+
+ // This symbol is optional.
+ void *DeviceEnvironmentPtr;
+ if (RTL->get_global(Binary, sizeof(DeviceEnvironmentTy),
+ "__omp_rtl_device_environment", &DeviceEnvironmentPtr))
+ return Binary;
+
+ // Obtain a table mapping host function pointers to device function pointers.
+ auto CallTablePairOrErr = setupIndirectCallTable(*this, Img, Binary);
+ if (!CallTablePairOrErr)
+ return CallTablePairOrErr.takeError();
+
+ GenericDeviceTy &GenericDevice = RTL->getDevice(RTLDeviceID);
+ DeviceEnvironmentTy DeviceEnvironment;
+ DeviceEnvironment.DeviceDebugKind = GenericDevice.getDebugKind();
+ DeviceEnvironment.NumDevices = RTL->getNumDevices();
+ // TODO: The device ID used here is not the real device ID used by OpenMP.
+ DeviceEnvironment.DeviceNum = RTLDeviceID;
+ DeviceEnvironment.DynamicMemSize = GenericDevice.getDynamicMemorySize();
+ DeviceEnvironment.ClockFrequency = GenericDevice.getClockFrequency();
+ DeviceEnvironment.IndirectCallTable =
+ reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
+ DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
+ DeviceEnvironment.HardwareParallelism =
+ GenericDevice.getHardwareParallelism();
+
+ AsyncInfoTy AsyncInfo(*this);
+ if (submitData(DeviceEnvironmentPtr, &DeviceEnvironment,
+ sizeof(DeviceEnvironment), AsyncInfo))
+ return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
+ "failed to copy data");
+
return Binary;
}
diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h
index f0c05a1b90716..6ff3ef8cda177 100644
--- a/offload/plugins-nextgen/common/include/PluginInterface.h
+++ b/offload/plugins-nextgen/common/include/PluginInterface.h
@@ -839,11 +839,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
Error unloadBinary(DeviceImageTy *Image);
virtual Error unloadBinaryImpl(DeviceImageTy *Image) = 0;
- /// Setup the device environment if needed. Notice this setup may not be run
- /// on some plugins. By default, it will be executed, but plugins can change
- /// this behavior by overriding the shouldSetupDeviceEnvironment function.
- Error setupDeviceEnvironment(GenericPluginTy &Plugin, DeviceImageTy &Image);
-
/// Setup the global device memory pool, if the plugin requires one.
Error setupDeviceMemoryPool(GenericPluginTy &Plugin, DeviceImageTy &Image,
uint64_t PoolSize);
@@ -1043,6 +1038,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
uint32_t getDefaultNumBlocks() const {
return GridValues.GV_Default_Num_Teams;
}
+ uint32_t getDebugKind() const { return OMPX_DebugKind; }
uint32_t getDynamicMemorySize() const { return OMPX_SharedMemorySize; }
virtual uint64_t getClockFrequency() const { return CLOCKS_PER_SEC; }
@@ -1183,11 +1179,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
virtual Error getDeviceHeapSize(uint64_t &V) = 0;
virtual Error setDeviceHeapSize(uint64_t V) = 0;
- /// Indicate whether the device should setup the device environment. Notice
- /// that returning false in this function will change the behavior of the
- /// setupDeviceEnvironment() function.
- virtual bool shouldSetupDeviceEnvironment() const { return true; }
-
/// Indicate whether the device should setup the global device memory pool. If
/// false is return the value on the device will be uninitialized.
virtual bool shouldSetupDeviceMemoryPool() const { return true; }
@@ -1243,7 +1234,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
enum class PeerAccessState : uint8_t { AVAILABLE, UNAVAILABLE, PENDING };
/// Array of peer access states with the rest of devices. This means that if
- /// the device I has a matrix PeerAccesses with PeerAccesses[J] == AVAILABLE,
+ /// the device I has a matrix PeerAccesses with PeerAccesses == AVAILABLE,
/// the device I can access device J's memory directly. However, notice this
/// does not mean that device J can access device I's memory directly.
llvm::SmallVector<PeerAccessState> PeerAccesses;
diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp
index e5a313d5e9bb4..36cdd6035e26d 100644
--- a/offload/plugins-nextgen/common/src/PluginInterface.cpp
+++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp
@@ -371,54 +371,6 @@ struct RecordReplayTy {
};
} // namespace llvm::omp::target::plugin
-// Extract the mapping of host function pointers to device function pointers
-// from the entry table. Functions marked as 'indirect' in OpenMP will have
-// offloading entries generated for them which map the host's function pointer
-// to a global containing the corresponding function pointer on the device.
-static Expected<std::pair<void *, uint64_t>>
-setupIndirectCallTable(GenericPluginTy &Plugin, GenericDeviceTy &Device,
- DeviceImageTy &Image) {
- GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();
-
- llvm::ArrayRef<llvm::offloading::EntryTy> Entries(
- Image.getTgtImage()->EntriesBegin, Image.getTgtImage()->EntriesEnd);
- llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
- for (const auto &Entry : Entries) {
- if (Entry.Kind != object::OffloadKind::OFK_OpenMP || Entry.Size == 0 ||
- !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
- continue;
-
- assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
- auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();
-
- GlobalTy DeviceGlobal(Entry.SymbolName, Entry.Size);
- if (auto Err =
- Handler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal))
- return std::move(Err);
-
- HstPtr = Entry.Address;
- if (auto Err = Device.dataRetrieve(&DevPtr, DeviceGlobal.getPtr(),
- Entry.Size, nullptr))
- return std::move(Err);
- }
-
- // If we do not have any indirect globals we exit early.
- if (IndirectCallTable.empty())
- return std::pair{nullptr, 0};
-
- // Sort the array to allow for more efficient lookup of device pointers.
- llvm::sort(IndirectCallTable,
- [](const auto &x, const auto &y) { return x.first < y.first; });
-
- uint64_t TableSize =
- IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
- void *DevicePtr = Device.allocate(TableSize, nullptr, TARGET_ALLOC_DEVICE);
- if (auto Err = Device.dataSubmit(DevicePtr, IndirectCallTable.data(),
- TableSize, nullptr))
- return std::move(Err);
- return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
-}
-
AsyncInfoWrapperTy::AsyncInfoWrapperTy(GenericDeviceTy &Device,
__tgt_async_info *AsyncInfoPtr)
: Device(Device),
@@ -943,10 +895,6 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
// Add the image to list.
LoadedImages.push_back(Image);
- // Setup the device environment if needed.
- if (auto Err = setupDeviceEnvironment(Plugin, *Image))
- return std::move(Err);
-
// Setup the global device memory pool if needed.
if (!Plugin.getRecordReplay().isReplaying() &&
shouldSetupDeviceMemoryPool()) {
@@ -982,43 +930,6 @@ GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
return Image;
}
-Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
- DeviceImageTy &Image) {
- // There are some plugins that do not need this step.
- if (!shouldSetupDeviceEnvironment())
- return Plugin::success();
-
- // Obtain a table mapping host function pointers to device function pointers.
- auto CallTablePairOrErr = setupIndirectCallTable(Plugin, *this, Image);
- if (!CallTablePairOrErr)
- return CallTablePairOrErr.takeError();
-
- DeviceEnvironmentTy DeviceEnvironment;
- DeviceEnvironment.DeviceDebugKind = OMPX_DebugKind;
- DeviceEnvironment.NumDevices = Plugin.getNumDevices();
- // TODO: The device ID used here is not the real device ID used by OpenMP.
- DeviceEnvironment.DeviceNum = DeviceId;
- DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
- DeviceEnvironment.ClockFrequency = getClockFrequency();
- DeviceEnvironment.IndirectCallTable =
- reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
- DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
- DeviceEnvironment.HardwareParallelism = getHardwareParallelism();
-
- // Create the metainfo of the device environment global.
- GlobalTy DevEnvGlobal("__omp_rtl_device_environment",
- sizeof(DeviceEnvironmentTy), &DeviceEnvironment);
-
- // Write device environment values to the device.
- GenericGlobalHandlerTy &GHandler = Plugin.getGlobalHandler();
- if (auto Err = GHandler.writeGlobalToDevice(*this, Image, DevEnvGlobal)) {
- DP("Missing symbol %s, continue execution anyway.\n",
- DevEnvGlobal.getName().data());
- consumeError(std::move(Err));
- }
- return Plugin::success();
-}
-
Error GenericDeviceTy::setupDeviceMemoryPool(GenericPluginTy &Plugin,
DeviceImageTy &Image,
uint64_t PoolSize) {
@@ -2259,8 +2170,7 @@ int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size,
GenericGlobalHandlerTy &GHandler = getGlobalHandler();
if (auto Err =
GHandler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal)) {
- REPORT("Failure to look up global address: %s\n",
- toString(std::move(Err)).data());
+ consumeError(std::move(Err));
return OFFLOAD_FAIL;
}
diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp
index f440ebaf17fe4..5436cae3b0293 100644
--- a/offload/plugins-nextgen/host/src/rtl.cpp
+++ b/offload/plugins-nextgen/host/src/rtl.cpp
@@ -387,7 +387,6 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
}
/// This plugin should not setup the device environment or memory pool.
- virtual bool shouldSetupDeviceEnvironment() const override { return false; };
virtual bool shouldSetupDeviceMemoryPool() const override { return false; };
/// Getters and setters for stack size and heap size not relevant.
|
| static llvm::Expected<std::pair<void *, uint64_t>> | ||
| setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image, | ||
| __tgt_device_binary Binary) { | ||
| AsyncInfoTy AsyncInfo(Device); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This line is new compared to the code before.
Why is that needed now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The internal plugin API doesn't need it but this one that OpenMP has does, it's just a nullptr so it's just there to tell it that it's synchronous.
jplehr
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
Summary:
This operation is done every time we load a binary, this behavior should
be moved into OpenMP since it concerns an OpenMP specific data struct.
This is a little messy, because ideally we should only be using public
APIs, but more can be extracted later.