@@ -98,6 +98,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = {
9898};
9999#endif
100100
101+ unsigned
102+ mca_spml_ucx_mem_map_flags_symmetric_rkey (struct mca_spml_ucx * spml_ucx )
103+ {
104+ #if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY
105+ if (spml_ucx -> symmetric_rkey_max_count > 0 ) {
106+ return UCP_MEM_MAP_SYMMETRIC_RKEY ;
107+ }
108+ #endif
109+
110+ return 0 ;
111+ }
112+
113+ void mca_spml_ucx_rkey_store_init (mca_spml_ucx_rkey_store_t * store )
114+ {
115+ store -> array = NULL ;
116+ store -> count = 0 ;
117+ store -> size = 0 ;
118+ }
119+
120+ void mca_spml_ucx_rkey_store_cleanup (mca_spml_ucx_rkey_store_t * store )
121+ {
122+ int i ;
123+
124+ for (i = 0 ; i < store -> count ; i ++ ) {
125+ if (store -> array [i ].refcnt != 0 ) {
126+ SPML_UCX_ERROR ("rkey store destroy: %d/%d has refcnt %d > 0" ,
127+ i , store -> count , store -> array [i ].refcnt );
128+ }
129+
130+ ucp_rkey_destroy (store -> array [i ].rkey );
131+ }
132+
133+ free (store -> array );
134+ }
135+
136+ /**
137+ * Find position in sorted array for existing or future entry
138+ *
139+ * @param[in] store Store of the entries
140+ * @param[in] worker Common worker for rkeys used
141+ * @param[in] rkey Remote key to search for
142+ * @param[out] index Index of entry
143+ *
144+ * @return
145+ * OSHMEM_ERR_NOT_FOUND: index contains the position where future element
146+ * should be inserted to keep array sorted
147+ * OSHMEM_SUCCESS : index contains the position of the element
148+ * Other error : index is not valid
149+ */
150+ static int mca_spml_ucx_rkey_store_find (const mca_spml_ucx_rkey_store_t * store ,
151+ const ucp_worker_h worker ,
152+ const ucp_rkey_h rkey ,
153+ int * index )
154+ {
155+ #if HAVE_DECL_UCP_RKEY_COMPARE
156+ ucp_rkey_compare_params_t params ;
157+ int i , result , m , end ;
158+ ucs_status_t status ;
159+
160+ for (i = 0 , end = store -> count ; i < end ;) {
161+ m = (i + end ) / 2 ;
162+
163+ params .field_mask = 0 ;
164+ status = ucp_rkey_compare (worker , store -> array [m ].rkey ,
165+ rkey , & params , & result );
166+ if (status != UCS_OK ) {
167+ return OSHMEM_ERROR ;
168+ } else if (result == 0 ) {
169+ * index = m ;
170+ return OSHMEM_SUCCESS ;
171+ } else if (result > 0 ) {
172+ end = m ;
173+ } else {
174+ i = m + 1 ;
175+ }
176+ }
177+
178+ * index = i ;
179+ return OSHMEM_ERR_NOT_FOUND ;
180+ #else
181+ return OSHMEM_ERROR ;
182+ #endif
183+ }
184+
185+ static void mca_spml_ucx_rkey_store_insert (mca_spml_ucx_rkey_store_t * store ,
186+ int i , ucp_rkey_h rkey )
187+ {
188+ int size ;
189+ mca_spml_ucx_rkey_t * tmp ;
190+
191+ if (store -> count >= mca_spml_ucx .symmetric_rkey_max_count ) {
192+ return ;
193+ }
194+
195+ if (store -> count >= store -> size ) {
196+ size = sshmem_ucx_min (sshmem_ucx_max (store -> size , 8 ) * 2 ,
197+ mca_spml_ucx .symmetric_rkey_max_count );
198+ tmp = realloc (store -> array , size * sizeof (* store -> array ));
199+ if (tmp == NULL ) {
200+ return ;
201+ }
202+
203+ store -> array = tmp ;
204+ store -> size = size ;
205+ }
206+
207+ memmove (& store -> array [i + 1 ], & store -> array [i ],
208+ (store -> count - i ) * sizeof (* store -> array ));
209+ store -> array [i ].rkey = rkey ;
210+ store -> array [i ].refcnt = 1 ;
211+ store -> count ++ ;
212+ return ;
213+ }
214+
215+ /* Takes ownership of input ucp remote key */
216+ static ucp_rkey_h mca_spml_ucx_rkey_store_get (mca_spml_ucx_rkey_store_t * store ,
217+ ucp_worker_h worker ,
218+ ucp_rkey_h rkey )
219+ {
220+ int ret , i ;
221+
222+ if (mca_spml_ucx .symmetric_rkey_max_count == 0 ) {
223+ return rkey ;
224+ }
225+
226+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
227+ if (ret == OSHMEM_SUCCESS ) {
228+ ucp_rkey_destroy (rkey );
229+ store -> array [i ].refcnt ++ ;
230+ return store -> array [i ].rkey ;
231+ }
232+
233+ if (ret == OSHMEM_ERR_NOT_FOUND ) {
234+ mca_spml_ucx_rkey_store_insert (store , i , rkey );
235+ }
236+
237+ return rkey ;
238+ }
239+
240+ static void mca_spml_ucx_rkey_store_put (mca_spml_ucx_rkey_store_t * store ,
241+ ucp_worker_h worker ,
242+ ucp_rkey_h rkey )
243+ {
244+ mca_spml_ucx_rkey_t * entry ;
245+ int ret , i ;
246+
247+ ret = mca_spml_ucx_rkey_store_find (store , worker , rkey , & i );
248+ if (ret != OSHMEM_SUCCESS ) {
249+ goto out ;
250+ }
251+
252+ entry = & store -> array [i ];
253+ assert (entry -> rkey == rkey );
254+ if (-- entry -> refcnt > 0 ) {
255+ return ;
256+ }
257+
258+ memmove (& store -> array [i ], & store -> array [i + 1 ],
259+ (store -> count - (i + 1 )) * sizeof (* store -> array ));
260+ store -> count -- ;
261+
262+ out :
263+ ucp_rkey_destroy (rkey );
264+ }
265+
101266int mca_spml_ucx_enable (bool enable )
102267{
103268 SPML_UCX_VERBOSE (50 , "*** ucx ENABLED ****" );
@@ -212,6 +377,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
212377{
213378 int rc ;
214379 ucs_status_t err ;
380+ ucp_rkey_h rkey ;
215381
216382 rc = mca_spml_ucx_ctx_mkey_new (ucx_ctx , pe , segno , ucx_mkey );
217383 if (OSHMEM_SUCCESS != rc ) {
@@ -220,11 +386,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
220386 }
221387
222388 if (mkey -> u .data ) {
223- err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & (( * ucx_mkey ) -> rkey ) );
389+ err = ucp_ep_rkey_unpack (ucx_ctx -> ucp_peers [pe ].ucp_conn , mkey -> u .data , & rkey );
224390 if (UCS_OK != err ) {
225391 SPML_UCX_ERROR ("failed to unpack rkey: %s" , ucs_status_string (err ));
226392 return OSHMEM_ERROR ;
227393 }
394+
395+ if (!oshmem_proc_on_local_node (pe )) {
396+ rkey = mca_spml_ucx_rkey_store_get (& ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [0 ], rkey );
397+ }
398+
399+ (* ucx_mkey )-> rkey = rkey ;
400+
228401 rc = mca_spml_ucx_ctx_mkey_cache (ucx_ctx , mkey , segno , pe );
229402 if (OSHMEM_SUCCESS != rc ) {
230403 SPML_UCX_ERROR ("mca_spml_ucx_ctx_mkey_cache failed" );
@@ -239,7 +412,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn
239412 ucp_peer_t * ucp_peer ;
240413 int rc ;
241414 ucp_peer = & (ucx_ctx -> ucp_peers [pe ]);
242- ucp_rkey_destroy ( ucx_mkey -> rkey );
415+ mca_spml_ucx_rkey_store_put ( & ucx_ctx -> rkey_store , ucx_ctx -> ucp_worker [ 0 ], ucx_mkey -> rkey );
243416 ucx_mkey -> rkey = NULL ;
244417 rc = mca_spml_ucx_peer_mkey_cache_del (ucp_peer , segno );
245418 if (OSHMEM_SUCCESS != rc ){
@@ -697,7 +870,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr,
697870 UCP_MEM_MAP_PARAM_FIELD_FLAGS ;
698871 mem_map_params .address = addr ;
699872 mem_map_params .length = size ;
700- mem_map_params .flags = flags ;
873+ mem_map_params .flags = flags |
874+ mca_spml_ucx_mem_map_flags_symmetric_rkey (& mca_spml_ucx );
701875
702876 status = ucp_mem_map (mca_spml_ucx .ucp_context , & mem_map_params , & mem_h );
703877 if (UCS_OK != status ) {
@@ -887,6 +1061,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
8871061 }
8881062 }
8891063
1064+ mca_spml_ucx_rkey_store_init (& ucx_ctx -> rkey_store );
1065+
8901066 * ucx_ctx_p = ucx_ctx ;
8911067
8921068 return OSHMEM_SUCCESS ;
0 commit comments