Skip to content

Commit be91e99

Browse files
committed
oshmem: Add symmetric remote key storage structure
After unpacking, we try to find an already existing equivalent remote key. When found, we destroy the new one and use the old one instead. Signed-off-by: Thomas Vegas <[email protected]>
1 parent 7d734ac commit be91e99

File tree

4 files changed

+207
-27
lines changed

4 files changed

+207
-27
lines changed

config/ompi_check_ucx.m4

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[
125125
[#include <ucp/api/ucp.h>])
126126
AC_CHECK_DECLS([ucp_tag_send_nbx,
127127
ucp_tag_send_sync_nbx,
128-
ucp_tag_recv_nbx],
128+
ucp_tag_recv_nbx,
129+
ucp_rkey_compare],
129130
[], [],
130131
[#include <ucp/api/ucp.h>])
131132
AC_CHECK_TYPES([ucp_request_param_t],

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "opal/datatype/opal_convertor.h"
2323
#include "opal/mca/common/ucx/common_ucx.h"
2424
#include "opal/util/opal_environ.h"
25+
#include "opal/util/minmax.h"
2526
#include "ompi/datatype/ompi_datatype.h"
2627
#include "ompi/mca/pml/pml.h"
2728

@@ -138,6 +139,159 @@ mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx)
138139
return 0;
139140
}
140141

142+
/**
 * Put an rkey store into the empty state.
 *
 * No memory is allocated: the entry array is NULL and both the capacity
 * and the number of used entries are zero.
 *
 * @param[out] store  Store to initialize
 */
void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store)
{
    store->size  = 0;
    store->count = 0;
    store->array = NULL;
}
148+
149+
/**
 * Destroy every remote key still held by the store and release its array.
 *
 * An entry whose reference count is not zero at this point is reported
 * with SPML_UCX_ERROR, but its key is destroyed anyway.
 *
 * The store is left empty (NULL array, zero count and size) so that it
 * does not keep a dangling pointer: calling cleanup again, or reusing the
 * store afterwards, is harmless.
 *
 * @param[in,out] store  Store to clean up
 */
void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store)
{
    int i;

    for (i = 0; i < store->count; i++) {
        if (store->array[i].refcnt != 0) {
            SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0",
                           i, store->count, store->array[i].refcnt);
        }

        ucp_rkey_destroy(store->array[i].rkey);
    }

    free(store->array);

    /* Do not leave stale state behind: the keys and the array are gone */
    store->array = NULL;
    store->count = 0;
    store->size  = 0;
}
164+
165+
/**
166+
* Find position in sorted array for existing or future entry
167+
*
168+
* @param[in] store Store of the entries
169+
* @param[in] worker Common worker for rkeys used
170+
* @param[in] rkey Remote key to search for
171+
* @param[out] index Index of entry
172+
*
173+
* @return
174+
* OSHMEM_ERR_NOT_FOUND: index contains the position where future element
175+
* should be inserted to keep array sorted
176+
* OSHMEM_SUCCESS : index contains the position of the element
177+
* Other error : index is not valid
178+
*/
179+
static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store,
180+
const ucp_worker_h worker,
181+
const ucp_rkey_h rkey,
182+
int *index)
183+
{
184+
#if HAVE_DECL_UCP_RKEY_COMPARE
185+
ucp_rkey_compare_params_t params;
186+
int i, result, m, end;
187+
ucs_status_t status;
188+
189+
for (i = 0, end = store->count; i < end;) {
190+
m = (i + end) / 2;
191+
192+
params.field_mask = 0;
193+
status = ucp_rkey_compare(worker, store->array[m].rkey,
194+
rkey, &params, &result);
195+
if (status != UCS_OK) {
196+
return OSHMEM_ERROR;
197+
} else if (result == 0) {
198+
*index = m;
199+
return OSHMEM_SUCCESS;
200+
} else if (result > 0) {
201+
end = m;
202+
} else {
203+
i = m + 1;
204+
}
205+
}
206+
207+
*index = i;
208+
return OSHMEM_ERR_NOT_FOUND;
209+
#else
210+
return OSHMEM_ERROR;
211+
#endif
212+
}
213+
214+
/**
 * Insert @a rkey at position @a i with a reference count of 1, shifting
 * the entries from @a i onwards up by one slot.
 *
 * The insertion is best effort: nothing is stored, and nothing is
 * reported, when the store already holds
 * mca_spml_ucx.symmetric_rkey_max_count entries or when growing the array
 * fails. The array grows to twice max(current size, 8), capped at
 * symmetric_rkey_max_count.
 *
 * @param[in,out] store  Store to insert into
 * @param[in]     i      Position of the new entry
 * @param[in]     rkey   Remote key to store
 */
