diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c index 8c5fa1d1a64..b3cddcd6d2b 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_cswap.c @@ -50,6 +50,6 @@ int mca_atomic_ucx_cswap(shmem_ctx_t ctx, mca_spml_ucx_remote_op_posted(ucx_ctx, pe); } - return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker, + return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], "ucp_atomic_fetch_nb"); } diff --git a/oshmem/mca/atomic/ucx/atomic_ucx_module.c b/oshmem/mca/atomic/ucx/atomic_ucx_module.c index 882b83f6520..34ed0b551b9 100644 --- a/oshmem/mca/atomic/ucx/atomic_ucx_module.c +++ b/oshmem/mca/atomic/ucx/atomic_ucx_module.c @@ -80,7 +80,7 @@ int mca_atomic_ucx_fop(shmem_ctx_t ctx, op, value, prev, size, rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); - return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker, + return opal_common_ucx_wait_request(status_ptr, ucx_ctx->ucp_worker[0], "ucp_atomic_fetch_nb"); } diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 6ade52446ae..60453e92438 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -105,8 +105,9 @@ int mca_spml_ucx_enable(bool enable) int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs) { opal_common_ucx_del_proc_t *del_procs; - size_t i; + size_t i, w, n; int ret; + int ucp_workers = mca_spml_ucx.ucp_workers; oshmem_shmem_barrier(); @@ -129,10 +130,23 @@ int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs) ret = opal_common_ucx_del_procs_nofence(del_procs, nprocs, oshmem_my_proc_id(), mca_spml_ucx.num_disconnect, - mca_spml_ucx_ctx_default.ucp_worker); + mca_spml_ucx_ctx_default.ucp_worker[0]); /* No need to barrier here - barrier is called in _shmem_finalize */ free(del_procs); - free(mca_spml_ucx.remote_addrs_tbl); + if (mca_spml_ucx.remote_addrs_tbl) { + for (w = 0; w < ucp_workers; w++) { + if (mca_spml_ucx.remote_addrs_tbl[w]) { + for (n = 0; n < nprocs; n++) { + if (mca_spml_ucx.remote_addrs_tbl[w][n]) { + free(mca_spml_ucx.remote_addrs_tbl[w][n]); + } + } + free(mca_spml_ucx.remote_addrs_tbl[w]); + } + } + free(mca_spml_ucx.remote_addrs_tbl); + } + free(mca_spml_ucx_ctx_default.ucp_peers); mca_spml_ucx_ctx_default.ucp_peers = NULL; @@ -142,48 +156,80 @@ int mca_spml_ucx_del_procs(ompi_proc_t** procs, size_t nprocs) /* TODO: move func into common place, use it with rkey exchng too */ static int oshmem_shmem_xchng( - void *local_data, int local_size, int nprocs, - void **rdata_p, int **roffsets_p, int **rsizes_p) + void **local_data, unsigned int *local_size, int nprocs, int ucp_workers, + void **rdata_p, unsigned int **roffsets_p, unsigned int **rsizes_p) { - int *rcv_sizes = NULL; - int *rcv_offsets = NULL; - void *rcv_buf = NULL; + unsigned int *rcv_sizes = NULL; + int *_rcv_sizes = NULL; + unsigned int *rcv_offsets = NULL; + int *_rcv_offsets = NULL; + void *rcv_buf = NULL; int rc; - int i; + int i,j,k; - /* do llgatherv */ - rcv_offsets = malloc(nprocs * sizeof(*rcv_offsets)); + /* do allgatherv */ + rcv_offsets = calloc(ucp_workers * nprocs, sizeof(*rcv_offsets)); if (NULL == rcv_offsets) { goto err; } /* todo: move into separate function. do allgatherv */ - rcv_sizes = malloc(nprocs * sizeof(*rcv_sizes)); + rcv_sizes = calloc(ucp_workers * nprocs, sizeof(*rcv_sizes)); if (NULL == rcv_sizes) { goto err; } - - rc = oshmem_shmem_allgather(&local_size, rcv_sizes, sizeof(int)); + + rc = oshmem_shmem_allgather(local_size, rcv_sizes, ucp_workers * sizeof(*rcv_sizes)); if (MPI_SUCCESS != rc) { goto err; } /* calculate displacements */ rcv_offsets[0] = 0; - for (i = 1; i < nprocs; i++) { + for (i = 1; i < ucp_workers * nprocs; i++) { rcv_offsets[i] = rcv_offsets[i - 1] + rcv_sizes[i - 1]; } - rcv_buf = malloc(rcv_offsets[nprocs - 1] + rcv_sizes[nprocs - 1]); + rcv_buf = calloc(1, rcv_offsets[(ucp_workers * nprocs) - 1] + + rcv_sizes[(ucp_workers * nprocs) - 1]); if (NULL == rcv_buf) { goto err; } + + int _local_size = 0; + for (i = 0; i < ucp_workers; i++) { + _local_size += local_size[i]; + } + _rcv_offsets = calloc(nprocs, sizeof(*rcv_offsets)); + _rcv_sizes = calloc(nprocs, sizeof(*rcv_sizes)); + + k = 0; + for (i = 0; i < nprocs; i++) { + for (j = 0; j < ucp_workers; j++, k++) { + _rcv_sizes[i] += rcv_sizes[k]; + } + } + + _rcv_offsets[0] = 0; + for (i = 1; i < nprocs; i++) { + _rcv_offsets[i] = _rcv_offsets[i - 1] + _rcv_sizes[i - 1]; + } + + char *_local_data = calloc(_local_size, 1); + int new_offset = 0; + for (i = 0; i < ucp_workers; i++) { + memcpy((char *) (_local_data+new_offset), (char *)local_data[i], local_size[i]); + new_offset += local_size[i]; + } - rc = oshmem_shmem_allgatherv(local_data, rcv_buf, local_size, rcv_sizes, rcv_offsets); + rc = oshmem_shmem_allgatherv(_local_data, rcv_buf, _local_size, _rcv_sizes, _rcv_offsets); if (MPI_SUCCESS != rc) { goto err; } + free (_local_data); + free (_rcv_sizes); + free (_rcv_offsets); *rdata_p = rcv_buf; *roffsets_p = rcv_offsets; *rsizes_p = rcv_sizes; @@ -199,19 +245,6 @@ static int oshmem_shmem_xchng( return OSHMEM_ERROR; } -static void dump_address(int pe, char *addr, size_t len) -{ -#ifdef SPML_UCX_DEBUG - int my_rank = oshmem_my_proc_id(); - unsigned i; - - printf("me=%d dest_pe=%d addr=%p len=%d\n", my_rank, pe, addr, len); - for (i = 0; i < len; i++) { - printf("%02X ", (unsigned)0xFF&addr[i]); - } - printf("\n"); -#endif -} static char spml_ucx_transport_ids[1] = { 0 }; @@ -251,17 +284,20 @@ int mca_spml_ucx_clear_put_op_mask(mca_spml_ucx_ctx_t *ctx) int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs) { - size_t i, j, n; + size_t i, j, k, w, n; int rc = OSHMEM_ERROR; int my_rank = oshmem_my_proc_id(); + int ucp_workers = mca_spml_ucx.ucp_workers; ucs_status_t err; - ucp_address_t *wk_local_addr; - size_t wk_addr_len; - int *wk_roffs = NULL; - int *wk_rsizes = NULL; + ucp_address_t **wk_local_addr; + unsigned int *wk_addr_len; + unsigned int *wk_roffs = NULL; + unsigned int *wk_rsizes = NULL; char *wk_raddrs = NULL; ucp_ep_params_t ep_params; + wk_local_addr = calloc(mca_spml_ucx.ucp_workers, sizeof(ucp_address_t *)); + wk_addr_len = calloc(mca_spml_ucx.ucp_workers, sizeof(size_t)); mca_spml_ucx_ctx_default.ucp_peers = (ucp_peer_t *) calloc(nprocs, sizeof(*(mca_spml_ucx_ctx_default.ucp_peers))); if (NULL == mca_spml_ucx_ctx_default.ucp_peers) { @@ -273,13 +309,16 @@ int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs) goto error; } - err = ucp_worker_get_address(mca_spml_ucx_ctx_default.ucp_worker, &wk_local_addr, &wk_addr_len); - if (err != UCS_OK) { - goto error; + for (i = 0; i < mca_spml_ucx.ucp_workers; i++) { + size_t tmp_len; + err = ucp_worker_get_address(mca_spml_ucx_ctx_default.ucp_worker[i], &wk_local_addr[i], &tmp_len); + wk_addr_len[i] = (unsigned int)tmp_len; + if (err != UCS_OK) { + goto error; + } } - dump_address(my_rank, (char *)wk_local_addr, wk_addr_len); - rc = oshmem_shmem_xchng(wk_local_addr, wk_addr_len, nprocs, + rc = oshmem_shmem_xchng((void **)wk_local_addr, wk_addr_len, nprocs, (int) mca_spml_ucx.ucp_workers, (void **)&wk_raddrs, &wk_roffs, &wk_rsizes); if (rc != OSHMEM_SUCCESS) { goto error; @@ -287,22 +326,34 @@ int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs) opal_progress_register(spml_ucx_default_progress); - mca_spml_ucx.remote_addrs_tbl = (char **)calloc(nprocs, sizeof(char *)); - memset(mca_spml_ucx.remote_addrs_tbl, 0, nprocs * sizeof(char *)); + mca_spml_ucx.remote_addrs_tbl = (char ***)calloc(mca_spml_ucx.ucp_workers, + sizeof(mca_spml_ucx.remote_addrs_tbl[0])); + for (w = 0; w < ucp_workers; w++) { + mca_spml_ucx.remote_addrs_tbl[w] = (char **)calloc(nprocs, sizeof(mca_spml_ucx.remote_addrs_tbl[w][0])); + } + + /* Store all remote addresses */ + int offset = 0; + for (i = 0, n = 0; n < nprocs; n++) { + for (w = 0; w < ucp_workers; w++, i++) { + mca_spml_ucx.remote_addrs_tbl[w][n] = (char *)malloc(wk_rsizes[i]); + memcpy(mca_spml_ucx.remote_addrs_tbl[w][n], (char *)(wk_raddrs + offset), wk_rsizes[i]); + offset+=wk_rsizes[i]; + } + } /* Get the EP connection requests for all the processes from modex */ for (n = 0; n < nprocs; ++n) { i = (my_rank + n) % nprocs; - dump_address(i, (char *)(wk_raddrs + wk_roffs[i]), wk_rsizes[i]); ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = (ucp_address_t *)(wk_raddrs + wk_roffs[i]); + ep_params.address = (ucp_address_t *)mca_spml_ucx.remote_addrs_tbl[0][i]; - err = ucp_ep_create(mca_spml_ucx_ctx_default.ucp_worker, &ep_params, - &mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn); + err = ucp_ep_create(mca_spml_ucx_ctx_default.ucp_worker[0], &ep_params, + &mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn); if (UCS_OK != err) { SPML_UCX_ERROR("ucp_ep_create(proc=%zu/%zu) failed: %s", n, nprocs, - ucs_status_string(err)); + ucs_status_string(err)); goto error2; } @@ -312,20 +363,22 @@ int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs) for (j = 0; j < MCA_MEMHEAP_MAX_SEGMENTS; j++) { mca_spml_ucx_ctx_default.ucp_peers[i].mkeys[j].key.rkey = NULL; } + } - mca_spml_ucx.remote_addrs_tbl[i] = (char *)malloc(wk_rsizes[i]); - memcpy(mca_spml_ucx.remote_addrs_tbl[i], (char *)(wk_raddrs + wk_roffs[i]), - wk_rsizes[i]); + for (i = 0; i < mca_spml_ucx.ucp_workers; i++) { + ucp_worker_release_address(mca_spml_ucx_ctx_default.ucp_worker[i], wk_local_addr[i]); } - ucp_worker_release_address(mca_spml_ucx_ctx_default.ucp_worker, wk_local_addr); free(wk_raddrs); free(wk_rsizes); free(wk_roffs); + free(wk_addr_len); + free(wk_local_addr); SPML_UCX_VERBOSE(50, "*** ADDED PROCS ***"); opal_common_ucx_mca_proc_added(); + return OSHMEM_SUCCESS; error2: @@ -333,20 +386,31 @@ int mca_spml_ucx_add_procs(ompi_proc_t** procs, size_t nprocs) if (mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn) { ucp_ep_destroy(mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn); } - if (mca_spml_ucx.remote_addrs_tbl[i]) { - free(mca_spml_ucx.remote_addrs_tbl[i]); - } + } + + if (mca_spml_ucx.remote_addrs_tbl) { + for (w = 0; w < ucp_workers; w++) { + if (mca_spml_ucx.remote_addrs_tbl[w]) { + for (n = 0; n < nprocs; n++) { + if (mca_spml_ucx.remote_addrs_tbl[w][n]) { + free(mca_spml_ucx.remote_addrs_tbl[w][n]); + } + } + free(mca_spml_ucx.remote_addrs_tbl[w]); + } + } + free(mca_spml_ucx.remote_addrs_tbl); } mca_spml_ucx_clear_put_op_mask(&mca_spml_ucx_ctx_default); if (mca_spml_ucx_ctx_default.ucp_peers) free(mca_spml_ucx_ctx_default.ucp_peers); - if (mca_spml_ucx.remote_addrs_tbl) - free(mca_spml_ucx.remote_addrs_tbl); free(wk_raddrs); free(wk_rsizes); free(wk_roffs); error: + free(wk_addr_len); + free(wk_local_addr); rc = OSHMEM_ERR_OUT_OF_RESOURCE; SPML_UCX_ERROR("add procs FAILED rc=%d", rc); return rc; @@ -596,6 +660,7 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx ucp_ep_params_t ep_params; size_t i, nprocs = oshmem_num_procs(); int j; + unsigned int cur_ucp_worker = mca_spml_ucx.ucp_worker_cnt++ % mca_spml_ucx.ucp_workers; ucs_status_t err; spml_ucx_mkey_t *ucx_mkey; sshmem_mkey_t *mkey; @@ -604,6 +669,8 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx ucx_ctx = malloc(sizeof(mca_spml_ucx_ctx_t)); ucx_ctx->options = options; + ucx_ctx->ucp_worker = calloc(1, sizeof(ucp_worker_h)); + ucx_ctx->ucp_workers = 1; params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE || options & SHMEM_CTX_PRIVATE || options & SHMEM_CTX_SERIALIZED) { @@ -613,7 +680,7 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx } err = ucp_worker_create(mca_spml_ucx.ucp_context, ¶ms, - &ucx_ctx->ucp_worker); + &ucx_ctx->ucp_worker[0]); if (UCS_OK != err) { free(ucx_ctx); return OSHMEM_ERROR; @@ -631,8 +698,9 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx for (i = 0; i < nprocs; i++) { ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = (ucp_address_t *)(mca_spml_ucx.remote_addrs_tbl[i]); - err = ucp_ep_create(ucx_ctx->ucp_worker, &ep_params, + ep_params.address = (ucp_address_t *)(mca_spml_ucx.remote_addrs_tbl[cur_ucp_worker][i]); + + err = ucp_ep_create(ucx_ctx->ucp_worker[0], &ep_params, &ucx_ctx->ucp_peers[i].ucp_conn); if (UCS_OK != err) { SPML_ERROR("ucp_ep_create(proc=%d/%d) failed: %s", i, nprocs, @@ -673,7 +741,9 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx free(ucx_ctx->ucp_peers); error: - ucp_worker_destroy(ucx_ctx->ucp_worker); + ucp_worker_destroy(ucx_ctx->ucp_worker[0]); + free(ucx_ctx->ucp_worker); + ucx_ctx->ucp_worker = NULL; free(ucx_ctx); rc = OSHMEM_ERR_OUT_OF_RESOURCE; SPML_ERROR("ctx create FAILED rc=%d", rc); @@ -755,7 +825,7 @@ int mca_spml_ucx_get(shmem_ctx_t ctx, void *src_addr, size_t size, void *dst_add #if HAVE_DECL_UCP_GET_NB request = ucp_get_nb(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); - return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker, "ucp_get_nb"); + return opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_get_nb"); #else status = ucp_get(ucx_ctx->ucp_peers[src].ucp_conn, dst_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -791,7 +861,7 @@ int mca_spml_ucx_get_nb_wprogress(shmem_ctx_t ctx, void *src_addr, size_t size, if (++ucx_ctx->nb_progress_cnt > mca_spml_ucx.nb_get_progress_thresh) { for (i = 0; i < mca_spml_ucx.nb_ucp_worker_progress; i++) { - if (!ucp_worker_progress(ucx_ctx->ucp_worker)) { + if (!ucp_worker_progress(ucx_ctx->ucp_worker[0])) { ucx_ctx->nb_progress_cnt = 0; break; } @@ -817,7 +887,7 @@ int mca_spml_ucx_put(shmem_ctx_t ctx, void* dst_addr, size_t size, void* src_add #if HAVE_DECL_UCP_PUT_NB request = ucp_put_nb(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey, opal_common_ucx_empty_complete_cb); - res = opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker, "ucp_put_nb"); + res = opal_common_ucx_wait_request(request, ucx_ctx->ucp_worker[0], "ucp_put_nb"); #else status = ucp_put(ucx_ctx->ucp_peers[dst].ucp_conn, src_addr, size, (uint64_t)rva, ucx_mkey->rkey); @@ -867,7 +937,7 @@ int mca_spml_ucx_put_nb_wprogress(shmem_ctx_t ctx, void* dst_addr, size_t size, if (++ucx_ctx->nb_progress_cnt > mca_spml_ucx.nb_put_progress_thresh) { for (i = 0; i < mca_spml_ucx.nb_ucp_worker_progress; i++) { - if (!ucp_worker_progress(ucx_ctx->ucp_worker)) { + if (!ucp_worker_progress(ucx_ctx->ucp_worker[0])) { ucx_ctx->nb_progress_cnt = 0; break; } @@ -880,15 +950,20 @@ int mca_spml_ucx_put_nb_wprogress(shmem_ctx_t ctx, void* dst_addr, size_t size, int mca_spml_ucx_fence(shmem_ctx_t ctx) { ucs_status_t err; + unsigned int i = 0; mca_spml_ucx_ctx_t *ucx_ctx = (mca_spml_ucx_ctx_t *)ctx; opal_atomic_wmb(); - err = ucp_worker_fence(ucx_ctx->ucp_worker); - if (UCS_OK != err) { - SPML_UCX_ERROR("fence failed: %s", ucs_status_string(err)); - oshmem_shmem_abort(-1); - return OSHMEM_ERROR; + for (i=0; i < ucx_ctx->ucp_workers; i++) { + if (ucx_ctx->ucp_worker[i] != NULL) { + err = ucp_worker_fence(ucx_ctx->ucp_worker[i]); + if (UCS_OK != err) { + SPML_UCX_ERROR("fence failed: %s", ucs_status_string(err)); + oshmem_shmem_abort(-1); + return OSHMEM_ERROR; + } + } } return OSHMEM_SUCCESS; } @@ -919,10 +994,14 @@ int mca_spml_ucx_quiet(shmem_ctx_t ctx) opal_atomic_wmb(); - ret = opal_common_ucx_worker_flush(ucx_ctx->ucp_worker); - if (OMPI_SUCCESS != ret) { - oshmem_shmem_abort(-1); - return ret; + for (i = 0; i < ucx_ctx->ucp_workers; i++) { + if (ucx_ctx->ucp_worker[i] != NULL) { + ret = opal_common_ucx_worker_flush(ucx_ctx->ucp_worker[i]); + if (OMPI_SUCCESS != ret) { + oshmem_shmem_abort(-1); + return ret; + } + } } /* If put_all_nb op/s is/are being executed asynchronously, need to wait its @@ -1060,7 +1139,7 @@ int mca_spml_ucx_put_all_nb(void *dest, const void *source, size_t size, long *c RUNTIME_CHECK_RC(rc); } - request = ucp_worker_flush_nb(((mca_spml_ucx_ctx_t*)ctx)->ucp_worker, 0, + request = ucp_worker_flush_nb(((mca_spml_ucx_ctx_t*)ctx)->ucp_worker[0], 0, mca_spml_ucx_put_all_complete_cb); if (!UCS_PTR_IS_PTR(request)) { mca_spml_ucx_put_all_complete_cb(NULL, UCS_PTR_STATUS(request)); diff --git a/oshmem/mca/spml/ucx/spml_ucx.h b/oshmem/mca/spml/ucx/spml_ucx.h index d390002f3ed..56db4278c4e 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.h +++ b/oshmem/mca/spml/ucx/spml_ucx.h @@ -68,11 +68,12 @@ struct ucp_peer { typedef struct ucp_peer ucp_peer_t; struct mca_spml_ucx_ctx { - ucp_worker_h ucp_worker; + ucp_worker_h *ucp_worker; ucp_peer_t *ucp_peers; long options; opal_bitmap_t put_op_bitmap; unsigned long nb_progress_cnt; + unsigned int ucp_workers; int *put_proc_indexes; unsigned put_proc_count; }; @@ -95,7 +96,7 @@ struct mca_spml_ucx { int heap_reg_nb; bool enabled; mca_spml_ucx_get_mkey_slow_fn_t get_mkey_slow; - char **remote_addrs_tbl; + char ***remote_addrs_tbl; mca_spml_ucx_ctx_array_t active_array; mca_spml_ucx_ctx_array_t idle_array; int priority; /* component priority */ @@ -114,6 +115,8 @@ struct mca_spml_ucx { unsigned long nb_put_progress_thresh; unsigned long nb_get_progress_thresh; unsigned long nb_ucp_worker_progress; + unsigned int ucp_workers; + unsigned int ucp_worker_cnt; }; typedef struct mca_spml_ucx mca_spml_ucx_t; diff --git a/oshmem/mca/spml/ucx/spml_ucx_component.c b/oshmem/mca/spml/ucx/spml_ucx_component.c index 192934649ac..5fd43bdbe5e 100644 --- a/oshmem/mca/spml/ucx/spml_ucx_component.c +++ b/oshmem/mca/spml/ucx/spml_ucx_component.c @@ -75,6 +75,21 @@ static inline void mca_spml_ucx_param_register_ulong(const char* param_name, storage); } +static inline void mca_spml_ucx_param_register_uint(const char* param_name, + unsigned int default_value, + const char *help_msg, + unsigned int *storage) +{ + *storage = default_value; + (void) mca_base_component_var_register(&mca_spml_ucx_component.spmlm_version, + param_name, + help_msg, + MCA_BASE_VAR_TYPE_UNSIGNED_INT, NULL, 0, 0, + OPAL_INFO_LVL_9, + MCA_BASE_VAR_SCOPE_READONLY, + storage); +} + static inline void mca_spml_ucx_param_register_int(const char* param_name, int default_value, const char *help_msg, @@ -161,6 +176,9 @@ static int mca_spml_ucx_component_register(void) mca_spml_ucx_param_register_ulong("nb_ucp_worker_progress", 32, "Maximum number of ucx worker progress calls if triggered during nb_put or nb_get", &mca_spml_ucx.nb_ucp_worker_progress); + mca_spml_ucx_param_register_uint("default_ctx_ucp_workers", 1, + "Number of ucp workers per default context", + &mca_spml_ucx.ucp_workers); opal_common_ucx_mca_var_register(&mca_spml_ucx_component.spmlm_version); @@ -171,14 +189,17 @@ int spml_ucx_ctx_progress(void) { int i; for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) { - ucp_worker_progress(mca_spml_ucx.active_array.ctxs[i]->ucp_worker); + ucp_worker_progress(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]); } return 1; } int spml_ucx_default_progress(void) { - ucp_worker_progress(mca_spml_ucx_ctx_default.ucp_worker); + unsigned int i=0; + for (i = 0; i < mca_spml_ucx.ucp_workers; i++) { + ucp_worker_progress(mca_spml_ucx_ctx_default.ucp_worker[i]); + } return 1; } @@ -194,7 +215,7 @@ int spml_ucx_progress_aux_ctx(void) return 0; } - count = ucp_worker_progress(mca_spml_ucx.aux_ctx->ucp_worker); + count = ucp_worker_progress(mca_spml_ucx.aux_ctx->ucp_worker[0]); pthread_spin_unlock(&mca_spml_ucx.async_lock); return count; @@ -209,7 +230,7 @@ void mca_spml_ucx_async_cb(int fd, short event, void *cbdata) } do { - count = ucp_worker_progress(mca_spml_ucx.aux_ctx->ucp_worker); + count = ucp_worker_progress(mca_spml_ucx.aux_ctx->ucp_worker[0]); } while (count); pthread_spin_unlock(&mca_spml_ucx.async_lock); @@ -227,12 +248,13 @@ static int mca_spml_ucx_component_close(void) static int spml_ucx_init(void) { + unsigned int i; ucs_status_t err; ucp_config_t *ucp_config; ucp_params_t params; ucp_context_attr_t attr; ucp_worker_params_t wkr_params; - ucp_worker_attr_t wkr_attr; + ucp_worker_attr_t wrk_attr; err = ucp_config_read("OSHMEM", NULL, &ucp_config); if (UCS_OK != err) { @@ -293,18 +315,22 @@ static int spml_ucx_init(void) } else { wkr_params.thread_mode = UCS_THREAD_MODE_SINGLE; } - - err = ucp_worker_create(mca_spml_ucx.ucp_context, &wkr_params, - &mca_spml_ucx_ctx_default.ucp_worker); - if (UCS_OK != err) { - return OSHMEM_ERROR; + + mca_spml_ucx_ctx_default.ucp_worker = calloc(mca_spml_ucx.ucp_workers, sizeof(ucp_worker_h)); + for (i = 0; i < mca_spml_ucx.ucp_workers; i++) { + err = ucp_worker_create(mca_spml_ucx.ucp_context, &wkr_params, + &mca_spml_ucx_ctx_default.ucp_worker[i]); + if (UCS_OK != err) { + return OSHMEM_ERROR; + } + mca_spml_ucx_ctx_default.ucp_workers++; } - wkr_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; - err = ucp_worker_query(mca_spml_ucx_ctx_default.ucp_worker, &wkr_attr); + wrk_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; + err = ucp_worker_query(mca_spml_ucx_ctx_default.ucp_worker[0], &wrk_attr); if (oshmem_mpi_thread_requested == SHMEM_THREAD_MULTIPLE && - wkr_attr.thread_mode != UCS_THREAD_MODE_MULTI) { + wrk_attr.thread_mode != UCS_THREAD_MODE_MULTI) { oshmem_mpi_thread_provided = SHMEM_THREAD_SINGLE; } @@ -377,7 +403,7 @@ static void _ctx_cleanup(mca_spml_ucx_ctx_t *ctx) opal_common_ucx_del_procs_nofence(del_procs, nprocs, oshmem_my_proc_id(), mca_spml_ucx.num_disconnect, - ctx->ucp_worker); + ctx->ucp_worker[0]); free(del_procs); mca_spml_ucx_clear_put_op_mask(ctx); free(ctx->ucp_peers); @@ -423,37 +449,45 @@ static int mca_spml_ucx_component_fini(void) while (!fenced) { for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) { - ucp_worker_progress(mca_spml_ucx.active_array.ctxs[i]->ucp_worker); + ucp_worker_progress(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]); } for (i = 0; i < mca_spml_ucx.idle_array.ctxs_count; i++) { - ucp_worker_progress(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker); + ucp_worker_progress(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker[0]); + } + + for (i = 0; i < (signed int)mca_spml_ucx.ucp_workers; i++) { + ucp_worker_progress(mca_spml_ucx_ctx_default.ucp_worker[i]); } - - ucp_worker_progress(mca_spml_ucx_ctx_default.ucp_worker); if (mca_spml_ucx.aux_ctx != NULL) { - ucp_worker_progress(mca_spml_ucx.aux_ctx->ucp_worker); + ucp_worker_progress(mca_spml_ucx.aux_ctx->ucp_worker[0]); } } /* delete all workers */ for (i = 0; i < mca_spml_ucx.active_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.active_array.ctxs[i]->ucp_worker); + ucp_worker_destroy(mca_spml_ucx.active_array.ctxs[i]->ucp_worker[0]); + free(mca_spml_ucx.active_array.ctxs[i]->ucp_worker); free(mca_spml_ucx.active_array.ctxs[i]); } for (i = 0; i < mca_spml_ucx.idle_array.ctxs_count; i++) { - ucp_worker_destroy(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker); + ucp_worker_destroy(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker[0]); + free(mca_spml_ucx.idle_array.ctxs[i]->ucp_worker); free(mca_spml_ucx.idle_array.ctxs[i]); } if (mca_spml_ucx_ctx_default.ucp_worker) { - ucp_worker_destroy(mca_spml_ucx_ctx_default.ucp_worker); + for (i = 0; i < (signed int)mca_spml_ucx.ucp_workers; i++) { + ucp_worker_destroy(mca_spml_ucx_ctx_default.ucp_worker[i]); + } + free(mca_spml_ucx_ctx_default.ucp_worker); } if (mca_spml_ucx.aux_ctx != NULL) { - ucp_worker_destroy(mca_spml_ucx.aux_ctx->ucp_worker); + ucp_worker_destroy(mca_spml_ucx.aux_ctx->ucp_worker[0]); + free(mca_spml_ucx.aux_ctx->ucp_worker); } mca_spml_ucx.enabled = false; /* not anymore */ @@ -472,4 +506,3 @@ static int mca_spml_ucx_component_fini(void) return OSHMEM_SUCCESS; } -