Skip to content

Commit 9f516d6

Browse files
[SYCL] Avoid additional allocations in has_extension (#19765)
#19264 fixed an issue where has_extension would allow partial matches to return true. However, the proposed solution makes additional allocations to pad the strings used. To avoid this, this commit changes the implementation to explicitly check for the partial match conditions without padding either string. --------- Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 66429d5 commit 9f516d6

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

sycl/source/detail/device_impl.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,34 @@ device_impl::get_backend_info<info::device::backend_version>() const {
106106
#endif
107107

108108
bool device_impl::has_extension(const std::string &ExtensionName) const {
109+
if (ExtensionName.empty())
110+
return false;
111+
109112
const std::string AllExtensionNames{
110113
get_info_impl<UR_DEVICE_INFO_EXTENSIONS>()};
111114

112-
// We add a space to both sides of both the extension string and the query
113-
// string. This prevents to lookup from finding partial extension matches.
114-
return ((" " + AllExtensionNames + " ").find(" " + ExtensionName + " ") !=
115-
std::string::npos);
115+
size_t FoundExtPos = AllExtensionNames.find(ExtensionName);
116+
while (FoundExtPos != std::string::npos) {
117+
// If the extension name was found, we need to ensure it is not a partial
118+
// match. That is, the following must hold:
119+
// * The match must be at the start of the list of names or have a
120+
// whitespace before it and
121+
// * the match must end at the end of the list of names or have a
122+
// whitespace after it.
123+
bool IsStartOrTerminated =
124+
FoundExtPos == 0 || AllExtensionNames[FoundExtPos - 1] == ' ';
125+
bool IsEndOrTerminated =
126+
FoundExtPos + ExtensionName.size() == AllExtensionNames.size() ||
127+
AllExtensionNames[FoundExtPos + ExtensionName.size()] == ' ';
128+
if (IsStartOrTerminated && IsEndOrTerminated)
129+
return true;
130+
131+
// If the match was partial, the extension name could still be later in the
132+
// list. As such, search for the next match and recheck.
133+
FoundExtPos = AllExtensionNames.find(ExtensionName,
134+
FoundExtPos + ExtensionName.size());
135+
}
136+
return false;
116137
}
117138

118139
bool device_impl::is_partition_supported(info::partition_property Prop) const {

sycl/unittests/context_device/HasExtensionWordBoundary.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ TEST_F(HasExtensionWordBoundaryTest, SingleExtension) {
9191

9292
sycl::platform Plt{sycl::platform()};
9393
sycl::device Dev = Plt.get_devices()[0];
94-
auto DevImpl = detail::getSyclObjImpl(Dev);
9594

9695
EXPECT_TRUE(Dev.has_extension("cl_khr_fp64"));
9796
EXPECT_FALSE(Dev.has_extension("cl_khr_fp6"));
@@ -102,13 +101,28 @@ TEST_F(HasExtensionWordBoundaryTest, FirstMiddleLastExtensions) {
102101

103102
sycl::platform Plt{sycl::platform()};
104103
sycl::device Dev = Plt.get_devices()[0];
105-
auto DevImpl = detail::getSyclObjImpl(Dev);
106104

107105
EXPECT_TRUE(Dev.has_extension("cl_first_ext"));
108106
EXPECT_TRUE(Dev.has_extension("cl_middle_ext"));
109107
EXPECT_TRUE(Dev.has_extension("cl_last_ext"));
110108
}
111109

110+
// The queried name first occurs as a prefix of a longer entry and only later
// as a complete entry; the lookup must keep searching past the partial hit.
TEST_F(HasExtensionWordBoundaryTest, MatchAfterPartialMatch) {
  MockExtensions = "cl_khr_fp64_with_more cl_khr_fp64";

  sycl::platform Platform{sycl::platform()};
  sycl::device FirstDevice = Platform.get_devices().front();

  EXPECT_TRUE(FirstDevice.has_extension("cl_khr_fp64"));
}
118+
119+
// Querying with an empty extension name must report the extension as absent.
TEST_F(HasExtensionWordBoundaryTest, MatchEmptyString) {
  sycl::platform Platform{sycl::platform()};
  sycl::device FirstDevice = Platform.get_devices().front();

  EXPECT_FALSE(FirstDevice.has_extension(""));
}
125+
112126
TEST_F(HasExtensionWordBoundaryTest, NonUniformGroupExtensions) {
113127
MockExtensions = "cl_khr_subgroup_non_uniform_vote "
114128
"cl_khr_subgroup_ballot "
@@ -118,7 +132,6 @@ TEST_F(HasExtensionWordBoundaryTest, NonUniformGroupExtensions) {
118132

119133
sycl::platform Plt{sycl::platform()};
120134
sycl::device Dev = Plt.get_devices()[0];
121-
auto DevImpl = detail::getSyclObjImpl(Dev);
122135

123136
EXPECT_TRUE(Dev.has_extension("cl_khr_subgroup_non_uniform_vote"));
124137
EXPECT_TRUE(Dev.has_extension("cl_khr_subgroup_ballot"));

0 commit comments

Comments
 (0)