1515#include " umf_pools/disjoint_pool_config_parser.hpp"
1616#include " usm.hpp"
1717
18- #include < umf/pools/pool_disjoint.h>
19- #include < umf/pools/pool_proxy.h>
2018#include < umf/providers/provider_level_zero.h>
2119
2220namespace umf {
@@ -34,7 +32,17 @@ ur_result_t getProviderNativeError(const char *providerName,
3432}
3533} // namespace umf
3634
37- static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig () {
35+ static std::optional<usm::DisjointPoolAllConfigs>
36+ initializeDisjointPoolConfig () {
37+ const char *UrRetDisable = std::getenv (" UR_L0_DISABLE_USM_ALLOCATOR" );
38+ const char *PiRetDisable =
39+ std::getenv (" SYCL_PI_LEVEL_ZERO_DISABLE_USM_ALLOCATOR" );
40+ const char *Disable =
41+ UrRetDisable ? UrRetDisable : (PiRetDisable ? PiRetDisable : nullptr );
42+ if (Disable != nullptr && Disable != std::string (" " )) {
43+ return std::nullopt ;
44+ }
45+
3846 const char *PoolUrTraceVal = std::getenv (" UR_L0_USM_ALLOCATOR_TRACE" );
3947
4048 int PoolTrace = 0 ;
@@ -47,7 +55,14 @@ static usm::DisjointPoolAllConfigs initializeDisjointPoolConfig() {
4755 return usm::DisjointPoolAllConfigs (PoolTrace);
4856 }
4957
50- return usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
58+ // TODO: rework parseDisjointPoolConfig to return optional,
59+ // once EnableBuffers is no longer used (by legacy L0)
60+ auto configs = usm::parseDisjointPoolConfig (PoolUrConfigVal, PoolTrace);
61+ if (configs.EnableBuffers ) {
62+ return configs;
63+ }
64+
65+ return std::nullopt ;
5166}
5267
5368inline umf_usm_memory_type_t urToUmfMemoryType (ur_usm_type_t type) {
@@ -81,32 +96,35 @@ descToDisjoinPoolMemType(const usm::pool_descriptor &desc) {
8196 }
8297}
8398
84- static umf::pool_unique_handle_t
85- makePool (usm::umf_disjoint_pool_config_t *poolParams,
86- usm::pool_descriptor poolDescriptor) {
87- umf_level_zero_memory_provider_params_handle_t params = NULL ;
88- umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (¶ms);
99+ static umf::provider_unique_handle_t
100+ makeProvider (usm::pool_descriptor poolDescriptor) {
101+ umf_level_zero_memory_provider_params_handle_t hParams;
102+ umf_result_t umf_ret = umfLevelZeroMemoryProviderParamsCreate (&hParams);
89103 if (umf_ret != UMF_RESULT_SUCCESS) {
90104 throw umf::umf2urResult (umf_ret);
91105 }
92106
107+ std::unique_ptr<umf_level_zero_memory_provider_params_t ,
108+ decltype (&umfLevelZeroMemoryProviderParamsDestroy)>
109+ params (hParams, &umfLevelZeroMemoryProviderParamsDestroy);
110+
93111 umf_ret = umfLevelZeroMemoryProviderParamsSetContext (
94- params , poolDescriptor.hContext ->getZeHandle ());
112+ hParams , poolDescriptor.hContext ->getZeHandle ());
95113 if (umf_ret != UMF_RESULT_SUCCESS) {
96114 throw umf::umf2urResult (umf_ret);
97115 };
98116
99117 ze_device_handle_t level_zero_device_handle =
100118 poolDescriptor.hDevice ? poolDescriptor.hDevice ->ZeDevice : nullptr ;
101119
102- umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (params ,
120+ umf_ret = umfLevelZeroMemoryProviderParamsSetDevice (hParams ,
103121 level_zero_device_handle);
104122 if (umf_ret != UMF_RESULT_SUCCESS) {
105123 throw umf::umf2urResult (umf_ret);
106124 }
107125
108126 umf_ret = umfLevelZeroMemoryProviderParamsSetMemoryType (
109- params , urToUmfMemoryType (poolDescriptor.type ));
127+ hParams , urToUmfMemoryType (poolDescriptor.type ));
110128 if (umf_ret != UMF_RESULT_SUCCESS) {
111129 throw umf::umf2urResult (umf_ret);
112130 }
@@ -123,46 +141,37 @@ makePool(usm::umf_disjoint_pool_config_t *poolParams,
123141 }
124142
125143 umf_ret = umfLevelZeroMemoryProviderParamsSetResidentDevices (
126- params , residentZeHandles.data (), residentZeHandles.size ());
144+ hParams , residentZeHandles.data (), residentZeHandles.size ());
127145 if (umf_ret != UMF_RESULT_SUCCESS) {
128146 throw umf::umf2urResult (umf_ret);
129147 }
130148 }
131149
132150 auto [ret, provider] =
133- umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), params );
151+ umf::providerMakeUniqueFromOps (umfLevelZeroMemoryProviderOps (), hParams );
134152 if (ret != UMF_RESULT_SUCCESS) {
135153 throw umf::umf2urResult (ret);
136154 }
137155
138- if (!poolParams) {
139- auto [ret, poolHandle] = umf::poolMakeUniqueFromOps (
140- umfProxyPoolOps (), std::move (provider), nullptr );
141- if (ret != UMF_RESULT_SUCCESS)
142- throw umf::umf2urResult (ret);
143- return std::move (poolHandle);
144- } else {
145- auto umfParams = getUmfParamsHandle (*poolParams);
146-
147- auto [ret, poolHandle] =
148- umf::poolMakeUniqueFromOps (umfDisjointPoolOps (), std::move (provider),
149- static_cast <void *>(umfParams.get ()));
150- if (ret != UMF_RESULT_SUCCESS)
151- throw umf::umf2urResult (ret);
152- return std::move (poolHandle);
153- }
156+ return std::move (provider);
154157}
155158
156159ur_usm_pool_handle_t_::ur_usm_pool_handle_t_ (ur_context_handle_t hContext,
157160 ur_usm_pool_desc_t *pPoolDesc)
158161 : hContext(hContext) {
159162 // TODO: handle UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK from pPoolDesc
160163 auto disjointPoolConfigs = initializeDisjointPoolConfig ();
161- if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
162- for (auto &config : disjointPoolConfigs.Configs ) {
163- config.MaxPoolableSize = limits->maxPoolableSize ;
164- config.SlabMinSize = limits->minDriverAllocSize ;
164+
165+ if (disjointPoolConfigs.has_value ()) {
166+ if (auto limits = find_stype_node<ur_usm_pool_limits_desc_t >(pPoolDesc)) {
167+ for (auto &config : disjointPoolConfigs.value ().Configs ) {
168+ config.MaxPoolableSize = limits->maxPoolableSize ;
169+ config.SlabMinSize = limits->minDriverAllocSize ;
170+ }
165171 }
172+ } else {
173+ // If pooling is disabled, do nothing.
174+ logger::info (" USM pooling is disabled. Skiping pool limits adjustment." );
166175 }
167176
168177 auto [result, descriptors] = usm::pool_descriptor::create (this , hContext);
@@ -171,12 +180,13 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t hContext,
171180 }
172181
173182 for (auto &desc : descriptors) {
174- if (disjointPoolConfigs.EnableBuffers ) {
183+ if (disjointPoolConfigs.has_value () ) {
175184 auto &poolConfig =
176- disjointPoolConfigs.Configs [descToDisjoinPoolMemType (desc)];
177- poolManager.addPool (desc, makePool (&poolConfig, desc));
185+ disjointPoolConfigs.value ().Configs [descToDisjoinPoolMemType (desc)];
186+ poolManager.addPool (
187+ desc, usm::makeDisjointPool (makeProvider (desc), poolConfig));
178188 } else {
179- poolManager.addPool (desc, makePool ( nullptr , desc));
189+ poolManager.addPool (desc, usm::makeProxyPool ( makeProvider ( desc) ));
180190 }
181191 }
182192}
0 commit comments