static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store,
                                           int i, ucp_rkey_h rkey)
{
    mca_spml_ucx_rkey_t *grown, *slot;
    int capacity;

    if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) {
        return;
    }

    if (store->count >= store->size) {
        capacity = opal_min(opal_max(store->size, 8) * 2,
                            mca_spml_ucx.symmetric_rkey_max_count);
        grown    = realloc(store->array, capacity * sizeof(*store->array));
        if (grown == NULL) {
            /* Old array is still valid; just skip the insertion */
            return;
        }

        store->array = grown;
        store->size  = capacity;
    }

    /* Open a gap at position i, then fill it */
    slot = &store->array[i];
    memmove(slot + 1, slot, (store->count - i) * sizeof(*store->array));
    slot->rkey   = rkey;
    slot->refcnt = 1;
    store->count++;
}
243+
244+
/* Takes ownership of input ucp remote key */
245+
static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store,
246+
ucp_worker_h worker,
247+
ucp_rkey_h rkey)
248+
{
249+
int ret, i;
250+
251+
if (mca_spml_ucx.symmetric_rkey_max_count == 0) {
252+
return rkey;
253+
}
254+
255+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
256+
if (ret == OSHMEM_SUCCESS) {
257+
ucp_rkey_destroy(rkey);
258+
store->array[i].refcnt++;
259+
return store->array[i].rkey;
260+
}
261+
262+
if (ret == OSHMEM_ERR_NOT_FOUND) {
263+
mca_spml_ucx_rkey_store_insert(store, i, rkey);
264+
}
265+
266+
return rkey;
267+
}
268+
269+
static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store,
270+
ucp_worker_h worker,
271+
ucp_rkey_h rkey)
272+
{
273+
mca_spml_ucx_rkey_t *entry;
274+
int ret, i;
275+
276+
ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i);
277+
if (ret != OSHMEM_SUCCESS) {
278+
goto out;
279+
}
280+
281+
entry = &store->array[i];
282+
assert(entry->rkey == rkey);
283+
if (--entry->refcnt > 0) {
284+
return;
285+
}
286+
287+
memmove(&store->array[i], &store->array[i + 1],
288+
(store->count - (i + 1)) * sizeof(*store->array));
289+
store->count--;
290+
291+
out:
292+
ucp_rkey_destroy(rkey);
293+
}
294+
141295
int mca_spml_ucx_enable(bool enable)
142296
{
143297
SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****");
@@ -930,6 +1084,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
9301084
}
9311085
}
9321086

1087+
mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store);
1088+
9331089
*ucx_ctx_p = ucx_ctx;
9341090

9351091
return OSHMEM_SUCCESS;

oshmem/mca/spml/ucx/spml_ucx.h

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,31 @@ struct ucp_peer {
7676
size_t mkeys_cnt;
7777
};
7878
typedef struct ucp_peer ucp_peer_t;
79-
79+
80+
/* An rkey_store entry */
81+
typedef struct mca_spml_ucx_rkey {
82+
ucp_rkey_h rkey;
83+
int refcnt;
84+
} mca_spml_ucx_rkey_t;
85+
86+
typedef struct mca_spml_ucx_rkey_store {
87+
mca_spml_ucx_rkey_t *array;
88+
int size;
89+
int count;
90+
} mca_spml_ucx_rkey_store_t;
91+
8092
struct mca_spml_ucx_ctx {
81-
ucp_worker_h *ucp_worker;
82-
ucp_peer_t *ucp_peers;
83-
long options;
84-
opal_bitmap_t put_op_bitmap;
85-
unsigned long nb_progress_cnt;
86-
unsigned int ucp_workers;
87-
int *put_proc_indexes;
88-
unsigned put_proc_count;
89-
bool synchronized_quiet;
90-
int strong_sync;
93+
ucp_worker_h *ucp_worker;
94+
ucp_peer_t *ucp_peers;
95+
long options;
96+
opal_bitmap_t put_op_bitmap;
97+
unsigned long nb_progress_cnt;
98+
unsigned int ucp_workers;
99+
int *put_proc_indexes;
100+
unsigned put_proc_count;
101+
bool synchronized_quiet;
102+
int strong_sync;
103+
mca_spml_ucx_rkey_store_t rkey_store;
91104
};
92105
typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t;
93106

@@ -284,6 +297,9 @@ extern int mca_spml_ucx_team_reduce(shmem_team_t team, void
284297
extern unsigned
285298
mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx);
286299

