Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ompi/mca/osc/ucx/osc_ucx_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,12 +319,14 @@ static int component_init(bool enable_progress_threads, bool enable_mpi_threads)
}

static int component_finalize(void) {

if (!opal_common_ucx_thread_enabled) {
int i;
for (i = 0; i < mca_osc_ucx_component.comm_world_size; i++) {
ucp_ep_h ep = mca_osc_ucx_component.endpoints[i];
if (ep != NULL) {
ucp_ep_destroy(ep);
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_ep_counts, -1);
}
}
free(mca_osc_ucx_component.endpoints);
Expand All @@ -334,6 +336,9 @@ static int component_finalize(void) {
opal_common_ucx_wpool_finalize(mca_osc_ucx_component.wpool);
}
opal_common_ucx_wpool_free(mca_osc_ucx_component.wpool);

assert(opal_common_ucx_ep_counts == 0);
assert(opal_common_ucx_unpacked_rkey_counts == 0);
return OMPI_SUCCESS;
}

Expand Down Expand Up @@ -790,6 +795,11 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
goto error;
}

if (my_mem_addr_size != 0) {
/* rkey object is already distributed among comm processes */
ucp_rkey_buffer_release(my_mem_addr);
}

state_base = (void *)&(module->state);
ret = opal_common_ucx_wpmem_create(module->ctx, &state_base,
sizeof(ompi_osc_ucx_state_t),
Expand All @@ -803,6 +813,11 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
goto error;
}

if (my_mem_addr_size != 0) {
/* rkey object is already distributed among comm processes */
ucp_rkey_buffer_release(my_mem_addr);
}

/* exchange window addrs */
if (flavor == MPI_WIN_FLAVOR_ALLOCATE || flavor == MPI_WIN_FLAVOR_CREATE ||
flavor == MPI_WIN_FLAVOR_SHARED) {
Expand Down
25 changes: 25 additions & 0 deletions opal/mca/common/ucx/common_ucx_wpool.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ __thread int initialized = 0;
#endif

bool opal_common_ucx_thread_enabled = false;
opal_atomic_int64_t opal_common_ucx_ep_counts = 0;
opal_atomic_int64_t opal_common_ucx_unpacked_rkey_counts = 0;

static _ctx_record_t *_tlocal_add_ctx_rec(opal_common_ucx_ctx_t *ctx);
static inline _ctx_record_t *_tlocal_get_ctx_rec(opal_tsd_tracked_key_t tls_key);
Expand Down Expand Up @@ -102,6 +104,7 @@ static void _winfo_destructor(opal_common_ucx_winfo_t *winfo)
for (i = 0; i < winfo->comm_size; i++) {
if (NULL != winfo->endpoints[i]) {
ucp_ep_destroy(winfo->endpoints[i]);
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_ep_counts, -1);
}
assert(winfo->inflight_ops[i] == 0);
}
Expand Down Expand Up @@ -326,9 +329,26 @@ static opal_common_ucx_winfo_t *_wpool_get_winfo(opal_common_ucx_wpool_t *wpool,
return winfo;
}

/* Remove the winfo from active workers and add it to idle workers */
static void _wpool_put_winfo(opal_common_ucx_wpool_t *wpool, opal_common_ucx_winfo_t *winfo)
{
opal_mutex_lock(&wpool->mutex);
if (winfo->comm_size != 0) {
size_t i;
if (opal_common_ucx_thread_enabled) {
for (i = 0; i < winfo->comm_size; i++) {
if (NULL != winfo->endpoints[i]) {
ucp_ep_destroy(winfo->endpoints[i]);
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_ep_counts, -1);
}
assert(winfo->inflight_ops[i] == 0);
}
}
free(winfo->endpoints);
free(winfo->inflight_ops);
}
winfo->endpoints = NULL;
winfo->comm_size = 0;
opal_list_remove_item(&wpool->active_workers, &winfo->super);
opal_list_prepend(&wpool->idle_workers, &winfo->super);
opal_mutex_unlock(&wpool->mutex);
Expand Down Expand Up @@ -632,6 +652,7 @@ static int _tlocal_ctx_connect(_ctx_record_t *ctx_rec, int target)
memset(&ep_params, 0, sizeof(ucp_ep_params_t));
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;

assert(winfo->endpoints[target] == NULL);
opal_mutex_lock(&winfo->mutex);
displ = gctx->recv_worker_displs[target];
ep_params.address = (ucp_address_t *) &(gctx->recv_worker_addrs[displ]);
Expand All @@ -641,7 +662,9 @@ static int _tlocal_ctx_connect(_ctx_record_t *ctx_rec, int target)
opal_mutex_unlock(&winfo->mutex);
return OPAL_ERROR;
}
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_ep_counts, 1);
opal_mutex_unlock(&winfo->mutex);
assert(winfo->endpoints[target] != NULL);
return OPAL_SUCCESS;
}

Expand All @@ -662,6 +685,7 @@ static void _tlocal_mem_rec_cleanup(_mem_record_t *mem_rec)
for (i = 0; i < mem_rec->gmem->ctx->comm_size; i++) {
if (mem_rec->rkeys[i]) {
ucp_rkey_destroy(mem_rec->rkeys[i]);
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_unpacked_rkey_counts, -1);
}
}
opal_mutex_unlock(&mem_rec->winfo->mutex);
Expand Down Expand Up @@ -701,6 +725,7 @@ static int _tlocal_mem_create_rkey(_mem_record_t *mem_rec, ucp_ep_h ep, int targ

opal_mutex_lock(&mem_rec->winfo->mutex);
status = ucp_ep_rkey_unpack(ep, &gmem->mem_addrs[displ], &mem_rec->rkeys[target]);
OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(opal_common_ucx_unpacked_rkey_counts, 1);
opal_mutex_unlock(&mem_rec->winfo->mutex);
if (status != UCS_OK) {
MCA_COMMON_UCX_VERBOSE(1, "ucp_ep_rkey_unpack failed: %d", status);
Expand Down
11 changes: 11 additions & 0 deletions opal/mca/common/ucx/common_ucx_wpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,17 @@ typedef struct {
} opal_common_ucx_wpool_t;

extern bool opal_common_ucx_thread_enabled;
extern opal_atomic_int64_t opal_common_ucx_ep_counts;
extern opal_atomic_int64_t opal_common_ucx_unpacked_rkey_counts;

#if OPAL_ENABLE_DEBUG
#define OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(_var, _val) \
do { \
opal_atomic_add_fetch_64(&(_var), (_val)); \
} while(0);
#else
#define OPAL_COMMON_UCX_DEBUG_ATOMIC_ADD(&(_var), (_val));
#endif

/* Worker Pool Context (wpctx) is an object that is comprised of a set of UCP
* workers that are considered as one logical communication entity.
Expand Down