10
10
11
11
#include " common.hpp"
12
12
13
+ inline cl_mem_alloc_flags_intel
14
+ hostDescToClFlags (const ur_usm_host_desc_t &desc) {
15
+ cl_mem_alloc_flags_intel allocFlags = 0 ;
16
+ if (desc.flags & UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT) {
17
+ allocFlags |= CL_MEM_ALLOC_INITIAL_PLACEMENT_HOST_INTEL;
18
+ }
19
+ return allocFlags;
20
+ }
21
+
22
+ inline cl_mem_alloc_flags_intel
23
+ deviceDescToClFlags (const ur_usm_device_desc_t &desc) {
24
+ cl_mem_alloc_flags_intel allocFlags = 0 ;
25
+ if (desc.flags & UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT) {
26
+ allocFlags |= CL_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE_INTEL;
27
+ }
28
+ if (desc.flags & UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED) {
29
+ allocFlags |= CL_MEM_ALLOC_WRITE_COMBINED_INTEL;
30
+ }
31
+ return allocFlags;
32
+ }
33
+
34
+ ur_result_t
35
+ usmDescToCLMemProperties (const ur_base_desc_t *Desc,
36
+ std::vector<cl_mem_properties_intel> &Properties) {
37
+ cl_mem_alloc_flags_intel AllocFlags = 0 ;
38
+ const auto *Next = Desc;
39
+ do {
40
+ switch (Next->stype ) {
41
+ case UR_STRUCTURE_TYPE_USM_HOST_DESC: {
42
+ auto HostDesc = reinterpret_cast <const ur_usm_host_desc_t *>(Next);
43
+ if (UR_USM_HOST_MEM_FLAGS_MASK & HostDesc->flags ) {
44
+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
45
+ }
46
+ AllocFlags |= hostDescToClFlags (*HostDesc);
47
+ break ;
48
+ }
49
+ case UR_STRUCTURE_TYPE_USM_DEVICE_DESC: {
50
+ auto DeviceDesc = reinterpret_cast <const ur_usm_device_desc_t *>(Next);
51
+ if (UR_USM_HOST_MEM_FLAGS_MASK & DeviceDesc->flags ) {
52
+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
53
+ }
54
+ AllocFlags |= deviceDescToClFlags (*DeviceDesc);
55
+ break ;
56
+ }
57
+ case UR_STRUCTURE_TYPE_USM_ALLOC_LOCATION_DESC: {
58
+ auto LocationDesc =
59
+ reinterpret_cast <const ur_usm_alloc_location_desc_t *>(Next);
60
+ Properties.push_back (CL_MEM_ALLOC_BUFFER_LOCATION_INTEL);
61
+ // CL bitfields are cl_ulong
62
+ Properties.push_back (static_cast <cl_ulong>(LocationDesc->location ));
63
+ break ;
64
+ }
65
+ default :
66
+ return UR_RESULT_ERROR_INVALID_VALUE;
67
+ }
68
+
69
+ Next = Next->pNext ? static_cast <const ur_base_desc_t *>(Next->pNext )
70
+ : nullptr ;
71
+ } while (Next);
72
+
73
+ if (AllocFlags) {
74
+ Properties.push_back (CL_MEM_ALLOC_FLAGS_INTEL);
75
+ Properties.push_back (AllocFlags);
76
+ }
77
+ Properties.push_back (0 );
78
+
79
+ return UR_RESULT_SUCCESS;
80
+ }
81
+
13
82
UR_APIEXPORT ur_result_t UR_APICALL
14
83
urUSMHostAlloc (ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
15
84
ur_usm_pool_handle_t , size_t size, void **ppMem) {
16
85
17
86
void *Ptr = nullptr ;
18
87
uint32_t Alignment = pUSMDesc ? pUSMDesc->align : 0 ;
19
88
20
- cl_mem_alloc_flags_intel Flags = 0 ;
21
- cl_mem_properties_intel Properties[3 ];
22
-
23
- if (pUSMDesc && pUSMDesc->pNext &&
24
- static_cast <const ur_base_desc_t *>(pUSMDesc->pNext )->stype ==
25
- UR_STRUCTURE_TYPE_USM_HOST_DESC) {
26
- const auto *HostDesc =
27
- static_cast <const ur_usm_host_desc_t *>(pUSMDesc->pNext );
28
-
29
- if (HostDesc->flags & UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT) {
30
- Flags |= CL_MEM_ALLOC_INITIAL_PLACEMENT_HOST_INTEL;
31
- }
32
- Properties[0 ] = CL_MEM_ALLOC_FLAGS_INTEL;
33
- Properties[1 ] = Flags;
34
- Properties[2 ] = 0 ;
35
- } else {
36
- Properties[0 ] = 0 ;
89
+ std::vector<cl_mem_properties_intel> AllocProperties;
90
+ if (pUSMDesc && pUSMDesc->pNext ) {
91
+ UR_RETURN_ON_FAILURE (usmDescToCLMemProperties (
92
+ static_cast <const ur_base_desc_t *>(pUSMDesc->pNext ), AllocProperties));
37
93
}
38
94
39
95
// First we need to look up the function pointer
@@ -47,7 +103,9 @@ urUSMHostAlloc(ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc,
47
103
48
104
if (FuncPtr) {
49
105
cl_int ClResult = CL_SUCCESS;
50
- Ptr = FuncPtr (CLContext, Properties, size, Alignment, &ClResult);
106
+ Ptr = FuncPtr (CLContext,
107
+ AllocProperties.empty () ? nullptr : AllocProperties.data (),
108
+ size, Alignment, &ClResult);
51
109
if (ClResult == CL_INVALID_BUFFER_SIZE) {
52
110
return UR_RESULT_ERROR_INVALID_USM_SIZE;
53
111
}
@@ -71,25 +129,10 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
71
129
void *Ptr = nullptr ;
72
130
uint32_t Alignment = pUSMDesc ? pUSMDesc->align : 0 ;
73
131
74
- cl_mem_alloc_flags_intel Flags = 0 ;
75
- cl_mem_properties_intel Properties[3 ];
76
- if (pUSMDesc && pUSMDesc->pNext &&
77
- static_cast <const ur_base_desc_t *>(pUSMDesc->pNext )->stype ==
78
- UR_STRUCTURE_TYPE_USM_DEVICE_DESC) {
79
- const auto *HostDesc =
80
- static_cast <const ur_usm_device_desc_t *>(pUSMDesc->pNext );
81
-
82
- if (HostDesc->flags & UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT) {
83
- Flags |= CL_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE_INTEL;
84
- }
85
- if (HostDesc->flags & UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED) {
86
- Flags |= CL_MEM_ALLOC_WRITE_COMBINED_INTEL;
87
- }
88
- Properties[0 ] = CL_MEM_ALLOC_FLAGS_INTEL;
89
- Properties[1 ] = Flags;
90
- Properties[2 ] = 0 ;
91
- } else {
92
- Properties[0 ] = 0 ;
132
+ std::vector<cl_mem_properties_intel> AllocProperties;
133
+ if (pUSMDesc && pUSMDesc->pNext ) {
134
+ UR_RETURN_ON_FAILURE (usmDescToCLMemProperties (
135
+ static_cast <const ur_base_desc_t *>(pUSMDesc->pNext ), AllocProperties));
93
136
}
94
137
95
138
// First we need to look up the function pointer
@@ -104,8 +147,8 @@ urUSMDeviceAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
104
147
if (FuncPtr) {
105
148
cl_int ClResult = CL_SUCCESS;
106
149
Ptr = FuncPtr (CLContext, cl_adapter::cast<cl_device_id>(hDevice),
107
- cl_adapter::cast<cl_mem_properties_intel *>(Properties), size ,
108
- Alignment, &ClResult);
150
+ AllocProperties. empty () ? nullptr : AllocProperties. data () ,
151
+ size, Alignment, &ClResult);
109
152
if (ClResult == CL_INVALID_BUFFER_SIZE) {
110
153
return UR_RESULT_ERROR_INVALID_USM_SIZE;
111
154
}
@@ -129,35 +172,10 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
129
172
void *Ptr = nullptr ;
130
173
uint32_t Alignment = pUSMDesc ? pUSMDesc->align : 0 ;
131
174
132
- cl_mem_alloc_flags_intel Flags = 0 ;
133
- const auto *NextStruct =
134
- (pUSMDesc ? static_cast <const ur_base_desc_t *>(pUSMDesc->pNext )
135
- : nullptr );
136
- while (NextStruct) {
137
- if (NextStruct->stype == UR_STRUCTURE_TYPE_USM_HOST_DESC) {
138
- const auto *HostDesc =
139
- reinterpret_cast <const ur_usm_host_desc_t *>(NextStruct);
140
- if (HostDesc->flags & UR_USM_HOST_MEM_FLAG_INITIAL_PLACEMENT) {
141
- Flags |= CL_MEM_ALLOC_INITIAL_PLACEMENT_HOST_INTEL;
142
- }
143
- } else if (NextStruct->stype == UR_STRUCTURE_TYPE_USM_DEVICE_DESC) {
144
- const auto *DevDesc =
145
- reinterpret_cast <const ur_usm_device_desc_t *>(NextStruct);
146
- if (DevDesc->flags & UR_USM_DEVICE_MEM_FLAG_INITIAL_PLACEMENT) {
147
- Flags |= CL_MEM_ALLOC_INITIAL_PLACEMENT_DEVICE_INTEL;
148
- }
149
- if (DevDesc->flags & UR_USM_DEVICE_MEM_FLAG_WRITE_COMBINED) {
150
- Flags |= CL_MEM_ALLOC_WRITE_COMBINED_INTEL;
151
- }
152
- }
153
- NextStruct = static_cast <const ur_base_desc_t *>(NextStruct->pNext );
154
- }
155
-
156
- cl_mem_properties_intel Properties[3 ] = {CL_MEM_ALLOC_FLAGS_INTEL, Flags, 0 };
157
-
158
- // Passing a flags value of 0 doesn't work, so truncate the properties
159
- if (Flags == 0 ) {
160
- Properties[0 ] = 0 ;
175
+ std::vector<cl_mem_properties_intel> AllocProperties;
176
+ if (pUSMDesc && pUSMDesc->pNext ) {
177
+ UR_RETURN_ON_FAILURE (usmDescToCLMemProperties (
178
+ static_cast <const ur_base_desc_t *>(pUSMDesc->pNext ), AllocProperties));
161
179
}
162
180
163
181
// First we need to look up the function pointer
@@ -172,8 +190,8 @@ urUSMSharedAlloc(ur_context_handle_t hContext, ur_device_handle_t hDevice,
172
190
if (FuncPtr) {
173
191
cl_int ClResult = CL_SUCCESS;
174
192
Ptr = FuncPtr (CLContext, cl_adapter::cast<cl_device_id>(hDevice),
175
- cl_adapter::cast<cl_mem_properties_intel *>(Properties), size ,
176
- Alignment, cl_adapter::cast<cl_int *>(&ClResult));
193
+ AllocProperties. empty () ? nullptr : AllocProperties. data () ,
194
+ size, Alignment, cl_adapter::cast<cl_int *>(&ClResult));
177
195
if (ClResult == CL_INVALID_BUFFER_SIZE) {
178
196
return UR_RESULT_ERROR_INVALID_USM_SIZE;
179
197
}
0 commit comments