300+
extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store);
301+
extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store);
302+
287303
static inline int
288304
mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey)
289305
{

oshmem/mca/spml/ucx/spml_ucx_component.c

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ static int mca_spml_ucx_component_close(void)
260260

261261
static int spml_ucx_init(void)
262262
{
263-
unsigned int i;
263+
unsigned int i, ret;
264264
ucs_status_t err;
265265
ucp_config_t *ucp_config;
266266
ucp_params_t params;
@@ -336,6 +336,8 @@ static int spml_ucx_init(void)
336336
mca_spml_ucx_ctx_default.ucp_workers++;
337337
}
338338

339+
mca_spml_ucx_rkey_store_init(&mca_spml_ucx_ctx_default.rkey_store);
340+
339341
wrk_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
340342
err = ucp_worker_query(mca_spml_ucx_ctx_default.ucp_worker[0], &wrk_attr);
341343

@@ -440,10 +442,25 @@ static void _ctx_cleanup(mca_spml_ucx_ctx_t *ctx)
440442
free(ctx->ucp_peers);
441443
}
442444

445+
/**
 * Tear down a context: clean up its rkey store, destroy each of its
 * ctx->ucp_workers workers and free the worker array.
 *
 * The context structure itself is freed as well, except for
 * mca_spml_ucx_ctx_default, which is not heap allocated.
 *
 * @param[in] ctx  Context to finalize
 */
static void mca_spml_ucx_ctx_fini(mca_spml_ucx_ctx_t *ctx)
{
    unsigned int w;

    /* Stored remote keys are released before the workers are destroyed */
    mca_spml_ucx_rkey_store_cleanup(&ctx->rkey_store);

    for (w = 0; w < ctx->ucp_workers; w++) {
        ucp_worker_destroy(ctx->ucp_worker[w]);
    }

    free(ctx->ucp_worker);

    if (ctx == &mca_spml_ucx_ctx_default) {
        return;
    }

    free(ctx);
}
458+
443459
static int mca_spml_ucx_component_fini(void)
444460
{
445461
int fenced = 0, i;
446462
int ret = OSHMEM_SUCCESS;
463+
mca_spml_ucx_ctx_t *ctx;
447464

448465
opal_progress_unregister(spml_ucx_default_progress);
449466
if (mca_spml_ucx.active_array.ctxs_count) {
@@ -496,36 +513,26 @@ static int mca_spml_ucx_component_fini(void)
496513
}
497514
}
498515

499-
/* delete all workers */
500516
for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) {
501-
ucp_worker_destroy(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]);
502-
free(mca_spml_ucx.active_array.ctxs[i]->ucp_worker);
503-
free(mca_spml_ucx.active_array.ctxs[i]);
517+
mca_spml_ucx_ctx_fini(mca_spml_ucx.active_array.ctxs[i]);
504518
}
505519

506520
for (i = 0; i < mca_spml_ucx.idle_array.ctxs_count; i++) {
507-
ucp_worker_destroy(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker[0]);
508-
free(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker);
509-
free(mca_spml_ucx.idle_array.ctxs[i]);
521+
mca_spml_ucx_ctx_fini(mca_spml_ucx.idle_array.ctxs[i]);
510522
}
511523

512524
if (mca_spml_ucx_ctx_default.ucp_worker) {
513-
for (i = 0; i < (signed int)mca_spml_ucx.ucp_workers; i++) {
514-
ucp_worker_destroy(mca_spml_ucx_ctx_default.ucp_worker[i]);
515-
}
516-
free(mca_spml_ucx_ctx_default.ucp_worker);
525+
mca_spml_ucx_ctx_fini(&mca_spml_ucx_ctx_default);
517526
}
518527

519528
if (mca_spml_ucx.aux_ctx != NULL) {
520-
ucp_worker_destroy(mca_spml_ucx.aux_ctx->ucp_worker[0]);
521-
free(mca_spml_ucx.aux_ctx->ucp_worker);
529+
mca_spml_ucx_ctx_fini(mca_spml_ucx.aux_ctx);
522530
}
523531

524532
mca_spml_ucx.enabled = false; /* not anymore */
525533

526534
free(mca_spml_ucx.active_array.ctxs);
527535
free(mca_spml_ucx.idle_array.ctxs);
528-
free(mca_spml_ucx.aux_ctx);
529536

530537
SHMEM_MUTEX_DESTROY(mca_spml_ucx.internal_mutex);
531538
pthread_mutex_destroy(&mca_spml_ucx.ctx_create_mutex);

0 commit comments

Comments
 (0)