From 937abaf4597a5dc1f62fce712b84c2f54543fd10 Mon Sep 17 00:00:00 2001 From: Thomas Vegas Date: Tue, 19 Sep 2023 09:47:42 +0300 Subject: [PATCH] oshmem: Add symmetric remote key handling code At very high scale, having each rank storing each other rank's remote keys for each segment can lead to high memory consumption. We activate symmetric remote key option to generate remote keys that will be deduplicated and then used interchangeably. Signed-off-by: Thomas Vegas (cherry picked from commit 776e8babd6868b968d1724161a6999861723b08a) --- config/ompi_check_ucx.m4 | 6 +- oshmem/mca/spml/ucx/spml_ucx.c | 182 +++++++++++++++++++++- oshmem/mca/spml/ucx/spml_ucx.h | 42 +++-- oshmem/mca/spml/ucx/spml_ucx_component.c | 39 +++-- oshmem/mca/sshmem/ucx/sshmem_ucx.h | 2 + oshmem/mca/sshmem/ucx/sshmem_ucx_module.c | 3 +- 6 files changed, 243 insertions(+), 31 deletions(-) diff --git a/config/ompi_check_ucx.m4 b/config/ompi_check_ucx.m4 index d5bd30eb2ec..efc999c0530 100644 --- a/config/ompi_check_ucx.m4 +++ b/config/ompi_check_ucx.m4 @@ -140,7 +140,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ UCP_PARAM_FIELD_ESTIMATED_NUM_PPN, UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK, UCP_OP_ATTR_FLAG_MULTI_SEND, - UCS_MEMORY_TYPE_RDMA], + UCS_MEMORY_TYPE_RDMA, + UCP_MEM_MAP_SYMMETRIC_RKEY], [], [], [#include ]) AC_CHECK_DECLS([UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS], @@ -153,7 +154,8 @@ AC_DEFUN([OMPI_CHECK_UCX],[ [#include ]) AC_CHECK_DECLS([ucp_tag_send_nbx, ucp_tag_send_sync_nbx, - ucp_tag_recv_nbx], + ucp_tag_recv_nbx, + ucp_rkey_compare], [], [], [#include ]) AC_CHECK_TYPES([ucp_request_param_t], diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 0acdf2c592f..21a1d817bc1 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -98,6 +98,171 @@ static ucp_request_param_t mca_spml_ucx_request_param_b = { }; #endif +unsigned +mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx) +{ +#if HAVE_DECL_UCP_MEM_MAP_SYMMETRIC_RKEY + if (spml_ucx->symmetric_rkey_max_count > 0) { + return UCP_MEM_MAP_SYMMETRIC_RKEY; + } +#endif + + return 0; +} + +void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store) +{ + store->array = NULL; + store->count = 0; + store->size = 0; +} + +void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store) +{ + int i; + + for (i = 0; i < store->count; i++) { + if (store->array[i].refcnt != 0) { + SPML_UCX_ERROR("rkey store destroy: %d/%d has refcnt %d > 0", + i, store->count, store->array[i].refcnt); + } + + ucp_rkey_destroy(store->array[i].rkey); + } + + free(store->array); +} + +/** + * Find position in sorted array for existing or future entry + * + * @param[in] store Store of the entries + * @param[in] worker Common worker for rkeys used + * @param[in] rkey Remote key to search for + * @param[out] index Index of entry + * + * @return + * OSHMEM_ERR_NOT_FOUND: index contains the position where future element + * should be inserted to keep array sorted + * OSHMEM_SUCCESS : index contains the position of the element + * Other error : index is not valid + */ +static int mca_spml_ucx_rkey_store_find(const mca_spml_ucx_rkey_store_t *store, + const ucp_worker_h worker, + const ucp_rkey_h rkey, + int *index) +{ +#if HAVE_DECL_UCP_RKEY_COMPARE + ucp_rkey_compare_params_t params; + int i, result, m, end; + ucs_status_t status; + + for (i = 0, end = store->count; i < end;) { + m = (i + end) / 2; + + params.field_mask = 0; + status = ucp_rkey_compare(worker, store->array[m].rkey, + rkey, ¶ms, &result); + if (status != UCS_OK) { + return OSHMEM_ERROR; + } else if (result == 0) { + *index = m; + return OSHMEM_SUCCESS; + } else if (result > 0) { + end = m; + } else { + i = m + 1; + } + } + + *index = i; + return OSHMEM_ERR_NOT_FOUND; +#else + return OSHMEM_ERROR; +#endif +} + +static void mca_spml_ucx_rkey_store_insert(mca_spml_ucx_rkey_store_t *store, + int i, ucp_rkey_h rkey) +{ + int size; + mca_spml_ucx_rkey_t *tmp; + + if (store->count >= mca_spml_ucx.symmetric_rkey_max_count) { + return; + } + + if (store->count >= store->size) { + size = sshmem_ucx_min(sshmem_ucx_max(store->size, 8) * 2, + mca_spml_ucx.symmetric_rkey_max_count); + tmp = realloc(store->array, size * sizeof(*store->array)); + if (tmp == NULL) { + return; + } + + store->array = tmp; + store->size = size; + } + + memmove(&store->array[i + 1], &store->array[i], + (store->count - i) * sizeof(*store->array)); + store->array[i].rkey = rkey; + store->array[i].refcnt = 1; + store->count++; + return; +} + +/* Takes ownership of input ucp remote key */ +static ucp_rkey_h mca_spml_ucx_rkey_store_get(mca_spml_ucx_rkey_store_t *store, + ucp_worker_h worker, + ucp_rkey_h rkey) +{ + int ret, i; + + if (mca_spml_ucx.symmetric_rkey_max_count == 0) { + return rkey; + } + + ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i); + if (ret == OSHMEM_SUCCESS) { + ucp_rkey_destroy(rkey); + store->array[i].refcnt++; + return store->array[i].rkey; + } + + if (ret == OSHMEM_ERR_NOT_FOUND) { + mca_spml_ucx_rkey_store_insert(store, i, rkey); + } + + return rkey; +} + +static void mca_spml_ucx_rkey_store_put(mca_spml_ucx_rkey_store_t *store, + ucp_worker_h worker, + ucp_rkey_h rkey) +{ + mca_spml_ucx_rkey_t *entry; + int ret, i; + + ret = mca_spml_ucx_rkey_store_find(store, worker, rkey, &i); + if (ret != OSHMEM_SUCCESS) { + goto out; + } + + entry = &store->array[i]; + assert(entry->rkey == rkey); + if (--entry->refcnt > 0) { + return; + } + + memmove(&store->array[i], &store->array[i + 1], + (store->count - (i + 1)) * sizeof(*store->array)); + store->count--; + +out: + ucp_rkey_destroy(rkey); +} + int mca_spml_ucx_enable(bool enable) { SPML_UCX_VERBOSE(50, "*** ucx ENABLED ****"); @@ -212,6 +377,7 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn { int rc; ucs_status_t err; + ucp_rkey_h rkey; rc = mca_spml_ucx_ctx_mkey_new(ucx_ctx, pe, segno, ucx_mkey); if (OSHMEM_SUCCESS != rc) { @@ -220,11 +386,18 @@ int mca_spml_ucx_ctx_mkey_add(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn } if (mkey->u.data) { - err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &((*ucx_mkey)->rkey)); + err = ucp_ep_rkey_unpack(ucx_ctx->ucp_peers[pe].ucp_conn, mkey->u.data, &rkey); if (UCS_OK != err) { SPML_UCX_ERROR("failed to unpack rkey: %s", ucs_status_string(err)); return OSHMEM_ERROR; } + + if (!oshmem_proc_on_local_node(pe)) { + rkey = mca_spml_ucx_rkey_store_get(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], rkey); + } + + (*ucx_mkey)->rkey = rkey; + rc = mca_spml_ucx_ctx_mkey_cache(ucx_ctx, mkey, segno, pe); if (OSHMEM_SUCCESS != rc) { SPML_UCX_ERROR("mca_spml_ucx_ctx_mkey_cache failed"); @@ -239,7 +412,7 @@ int mca_spml_ucx_ctx_mkey_del(mca_spml_ucx_ctx_t *ucx_ctx, int pe, uint32_t segn ucp_peer_t *ucp_peer; int rc; ucp_peer = &(ucx_ctx->ucp_peers[pe]); - ucp_rkey_destroy(ucx_mkey->rkey); + mca_spml_ucx_rkey_store_put(&ucx_ctx->rkey_store, ucx_ctx->ucp_worker[0], ucx_mkey->rkey); ucx_mkey->rkey = NULL; rc = mca_spml_ucx_peer_mkey_cache_del(ucp_peer, segno); if(OSHMEM_SUCCESS != rc){ @@ -697,7 +870,8 @@ sshmem_mkey_t *mca_spml_ucx_register(void* addr, UCP_MEM_MAP_PARAM_FIELD_FLAGS; mem_map_params.address = addr; mem_map_params.length = size; - mem_map_params.flags = flags; + mem_map_params.flags = flags | + mca_spml_ucx_mem_map_flags_symmetric_rkey(&mca_spml_ucx); status = ucp_mem_map(mca_spml_ucx.ucp_context, &mem_map_params, &mem_h); if (UCS_OK != status) { @@ -887,6 +1061,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx } } + mca_spml_ucx_rkey_store_init(&ucx_ctx->rkey_store); + *ucx_ctx_p = ucx_ctx; return OSHMEM_SUCCESS; diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index 00d34b092eb..162ae855827 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -77,18 +77,31 @@ struct ucp_peer { size_t mkeys_cnt; }; typedef struct ucp_peer ucp_peer_t; - + +/* An rkey_store entry */ +typedef struct mca_spml_ucx_rkey { + ucp_rkey_h rkey; + int refcnt; +} mca_spml_ucx_rkey_t; + +typedef struct mca_spml_ucx_rkey_store { + mca_spml_ucx_rkey_t *array; + int size; + int count; +} mca_spml_ucx_rkey_store_t; + struct mca_spml_ucx_ctx { - ucp_worker_h *ucp_worker; - ucp_peer_t *ucp_peers; - long options; - opal_bitmap_t put_op_bitmap; - unsigned long nb_progress_cnt; - unsigned int ucp_workers; - int *put_proc_indexes; - unsigned put_proc_count; - bool synchronized_quiet; - int strong_sync; + ucp_worker_h *ucp_worker; + ucp_peer_t *ucp_peers; + long options; + opal_bitmap_t put_op_bitmap; + unsigned long nb_progress_cnt; + unsigned int ucp_workers; + int *put_proc_indexes; + unsigned put_proc_count; + bool synchronized_quiet; + int strong_sync; + mca_spml_ucx_rkey_store_t rkey_store; }; typedef struct mca_spml_ucx_ctx mca_spml_ucx_ctx_t; @@ -129,6 +142,7 @@ struct mca_spml_ucx { unsigned long nb_ucp_worker_progress; unsigned int ucp_workers; unsigned int ucp_worker_cnt; + int symmetric_rkey_max_count; }; typedef struct mca_spml_ucx mca_spml_ucx_t; @@ -217,6 +231,12 @@ int mca_spml_ucx_peer_mkey_cache_del(ucp_peer_t *ucp_peer, int segno); void mca_spml_ucx_peer_mkey_cache_release(ucp_peer_t *ucp_peer); void mca_spml_ucx_peer_mkey_cache_init(mca_spml_ucx_ctx_t *ucx_ctx, int pe); +extern unsigned +mca_spml_ucx_mem_map_flags_symmetric_rkey(struct mca_spml_ucx *spml_ucx); + +extern void mca_spml_ucx_rkey_store_init(mca_spml_ucx_rkey_store_t *store); +extern void mca_spml_ucx_rkey_store_cleanup(mca_spml_ucx_rkey_store_t *store); + static inline int mca_spml_ucx_peer_mkey_get(ucp_peer_t *ucp_peer, int index, spml_ucx_cached_mkey_t **out_rmkey) { diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c index bdbcf25a5ad..89c04e2f9eb 100644 --- a/oshmem/mca/spml/ucx/spml_ucx_component.c +++ b/oshmem/mca/spml/ucx/spml_ucx_component.c @@ -153,6 +153,10 @@ static int mca_spml_ucx_component_register(void) "Enable asynchronous progress thread", &mca_spml_ucx.async_progress); + mca_spml_ucx_param_register_int("symmetric_rkey_max_count", 0, + "Size of the symmetric key store. Non-zero to enable, typical use 5000", + &mca_spml_ucx.symmetric_rkey_max_count); + mca_spml_ucx_param_register_int("async_tick_usec", 3000, "Asynchronous progress tick granularity (in usec)", &mca_spml_ucx.async_tick); @@ -332,6 +336,8 @@ static int spml_ucx_init(void) mca_spml_ucx_ctx_default.ucp_workers++; } + mca_spml_ucx_rkey_store_init(&mca_spml_ucx_ctx_default.rkey_store); + wrk_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; err = ucp_worker_query(mca_spml_ucx_ctx_default.ucp_worker[0], &wrk_attr); @@ -432,10 +438,25 @@ static void _ctx_cleanup(mca_spml_ucx_ctx_t *ctx) free(ctx->ucp_peers); } +static void mca_spml_ucx_ctx_fini(mca_spml_ucx_ctx_t *ctx) +{ + unsigned int i; + + mca_spml_ucx_rkey_store_cleanup(&ctx->rkey_store); + for (i = 0; i < ctx->ucp_workers; i++) { + ucp_worker_destroy(ctx->ucp_worker[i]); + } + free(ctx->ucp_worker); + if (ctx != &mca_spml_ucx_ctx_default) { + free(ctx); + } +} + static int mca_spml_ucx_component_fini(void) { int fenced = 0, i; int ret = OSHMEM_SUCCESS; + mca_spml_ucx_ctx_t *ctx; opal_progress_unregister(spml_ucx_default_progress); if (mca_spml_ucx.active_array.ctxs_count) { @@ -488,36 +509,26 @@ static int mca_spml_ucx_component_fini(void) } } - /* delete all workers */ for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]); - free(mca_spml_ucx.active_array.ctxs[i]->ucp_worker); - free(mca_spml_ucx.active_array.ctxs[i]); + mca_spml_ucx_ctx_fini(mca_spml_ucx.active_array.ctxs[i]); } for (i = 0; i < mca_spml_ucx.idle_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker[0]); - free(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker); - free(mca_spml_ucx.idle_array.ctxs[i]); + mca_spml_ucx_ctx_fini(mca_spml_ucx.idle_array.ctxs[i]); } if (mca_spml_ucx_ctx_default.ucp_worker) { - for (i = 0; i < (signed int)mca_spml_ucx.ucp_workers; i++) { - ucp_worker_destroy(mca_spml_ucx_ctx_default.ucp_worker[i]); - } - free(mca_spml_ucx_ctx_default.ucp_worker); + mca_spml_ucx_ctx_fini(&mca_spml_ucx_ctx_default); } if (mca_spml_ucx.aux_ctx != NULL) { - ucp_worker_destroy(mca_spml_ucx.aux_ctx->ucp_worker[0]); - free(mca_spml_ucx.aux_ctx->ucp_worker); + mca_spml_ucx_ctx_fini(mca_spml_ucx.aux_ctx); } mca_spml_ucx.enabled = false; /* not anymore */ free(mca_spml_ucx.active_array.ctxs); free(mca_spml_ucx.idle_array.ctxs); - free(mca_spml_ucx.aux_ctx); SHMEM_MUTEX_DESTROY(mca_spml_ucx.internal_mutex); pthread_mutex_destroy(&mca_spml_ucx.ctx_create_mutex); diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx.h b/oshmem/mca/sshmem/ucx/sshmem_ucx.h index 1b8e6e8dcce..06e4b20c737 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx.h +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx.h @@ -21,6 +21,8 @@ BEGIN_C_DECLS typedef struct sshmem_ucx_shadow_allocator sshmem_ucx_shadow_allocator_t; +#define sshmem_ucx_min(a, b) ((a) < (b) ? (a) : (b)) +#define sshmem_ucx_max(a, b) ((a) > (b) ? (a) : (b)) /** * globally exported variable to hold the ucx component. */ diff --git a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c index 4f48be0f66d..7ec5b0abba2 100644 --- a/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c +++ b/oshmem/mca/sshmem/ucx/sshmem_ucx_module.c @@ -117,7 +117,8 @@ segment_create_internal(map_segment_t *ds_buf, void *address, size_t size, mem_map_params.address = address; mem_map_params.length = size; - mem_map_params.flags = flags; + mem_map_params.flags = flags | + mca_spml_ucx_mem_map_flags_symmetric_rkey(spml); mem_map_params.memory_type = mem_type; status = ucp_mem_map(spml->ucp_context, &mem_map_params, &mem_h);