Skip to content

[NFC][SYCL] Change context_impl::getDevices to return devices_range #19456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions sycl/source/detail/context_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ class context_impl : public std::enable_shared_from_this<context_impl> {
/// \return an instance of raw UR context handle.
const ur_context_handle_t &getHandleRef() const;

/// Unlike `get_info<info::context::devices>', this function returns a
/// reference.
const std::vector<device> &getDevices() const { return MDevices; }
devices_range getDevices() const { return MDevices; }

using CachedLibProgramsT =
std::map<std::pair<DeviceLibExt, ur_device_handle_t>,
Expand Down
7 changes: 3 additions & 4 deletions sycl/source/detail/device_global_map_entry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
"USM allocations should not be acquired for device_global with "
"device_image_scope property.");
context_impl &CtxImpl = *getSyclObjImpl(Context);
device_impl &DevImpl = *getSyclObjImpl(CtxImpl.getDevices().front());
device_impl &DevImpl = CtxImpl.getDevices().front();
std::lock_guard<std::mutex> Lock(MDeviceToUSMPtrMapMutex);

auto DGUSMPtr = MDeviceToUSMPtrMap.find({&DevImpl, &CtxImpl});
Expand Down Expand Up @@ -153,9 +153,8 @@ DeviceGlobalMapEntry::getOrAllocateDeviceGlobalUSM(const context &Context) {
void DeviceGlobalMapEntry::removeAssociatedResources(
const context_impl *CtxImpl) {
std::lock_guard<std::mutex> Lock{MDeviceToUSMPtrMapMutex};
for (device Device : CtxImpl->getDevices()) {
auto USMPtrIt =
MDeviceToUSMPtrMap.find({getSyclObjImpl(Device).get(), CtxImpl});
for (device_impl &Device : CtxImpl->getDevices()) {
auto USMPtrIt = MDeviceToUSMPtrMap.find({&Device, CtxImpl});
if (USMPtrIt != MDeviceToUSMPtrMap.end()) {
DeviceGlobalUSMMem &USMMem = USMPtrIt->second;
detail::usm::freeInternal(USMMem.MPtr, CtxImpl);
Expand Down
4 changes: 4 additions & 0 deletions sycl/source/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ template <typename iterator> class iterator_range {
iterator_range(IterTy Begin, IterTy End, size_t Size)
: Begin(Begin), End(End), Size(Size) {}

iterator_range()
: iterator_range(static_cast<value_type *>(nullptr),
static_cast<value_type *>(nullptr), 0) {}

template <typename ContainerTy>
iterator_range(const ContainerTy &Container)
: iterator_range(Container.begin(), Container.end(), Container.size()) {}
Expand Down
42 changes: 18 additions & 24 deletions sycl/source/detail/image_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,11 @@ uint8_t GImageStreamID;
#endif

template <typename Param>
static bool checkImageValueRange(const std::vector<device> &Devices,
const size_t Value) {
return Value >= 1 && std::all_of(Devices.cbegin(), Devices.cend(),
[Value](const device &Dev) {
return Value <= Dev.get_info<Param>();
});
static bool checkImageValueRange(devices_range Devices, const size_t Value) {
return Value >= 1 &&
std::all_of(Devices.begin(), Devices.end(), [Value](device_impl &Dev) {
return Value <= Dev.get_info<Param>();
});
}

template <typename T, typename... Args> static bool checkAnyImpl(T) {
Expand Down Expand Up @@ -345,46 +344,47 @@ void *image_impl::allocateMem(context_impl *Context, bool InitFromUserData,

bool image_impl::checkImageDesc(const ur_image_desc_t &Desc,
context_impl *Context, void *UserPtr) {
devices_range Devices = Context ? Context->getDevices() : devices_range{};
if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE1D, UR_MEM_TYPE_IMAGE1D_ARRAY,
UR_MEM_TYPE_IMAGE2D_ARRAY, UR_MEM_TYPE_IMAGE2D) &&
!checkImageValueRange<info::device::image2d_max_width>(
getDevices(Context), Desc.width))
!checkImageValueRange<info::device::image2d_max_width>(Devices,
Desc.width))
throw exception(make_error_code(errc::invalid),
"For a 1D/2D image/image array, the width must be a Value "
">= 1 and <= info::device::image2d_max_width");

if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE3D) &&
!checkImageValueRange<info::device::image3d_max_width>(
getDevices(Context), Desc.width))
!checkImageValueRange<info::device::image3d_max_width>(Devices,
Desc.width))
throw exception(make_error_code(errc::invalid),
"For a 3D image, the width must be a Value >= 1 and <= "
"info::device::image3d_max_width");

if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE2D, UR_MEM_TYPE_IMAGE2D_ARRAY) &&
!checkImageValueRange<info::device::image2d_max_height>(
getDevices(Context), Desc.height))
!checkImageValueRange<info::device::image2d_max_height>(Devices,
Desc.height))
throw exception(make_error_code(errc::invalid),
"For a 2D image or image array, the height must be a Value "
">= 1 and <= info::device::image2d_max_height");

if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE3D) &&
!checkImageValueRange<info::device::image3d_max_height>(
getDevices(Context), Desc.height))
!checkImageValueRange<info::device::image3d_max_height>(Devices,
Desc.height))
throw exception(make_error_code(errc::invalid),
"For a 3D image, the heightmust be a Value >= 1 and <= "
"info::device::image3d_max_height");

