Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions ompi/mca/osc/ucx/osc_ucx.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@
#define OMPI_OSC_UCX_ATTACH_MAX 48
#define OMPI_OSC_UCX_MEM_ADDR_MAX_LEN 1024


typedef struct ompi_osc_ucx_component {
ompi_osc_base_component_t super;
opal_common_ucx_wpool_t *wpool;
bool enable_mpi_threads;
opal_free_list_t requests; /* request free list for the r* communication variants */
opal_free_list_t accumulate_requests; /* request free list for the r* communication variants */
bool env_initialized; /* UCX environment is initialized or not */
int num_incomplete_req_ops;
int comm_world_size;
ucp_ep_h *endpoints;
int num_modules;
bool no_locks; /* Default value of the no_locks info key for new windows */
bool acc_single_intrinsic;
Expand All @@ -44,6 +47,16 @@ typedef struct ompi_osc_ucx_component {

OMPI_DECLSPEC extern ompi_osc_ucx_component_t mca_osc_ucx_component;

#define OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(_module) \
do { \
opal_atomic_add_fetch_size_t(&_module->ctx->num_incomplete_req_ops, 1); \
} while(0);

#define OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS(_module) \
do { \
opal_atomic_add_fetch_size_t(&_module->ctx->num_incomplete_req_ops, -1); \
} while(0);

typedef enum ompi_osc_ucx_epoch {
NONE_EPOCH,
FENCE_EPOCH,
Expand All @@ -69,7 +82,8 @@ typedef struct ompi_osc_ucx_epoch_type {
#define OSC_UCX_STATE_COMPLETE_COUNT_OFFSET (sizeof(uint64_t) * 3)
#define OSC_UCX_STATE_POST_INDEX_OFFSET (sizeof(uint64_t) * 4)
#define OSC_UCX_STATE_POST_STATE_OFFSET (sizeof(uint64_t) * 5)
#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (5 + OMPI_OSC_UCX_POST_PEER_MAX))
#define OSC_UCX_STATE_DYNAMIC_LOCK_OFFSET (sizeof(uint64_t) * 6)
#define OSC_UCX_STATE_DYNAMIC_WIN_CNT_OFFSET (sizeof(uint64_t) * (6 + OMPI_OSC_UCX_POST_PEER_MAX))

typedef struct ompi_osc_dynamic_win_info {
uint64_t base;
Expand Down Expand Up @@ -102,6 +116,7 @@ typedef struct ompi_osc_ucx_module {
size_t size;
uint64_t *addrs;
uint64_t *state_addrs;
uint64_t *comm_world_ranks;
int disp_unit; /* if disp_unit >= 0, then everyone has the same
* disp unit size; if disp_unit == -1, then we
* need to look at disp_units */
Expand All @@ -125,6 +140,7 @@ typedef struct ompi_osc_ucx_module {
opal_common_ucx_wpmem_t *mem;
opal_common_ucx_wpmem_t *state_mem;

bool skip_sync_check;
bool noncontig_shared_win;
size_t *sizes;
/* in shared windows, shmem_addrs can be used for direct load store to
Expand All @@ -147,9 +163,18 @@ typedef struct ompi_osc_ucx_lock {
bool is_nocheck;
} ompi_osc_ucx_lock_t;

#define OSC_UCX_GET_EP(comm_, rank_) (ompi_comm_peer_lookup(comm_, rank_)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_UCX])
#define OSC_UCX_GET_EP(_module, rank_) (mca_osc_ucx_component.endpoints[_module->comm_world_ranks[rank_]])
#define OSC_UCX_GET_DISP(module_, rank_) ((module_->disp_unit < 0) ? module_->disp_units[rank_] : module_->disp_unit)

#define OSC_UCX_GET_DEFAULT_EP(_ep_ptr, _module, _target) \
if (opal_common_ucx_thread_enabled) { \
_ep_ptr = NULL; \
} else { \
_ep_ptr = (ucp_ep_h *)&(OSC_UCX_GET_EP(_module, _target)); \
}

extern size_t ompi_osc_ucx_outstanding_ops_flush_threshold;

int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size,
int *disp_unit, void * baseptr);
int ompi_osc_ucx_win_attach(struct ompi_win_t *win, void *base, size_t len);
Expand All @@ -169,6 +194,11 @@ int ompi_osc_ucx_accumulate(const void *origin_addr, int origin_count,
int target, ptrdiff_t target_disp, int target_count,
struct ompi_datatype_t *target_dt,
struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_accumulate_nb(const void *origin_addr, int origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, int target_count,
struct ompi_datatype_t *target_dt,
struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_compare_and_swap(const void *origin_addr, const void *compare_addr,
void *result_addr, struct ompi_datatype_t *dt,
int target, ptrdiff_t target_disp,
Expand All @@ -184,6 +214,13 @@ int ompi_osc_ucx_get_accumulate(const void *origin_addr, int origin_count,
int target_rank, ptrdiff_t target_disp,
int target_count, struct ompi_datatype_t *target_datatype,
struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_get_accumulate_nb(const void *origin_addr, int origin_count,
struct ompi_datatype_t *origin_datatype,
void *result_addr, int result_count,
struct ompi_datatype_t *result_datatype,
int target_rank, ptrdiff_t target_disp,
int target_count, struct ompi_datatype_t *target_datatype,
struct ompi_op_t *op, struct ompi_win_t *win);
int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
struct ompi_datatype_t *origin_dt,
int target, ptrdiff_t target_disp, int target_count,
Expand Down Expand Up @@ -228,10 +265,7 @@ int ompi_osc_ucx_flush_local_all(struct ompi_win_t *win);
int ompi_osc_find_attached_region_position(ompi_osc_dynamic_win_info_t *dynamic_wins,
int min_index, int max_index,
uint64_t base, size_t len, int *insert);
extern inline bool ompi_osc_need_acc_lock(ompi_osc_ucx_module_t *module, int target);
extern inline int ompi_osc_state_lock(ompi_osc_ucx_module_t *module, int target,
bool *lock_acquired, bool force_lock);
extern inline int ompi_osc_state_unlock(ompi_osc_ucx_module_t *module, int target,
bool lock_acquired, void *free_ptr);
int ompi_osc_ucx_dynamic_lock(ompi_osc_ucx_module_t *module, int target);
int ompi_osc_ucx_dynamic_unlock(ompi_osc_ucx_module_t *module, int target);

#endif /* OMPI_OSC_UCX_H */
19 changes: 11 additions & 8 deletions ompi/mca/osc/ucx/osc_ucx_active_target.c
Original file line number Diff line number Diff line change
Expand Up @@ -165,31 +165,33 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) {
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
int i, size;
int ret = OMPI_SUCCESS;
ucp_ep_h *ep;

if (module->epoch_type.access != START_COMPLETE_EPOCH) {
return OMPI_ERR_RMA_SYNC;
}

module->epoch_type.access = NONE_EPOCH;

ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_WORKER, 0/*ignore*/);
if (ret != OMPI_SUCCESS) {
return ret;
}

module->epoch_type.access = NONE_EPOCH;

size = ompi_group_size(module->start_group);
for (i = 0; i < size; i++) {
uint64_t remote_addr = module->state_addrs[module->start_grp_ranks[i]] + OSC_UCX_STATE_COMPLETE_COUNT_OFFSET; // write to state.complete_count on remote side

OSC_UCX_GET_DEFAULT_EP(ep, module, module->start_grp_ranks[i]);

ret = opal_common_ucx_wpmem_post(module->state_mem, UCP_ATOMIC_POST_OP_ADD,
1, module->start_grp_ranks[i], sizeof(uint64_t),
remote_addr);
remote_addr, ep);
if (ret != OMPI_SUCCESS) {
OSC_UCX_VERBOSE(1, "opal_common_ucx_mem_post failed: %d", ret);
}

ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP,
module->start_grp_ranks[i]);
ret = opal_common_ucx_ctx_flush(module->ctx, OPAL_COMMON_UCX_SCOPE_EP, module->start_grp_ranks[i]);
if (ret != OMPI_SUCCESS) {
return ret;
}
Expand All @@ -204,6 +206,7 @@ int ompi_osc_ucx_complete(struct ompi_win_t *win) {

int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_win_t *win) {
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
ucp_ep_h *ep;
int ret = OMPI_SUCCESS;

if (module->epoch_type.exposure != NONE_EPOCH) {
Expand Down Expand Up @@ -243,12 +246,12 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_wi
uint64_t remote_addr = module->state_addrs[ranks_in_win_grp[i]] + OSC_UCX_STATE_POST_INDEX_OFFSET; // write to state.post_index on remote side
uint64_t curr_idx = 0, result = 0;


OSC_UCX_GET_DEFAULT_EP(ep, module, ranks_in_win_grp[i]);

/* do fop first to get an post index */
ret = opal_common_ucx_wpmem_fetch(module->state_mem, UCP_ATOMIC_FETCH_OP_FADD,
1, ranks_in_win_grp[i], &result,
sizeof(result), remote_addr);
sizeof(result), remote_addr, ep);

if (ret != OMPI_SUCCESS) {
ret = OMPI_ERROR;
Expand All @@ -265,7 +268,7 @@ int ompi_osc_ucx_post(struct ompi_group_t *group, int mpi_assert, struct ompi_wi
result = myrank + 1;
ret = opal_common_ucx_wpmem_cmpswp(module->state_mem, 0, result,
ranks_in_win_grp[i], &result, sizeof(result),
remote_addr);
remote_addr, ep);

if (ret != OMPI_SUCCESS) {
ret = OMPI_ERROR;
Expand Down
Loading