From 1ea6fb9a672766352bb858601d0fd363a700ad16 Mon Sep 17 00:00:00 2001 From: "Mamzi Bayatpour mbayatpour@nvidia.com ()" Date: Wed, 3 Aug 2022 17:45:17 -0700 Subject: [PATCH] OSC/UCX: Adding the following optimizations: 1) Reuse the same worker/eps in single threaded applications, this is helpful if an application creates many windows, therefore, we avoid the unnecessary overheads and 2) adding the truly nonblocking MPI_Accumulate/Get_Accumulate. Signed-off-by: Mamzi Bayatpour Co-authored-by: Tomislav Janjusic Co-authored-by: Joseph Schuchart --- ompi/mca/osc/ucx/osc_ucx.h | 50 +- ompi/mca/osc/ucx/osc_ucx_active_target.c | 19 +- ompi/mca/osc/ucx/osc_ucx_comm.c | 740 +++++++++++++++++++--- ompi/mca/osc/ucx/osc_ucx_component.c | 194 +++--- ompi/mca/osc/ucx/osc_ucx_passive_target.c | 24 +- ompi/mca/osc/ucx/osc_ucx_request.c | 25 +- ompi/mca/osc/ucx/osc_ucx_request.h | 140 +++- opal/mca/common/ucx/common_ucx_wpool.c | 107 +++- opal/mca/common/ucx/common_ucx_wpool.h | 44 +- 9 files changed, 1073 insertions(+), 270 deletions(-) diff --git a/ompi/mca/osc/ucx/osc_ucx.h b/ompi/mca/osc/ucx/osc_ucx.h index f4e16a35808..8a1c6153cd5 100644 --- a/ompi/mca/osc/ucx/osc_ucx.h +++ b/ompi/mca/osc/ucx/osc_ucx.h @@ -27,13 +27,16 @@ #define OMPI_OSC_UCX_ATTACH_MAX 48 #define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024 + typedef struct ompi_osc_ucx_component { ompi_osc_base_component_t super; opal_common_ucx_wpool_t *wpool; bool enable_mpi_threads; opal_free_list_t requests; /* request free list for the r* communication variants */ + opal_free_list_t accumulate_requests; /* request free list for the r* communication variants */ bool env_initialized; /* UCX environment is initialized or not */ - int num_incomplete_req_ops; + int comm_world_size; + ucp_ep_h *endpoints; int num_modules; bool no_locks; /* Default value of the no_locks info key for new windows */ bool acc_single_intrinsic; @@ -44,6 +47,16 @@ typedef struct ompi_osc_ucx_component { OMPI_DECLSPEC extern ompi_osc_ucx_component_t 
mca_osc_ucx_component; +#define OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(_module) \ + do { \ + opal_atomic_add_fetch_size_t(&_module->ctx->num_incomplete_req_ops, 1); \ + } while(0); + +#define OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS(_module) \ + do { \ + opal_atomic_add_fetch_size_t(&_module->ctx->num_incomplete_req_ops, -1); \ + } while(0); + typedef enum ompi_osc_ucx_epoch { NONE_EPOCH, FENCE_EPOCH, @@ -69,7 +82,8 @@ typedef struct ompi_osc_ucx_epoch_type { #define OSC_UCX_STATE_COMPLETE_COUNT_OFFSET (sizeof(uint64_t) * 3) #define OSC_UCX_STATE_POST_INDEX_OFFSET (sizeof(uint64_t) * 4) #define OSC_UCX_STATE_POST_STATE_OFFSET (sizeof(uint64_t) * 5) -#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (5 + OMPI_OSC_UCX_POST_PEER_MAX)) +#define OSC_UCX_STATE_DYNAMIC_LOCK_OFFSET (sizeof(uint64_t) * 6) +#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (6 + OMPI_OSC_UCX_POST_PEER_MAX)) typedef struct ompi_osc_dynamic_win_info { uint64_t base; @@ -102,6 +116,7 @@ typedef struct ompi_osc_ucx_module { size_t size; uint64_t *addrs; uint64_t *state_addrs; + uint64_t *comm_world_ranks; int disp_unit; /* if disp_unit >= 0, then everyone has the same * disp unit size; if disp_unit == -1, then we * need to look at disp_units */ @@ -125,6 +140,7 @@ typedef struct ompi_osc_ucx_module { opal_common_ucx_wpmem_t *mem; opal_common_ucx_wpmem_t *state_mem; + bool skip_sync_check; bool noncontig_shared_win; size_t *sizes; /* in shared windows, shmem_addrs can be used for direct load store to @@ -147,9 +163,18 @@ typedef struct ompi_osc_ucx_lock { bool is_nocheck; } ompi_osc_ucx_lock_t; -#define OSC_UCX_GET_EP(comm_, rank_) (ompi_comm_peer_lookup(comm_, rank_)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_UCX]) +#define OSC_UCX_GET_EP(_module, rank_) (mca_osc_ucx_component.endpoints[_module->comm_world_ranks[rank_]]) #define OSC_UCX_GET_DISP(module_, rank_) ((module_->disp_unit < 0) ? 
module_->disp_units[rank_] : module_->disp_unit) +#define OSC_UCX_GET_DEFAULT_EP(_ep_ptr, _module, _target) \ + if (opal_common_ucx_thread_enabled) { \ + _ep_ptr = NULL; \ + } else { \ + _ep_ptr = (ucp_ep_h *)&(OSC_UCX_GET_EP(_module, _target)); \ + } + +extern size_t ompi_osc_ucx_outstanding_ops_flush_threshold; + int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size, int *disp_unit, void * baseptr); int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len); @@ -169,6 +194,11 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, int target, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t *op, struct ompi_win_t *win); +int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count, + struct ompi_datatype_t *origin_dt, + int target, ptrdiff_t target_disp, int target_count, + struct ompi_datatype_t *target_dt, + struct ompi_op_t *op, struct ompi_win_t *win); int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr, void *result_addr, struct ompi_datatype_t *dt, int target, ptrdiff_t target_disp, @@ -184,6 +214,13 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, int target_rank, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_datatype, struct ompi_op_t *op, struct ompi_win_t *win); +int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count, + struct ompi_datatype_t *origin_datatype, + void *result_addr, int result_count, + struct ompi_datatype_t *result_datatype, + int target_rank, ptrdiff_t target_disp, + int target_count, struct ompi_datatype_t *target_datatype, + struct ompi_op_t *op, struct ompi_win_t *win); int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt, int target, ptrdiff_t target_disp, int target_count, @@ -228,10 +265,7 @@ int ompi_osc_ucx_flush_local_all(struct ompi_win_t *win); int 
ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_wins, int min_index, int max_index, uint64_t base, size_t len, int *insert); -extern inline bool ompi_osc_need_acc_lock(ompi_osc_ucx_module_t *module, int target); -extern inline int ompi_osc_state_lock(ompi_osc_ucx_module_t *module, int target, - bool *lock_acquired, bool force_lock); -extern inline int ompi_osc_state_unlock(ompi_osc_ucx_module_t *module, int target, - bool lock_acquired, void *free_ptr); +int ompi_osc_ucx_dynamic_lock(ompi_osc_ucx_module_t *module, int target); +int ompi_osc_ucx_dynamic_unlock(ompi_osc_ucx_module_t *module, int target); #endif /* OMPI_OSC_UCX_H */ diff --git a/ompi/mca/osc/ucx/osc_ucx_active_target.c b/ompi/mca/osc/ucx/osc_ucx_active_target.c index 90c63591665..57084e6ee7c 100644 --- a/ompi/mca/osc/ucx/osc_ucx_active_target.c +++ b/ompi/mca/osc/ucx/osc_ucx_active_target.c @@ -165,31 +165,33 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; int i, size; int ret = OMPI_SUCCESS; + ucp_ep_h *ep; if (module->epoch_type.access != START_COMPLETE_EPOCH) { return OMPI_ERR_RMA_SYNC; } - module->epoch_type.access = NONE_EPOCH; - ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_WORKER, 0/*ignore*/); if (ret != OMPI_SUCCESS) { return ret; } + module->epoch_type.access = NONE_EPOCH; + size = ompi_group_size(module->start_group); for (i = 0; i < size; i++) { uint64_t remote_addr = module->state_addrs[module->start_grp_ranks[i]] + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; // write to state.complete_count on remote side + OSC_UCX_GET_DEFAULT_EP(ep, module, module->start_grp_ranks[i]); + ret = opal_common_ucx_wpmem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD, 1, module->start_grp_ranks[i], sizeof(uint64_t), - remote_addr); + remote_addr, ep); if (ret != OMPI_SUCCESS) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_post failed: %d", ret); } - ret = 
opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, - module->start_grp_ranks[i]); + ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, module->start_grp_ranks[i]); if (ret != OMPI_SUCCESS) { return ret; } @@ -204,6 +206,7 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) { int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; + ucp_ep_h *ep; int ret = OMPI_SUCCESS; if (module->epoch_type.exposure != NONE_EPOCH) { @@ -243,12 +246,12 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_wi uint64_t remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_INDEX_OFFSET; // write to state.post_index on remote side uint64_t curr_idx = 0, result = 0; - + OSC_UCX_GET_DEFAULT_EP(ep, module, ranks_in_win_grp[i]); /* do fop first to get an post index */ ret = opal_common_ucx_wpmem_fetch(module->state_mem, UCP_ATOMIC_FETCH_OP_FADD, 1, ranks_in_win_grp[i], &result, - sizeof(result), remote_addr); + sizeof(result), remote_addr, ep); if (ret != OMPI_SUCCESS) { ret = OMPI_ERROR; @@ -265,7 +268,7 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_wi result = myrank + 1; ret = opal_common_ucx_wpmem_cmpswp(module->state_mem, 0, result, ranks_in_win_grp[i], &result, sizeof(result), - remote_addr); + remote_addr, ep); if (ret != OMPI_SUCCESS) { ret = OMPI_ERROR; diff --git a/ompi/mca/osc/ucx/osc_ucx_comm.c b/ompi/mca/osc/ucx/osc_ucx_comm.c index 7e794dcbc79..5bb240b3a46 100644 --- a/ompi/mca/osc/ucx/osc_ucx_comm.c +++ b/ompi/mca/osc/ucx/osc_ucx_comm.c @@ -27,9 +27,9 @@ return OMPI_ERROR; \ } -#define CHECK_DYNAMIC_WIN(_remote_addr, _module, _target, _ret, _lock_required) \ +#define CHECK_DYNAMIC_WIN(_remote_addr, _module, _target, _ret) \ if (_module->flavor == MPI_WIN_FLAVOR_DYNAMIC) { \ - _ret = get_dynamic_win_info(_remote_addr, _module, _target, 
_lock_required); \ + _ret = get_dynamic_win_info(_remote_addr, _module, _target); \ if (_ret != OMPI_SUCCESS) { \ return _ret; \ } \ @@ -40,8 +40,13 @@ typedef struct ucx_iovec { size_t len; } ucx_iovec_t; +size_t ompi_osc_ucx_outstanding_ops_flush_threshold = 64; + static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target, bool is_req_ops) { + + if (module->skip_sync_check) return OMPI_SUCCESS; + if (is_req_ops == false) { if (module->epoch_type.access == NONE_EPOCH) { return OMPI_ERR_RMA_SYNC; @@ -134,6 +139,8 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, int target, uint64_t remote_addr, int target_count, struct ompi_datatype_t *target_dt, bool is_target_contig, ptrdiff_t target_lb, bool is_get) { + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); ucx_iovec_t *origin_ucx_iov = NULL, *target_ucx_iov = NULL; uint32_t origin_ucx_iov_count = 0, target_ucx_iov_count = 0; uint32_t origin_ucx_iov_idx = 0, target_ucx_iov_idx = 0; @@ -169,7 +176,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, } status = opal_common_ucx_wpmem_putget(module->mem, op, target, origin_ucx_iov[origin_ucx_iov_idx].addr, curr_len, - remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr)); + remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), ep); if (OPAL_SUCCESS != status) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status); ret = OMPI_ERROR; @@ -204,7 +211,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, status = opal_common_ucx_wpmem_putget(module->mem, op, target, origin_ucx_iov[origin_ucx_iov_idx].addr, origin_ucx_iov[origin_ucx_iov_idx].len, - remote_addr + target_lb + prev_len); + remote_addr + target_lb + prev_len, ep); if (OPAL_SUCCESS != status) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status); ret = OMPI_ERROR; @@ -227,7 +234,7 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, status = 
opal_common_ucx_wpmem_putget(module->mem, op, target, (void *)((intptr_t)origin_addr + origin_lb + prev_len), target_ucx_iov[target_ucx_iov_idx].len, - remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr)); + remote_addr + (uint64_t)(target_ucx_iov[target_ucx_iov_idx].addr), ep); if (OPAL_SUCCESS != status) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", status); ret = OMPI_ERROR; @@ -240,7 +247,6 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, } cleanup: - if (origin_ucx_iov != NULL) { free(origin_ucx_iov); } @@ -251,8 +257,10 @@ static inline int ddt_put_get(ompi_osc_ucx_module_t *module, return ret; } -static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module_t *module, - int target, bool lock_required) { +static inline int get_dynamic_win_info(uint64_t remote_addr, + ompi_osc_ucx_module_t *module, int target) { + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); uint64_t remote_state_addr = (module->state_addrs)[target] + OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET; size_t remote_state_len = sizeof(uint64_t) + sizeof(ompi_osc_dynamic_win_info_t) * OMPI_OSC_UCX_ATTACH_MAX; char *temp_buf; @@ -262,21 +270,19 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module int insert = -1; int ret; - bool lock_acquired = false; - if (lock_required) { - /* We need to lock acc-lock even if the process has an exclusive lock. - * Therefore, force lock is needed. Remote process protects its window - * attach/detach operations with an acc-lock */ - ret = ompi_osc_state_lock(module, target, &lock_acquired, true); - if (OMPI_SUCCESS != ret) { - return ret; - } + /* We need to lock dyn-lock even if the process has an exclusive lock. 
+ * Remote process protects its window attach/detach operations with a + * dynamic lock */ + ret = ompi_osc_ucx_dynamic_lock(module, target); + if (ret != OPAL_SUCCESS) { + ret = OMPI_ERROR; + goto cleanup; } temp_buf = calloc(remote_state_len, 1); ret = opal_common_ucx_wpmem_putget(module->state_mem, OPAL_COMMON_UCX_GET, target, (void *)((intptr_t)temp_buf), - remote_state_len, remote_state_addr); + remote_state_len, remote_state_addr, ep); if (OPAL_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret); ret = OMPI_ERROR; @@ -299,7 +305,7 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module contain = ompi_osc_find_attached_region_position(temp_dynamic_wins, 0, win_count - 1, remote_addr, 1, &insert); if (contain < 0 || contain >= (int)win_count) { - OSC_UCX_ERROR("Dynamic window index not found contain: %d win_count: %d\n", + OSC_UCX_ERROR("Dynamic window index not found contain: %d win_count: %" PRIu64 "\n", contain, win_count); ret = MPI_ERR_RMA_RANGE; goto cleanup; @@ -314,7 +320,8 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module } if (mem_rec == NULL) { - ret = opal_common_ucx_tlocal_fetch_spath(module->mem, target); + OSC_UCX_GET_DEFAULT_EP(ep, module, target); + ret = opal_common_ucx_tlocal_fetch_spath(module->mem, target, ep); if (OPAL_SUCCESS != ret) { goto cleanup; } @@ -350,9 +357,8 @@ static inline int get_dynamic_win_info(uint64_t remote_addr, ompi_osc_ucx_module cleanup: free(temp_buf); - ompi_osc_state_unlock(module, target, lock_acquired, NULL); - - return ret; + /* unlock the dynamic lock */ + return ompi_osc_ucx_dynamic_unlock(module, target); } static inline @@ -415,16 +421,18 @@ static int do_atomic_op_intrinsic( struct ompi_datatype_t *dt, ptrdiff_t target_disp, void *result_addr, - ompi_osc_ucx_request_t *ucx_req) + ompi_osc_ucx_accumulate_request_t *ucx_req) { int ret = OMPI_SUCCESS; size_t origin_dt_bytes; opal_common_ucx_wpmem_t *mem = module->mem; 
ompi_datatype_type_size(dt, &origin_dt_bytes); + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, true); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); ucp_atomic_fetch_op_t opcode; bool is_no_op = false; @@ -447,7 +455,7 @@ static int do_atomic_op_intrinsic( uint64_t value = 0; if ((count - 1) == i && NULL != ucx_req) { // the last item is used to feed the request, if needed - user_req_cb = &req_completion; + user_req_cb = &ompi_osc_ucx_req_completion; user_req_ptr = ucx_req; // issue a fence if this is the last but not the only element if (0 < i) { @@ -465,7 +473,7 @@ static int do_atomic_op_intrinsic( } ret = opal_common_ucx_wpmem_fetch_nb(mem, opcode, value, target, output_addr, origin_dt_bytes, remote_addr, - user_req_cb, user_req_ptr); + user_req_cb, user_req_ptr, ep); // advance origin and remote address origin_addr = (void*)((intptr_t)origin_addr + origin_dt_bytes); @@ -483,6 +491,8 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data struct ompi_datatype_t *target_dt, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; opal_common_ucx_wpmem_t *mem = module->mem; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); bool is_origin_contig = false, is_target_contig = false; ptrdiff_t origin_lb, origin_extent, target_lb, target_extent; @@ -493,7 +503,7 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data return ret; } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, true); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); if (!target_count) { return OMPI_SUCCESS; @@ -514,7 +524,7 @@ int ompi_osc_ucx_put(const void *origin_addr, int origin_count, struct ompi_data ret = 
opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_PUT, target, (void *)((intptr_t)origin_addr + origin_lb), - origin_len, remote_addr + target_lb); + origin_len, remote_addr + target_lb, ep); if (OPAL_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret); return OMPI_ERROR; @@ -533,6 +543,8 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; opal_common_ucx_wpmem_t *mem = module->mem; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); ptrdiff_t origin_lb, origin_extent, target_lb, target_extent; bool is_origin_contig = false, is_target_contig = false; @@ -543,7 +555,7 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, return ret; } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, true); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); if (!target_count) { return OMPI_SUCCESS; @@ -564,7 +576,7 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, ret = opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_GET, target, (void *)((intptr_t)origin_addr + origin_lb), - origin_len, remote_addr + target_lb); + origin_len, remote_addr + target_lb, ep); if (OPAL_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret); return OMPI_ERROR; @@ -578,19 +590,97 @@ int ompi_osc_ucx_get(void *origin_addr, int origin_count, } } +static inline bool ompi_osc_need_acc_lock(ompi_osc_ucx_module_t *module, int target) +{ + ompi_osc_ucx_lock_t *lock = NULL; + opal_hash_table_get_value_uint32(&module->outstanding_locks, + (uint32_t) target, (void **) &lock); + + /* if there is an exclusive lock there is no need to acqurie the accumulate lock */ + return !(NULL != lock && LOCK_EXCLUSIVE == lock->type); +} + +static inline int 
ompi_osc_ucx_acc_lock(ompi_osc_ucx_module_t *module, int target, bool *lock_acquired) { + uint64_t result_value = -1; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); + int ret = OMPI_SUCCESS; + + if (ompi_osc_need_acc_lock(module, target)) { + for (;;) { + ret = opal_common_ucx_wpmem_cmpswp(module->state_mem, + TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, + target, &result_value, sizeof(result_value), + remote_addr, ep); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_cmpswp failed: %d", ret); + return OMPI_ERROR; + } + if (result_value == TARGET_LOCK_UNLOCKED) { + break; + } + + opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool); + } + + *lock_acquired = true; + } else { + *lock_acquired = false; + } + + return OMPI_SUCCESS; +} + +static inline int ompi_osc_ucx_acc_unlock(ompi_osc_ucx_module_t *module, int target, + bool lock_acquired, void *free_ptr) { + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); + int ret = OMPI_SUCCESS; + + if (lock_acquired) { + uint64_t result_value = 0; + /* fence any still active operations */ + ret = opal_common_ucx_wpmem_fence(module->mem); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); + return OMPI_ERROR; + } + + ret = opal_common_ucx_wpmem_fetch(module->state_mem, + UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, + target, &result_value, sizeof(result_value), + remote_addr, ep); + assert(result_value == TARGET_LOCK_EXCLUSIVE); + } else if (NULL != free_ptr){ + /* flush before freeing the buffer */ + ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target); + } + /* TODO: encapsulate in a request and make the release non-blocking */ + if (NULL != free_ptr) { + free(free_ptr); + } + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, 
"opal_common_ucx_mem_fetch failed: %d", ret); + return OMPI_ERROR; + } + + return ret; +} + static int accumulate_req(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt, int target, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t *op, struct ompi_win_t *win, - ompi_osc_ucx_request_t *ucx_req) { + ompi_osc_ucx_accumulate_request_t *ucx_req) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; int ret = OMPI_SUCCESS; uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); - opal_common_ucx_wpmem_t *mem = module->mem; void *free_ptr = NULL; bool lock_acquired = false; @@ -611,12 +701,12 @@ int accumulate_req(const void *origin_addr, int origin_count, } /* Start atomicity by acquiring acc lock */ - ret = ompi_osc_state_lock(module, target, &lock_acquired, false); + ret = ompi_osc_ucx_acc_lock(module, target, &lock_acquired); if (ret != OMPI_SUCCESS) { return ret; } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, !lock_acquired); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); if (op == &ompi_mpi_op_replace.op) { ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target, @@ -714,12 +804,17 @@ int accumulate_req(const void *origin_addr, int origin_count, } + ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target); + if (ret != OPAL_SUCCESS) { + return ret; + } + if (NULL != ucx_req) { // nothing to wait for, mark request as completed - ompi_request_complete(&ucx_req->super, true); + ompi_request_complete(&ucx_req->super.super, true); } - return ompi_osc_state_unlock(module, target, lock_acquired, free_ptr); + return ompi_osc_ucx_acc_unlock(module, target, lock_acquired, free_ptr); } int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, @@ -727,8 +822,218 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count, int target, ptrdiff_t target_disp, int 
target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t *op, struct ompi_win_t *win) { + return accumulate_req(origin_addr, origin_count, origin_dt, target, - target_disp, target_count, target_dt, op, win, NULL); + target_disp, target_count, target_dt, op, win, NULL); +} + +static inline int ompi_osc_ucx_acc_rputget(void *stage_addr, int stage_count, + struct ompi_datatype_t *stage_dt, int target, ptrdiff_t target_disp, + int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t + *op, struct ompi_win_t *win, bool lock_acquired, const void + *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt, bool is_put, + int phase, int acc_type) { + ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); + opal_common_ucx_wpmem_t *mem = module->mem; + uint64_t remote_addr = (module->state_addrs[target]) + OSC_UCX_STATE_REQ_FLAG_OFFSET; + ompi_osc_ucx_accumulate_request_t *ucx_req = NULL; + bool sync_check; + int ret = OMPI_SUCCESS; + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); + + if (acc_type != NONE) { + OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req); + assert(NULL != ucx_req); + ucx_req->op = op; + ucx_req->acc_type = acc_type; + ucx_req->phase = phase; + ucx_req->super.module = module; + ucx_req->target = target; + ucx_req->lock_acquired = lock_acquired; + ucx_req->win = win; + ucx_req->origin_addr = origin_addr; + ucx_req->origin_count = origin_count; + if (origin_dt != NULL) { + ucx_req->origin_dt = origin_dt; + if (!ompi_datatype_is_predefined(origin_dt)) { + OBJ_RETAIN(ucx_req->origin_dt); + } + } + ucx_req->stage_addr = stage_addr; + ucx_req->stage_count = stage_count; + if (stage_dt != NULL) { + ucx_req->stage_dt = stage_dt; + if (!ompi_datatype_is_predefined(stage_dt)) { + OBJ_RETAIN(ucx_req->stage_dt); + } + } + ucx_req->target = target; + if (target_dt != NULL) { + ucx_req->target_dt = target_dt; + if 
(!ompi_datatype_is_predefined(target_dt)) { + OBJ_RETAIN(ucx_req->target_dt); + } + } + ucx_req->target_disp = target_disp; + ucx_req->target_count = target_count; + ucx_req->free_ptr = NULL; + } + sync_check = module->skip_sync_check; + module->skip_sync_check = true; /* we already hold the acc lock, so no need for sync check*/ + + if (is_put) { + ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target, target_disp, + target_count, target_dt, win); + } else { + ret = ompi_osc_ucx_get(stage_addr, stage_count, stage_dt, target, target_disp, + target_count, target_dt, win); + } + if (ret != OMPI_SUCCESS) { + return ret; + } + + module->skip_sync_check = sync_check; + if (acc_type != NONE) { + OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module); + ret = opal_common_ucx_wpmem_flush_ep_nb(mem, target, ompi_osc_ucx_req_completion, ucx_req, ep); + + if (ret != OMPI_SUCCESS) { + /* fallback to using an atomic op to acquire a request handle */ + ret = opal_common_ucx_wpmem_fence(mem); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); + return OMPI_ERROR; + } + + ret = opal_common_ucx_wpmem_fetch_nb(mem, UCP_ATOMIC_FETCH_OP_FADD, + 0, target, &(module->req_result), + sizeof(uint64_t), remote_addr & (~0x7), + ompi_osc_ucx_req_completion, ucx_req, ep); + if (ret != OMPI_SUCCESS) { + OMPI_OSC_UCX_REQUEST_RETURN(ucx_req); + return ret; + } + } + } + + return ret; +} + +/* Nonblocking variant of accumulate. 
reduce+put happens inside completion call back + * of rget */ +static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int origin_count, + struct ompi_datatype_t *origin_dt, void *result_addr, int result_count, + struct ompi_datatype_t *result_dt, int target, ptrdiff_t target_disp, + int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t + *op, struct ompi_win_t *win) { + + ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; + int ret = OMPI_SUCCESS; + uint64_t remote_addr = (module->addrs[target]) + target_disp * + OSC_UCX_GET_DISP(module, target); + void *free_ptr = NULL; + bool lock_acquired = false; + + ret = check_sync_state(module, target, false); + if (ret != OMPI_SUCCESS) { + return ret; + } + + if (result_addr == NULL && op == &ompi_mpi_op_no_op.op) { + /* This is an accumulate (not get-accumulate) operation, so return */ + return ret; + } + + /* rely on UCX network atomics if the user told us that it safe */ + if (use_atomic_op(module, op, target_disp, origin_dt, target_dt, origin_count, target_count)) { + return do_atomic_op_intrinsic(module, op, target, + origin_addr, origin_count, origin_dt, + target_disp, result_addr, NULL); + } + + /* Start atomicity by acquiring acc lock */ + ret = ompi_osc_ucx_acc_lock(module, target, &lock_acquired); + if (ret != OMPI_SUCCESS) { + return ret; + } + + if (module->ctx->num_incomplete_req_ops > ompi_osc_ucx_outstanding_ops_flush_threshold) { + ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_WORKER, 0); + if (ret != OPAL_SUCCESS) { + ret = OMPI_ERROR; + return ret; + } + } + + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); + + if (result_addr != NULL) { + /* This is a get-accumulate operation, so read the target data into result addr */ + ret = ompi_osc_ucx_acc_rputget(result_addr, (int)result_count, result_dt, target, + target_disp, target_count, target_dt, op, win, lock_acquired, + origin_addr, origin_count, origin_dt, false, 
ACC_GET_RESULTS_DATA, GET_ACCUMULATE); + if (ret != OMPI_SUCCESS) { + return ret; + } else if (op == &ompi_mpi_op_no_op.op || op == &ompi_mpi_op_replace.op) { + /* Nothing else to do, so return */ + return MPI_SUCCESS; + } + } + + if (op == &ompi_mpi_op_replace.op) { + assert(result_addr == NULL); + /* No need for get, just use put and realize when to release the lock */ + ret = ompi_osc_ucx_acc_rputget(NULL, 0, NULL, target, target_disp, + target_count, target_dt, op, win, lock_acquired, origin_addr, + origin_count, origin_dt, true, ACC_PUT_TARGET_DATA, ACCUMULATE); + if (ret != OMPI_SUCCESS) { + return ret; + } + } else { + void *temp_addr = NULL; + uint32_t temp_count; + ompi_datatype_t *temp_dt; + ptrdiff_t temp_lb, temp_extent; + + if (ompi_datatype_is_predefined(target_dt)) { + temp_dt = target_dt; + temp_count = target_count; + } else { + ret = ompi_osc_base_get_primitive_type_info(target_dt, &temp_dt, &temp_count); + if (ret != OMPI_SUCCESS) { + return ret; + } + temp_count *= target_count; + } + ompi_datatype_get_true_extent(temp_dt, &temp_lb, &temp_extent); + temp_addr = free_ptr = malloc(temp_extent * temp_count); + if (temp_addr == NULL) { + return OMPI_ERR_TEMP_OUT_OF_RESOURCE; + } + + ret = ompi_osc_ucx_acc_rputget(temp_addr, (int)temp_count, temp_dt, target, + target_disp, target_count, target_dt, op, win, lock_acquired, + origin_addr, origin_count, origin_dt, false, ACC_GET_STAGE_DATA, + (result_addr == NULL) ? 
ACCUMULATE : GET_ACCUMULATE); + if (ret != OMPI_SUCCESS) { + return ret; + } + } + + return ret; +} + +int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count, + struct ompi_datatype_t *origin_dt, + int target, ptrdiff_t target_disp, int target_count, + struct ompi_datatype_t *target_dt, + struct ompi_op_t *op, struct ompi_win_t *win) { + + return ompi_osc_ucx_get_accumulate_nonblocking(origin_addr, origin_count, + origin_dt, (void *)NULL, 0, NULL, target, target_disp, + target_count, target_dt, op, win); } static int @@ -741,28 +1046,30 @@ do_atomic_compare_and_swap(const void *origin_addr, const void *compare_addr, bool lock_acquired = false; size_t dt_bytes; opal_common_ucx_wpmem_t *mem = module->mem; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); if (!module->acc_single_intrinsic) { /* Start atomicity by acquiring acc lock */ - ret = ompi_osc_state_lock(module, target, &lock_acquired, false); + ret = ompi_osc_ucx_acc_lock(module, target, &lock_acquired); if (ret != OMPI_SUCCESS) { return ret; } } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, !lock_acquired); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); ompi_datatype_type_size(dt, &dt_bytes); uint64_t compare_val = opal_common_ucx_load_uint64(compare_addr, dt_bytes); uint64_t value = opal_common_ucx_load_uint64(origin_addr, dt_bytes); ret = opal_common_ucx_wpmem_cmpswp_nb(mem, compare_val, value, target, result_addr, dt_bytes, remote_addr, - NULL, NULL); + NULL, NULL, ep); if (module->acc_single_intrinsic) { return ret; } - return ompi_osc_state_unlock(module, target, lock_acquired, NULL); + return ompi_osc_ucx_acc_unlock(module, target, lock_acquired, NULL); } int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr, @@ -771,6 +1078,8 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a struct ompi_win_t *win) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t *)win->w_osc_module; 
opal_common_ucx_wpmem_t *mem = module->mem; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); size_t dt_bytes; int ret = OMPI_SUCCESS; @@ -792,15 +1101,15 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a /* fall back to get-compare-put */ /* Start atomicity by acquiring acc lock */ - ret = ompi_osc_state_lock(module, target, &lock_acquired, false); + ret = ompi_osc_ucx_acc_lock(module, target, &lock_acquired); if (ret != OMPI_SUCCESS) { return ret; } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, !lock_acquired); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); ret = opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_GET, target, - result_addr, dt_bytes, remote_addr); + result_addr, dt_bytes, remote_addr, ep); if (OPAL_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret); return OMPI_ERROR; @@ -814,14 +1123,14 @@ int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_a if (0 == memcmp(result_addr, compare_addr, dt_bytes)) { // write the new value ret = opal_common_ucx_wpmem_putget(mem, OPAL_COMMON_UCX_PUT, target, - (void*)origin_addr, dt_bytes, remote_addr); + (void*)origin_addr, dt_bytes, remote_addr, ep); if (OPAL_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_putget failed: %d", ret); return OMPI_ERROR; } } - return ompi_osc_state_unlock(module, target, lock_acquired, NULL); + return ompi_osc_ucx_acc_unlock(module, target, lock_acquired, NULL); } int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, @@ -831,6 +1140,8 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, size_t dt_bytes; ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; opal_common_ucx_wpmem_t *mem = module->mem; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); int ret = OMPI_SUCCESS; ret = 
check_sync_state(module, target, false); @@ -848,13 +1159,13 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, if (!module->acc_single_intrinsic) { /* Start atomicity by acquiring acc lock */ - ret = ompi_osc_state_lock(module, target, &lock_acquired, false); + ret = ompi_osc_ucx_acc_lock(module, target, &lock_acquired); if (ret != OMPI_SUCCESS) { return ret; } } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, !lock_acquired); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); value = origin_addr ? opal_common_ucx_load_uint64(origin_addr, dt_bytes) : 0; @@ -869,13 +1180,13 @@ int ompi_osc_ucx_fetch_and_op(const void *origin_addr, void *result_addr, ret = opal_common_ucx_wpmem_fetch_nb(mem, opcode, value, target, (void *)result_addr, dt_bytes, - remote_addr, NULL, NULL); + remote_addr, NULL, NULL, ep); if (module->acc_single_intrinsic) { return ret; } - return ompi_osc_state_unlock(module, target, lock_acquired, NULL); + return ompi_osc_ucx_acc_unlock(module, target, lock_acquired, NULL); } else { return ompi_osc_ucx_get_accumulate(origin_addr, 1, dt, result_addr, 1, dt, target, target_disp, 1, dt, op, win); @@ -890,12 +1201,11 @@ int get_accumulate_req(const void *origin_addr, int origin_count, int target, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_dt, struct ompi_op_t *op, struct ompi_win_t *win, - ompi_osc_ucx_request_t *ucx_req) { + ompi_osc_ucx_accumulate_request_t *ucx_req) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; int ret = OMPI_SUCCESS; uint64_t remote_addr = (module->addrs[target]) + target_disp * OSC_UCX_GET_DISP(module, target); - opal_common_ucx_wpmem_t *mem = module->mem; void *free_addr = NULL; bool lock_acquired = false; @@ -912,12 +1222,12 @@ int get_accumulate_req(const void *origin_addr, int origin_count, } /* Start atomicity by acquiring acc lock */ - ret = ompi_osc_state_lock(module, target, &lock_acquired, false); + ret = 
ompi_osc_ucx_acc_lock(module, target, &lock_acquired); if (ret != OMPI_SUCCESS) { return ret; } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, !lock_acquired); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); ret = ompi_osc_ucx_get(result_addr, result_count, result_dt, target, target_disp, target_count, target_dt, win); @@ -925,13 +1235,13 @@ int get_accumulate_req(const void *origin_addr, int origin_count, return ret; } - ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target); - if (ret != OMPI_SUCCESS) { - return ret; - } - if (op != &ompi_mpi_op_no_op.op) { if (op == &ompi_mpi_op_replace.op) { + ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target); + if (ret != OMPI_SUCCESS) { + return ret; + } + ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target, target_disp, target_count, target_dt, win); @@ -1023,13 +1333,18 @@ int get_accumulate_req(const void *origin_addr, int origin_count, } } + ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target); + if (ret != OPAL_SUCCESS) { + return ret; + } + if (NULL != ucx_req) { // nothing to wait for, mark request as completed - ompi_request_complete(&ucx_req->super, true); + ompi_request_complete(&ucx_req->super.super, true); } - return ompi_osc_state_unlock(module, target, lock_acquired, free_addr); + return ompi_osc_ucx_acc_unlock(module, target, lock_acquired, free_addr); } int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, @@ -1045,15 +1360,30 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count, target_count, target_dt, op, win, NULL); } +int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count, + struct ompi_datatype_t *origin_dt, + void *result_addr, int result_count, + struct ompi_datatype_t *result_dt, + int target, ptrdiff_t target_disp, + int target_count, struct ompi_datatype_t *target_dt, + struct ompi_op_t *op, struct ompi_win_t *win) { + + 
return ompi_osc_ucx_get_accumulate_nonblocking(origin_addr, origin_count, origin_dt, + result_addr, result_count, result_dt, target, target_disp, + target_count, target_dt, op, win); +} + int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, struct ompi_datatype_t *origin_dt, int target, ptrdiff_t target_disp, int target_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win, struct ompi_request_t **request) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); opal_common_ucx_wpmem_t *mem = module->mem; uint64_t remote_addr = (module->state_addrs[target]) + OSC_UCX_STATE_REQ_FLAG_OFFSET; - ompi_osc_ucx_request_t *ucx_req = NULL; + ompi_osc_ucx_generic_request_t *ucx_req = NULL; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, true); @@ -1061,7 +1391,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, return ret; } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, true); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); ret = ompi_osc_ucx_put(origin_addr, origin_count, origin_dt, target, target_disp, target_count, target_dt, win); @@ -1069,10 +1399,11 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, return ret; } - OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); + OMPI_OSC_UCX_GENERIC_REQUEST_ALLOC(win, ucx_req, RPUT_REQ); + ucx_req->super.module = module; - mca_osc_ucx_component.num_incomplete_req_ops++; - ret = opal_common_ucx_wpmem_flush_ep_nb(mem, target, req_completion, ucx_req); + OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module); + ret = opal_common_ucx_wpmem_flush_ep_nb(mem, target, ompi_osc_ucx_req_completion, ucx_req, ep); if (ret != OMPI_SUCCESS) { /* fallback to using an atomic op to acquire a request handle */ @@ -1086,14 +1417,14 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count, ret = opal_common_ucx_wpmem_fetch_nb(mem, UCP_ATOMIC_FETCH_OP_FADD, 0, target, &(module->req_result), 
sizeof(uint64_t), remote_addr & (~0x7), - req_completion, ucx_req); + ompi_osc_ucx_req_completion, ucx_req, ep); if (ret != OMPI_SUCCESS) { OMPI_OSC_UCX_REQUEST_RETURN(ucx_req); return ret; } } - *request = &ucx_req->super; + *request = &ucx_req->super.super; return ret; } @@ -1104,9 +1435,11 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, struct ompi_datatype_t *target_dt, struct ompi_win_t *win, struct ompi_request_t **request) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); opal_common_ucx_wpmem_t *mem = module->mem; uint64_t remote_addr = (module->state_addrs[target]) + OSC_UCX_STATE_REQ_FLAG_OFFSET; - ompi_osc_ucx_request_t *ucx_req = NULL; + ompi_osc_ucx_generic_request_t *ucx_req = NULL; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, true); @@ -1114,7 +1447,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, return ret; } - CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, true); + CHECK_DYNAMIC_WIN(remote_addr, module, target, ret); ret = ompi_osc_ucx_get(origin_addr, origin_count, origin_dt, target, target_disp, target_count, target_dt, win); @@ -1122,10 +1455,11 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, return ret; } - OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); + OMPI_OSC_UCX_GENERIC_REQUEST_ALLOC(win, ucx_req, RGET_REQ); + ucx_req->super.module = module; - mca_osc_ucx_component.num_incomplete_req_ops++; - ret = opal_common_ucx_wpmem_flush_ep_nb(mem, target, req_completion, ucx_req); + OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module); + ret = opal_common_ucx_wpmem_flush_ep_nb(mem, target, ompi_osc_ucx_req_completion, ucx_req, ep); if (ret != OMPI_SUCCESS) { /* fallback to using an atomic op to acquire a request handle */ @@ -1139,14 +1473,14 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count, ret = opal_common_ucx_wpmem_fetch_nb(mem, UCP_ATOMIC_FETCH_OP_FADD, 0, target, &(module->req_result), 
sizeof(uint64_t), remote_addr & (~0x7), - req_completion, ucx_req); + ompi_osc_ucx_req_completion, ucx_req, ep); if (ret != OMPI_SUCCESS) { OMPI_OSC_UCX_REQUEST_RETURN(ucx_req); return ret; } } - *request = &ucx_req->super; + *request = &ucx_req->super.super; return ret; } @@ -1157,7 +1491,7 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count, struct ompi_datatype_t *target_dt, struct ompi_op_t *op, struct ompi_win_t *win, struct ompi_request_t **request) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ompi_osc_ucx_request_t *ucx_req = NULL; + ompi_osc_ucx_accumulate_request_t *ucx_req = NULL; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, true); @@ -1165,7 +1499,8 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count, return ret; } - OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); + OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req); + ucx_req->super.module = module; assert(NULL != ucx_req); ret = accumulate_req(origin_addr, origin_count, origin_dt, target, target_disp, @@ -1175,7 +1510,7 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count, return ret; } - *request = &ucx_req->super; + *request = &ucx_req->super.super; return ret; } @@ -1189,7 +1524,7 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count, struct ompi_op_t *op, struct ompi_win_t *win, struct ompi_request_t **request) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; - ompi_osc_ucx_request_t *ucx_req = NULL; + ompi_osc_ucx_accumulate_request_t *ucx_req = NULL; int ret = OMPI_SUCCESS; ret = check_sync_state(module, target, true); @@ -1197,7 +1532,8 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count, return ret; } - OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req); + OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req); + ucx_req->super.module = module; assert(NULL != ucx_req); ret = get_accumulate_req(origin_addr, 
origin_count, origin_datatype, @@ -1209,7 +1545,247 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count, return ret; } - *request = &ucx_req->super; + *request = &ucx_req->super.super; return ret; } + +static inline int ompi_osc_ucx_nonblocking_ops_finalize(ompi_osc_ucx_module_t *module, int + target, bool lock_acquired, struct ompi_win_t *win, void *free_ptr) { + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); + int ret = OMPI_SUCCESS; + ompi_osc_ucx_accumulate_request_t *ucx_req = NULL; + + OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req); + assert(NULL != ucx_req); + ucx_req->free_ptr = free_ptr; + ucx_req->phase = ACC_FINALIZE; + ucx_req->acc_type = ANY; + ucx_req->super.module = module; + + /* Fence any still active operations */ + ret = opal_common_ucx_wpmem_fence(module->mem); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); + return OMPI_ERROR; + } + + if (lock_acquired) { + OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module); + ret = opal_common_ucx_wpmem_fetch_nb(module->state_mem, + UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, + target, &(module->req_result), sizeof(module->req_result), + remote_addr, ompi_osc_ucx_req_completion, ucx_req, ep); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_wpmem_fetch_nb failed: %d", ret); + OMPI_OSC_UCX_REQUEST_RETURN(ucx_req); + return ret; + } + } else { + /* Lock is not acquired, but still, we need to know when the + * acc is finalized so that we can free the temp buffers */ + OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module); + ret = opal_common_ucx_wpmem_flush_ep_nb(module->mem, target, ompi_osc_ucx_req_completion, ucx_req, ep); + + if (ret != OMPI_SUCCESS) { + /* fallback to using an atomic op to acquire a request handle */ + ret = opal_common_ucx_wpmem_fetch_nb(module->mem, UCP_ATOMIC_FETCH_OP_FADD, + 0, target, 
&(module->req_result), + sizeof(uint64_t), remote_addr & (~0x7), + ompi_osc_ucx_req_completion, ucx_req, ep); + if (ret != OMPI_SUCCESS) { + OMPI_OSC_UCX_REQUEST_RETURN(ucx_req); + return ret; + } + } + } + + return ret; +} + +void ompi_osc_ucx_req_completion(void *request) { + ompi_osc_ucx_generic_request_t *ucx_req = (ompi_osc_ucx_generic_request_t *)request; + int ret = OMPI_SUCCESS; + ompi_osc_ucx_module_t *module = ucx_req->super.module; + if (ucx_req->super.request_type == ACCUMULATE_REQ) { + /* This is an accumulate request */ + ompi_osc_ucx_accumulate_request_t *req = (ompi_osc_ucx_accumulate_request_t *)request; + assert(req->phase != ACC_INIT); + void *free_addr = NULL; + bool release_lock = false; + ptrdiff_t temp_extent; + const void *origin_addr = req->origin_addr; + int origin_count = req->origin_count; + struct ompi_datatype_t *origin_dt = req->origin_dt; + void *temp_addr = req->stage_addr; + int temp_count = req->stage_count; + struct ompi_datatype_t *temp_dt = req->stage_dt; + int target = req->target; + int target_count = req->target_count; + int target_disp = req->target_disp; + struct ompi_datatype_t *target_dt = req->target_dt; + struct ompi_win_t *win = req->win; + struct ompi_op_t *op = req->op; + + if (req->phase != ACC_FINALIZE) { + /* Avoid calling flush while we are already in progress */ + module->mem->skip_periodic_flush = true; + module->state_mem->skip_periodic_flush = true; + } + + switch (req->phase) { + case ACC_FINALIZE: + { + if (req->free_ptr != NULL) { + free(req->free_ptr); + req->free_ptr = NULL; + } + if (origin_dt != NULL && !ompi_datatype_is_predefined(origin_dt)) { + OBJ_RELEASE(origin_dt); + } + if (target_dt != NULL && !ompi_datatype_is_predefined(target_dt)) { + OBJ_RELEASE(target_dt); + } + if (temp_dt != NULL && !ompi_datatype_is_predefined(temp_dt)) { + OBJ_RELEASE(temp_dt); + } + break; + } + case ACC_GET_RESULTS_DATA: + { + /* This is a get-accumulate operation */ + if (op == &ompi_mpi_op_no_op.op) { + /* Done 
with reading the target data, so release the + * acc lock and return */ + release_lock = true; + } else if (op == &ompi_mpi_op_replace.op) { + assert(target_dt != NULL && origin_dt != NULL); + /* Now that we have the results data, replace the target + * buffer with origin buffer and then release the lock */ + ret = ompi_osc_ucx_acc_rputget(NULL, 0, NULL, target, target_disp, + target_count, target_dt, op, win, 0, origin_addr, origin_count, + origin_dt, true, -1, NONE); + if (ret != OMPI_SUCCESS) { + OSC_UCX_ERROR("ompi_osc_ucx_acc_rputget failed ret= %d\n", ret); + free(temp_addr); + abort(); + } + release_lock = true; + } + break; + } + + case ACC_PUT_TARGET_DATA: + { + /* This is an accumulate (not get-accumulate) operation */ + assert(op == &ompi_mpi_op_replace.op); + release_lock = true; + break; + } + + case ACC_GET_STAGE_DATA: + { + assert(op != &ompi_mpi_op_replace.op && op != &ompi_mpi_op_no_op.op); + assert(origin_dt != NULL && temp_dt != NULL); + + bool is_origin_contig = + ompi_datatype_is_contiguous_memory_layout(origin_dt, origin_count); + + if (ompi_datatype_is_predefined(origin_dt) || is_origin_contig) { + ompi_op_reduce(op, (void *)origin_addr, temp_addr, (int)temp_count, temp_dt); + } else { + ucx_iovec_t *origin_ucx_iov = NULL; + uint32_t origin_ucx_iov_count = 0; + uint32_t origin_ucx_iov_idx = 0; + + ret = create_iov_list(origin_addr, origin_count, origin_dt, + &origin_ucx_iov, &origin_ucx_iov_count); + if (ret != OMPI_SUCCESS) { + OSC_UCX_ERROR("create_iov_list failed ret= %d\n", ret); + free(temp_addr); + abort(); + } + + if ((op != &ompi_mpi_op_maxloc.op && op != &ompi_mpi_op_minloc.op) || + ompi_datatype_is_contiguous_memory_layout(temp_dt, temp_count)) { + size_t temp_size; + char *curr_temp_addr = (char *)temp_addr; + ompi_datatype_type_size(temp_dt, &temp_size); + while (origin_ucx_iov_idx < origin_ucx_iov_count) { + int curr_count = origin_ucx_iov[origin_ucx_iov_idx].len / temp_size; + ompi_op_reduce(op, 
origin_ucx_iov[origin_ucx_iov_idx].addr, + curr_temp_addr, curr_count, temp_dt); + curr_temp_addr += curr_count * temp_size; + origin_ucx_iov_idx++; + } + } else { + int i; + void *curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr; + for (i = 0; i < (int)temp_count; i++) { + ompi_op_reduce(op, curr_origin_addr, + (void *)((char *)temp_addr + i * temp_extent), + 1, temp_dt); + curr_origin_addr = (void *)((char *)curr_origin_addr + temp_extent); + origin_ucx_iov_idx++; + if (curr_origin_addr >= (void *)((char + *)origin_ucx_iov[origin_ucx_iov_idx].addr + + + origin_ucx_iov[origin_ucx_iov_idx].len)) + { + origin_ucx_iov_idx++; + curr_origin_addr = origin_ucx_iov[origin_ucx_iov_idx].addr; + } + } + } + + free(origin_ucx_iov); + } + + if (req->acc_type == GET_ACCUMULATE) { + /* Do fence to make sure target results are received before + * writing into target */ + ret = opal_common_ucx_wpmem_fence(module->mem); + if (ret != OMPI_SUCCESS) { + OSC_UCX_ERROR("opal_common_ucx_mem_fence failed: %d", ret); + abort(); + } + } + + ret = ompi_osc_ucx_acc_rputget(NULL, 0, NULL, target, target_disp, + target_count, target_dt, op, win, 0, temp_addr, temp_count, + temp_dt, true, -1, NONE); + if (ret != OMPI_SUCCESS) { + OSC_UCX_ERROR("ompi_osc_ucx_acc_rputget failed ret= %d\n", ret); + free(temp_addr); + abort(); + } + release_lock = true; + free_addr = temp_addr; + break; + } + + default: + { + OSC_UCX_ERROR("accumulate progress failed\n"); + abort(); + } + } + + if (release_lock) { + /* Ordering between previous put/get operations and unlock will be realized + * through the ucp fence inside the finalize function */ + ompi_osc_ucx_nonblocking_ops_finalize(module, target, + req->lock_acquired, win, free_addr); + } + + if (req->phase != ACC_FINALIZE) { + module->mem->skip_periodic_flush = false; + module->state_mem->skip_periodic_flush = false; + } + } + OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS(module); + ompi_request_complete(&(ucx_req->super.super), true); + 
assert(module->ctx->num_incomplete_req_ops >= 0); +} diff --git a/ompi/mca/osc/ucx/osc_ucx_component.c b/ompi/mca/osc/ucx/osc_ucx_component.c index 87d020e0c50..326b730d388 100644 --- a/ompi/mca/osc/ucx/osc_ucx_component.c +++ b/ompi/mca/osc/ucx/osc_ucx_component.c @@ -43,6 +43,7 @@ static void _osc_ucx_init_unlock(void) } } +static bool enable_nonblocking_accumulate = false; static int component_open(void); static int component_close(void); @@ -78,9 +79,10 @@ ompi_osc_ucx_component_t mca_osc_ucx_component = { }, .wpool = NULL, .env_initialized = false, - .num_incomplete_req_ops = 0, .num_modules = 0, - .acc_single_intrinsic = false + .acc_single_intrinsic = false, + .comm_world_size = 0, + .endpoints = NULL }; ompi_osc_ucx_module_t ompi_osc_ucx_module_template = { @@ -184,13 +186,34 @@ static int component_register(void) { MCA_BASE_VAR_SCOPE_GROUP, &mca_osc_ucx_component.no_locks); free(description_str); + opal_common_ucx_thread_enabled = opal_using_threads(); mca_osc_ucx_component.acc_single_intrinsic = false; + opal_asprintf(&description_str, "Enable optimizations for MPI_Fetch_and_op, MPI_Accumulate, etc for codes " "that will not use anything more than a single predefined datatype (default: %s)", mca_osc_ucx_component.acc_single_intrinsic ? "true" : "false"); (void) mca_base_component_var_register(&mca_osc_ucx_component.super.osc_version, "acc_single_intrinsic", description_str, MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_5, MCA_BASE_VAR_SCOPE_GROUP, &mca_osc_ucx_component.acc_single_intrinsic); + + opal_asprintf(&description_str, "Enable nonblocking MPI_Accumulate and MPI_Get_accumulate (default: %s)", + enable_nonblocking_accumulate ? 
"true" : "false"); + (void) mca_base_component_var_register(&mca_osc_ucx_component.super.osc_version, "enable_nonblocking_accumulate", + description_str, MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_GROUP, &enable_nonblocking_accumulate); + + opal_asprintf(&description_str, "Enable optimizations for multi-threaded applications by allocating a separate worker " + "for each thread and a separate endpoint for each window (default: %s)", + opal_common_ucx_thread_enabled ? "true" : "false"); + (void) mca_base_component_var_register(&mca_osc_ucx_component.super.osc_version, "enable_wpool_thread_multiple", + description_str, MCA_BASE_VAR_TYPE_BOOL, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_GROUP, &opal_common_ucx_thread_enabled); + + opal_asprintf(&description_str, "Threshold on number of nonblocking accumulate calls on which there is a periodical " + "flush (default: %ld)", ompi_osc_ucx_outstanding_ops_flush_threshold); + (void) mca_base_component_var_register(&mca_osc_ucx_component.super.osc_version, "outstanding_ops_flush_threshold", + description_str, MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_5, + MCA_BASE_VAR_SCOPE_GROUP, &ompi_osc_ucx_outstanding_ops_flush_threshold); free(description_str); opal_common_ucx_mca_var_register(&mca_osc_ucx_component.super.osc_version); @@ -296,7 +319,17 @@ static int component_init(bool enable_progress_threads, bool enable_mpi_threads) } static int component_finalize(void) { - + if (!opal_common_ucx_thread_enabled) { + int i; + for (i = 0; i < mca_osc_ucx_component.comm_world_size; i++) { + ucp_ep_h ep = mca_osc_ucx_component.endpoints[i]; + if (ep != NULL) { + ucp_ep_destroy(ep); + } + } + free(mca_osc_ucx_component.endpoints); + } + opal_common_ucx_mca_deregister(); if (mca_osc_ucx_component.env_initialized) { opal_common_ucx_wpool_finalize(mca_osc_ucx_component.wpool); } @@ -451,7 +484,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in 
opal_common_ucx_mem_type_t mem_type; char *my_mem_addr; int my_mem_addr_size; - uint64_t my_info[2] = {0}; + uint64_t my_info[3] = {0}; char *recv_buf = NULL; void *dynamic_base = NULL; unsigned long total, *rbuf; @@ -462,6 +495,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in /* May be called concurrently - protect */ _osc_ucx_init_lock(); + if (mca_osc_ucx_component.env_initialized == false) { /* Lazy initialization of the global state. * As not all of the MPI applications are using One-Sided functionality @@ -470,9 +504,20 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in OBJ_CONSTRUCT(&mca_osc_ucx_component.requests, opal_free_list_t); ret = opal_free_list_init (&mca_osc_ucx_component.requests, - sizeof(ompi_osc_ucx_request_t), + sizeof(ompi_osc_ucx_generic_request_t), opal_cache_line_size, - OBJ_CLASS(ompi_osc_ucx_request_t), + OBJ_CLASS(ompi_osc_ucx_generic_request_t), + 0, 0, 8, 0, 8, NULL, 0, NULL, NULL, NULL); + if (OMPI_SUCCESS != ret) { + OSC_UCX_VERBOSE(1, "opal_free_list_init failed: %d", ret); + goto select_unlock; + } + + OBJ_CONSTRUCT(&mca_osc_ucx_component.accumulate_requests, opal_free_list_t); + ret = opal_free_list_init (&mca_osc_ucx_component.accumulate_requests, + sizeof(ompi_osc_ucx_accumulate_request_t), + opal_cache_line_size, + OBJ_CLASS(ompi_osc_ucx_accumulate_request_t), 0, 0, 8, 0, 8, NULL, 0, NULL, NULL, NULL); if (OMPI_SUCCESS != ret) { OSC_UCX_VERBOSE(1, "opal_free_list_init failed: %d", ret); @@ -484,7 +529,10 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in OSC_UCX_VERBOSE(1, "opal_common_ucx_wpool_init failed: %d", ret); goto select_unlock; } - + if (!opal_common_ucx_thread_enabled) { + mca_osc_ucx_component.comm_world_size = ompi_proc_world_size(); + mca_osc_ucx_component.endpoints = calloc(mca_osc_ucx_component.comm_world_size, sizeof(ucp_ep_h)); + } /* Make sure that all memory updates performed above are globally * 
observable before (mca_osc_ucx_component.env_initialized = true) */ @@ -523,6 +571,12 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in /* fill in the function pointer part */ memcpy(module, &ompi_osc_ucx_module_template, sizeof(ompi_osc_base_module_t)); + /* TODO Provide support for nonblocking operations with dynamic windows */ + if (enable_nonblocking_accumulate && flavor != MPI_WIN_FLAVOR_DYNAMIC) { + module->super.osc_accumulate = ompi_osc_ucx_accumulate_nb; + module->super.osc_get_accumulate = ompi_osc_ucx_get_accumulate_nb; + } + ret = ompi_comm_dup(comm, &module->comm); if (ret != OMPI_SUCCESS) { goto error; @@ -537,6 +591,7 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in module->size = size; module->no_locks = check_config_value_bool ("no_locks", info); module->acc_single_intrinsic = check_config_value_bool ("acc_single_intrinsic", info); + module->skip_sync_check = false; /* share everyone's displacement units. Only do an allgather if strictly necessary, since it requires O(p) state. 
*/ @@ -755,10 +810,11 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in my_info[0] = (uint64_t)dynamic_base; } my_info[1] = (uint64_t)state_base; + my_info[2] = ompi_comm_rank(&ompi_mpi_comm_world.comm); - recv_buf = (char *)calloc(comm_size, 2 * sizeof(uint64_t)); - ret = comm->c_coll->coll_allgather((void *)my_info, 2 * sizeof(uint64_t), - MPI_BYTE, recv_buf, 2 * sizeof(uint64_t), + recv_buf = (char *)calloc(comm_size, sizeof(my_info)); + ret = comm->c_coll->coll_allgather((void *)my_info, sizeof(my_info), + MPI_BYTE, recv_buf, sizeof(my_info), MPI_BYTE, comm, comm->c_coll->coll_allgather_module); if (ret != OMPI_SUCCESS) { goto error; @@ -766,9 +822,11 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in module->addrs = calloc(comm_size, sizeof(uint64_t)); module->state_addrs = calloc(comm_size, sizeof(uint64_t)); + module->comm_world_ranks = calloc(comm_size, sizeof(uint64_t)); for (i = 0; i < comm_size; i++) { - memcpy(&(module->addrs[i]), recv_buf + i * 2 * sizeof(uint64_t), sizeof(uint64_t)); - memcpy(&(module->state_addrs[i]), recv_buf + i * 2 * sizeof(uint64_t) + sizeof(uint64_t), sizeof(uint64_t)); + memcpy(&(module->addrs[i]), recv_buf + i * 3 * sizeof(uint64_t), sizeof(uint64_t)); + memcpy(&(module->state_addrs[i]), recv_buf + i * 3 * sizeof(uint64_t) + sizeof(uint64_t), sizeof(uint64_t)); + memcpy(&(module->comm_world_ranks[i]), recv_buf + i * 3 * sizeof(uint64_t) + 2 * sizeof(uint64_t), sizeof(uint64_t)); } free(recv_buf); @@ -860,85 +918,51 @@ int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_ base, len, insert); } } -inline bool ompi_osc_need_acc_lock(ompi_osc_ucx_module_t *module, int target) -{ - ompi_osc_ucx_lock_t *lock = NULL; - opal_hash_table_get_value_uint32(&module->outstanding_locks, - (uint32_t) target, (void **) &lock); - /* if there is an exclusive lock there is no need to acqurie the accumulate lock */ - return !(NULL != lock && 
LOCK_EXCLUSIVE == lock->type); -} - -inline int ompi_osc_state_lock( - ompi_osc_ucx_module_t *module, - int target, - bool *lock_acquired, - bool force_lock) { +int ompi_osc_ucx_dynamic_lock(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = -1; - uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_DYNAMIC_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); int ret = OMPI_SUCCESS; - if (force_lock || ompi_osc_need_acc_lock(module, target)) { - for (;;) { - ret = opal_common_ucx_wpmem_cmpswp(module->state_mem, - TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, - target, &result_value, sizeof(result_value), - remote_addr); - if (ret != OMPI_SUCCESS) { - OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_cmpswp failed: %d", ret); - return OMPI_ERROR; - } - if (result_value == TARGET_LOCK_UNLOCKED) { - break; - } - - opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool); + for (;;) { + ret = opal_common_ucx_wpmem_cmpswp(module->state_mem, + TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, + target, &result_value, sizeof(result_value), + remote_addr, ep); + if (ret != OMPI_SUCCESS) { + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_cmpswp failed: %d", ret); + return OMPI_ERROR; + } + if (result_value == TARGET_LOCK_UNLOCKED) { + break; } - *lock_acquired = true; - } else { - *lock_acquired = false; + opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool); } - return OMPI_SUCCESS; + return ret; } -inline int ompi_osc_state_unlock( - ompi_osc_ucx_module_t *module, - int target, - bool lock_acquired, - void *free_ptr) { - uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_ACC_LOCK_OFFSET; +int ompi_osc_ucx_dynamic_unlock(ompi_osc_ucx_module_t *module, int target) { + uint64_t result_value = -1; + uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_DYNAMIC_LOCK_OFFSET; + ucp_ep_h *ep; + 
OSC_UCX_GET_DEFAULT_EP(ep, module, target); int ret = OMPI_SUCCESS; - if (lock_acquired) { - uint64_t result_value = 0; - /* fence any still active operations */ - ret = opal_common_ucx_wpmem_fence(module->mem); - if (ret != OMPI_SUCCESS) { - OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); - return OMPI_ERROR; - } - - ret = opal_common_ucx_wpmem_fetch(module->state_mem, - UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, - target, &result_value, sizeof(result_value), - remote_addr); - assert(result_value == TARGET_LOCK_EXCLUSIVE); - } else if (NULL != free_ptr){ - /* flush before freeing the buffer */ - ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target); - } - /* TODO: encapsulate in a request and make the release non-blocking */ - if (NULL != free_ptr) { - free(free_ptr); - } + ret = opal_common_ucx_wpmem_fence(module->mem); if (ret != OMPI_SUCCESS) { - OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fetch failed: %d", ret); + OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_fence failed: %d", ret); return OMPI_ERROR; } + ret = opal_common_ucx_wpmem_fetch(module->state_mem, + UCP_ATOMIC_FETCH_OP_SWAP, TARGET_LOCK_UNLOCKED, + target, &result_value, sizeof(result_value), + remote_addr, ep); + assert(result_value == TARGET_LOCK_EXCLUSIVE); return ret; } @@ -948,15 +972,14 @@ int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) { int ret = OMPI_SUCCESS; if (module->state.dynamic_win_count >= OMPI_OSC_UCX_ATTACH_MAX) { - OSC_UCX_ERROR("Dynamic window attach failed: Cannot satisfy %d attached windows. " + OSC_UCX_ERROR("Dynamic window attach failed: Cannot satisfy %" PRIu64 "attached windows. 
" "Max attached windows is %d \n", module->state.dynamic_win_count+1, OMPI_OSC_UCX_ATTACH_MAX); return OMPI_ERR_TEMP_OUT_OF_RESOURCE; } - bool lock_acquired = false; - ret = ompi_osc_state_lock(module, ompi_comm_rank(module->comm), &lock_acquired, true); + ret = ompi_osc_ucx_dynamic_lock(module, ompi_comm_rank(module->comm)); if (ret != OMPI_SUCCESS) { return ret; } @@ -967,7 +990,7 @@ int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) { (uint64_t)base, len, &insert_index); if (contain_index >= 0) { module->local_dynamic_win_info[contain_index].refcnt++; - ompi_osc_state_unlock(module, ompi_comm_rank(module->comm), lock_acquired, NULL); + ret = ompi_osc_ucx_dynamic_unlock(module, ompi_comm_rank(module->comm)); return ret; } @@ -991,7 +1014,7 @@ int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) { &(module->local_dynamic_win_info[insert_index].my_mem_addr_size), &(module->local_dynamic_win_info[insert_index].mem)); if (ret != OMPI_SUCCESS) { - ompi_osc_state_unlock(module, ompi_comm_rank(module->comm), lock_acquired, NULL); + ompi_osc_ucx_dynamic_unlock(module, ompi_comm_rank(module->comm)); return ret; } @@ -1005,15 +1028,15 @@ int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len) { module->local_dynamic_win_info[insert_index].refcnt++; module->state.dynamic_win_count++; - return ompi_osc_state_unlock(module, ompi_comm_rank(module->comm), lock_acquired, NULL); + return ompi_osc_ucx_dynamic_unlock(module, ompi_comm_rank(module->comm)); } int ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base) { ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module; int insert, contain; + int ret = OMPI_SUCCESS; - bool lock_acquired = false; - int ret = ompi_osc_state_lock(module, ompi_comm_rank(module->comm), &lock_acquired, true); + ret = ompi_osc_ucx_dynamic_lock(module, ompi_comm_rank(module->comm)); if (ret != OMPI_SUCCESS) { return ret; } @@ -1027,7 +1050,7 @@ int 
ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base) { /* if we can't find region - just exit */ if (contain < 0) { - return ompi_osc_state_unlock(module, ompi_comm_rank(module->comm), lock_acquired, NULL); + return ompi_osc_ucx_dynamic_unlock(module, ompi_comm_rank(module->comm)); } module->local_dynamic_win_info[contain].refcnt--; @@ -1043,7 +1066,7 @@ int ompi_osc_ucx_win_detach(struct ompi_win_t *win, const void *base) { module->state.dynamic_win_count--; } - return ompi_osc_state_unlock(module, ompi_comm_rank(module->comm), lock_acquired, NULL); + return ompi_osc_ucx_dynamic_unlock(module, ompi_comm_rank(module->comm)); } @@ -1091,6 +1114,7 @@ int ompi_osc_ucx_free(struct ompi_win_t *win) { free(module->addrs); free(module->state_addrs); + free(module->comm_world_ranks); opal_common_ucx_wpmem_free(module->state_mem); if (NULL != module->mem) { diff --git a/ompi/mca/osc/ucx/osc_ucx_passive_target.c b/ompi/mca/osc/ucx/osc_ucx_passive_target.c index 7dafa15620e..455f7ae3302 100644 --- a/ompi/mca/osc/ucx/osc_ucx_passive_target.c +++ b/ompi/mca/osc/ucx/osc_ucx_passive_target.c @@ -21,12 +21,14 @@ OBJ_CLASS_INSTANCE(ompi_osc_ucx_lock_t, opal_object_t, NULL, NULL); static inline int start_shared(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = -1; uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); int ret = OMPI_SUCCESS; while (true) { ret = opal_common_ucx_wpmem_fetch(module->state_mem, UCP_ATOMIC_FETCH_OP_FADD, 1, target, &result_value, sizeof(result_value), - remote_addr); + remote_addr, ep); if (OMPI_SUCCESS != ret) { return ret; } @@ -35,7 +37,7 @@ static inline int start_shared(ompi_osc_ucx_module_t *module, int target) { if (result_value >= TARGET_LOCK_EXCLUSIVE) { ret = opal_common_ucx_wpmem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD, (-1), target, - sizeof(uint64_t), remote_addr); + sizeof(uint64_t), remote_addr, ep); if 
(OMPI_SUCCESS != ret) { return ret; } @@ -50,20 +52,24 @@ static inline int start_shared(ompi_osc_ucx_module_t *module, int target) { static inline int end_shared(ompi_osc_ucx_module_t *module, int target) { uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); return opal_common_ucx_wpmem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD, - (-1), target, sizeof(uint64_t), remote_addr); + (-1), target, sizeof(uint64_t), remote_addr, ep); } static inline int start_exclusive(ompi_osc_ucx_module_t *module, int target) { uint64_t result_value = -1; uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); int ret = OMPI_SUCCESS; for (;;) { ret = opal_common_ucx_wpmem_cmpswp(module->state_mem, TARGET_LOCK_UNLOCKED, TARGET_LOCK_EXCLUSIVE, target, &result_value, sizeof(result_value), - remote_addr); + remote_addr, ep); if (OMPI_SUCCESS != ret) { return ret; } @@ -76,9 +82,11 @@ static inline int start_exclusive(ompi_osc_ucx_module_t *module, int target) { static inline int end_exclusive(ompi_osc_ucx_module_t *module, int target) { uint64_t remote_addr = (module->state_addrs)[target] + OSC_UCX_STATE_LOCK_OFFSET; + ucp_ep_h *ep; + OSC_UCX_GET_DEFAULT_EP(ep, module, target); return opal_common_ucx_wpmem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD, -((int64_t)TARGET_LOCK_EXCLUSIVE), target, - sizeof(uint64_t), remote_addr); + sizeof(uint64_t), remote_addr, ep); } int ompi_osc_ucx_lock(int lock_type, int target, int mpi_assert, struct ompi_win_t *win) { @@ -155,12 +163,12 @@ int ompi_osc_ucx_unlock(int target, struct ompi_win_t *win) { return OMPI_ERR_RMA_SYNC; } - opal_hash_table_remove_value_uint32(&module->outstanding_locks, - (uint32_t)target); - ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, target); + ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_WORKER, 
0); if (ret != OMPI_SUCCESS) { return ret; } + opal_hash_table_remove_value_uint32(&module->outstanding_locks, + (uint32_t)target); if (lock->is_nocheck == false) { if (lock->type == LOCK_EXCLUSIVE) { diff --git a/ompi/mca/osc/ucx/osc_ucx_request.c b/ompi/mca/osc/ucx/osc_ucx_request.c index d2a3fd494fa..73b5efc733b 100644 --- a/ompi/mca/osc/ucx/osc_ucx_request.c +++ b/ompi/mca/osc/ucx/osc_ucx_request.c @@ -24,9 +24,9 @@ static int request_cancel(struct ompi_request_t *request, int complete) static int request_free(struct ompi_request_t **ompi_req) { - ompi_osc_ucx_request_t *request = (ompi_osc_ucx_request_t*) *ompi_req; + ompi_osc_ucx_generic_request_t *request = (ompi_osc_ucx_generic_request_t*) *ompi_req; - if (true != (bool)(request->super.req_complete)) { + if (true != (bool)(request->super.super.req_complete)) { return MPI_ERR_REQUEST; } @@ -37,20 +37,15 @@ static int request_free(struct ompi_request_t **ompi_req) return OMPI_SUCCESS; } -static void request_construct(ompi_osc_ucx_request_t *request) +static void request_construct(ompi_osc_ucx_generic_request_t *request) { - request->super.req_type = OMPI_REQUEST_WIN; - request->super.req_status._cancelled = 0; - request->super.req_free = request_free; - request->super.req_cancel = request_cancel; + request->super.super.req_type = OMPI_REQUEST_WIN; + request->super.super.req_status._cancelled = 0; + request->super.super.req_free = request_free; + request->super.super.req_cancel = request_cancel; } -void req_completion(void *request) { - ompi_osc_ucx_request_t *req = (ompi_osc_ucx_request_t *)request; - ompi_request_complete(&(req->super), true); - mca_osc_ucx_component.num_incomplete_req_ops--; - assert(mca_osc_ucx_component.num_incomplete_req_ops >= 0); -} - -OBJ_CLASS_INSTANCE(ompi_osc_ucx_request_t, ompi_request_t, +OBJ_CLASS_INSTANCE(ompi_osc_ucx_generic_request_t, ompi_request_t, + request_construct, NULL); +OBJ_CLASS_INSTANCE(ompi_osc_ucx_accumulate_request_t, ompi_request_t, request_construct, NULL); 
diff --git a/ompi/mca/osc/ucx/osc_ucx_request.h b/ompi/mca/osc/ucx/osc_ucx_request.h index b471e671bad..dfa718bc699 100644 --- a/ompi/mca/osc/ucx/osc_ucx_request.h +++ b/ompi/mca/osc/ucx/osc_ucx_request.h @@ -16,38 +16,132 @@ #include "ompi/request/request.h" + +enum req_type { + ACCUMULATE_REQ, + RPUT_REQ, + RGET_REQ +}; + +enum acc_rma_type { + NONE, + ACCUMULATE, + GET_ACCUMULATE, + ANY +}; + +enum acc_phases { + ACC_INIT, + ACC_GET_RESULTS_DATA, + ACC_GET_STAGE_DATA, + ACC_PUT_TARGET_DATA, + ACC_FINALIZE +}; + typedef struct ompi_osc_ucx_request { ompi_request_t super; + int request_type; + ompi_osc_ucx_module_t *module; } ompi_osc_ucx_request_t; +typedef struct ompi_osc_ucx_generic_request { + ompi_osc_ucx_request_t super; +} ompi_osc_ucx_generic_request_t; + +typedef struct ompi_osc_ucx_accumulate_request { + ompi_osc_ucx_request_t super; + struct ompi_op_t *op; + int phase; + int acc_type; + bool lock_acquired; + int target; + struct ompi_win_t *win; + const void *origin_addr; + int origin_count; + struct ompi_datatype_t *origin_dt; + void *stage_addr; + int stage_count; + struct ompi_datatype_t *stage_dt; + struct ompi_datatype_t *target_dt; + int target_disp; + int target_count; + void *free_ptr; +} ompi_osc_ucx_accumulate_request_t; + OBJ_CLASS_DECLARATION(ompi_osc_ucx_request_t); +OBJ_CLASS_DECLARATION(ompi_osc_ucx_generic_request_t); +OBJ_CLASS_DECLARATION(ompi_osc_ucx_accumulate_request_t); + +#define OMPI_OSC_UCX_GENERIC_REQUEST_ALLOC(win, req, _req_type) \ + do { \ + opal_free_list_item_t *item; \ + do { \ + item = opal_free_list_get(&mca_osc_ucx_component.requests); \ + if (item == NULL) { \ + if (module->ctx->num_incomplete_req_ops > 0) { \ + opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool); \ + } \ + } \ + } while (item == NULL); \ + req = (ompi_osc_ucx_generic_request_t*) item; \ + OMPI_REQUEST_INIT(&req->super.super, false); \ + req->super.super.req_mpi_object.win = win; \ + req->super.super.req_complete = false; \ + 
req->super.super.req_state = OMPI_REQUEST_ACTIVE; \ + req->super.super.req_status.MPI_ERROR = MPI_SUCCESS; \ + req->super.module = NULL; \ + req->super.request_type = _req_type; \ + } while (0) -#define OMPI_OSC_UCX_REQUEST_ALLOC(win, req) \ - do { \ - opal_free_list_item_t *item; \ - do { \ - item = opal_free_list_get(&mca_osc_ucx_component.requests); \ - if (item == NULL) { \ - if (mca_osc_ucx_component.num_incomplete_req_ops > 0) { \ - opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool); \ - } \ - } \ - } while (item == NULL); \ - req = (ompi_osc_ucx_request_t*) item; \ - OMPI_REQUEST_INIT(&req->super, false); \ - req->super.req_mpi_object.win = win; \ - req->super.req_complete = false; \ - req->super.req_state = OMPI_REQUEST_ACTIVE; \ - req->super.req_status.MPI_ERROR = MPI_SUCCESS; \ +#define OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, req) \ + do { \ + opal_free_list_item_t *item; \ + do { \ + item = opal_free_list_get(&mca_osc_ucx_component.accumulate_requests); \ + if (item == NULL) { \ + if (module->ctx->num_incomplete_req_ops > 0) { \ + opal_common_ucx_wpool_progress(mca_osc_ucx_component.wpool); \ + } \ + } \ + } while (item == NULL); \ + req = (ompi_osc_ucx_accumulate_request_t*) item; \ + OMPI_REQUEST_INIT(&req->super.super, false); \ + req->super.super.req_mpi_object.win = win; \ + req->super.super.req_complete = false; \ + req->super.super.req_state = OMPI_REQUEST_ACTIVE; \ + req->super.super.req_status.MPI_ERROR = MPI_SUCCESS; \ + req->super.module = NULL; \ + req->super.request_type = ACCUMULATE_REQ; \ + req->acc_type = NONE; \ + req->op = MPI_NO_OP; \ + req->phase = ACC_INIT; \ + req->target = -1; \ + req->lock_acquired = false; \ + req->win = NULL; \ + req->origin_addr = NULL; \ + req->origin_count = 0; \ + req->origin_dt = NULL; \ + req->stage_addr = NULL; \ + req->stage_count = 0; \ + req->stage_dt = NULL; \ + req->target_dt = NULL; \ + req->target_count = 0; \ + req->target_disp = 0; \ + req->free_ptr = NULL; \ } while (0) -#define 
OMPI_OSC_UCX_REQUEST_RETURN(req) \ - do { \ - OMPI_REQUEST_FINI(&req->super); \ - opal_free_list_return (&mca_osc_ucx_component.requests, \ - (opal_free_list_item_t*) req); \ +#define OMPI_OSC_UCX_REQUEST_RETURN(req) \ + do { \ + OMPI_REQUEST_FINI(&req->super.super); \ + if (req->super.request_type == ACCUMULATE_REQ) { \ + opal_free_list_return (&mca_osc_ucx_component.accumulate_requests, \ + (opal_free_list_item_t*) req); \ + } else { \ + opal_free_list_return (&mca_osc_ucx_component.requests, \ + (opal_free_list_item_t*) req); \ + } \ } while (0) -void req_completion(void *request); +void ompi_osc_ucx_req_completion(void *request); #endif /* OMPI_OSC_UCX_REQUEST_H */ diff --git a/opal/mca/common/ucx/common_ucx_wpool.c b/opal/mca/common/ucx/common_ucx_wpool.c index 8bdf5bd3dba..d31d7d4187e 100644 --- a/opal/mca/common/ucx/common_ucx_wpool.c +++ b/opal/mca/common/ucx/common_ucx_wpool.c @@ -9,7 +9,7 @@ #include "opal/memoryhooks/memory.h" #include "opal/util/proc.h" #include "opal/util/sys_limits.h" - +#include "opal/util/sys_limits.h" #include /******************************************************************************* @@ -31,6 +31,8 @@ __thread FILE *tls_pf = NULL; __thread int initialized = 0; #endif +bool opal_common_ucx_thread_enabled = false; + static _ctx_record_t *_tlocal_add_ctx_rec(opal_common_ucx_ctx_t *ctx); static inline _ctx_record_t *_tlocal_get_ctx_rec(opal_tsd_tracked_key_t tls_key); static void _tlocal_ctx_rec_cleanup(_ctx_record_t *ctx_rec); @@ -48,13 +50,18 @@ static opal_common_ucx_winfo_t *_winfo_create(opal_common_ucx_wpool_t *wpool) ucs_status_t status; opal_common_ucx_winfo_t *winfo = NULL; - memset(&worker_params, 0, sizeof(worker_params)); - worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; - status = ucp_worker_create(wpool->ucp_ctx, &worker_params, &worker); - if (UCS_OK != status) { - MCA_COMMON_UCX_ERROR("ucp_worker_create failed: %d", status); - goto exit; + if 
(opal_common_ucx_thread_enabled || wpool->dflt_winfo == NULL) { + memset(&worker_params, 0, sizeof(worker_params)); + worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; + worker_params.thread_mode = UCS_THREAD_MODE_SINGLE; + status = ucp_worker_create(wpool->ucp_ctx, &worker_params, &worker); + if (UCS_OK != status) { + MCA_COMMON_UCX_ERROR("ucp_worker_create failed: %d", status); + goto exit; + } + } else { + /* Single threaded application can reuse the default worker */ + worker = wpool->dflt_winfo->worker; } winfo = OBJ_NEW(opal_common_ucx_winfo_t); @@ -70,6 +77,7 @@ static opal_common_ucx_winfo_t *_winfo_create(opal_common_ucx_wpool_t *wpool) winfo->inflight_ops = NULL; winfo->global_inflight_ops = 0; winfo->inflight_req = UCS_OK; + winfo->is_dflt_winfo = false; return winfo; @@ -90,11 +98,13 @@ static void _winfo_destructor(opal_common_ucx_winfo_t *winfo) if (winfo->comm_size != 0) { size_t i; - for (i = 0; i < winfo->comm_size; i++) { - if (NULL != winfo->endpoints[i]) { - ucp_ep_destroy(winfo->endpoints[i]); + if (opal_common_ucx_thread_enabled) { + for (i = 0; i < winfo->comm_size; i++) { + if (NULL != winfo->endpoints[i]) { + ucp_ep_destroy(winfo->endpoints[i]); + } + assert(winfo->inflight_ops[i] == 0); } - assert(winfo->inflight_ops[i] == 0); } free(winfo->endpoints); free(winfo->inflight_ops); @@ -103,7 +113,10 @@ static void _winfo_destructor(opal_common_ucx_winfo_t *winfo) winfo->comm_size = 0; OBJ_DESTRUCT(&winfo->mutex); - ucp_worker_destroy(winfo->worker); + if (opal_common_ucx_thread_enabled || winfo->is_dflt_winfo) { + ucp_worker_destroy(winfo->worker); + } + } /* ----------------------------------------------------------------------------- @@ -145,12 +158,15 @@ OPAL_DECLSPEC int opal_common_ucx_wpool_init(opal_common_ucx_wpool_t *wpool) OBJ_CONSTRUCT(&wpool->idle_workers, opal_list_t); OBJ_CONSTRUCT(&wpool->active_workers, opal_list_t); + wpool->dflt_winfo = NULL; + winfo = _winfo_create(wpool); if (NULL == winfo) { 
MCA_COMMON_UCX_ERROR("Failed to create receive worker"); rc = OPAL_ERROR; goto err_worker_create; } + winfo->is_dflt_winfo = true; wpool->dflt_winfo = winfo; OBJ_RETAIN(wpool->dflt_winfo); @@ -332,7 +348,7 @@ OPAL_DECLSPEC int opal_common_ucx_wpctx_create(opal_common_ucx_wpool_t *wpool, i opal_common_ucx_ctx_t *ctx = calloc(1, sizeof(*ctx)); int ret = OPAL_SUCCESS; - OBJ_CONSTRUCT(&ctx->mutex, opal_mutex_t); + OBJ_CONSTRUCT(&ctx->mutex, opal_recursive_mutex_t); OBJ_CONSTRUCT(&ctx->ctx_records, opal_list_t); ctx->wpool = wpool; @@ -340,6 +356,7 @@ OPAL_DECLSPEC int opal_common_ucx_wpctx_create(opal_common_ucx_wpool_t *wpool, i ctx->recv_worker_addrs = NULL; ctx->recv_worker_displs = NULL; + ctx->num_incomplete_req_ops = 0; ret = exchange_func(wpool->recv_waddr, wpool->recv_waddr_len, &ctx->recv_worker_addrs, &ctx->recv_worker_displs, exchange_metadata); if (ret != OPAL_SUCCESS) { @@ -404,6 +421,7 @@ int opal_common_ucx_wpmem_create(opal_common_ucx_ctx_t *ctx, void **mem_base, si mem->ctx = ctx; mem->mem_addrs = NULL; mem->mem_displs = NULL; + mem->skip_periodic_flush = false; OBJ_CONSTRUCT(&mem->mutex, opal_mutex_t); @@ -693,7 +711,7 @@ static int _tlocal_mem_create_rkey(_mem_record_t *mem_rec, ucp_ep_h ep, int targ } /* Get the TLS in case of slow path (not everything has been yet initialized */ -OPAL_DECLSPEC int opal_common_ucx_tlocal_fetch_spath(opal_common_ucx_wpmem_t *mem, int target) +OPAL_DECLSPEC int opal_common_ucx_tlocal_fetch_spath(opal_common_ucx_wpmem_t *mem, int target, ucp_ep_h *dflt_ep) { _ctx_record_t *ctx_rec = NULL; _mem_record_t *mem_rec = NULL; @@ -712,9 +730,20 @@ OPAL_DECLSPEC int opal_common_ucx_tlocal_fetch_spath(opal_common_ucx_wpmem_t *me /* Obtain the endpoint */ if (OPAL_UNLIKELY(NULL == winfo->endpoints[target])) { - rc = _tlocal_ctx_connect(ctx_rec, target); - if (rc != OPAL_SUCCESS) { - return rc; + if (opal_common_ucx_thread_enabled || (dflt_ep == NULL) || + (*dflt_ep == NULL)) { + rc = _tlocal_ctx_connect(ctx_rec, target); + if (rc 
!= OPAL_SUCCESS) { + return rc; + } + if (!opal_common_ucx_thread_enabled && (dflt_ep != NULL) && + (*dflt_ep == NULL)) { + /* set the proc ep */ + *dflt_ep = winfo->endpoints[target]; + } + } else { + /* reuse the previously created ep */ + winfo->endpoints[target] = *dflt_ep; } } ep = winfo->endpoints[target]; @@ -784,8 +813,8 @@ OPAL_DECLSPEC int opal_common_ucx_winfo_flush(opal_common_ucx_winfo_t *winfo, in return rc; } -OPAL_DECLSPEC int opal_common_ucx_ctx_flush(opal_common_ucx_ctx_t *ctx, - opal_common_ucx_flush_scope_t scope, int target) +static inline int ctx_flush(opal_common_ucx_ctx_t *ctx, + opal_common_ucx_flush_scope_t scope, int target) { _ctx_record_t *ctx_rec; int rc = OPAL_SUCCESS; @@ -821,16 +850,48 @@ OPAL_DECLSPEC int opal_common_ucx_ctx_flush(opal_common_ucx_ctx_t *ctx, break; } } + opal_mutex_unlock(&ctx->mutex); return rc; } +OPAL_DECLSPEC int opal_common_ucx_ctx_flush(opal_common_ucx_ctx_t *ctx, + opal_common_ucx_flush_scope_t scope, int target) +{ + int rc = OPAL_SUCCESS; + int spin = 0; + + if (NULL == ctx) { + return OPAL_SUCCESS; + } + + rc = ctx_flush(ctx, scope, target); + if (rc != OPAL_SUCCESS) { + return rc; + } + + /* progress the nonblocking operations */ + while (ctx->num_incomplete_req_ops != 0) { + spin++; + rc = ctx_flush(ctx, OPAL_COMMON_UCX_SCOPE_WORKER, 0); + if (rc != OPAL_SUCCESS) { + return rc; + } + if (spin == opal_common_ucx.progress_iterations) { + opal_progress(); + spin = 0; + } + } + + return rc; +} + OPAL_DECLSPEC int opal_common_ucx_wpmem_flush_ep_nb(opal_common_ucx_wpmem_t *mem, int target, opal_common_ucx_user_req_handler_t user_req_cb, - void *user_req_ptr) + void *user_req_ptr, ucp_ep_h *dflt_ep) { #if HAVE_DECL_UCP_EP_FLUSH_NB int rc = OPAL_SUCCESS; @@ -842,7 +903,7 @@ OPAL_DECLSPEC int opal_common_ucx_wpmem_flush_ep_nb(opal_common_ucx_wpmem_t *mem return OPAL_SUCCESS; } - rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, 
&winfo, dflt_ep); if (OPAL_UNLIKELY(OPAL_SUCCESS != rc)) { MCA_COMMON_UCX_ERROR("tlocal_fetch failed: %d", rc); return rc; @@ -868,7 +929,7 @@ OPAL_DECLSPEC int opal_common_ucx_wpmem_flush_ep_nb(opal_common_ucx_wpmem_t *mem } - +/* TODO Replace the input with opal_common_ucx_ctx_t */ OPAL_DECLSPEC int opal_common_ucx_wpmem_fence(opal_common_ucx_wpmem_t *mem) { ucs_status_t status = UCS_OK; diff --git a/opal/mca/common/ucx/common_ucx_wpool.h b/opal/mca/common/ucx/common_ucx_wpool.h index 88665337903..48ee551685a 100644 --- a/opal/mca/common/ucx/common_ucx_wpool.h +++ b/opal/mca/common/ucx/common_ucx_wpool.h @@ -58,6 +58,8 @@ typedef struct { opal_list_t active_workers; } opal_common_ucx_wpool_t; +extern bool opal_common_ucx_thread_enabled; + /* Worker Pool Context (wpctx) is an object that is comprised of a set of UCP * workers that are considered as one logical communication entity. * One UCP worker per "active" thread is used. @@ -86,6 +88,7 @@ typedef struct { char *recv_worker_addrs; int *recv_worker_displs; size_t comm_size; + opal_atomic_size_t num_incomplete_req_ops; } opal_common_ucx_ctx_t; /* Worker Pool memory (wpmem) is an object that represents a remotely accessible @@ -105,6 +108,7 @@ typedef struct { ucp_mem_h memh; char *mem_addrs; int *mem_displs; + bool skip_periodic_flush; /* TLS item that allows each thread to * store endpoints and rkey arrays @@ -127,6 +131,7 @@ struct opal_common_ucx_winfo { short *inflight_ops; short global_inflight_ops; ucs_status_ptr_t inflight_req; + bool is_dflt_winfo; }; OBJ_CLASS_DECLARATION(opal_common_ucx_winfo_t); @@ -178,7 +183,6 @@ OBJ_CLASS_DECLARATION(_mem_record_t); typedef int (*opal_common_ucx_exchange_func_t)(void *my_info, size_t my_info_len, char **recv_info, int **disps, void *metadata); - /* Manage Worker Pool (wpool) */ OPAL_DECLSPEC opal_common_ucx_wpool_t *opal_common_ucx_wpool_allocate(void); OPAL_DECLSPEC void opal_common_ucx_wpool_free(opal_common_ucx_wpool_t *wpool); @@ -198,10 +202,11 @@ 
OPAL_DECLSPEC void opal_common_ucx_req_init(void *request); OPAL_DECLSPEC void opal_common_ucx_req_completion(void *request, ucs_status_t status); /* Managing thread local storage */ -OPAL_DECLSPEC int opal_common_ucx_tlocal_fetch_spath(opal_common_ucx_wpmem_t *mem, int target); +OPAL_DECLSPEC int opal_common_ucx_tlocal_fetch_spath(opal_common_ucx_wpmem_t *mem, int target, ucp_ep_h *_dflt_ep); static inline int opal_common_ucx_tlocal_fetch(opal_common_ucx_wpmem_t *mem, int target, ucp_ep_h *_ep, ucp_rkey_h *_rkey, - opal_common_ucx_winfo_t **_winfo) + opal_common_ucx_winfo_t **_winfo, + ucp_ep_h *_dflt_ep) { _mem_record_t *mem_rec = NULL; int is_ready; @@ -215,7 +220,7 @@ static inline int opal_common_ucx_tlocal_fetch(opal_common_ucx_wpmem_t *mem, int is_ready = mem_rec && (mem_rec->winfo->endpoints[target]) && (NULL != mem_rec->rkeys[target]); MCA_COMMON_UCX_ASSERT((NULL == mem_rec) || (NULL != mem_rec->winfo)); if (OPAL_UNLIKELY(!is_ready)) { - rc = opal_common_ucx_tlocal_fetch_spath(mem, target); + rc = opal_common_ucx_tlocal_fetch_spath(mem, target, _dflt_ep); if (OPAL_SUCCESS != rc) { return rc; } @@ -246,11 +251,12 @@ OPAL_DECLSPEC int opal_common_ucx_wpmem_create(opal_common_ucx_ctx_t *ctx, void OPAL_DECLSPEC void opal_common_ucx_wpmem_free(opal_common_ucx_wpmem_t *mem); OPAL_DECLSPEC int opal_common_ucx_ctx_flush(opal_common_ucx_ctx_t *ctx, - opal_common_ucx_flush_scope_t scope, int target); + opal_common_ucx_flush_scope_t scope, + int target); OPAL_DECLSPEC int opal_common_ucx_wpmem_flush_ep_nb(opal_common_ucx_wpmem_t *mem, int target, opal_common_ucx_user_req_handler_t user_req_cb, - void *user_req_ptr); + void *user_req_ptr, ucp_ep_h *_dflt_ep); OPAL_DECLSPEC int opal_common_ucx_wpmem_fence(opal_common_ucx_wpmem_t *mem); OPAL_DECLSPEC int opal_common_ucx_winfo_flush(opal_common_ucx_winfo_t *winfo, int target, @@ -309,6 +315,8 @@ static inline int _periodical_flush_nb(opal_common_ucx_wpmem_t *mem, opal_common { int rc = OPAL_SUCCESS; + if 
(mem->skip_periodic_flush) return OPAL_SUCCESS; + if (OPAL_UNLIKELY(winfo->inflight_ops[target] >= MCA_COMMON_UCX_PER_TARGET_OPS_THRESHOLD) || OPAL_UNLIKELY(winfo->global_inflight_ops >= MCA_COMMON_UCX_GLOBAL_OPS_THRESHOLD)) { opal_common_ucx_flush_scope_t scope; @@ -349,7 +357,7 @@ static inline int _periodical_flush_nb(opal_common_ucx_wpmem_t *mem, opal_common static inline int opal_common_ucx_wpmem_putget(opal_common_ucx_wpmem_t *mem, opal_common_ucx_op_t op, int target, void *buffer, - size_t len, uint64_t rem_addr) + size_t len, uint64_t rem_addr, ucp_ep_h *dflt_ep) { ucp_ep_h ep; ucp_rkey_h rkey; @@ -358,7 +366,7 @@ static inline int opal_common_ucx_wpmem_putget(opal_common_ucx_wpmem_t *mem, int rc = OPAL_SUCCESS; char *called_func = ""; - rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo, dflt_ep); if (OPAL_UNLIKELY(OPAL_SUCCESS != rc)) { MCA_COMMON_UCX_VERBOSE(1, "tlocal_fetch failed: %d", rc); return rc; @@ -401,7 +409,7 @@ static inline int opal_common_ucx_wpmem_putget(opal_common_ucx_wpmem_t *mem, static inline int opal_common_ucx_wpmem_cmpswp(opal_common_ucx_wpmem_t *mem, uint64_t compare, uint64_t value, int target, void *buffer, size_t len, - uint64_t rem_addr) + uint64_t rem_addr, ucp_ep_h *dflt_ep) { ucp_ep_h ep; ucp_rkey_h rkey; @@ -409,7 +417,7 @@ static inline int opal_common_ucx_wpmem_cmpswp(opal_common_ucx_wpmem_t *mem, uin ucs_status_t status; int rc = OPAL_SUCCESS; - rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo, dflt_ep); if (OPAL_UNLIKELY(OPAL_SUCCESS != rc)) { MCA_COMMON_UCX_ERROR("opal_common_ucx_tlocal_fetch failed: %d", rc); return rc; @@ -440,7 +448,7 @@ static inline int opal_common_ucx_wpmem_cmpswp_nb(opal_common_ucx_wpmem_t *mem, uint64_t value, int target, void *buffer, size_t len, uint64_t rem_addr, opal_common_ucx_user_req_handler_t user_req_cb, - void 
*user_req_ptr) + void *user_req_ptr, ucp_ep_h *dflt_ep) { ucp_ep_h ep; ucp_rkey_h rkey; @@ -448,7 +456,7 @@ static inline int opal_common_ucx_wpmem_cmpswp_nb(opal_common_ucx_wpmem_t *mem, opal_common_ucx_request_t *req; int rc = OPAL_SUCCESS; - rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo, dflt_ep); if (OPAL_UNLIKELY(OPAL_SUCCESS != rc)) { MCA_COMMON_UCX_ERROR("opal_common_ucx_tlocal_fetch failed: %d", rc); return rc; @@ -482,7 +490,7 @@ static inline int opal_common_ucx_wpmem_cmpswp_nb(opal_common_ucx_wpmem_t *mem, static inline int opal_common_ucx_wpmem_post(opal_common_ucx_wpmem_t *mem, ucp_atomic_post_op_t opcode, uint64_t value, - int target, size_t len, uint64_t rem_addr) + int target, size_t len, uint64_t rem_addr, ucp_ep_h *dflt_ep) { ucp_ep_h ep; ucp_rkey_h rkey; @@ -490,7 +498,7 @@ static inline int opal_common_ucx_wpmem_post(opal_common_ucx_wpmem_t *mem, ucs_status_t status; int rc = OPAL_SUCCESS; - rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo, dflt_ep); if (OPAL_UNLIKELY(OPAL_SUCCESS != rc)) { MCA_COMMON_UCX_ERROR("tlocal_fetch failed: %d", rc); return rc; @@ -518,7 +526,7 @@ static inline int opal_common_ucx_wpmem_post(opal_common_ucx_wpmem_t *mem, static inline int opal_common_ucx_wpmem_fetch(opal_common_ucx_wpmem_t *mem, ucp_atomic_fetch_op_t opcode, uint64_t value, int target, void *buffer, size_t len, - uint64_t rem_addr) + uint64_t rem_addr, ucp_ep_h *dflt_ep) { ucp_ep_h ep = NULL; ucp_rkey_h rkey = NULL; @@ -526,7 +534,7 @@ static inline int opal_common_ucx_wpmem_fetch(opal_common_ucx_wpmem_t *mem, ucs_status_t status; int rc = OPAL_SUCCESS; - rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo, dflt_ep); if (OPAL_UNLIKELY(OPAL_SUCCESS != rc)) { MCA_COMMON_UCX_ERROR("tlocal_fetch 
failed: %d", rc); return rc; @@ -558,7 +566,7 @@ static inline int opal_common_ucx_wpmem_fetch_nb(opal_common_ucx_wpmem_t *mem, int target, void *buffer, size_t len, uint64_t rem_addr, opal_common_ucx_user_req_handler_t user_req_cb, - void *user_req_ptr) + void *user_req_ptr, ucp_ep_h *dflt_ep) { ucp_ep_h ep = NULL; ucp_rkey_h rkey = NULL; @@ -566,7 +574,7 @@ static inline int opal_common_ucx_wpmem_fetch_nb(opal_common_ucx_wpmem_t *mem, int rc = OPAL_SUCCESS; opal_common_ucx_request_t *req; - rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo); + rc = opal_common_ucx_tlocal_fetch(mem, target, &ep, &rkey, &winfo, dflt_ep); if (OPAL_UNLIKELY(OPAL_SUCCESS != rc)) { MCA_COMMON_UCX_ERROR("tlocal_fetch failed: %d", rc); return rc;