if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE3D) &&
!checkImageValueRange<info::device::image3d_max_depth>(
getDevices(Context), Desc.depth))
!checkImageValueRange<info::device::image3d_max_depth>(Devices,
Desc.depth))
throw exception(make_error_code(errc::invalid),
"For a 3D image, the depth must be a Value >= 1 and <= "
"info::device::image2d_max_depth");

if (checkAny(Desc.type, UR_MEM_TYPE_IMAGE1D_ARRAY,
UR_MEM_TYPE_IMAGE2D_ARRAY) &&
!checkImageValueRange<info::device::image_max_array_size>(
getDevices(Context), Desc.arraySize))
!checkImageValueRange<info::device::image_max_array_size>(Devices,
Desc.arraySize))
throw exception(make_error_code(errc::invalid),
"For a 1D and 2D image array, the array_size must be a "
"Value >= 1 and <= info::device::image_max_array_size.");
Expand Down Expand Up @@ -451,12 +451,6 @@ bool image_impl::checkImageFormat(const ur_image_format_t &Format,
return true;
}

std::vector<device> image_impl::getDevices(context_impl *Context) {
if (!Context)
return {};
return Context->get_info<info::context::devices>();
}

void image_impl::sampledImageConstructorNotification(
const detail::code_location &CodeLoc, void *UserObj, const void *HostObj,
uint32_t Dim, size_t Range[3], image_format Format,
Expand Down
3 changes: 1 addition & 2 deletions sycl/source/detail/image_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class accessor;
class handler;

namespace detail {
class devices_range;

// utility functions and typedefs for image_impl
using image_allocator = aligned_allocator<byte>;
Expand Down Expand Up @@ -297,8 +298,6 @@ class image_impl final : public SYCLMemObjT {
void unsampledImageDestructorNotification(void *UserObj);

private:
std::vector<device> getDevices(context_impl *Context);

ur_mem_type_t getImageType() {
if (MDimensions == 1)
return (MIsArrayImage ? UR_MEM_TYPE_IMAGE1D_ARRAY : UR_MEM_TYPE_IMAGE1D);
Expand Down
6 changes: 3 additions & 3 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,10 @@ static bool isDeviceBinaryTypeSupported(context_impl &ContextImpl,
if (ContextBackend == backend::ext_oneapi_cuda)
return false;

const std::vector<device> &Devices = ContextImpl.getDevices();
devices_range Devices = ContextImpl.getDevices();

// Program type is SPIR-V, so we need a device compiler to do JIT.
for (const device &D : Devices) {
for (device_impl &D : Devices) {
if (!D.get_info<info::device::is_compiler_available>())
return false;
}
Expand All @@ -143,7 +143,7 @@ static bool isDeviceBinaryTypeSupported(context_impl &ContextImpl,
return true;
}

for (const device &D : Devices) {
for (device_impl &D : Devices) {
// We need cl_khr_il_program extension to be present
// and we can call clCreateProgramWithILKHR using the extension
std::vector<std::string> Extensions =
Expand Down
9 changes: 4 additions & 5 deletions sycl/source/detail/scheduler/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,9 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(queue_impl *Queue,
// which means that there is already an allocation(cl_mem) in some context.
// Registering this allocation in the SYCL graph.

std::vector<sycl::device> Devices =
InteropCtxPtr->get_info<info::context::devices>();
assert(Devices.size() != 0);
device_impl &Dev = *detail::getSyclObjImpl(Devices[0]);
devices_range Devices = InteropCtxPtr->getDevices();
assert(!Devices.empty());
device_impl &Dev = Devices.front();

// Since all the Scheduler commands require queue but we have only context
// here, we need to create a dummy queue bound to the context and one of the
Expand Down Expand Up @@ -675,7 +674,7 @@ static bool checkHostUnifiedMemory(context_impl *Ctx) {
if (Ctx == nullptr)
return true;

for (const device &Device : Ctx->getDevices()) {
for (device_impl &Device : Ctx->getDevices()) {
if (!Device.get_info<info::device::host_unified_memory>())
return false;
}
Expand Down
4 changes: 2 additions & 2 deletions sycl/source/detail/usm/usm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,13 +581,13 @@ device get_pointer_device(const void *Ptr, const context &Ctxt) {

// Check if ptr is a host allocation
if (get_pointer_type(Ptr, Ctxt) == alloc::host) {
auto Devs = detail::getSyclObjImpl(Ctxt)->getDevices();
detail::devices_range Devs = detail::getSyclObjImpl(Ctxt)->getDevices();
if (Devs.size() == 0)
throw exception(make_error_code(errc::invalid),
"No devices in passed context!");

// Just return the first device in the context
return Devs[0];
return detail::createSyclObjFromImpl<device>(Devs.front());
}

ur_device_handle_t DeviceId;
Expand Down