Skip to content

Commit 9f46f48

Browse files
author
Mamzi Bayatpour [email protected] ()
committed
create separate req obj for accumulate
1 parent ddf5d70 commit 9f46f48

File tree

5 files changed

+190
-135
lines changed

5 files changed

+190
-135
lines changed

ompi/mca/osc/ucx/osc_ucx.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ typedef struct ompi_osc_ucx_component {
3333
opal_common_ucx_wpool_t *wpool;
3434
bool enable_mpi_threads;
3535
opal_free_list_t requests; /* request free list for the r* communication variants */
36+
opal_free_list_t accumulate_requests; /* request free list for the r* communication variants */
3637
bool env_initialized; /* UCX environment is initialized or not */
3738
int comm_world_size;
3839
ucp_ep_h *endpoints;
@@ -173,7 +174,7 @@ extern bool thread_enabled;
173174
_ep_ptr = (ucp_ep_h *)&(OSC_UCX_GET_EP(_module, _target)); \
174175
}
175176

176-
extern int ompi_osc_ucx_outstanding_ops_flush_threshold;
177+
extern size_t ompi_osc_ucx_outstanding_ops_flush_threshold;
177178

178179
int ompi_osc_ucx_shared_query(struct ompi_win_t *win, int rank, size_t *size,
179180
int *disp_unit, void * baseptr);

ompi/mca/osc/ucx/osc_ucx_comm.c

Lines changed: 73 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ typedef struct ucx_iovec {
4040
size_t len;
4141
} ucx_iovec_t;
4242

43-
int ompi_osc_ucx_outstanding_ops_flush_threshold = 64;
43+
size_t ompi_osc_ucx_outstanding_ops_flush_threshold = 64;
4444

4545
static inline int check_sync_state(ompi_osc_ucx_module_t *module, int target,
4646
bool is_req_ops) {
@@ -423,7 +423,7 @@ static int do_atomic_op_intrinsic(
423423
struct ompi_datatype_t *dt,
424424
ptrdiff_t target_disp,
425425
void *result_addr,
426-
ompi_osc_ucx_request_t *ucx_req)
426+
ompi_osc_ucx_accumulate_request_t *ucx_req)
427427
{
428428
int ret = OMPI_SUCCESS;
429429
size_t origin_dt_bytes;
@@ -598,7 +598,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
598598
int target, ptrdiff_t target_disp, int target_count,
599599
struct ompi_datatype_t *target_dt,
600600
struct ompi_op_t *op, struct ompi_win_t *win,
601-
ompi_osc_ucx_request_t *ucx_req) {
601+
ompi_osc_ucx_accumulate_request_t *ucx_req) {
602602

603603
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
604604
int ret = OMPI_SUCCESS;
@@ -735,7 +735,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
735735

736736
if (NULL != ucx_req) {
737737
// nothing to wait for, mark request as completed
738-
ompi_request_complete(&ucx_req->super, true);
738+
ompi_request_complete(&ucx_req->super.super, true);
739739
}
740740

741741
return ompi_osc_ucx_state_unlock(module, target, lock_acquired, free_ptr);
@@ -927,7 +927,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
927927
int target, ptrdiff_t target_disp,
928928
int target_count, struct ompi_datatype_t *target_dt,
929929
struct ompi_op_t *op, struct ompi_win_t *win,
930-
ompi_osc_ucx_request_t *ucx_req) {
930+
ompi_osc_ucx_accumulate_request_t *ucx_req) {
931931
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
932932
int ret = OMPI_SUCCESS;
933933
uint64_t remote_addr = (module->addrs[target]) + target_disp *
@@ -1067,7 +1067,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
10671067

10681068
if (NULL != ucx_req) {
10691069
// nothing to wait for, mark request as completed
1070-
ompi_request_complete(&ucx_req->super, true);
1070+
ompi_request_complete(&ucx_req->super.super, true);
10711071
}
10721072

10731073

@@ -1110,7 +1110,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
11101110
OSC_UCX_GET_DEFAULT_EP(ep, module, target);
11111111
opal_common_ucx_wpmem_t *mem = module->mem;
11121112
uint64_t remote_addr = (module->state_addrs[target]) + OSC_UCX_STATE_REQ_FLAG_OFFSET;
1113-
ompi_osc_ucx_request_t *ucx_req = NULL;
1113+
ompi_osc_ucx_generic_request_t *ucx_req = NULL;
11141114
int ret = OMPI_SUCCESS;
11151115

11161116
ret = check_sync_state(module, target, true);
@@ -1126,8 +1126,8 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
11261126
return ret;
11271127
}
11281128

1129-
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
1130-
ucx_req->module = module;
1129+
OMPI_OSC_UCX_GENERIC_REQUEST_ALLOC(win, ucx_req, RPUT_REQ);
1130+
ucx_req->super.module = module;
11311131

11321132
OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module);
11331133
ret = opal_common_ucx_wpmem_flush_ep_nb(mem, target, req_completion, ucx_req, ep);
@@ -1151,7 +1151,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
11511151
}
11521152
}
11531153

1154-
*request = &ucx_req->super;
1154+
*request = &ucx_req->super.super;
11551155

11561156
return ret;
11571157
}
@@ -1166,7 +1166,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
11661166
OSC_UCX_GET_DEFAULT_EP(ep, module, target);
11671167
opal_common_ucx_wpmem_t *mem = module->mem;
11681168
uint64_t remote_addr = (module->state_addrs[target]) + OSC_UCX_STATE_REQ_FLAG_OFFSET;
1169-
ompi_osc_ucx_request_t *ucx_req = NULL;
1169+
ompi_osc_ucx_generic_request_t *ucx_req = NULL;
11701170
int ret = OMPI_SUCCESS;
11711171

11721172
ret = check_sync_state(module, target, true);
@@ -1182,8 +1182,8 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
11821182
return ret;
11831183
}
11841184

1185-
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
1186-
ucx_req->module = module;
1185+
OMPI_OSC_UCX_GENERIC_REQUEST_ALLOC(win, ucx_req, RGET_REQ);
1186+
ucx_req->super.module = module;
11871187

11881188
OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS(module);
11891189
ret = opal_common_ucx_wpmem_flush_ep_nb(mem, target, req_completion, ucx_req, ep);
@@ -1207,7 +1207,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
12071207
}
12081208
}
12091209

1210-
*request = &ucx_req->super;
1210+
*request = &ucx_req->super.super;
12111211

12121212
return ret;
12131213
}
@@ -1218,16 +1218,16 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
12181218
struct ompi_datatype_t *target_dt, struct ompi_op_t *op,
12191219
struct ompi_win_t *win, struct ompi_request_t **request) {
12201220
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
1221-
ompi_osc_ucx_request_t *ucx_req = NULL;
1221+
ompi_osc_ucx_accumulate_request_t *ucx_req = NULL;
12221222
int ret = OMPI_SUCCESS;
12231223

12241224
ret = check_sync_state(module, target, true);
12251225
if (ret != OMPI_SUCCESS) {
12261226
return ret;
12271227
}
12281228

1229-
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
1230-
ucx_req->module = module;
1229+
OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req);
1230+
ucx_req->super.module = module;
12311231
assert(NULL != ucx_req);
12321232

12331233
ret = accumulate_req(origin_addr, origin_count, origin_dt, target, target_disp,
@@ -1237,7 +1237,7 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
12371237
return ret;
12381238
}
12391239

1240-
*request = &ucx_req->super;
1240+
*request = &ucx_req->super.super;
12411241

12421242
return ret;
12431243
}
@@ -1251,16 +1251,16 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
12511251
struct ompi_op_t *op, struct ompi_win_t *win,
12521252
struct ompi_request_t **request) {
12531253
ompi_osc_ucx_module_t *module = (ompi_osc_ucx_module_t*) win->w_osc_module;
1254-
ompi_osc_ucx_request_t *ucx_req = NULL;
1254+
ompi_osc_ucx_accumulate_request_t *ucx_req = NULL;
12551255
int ret = OMPI_SUCCESS;
12561256

12571257
ret = check_sync_state(module, target, true);
12581258
if (ret != OMPI_SUCCESS) {
12591259
return ret;
12601260
}
12611261

1262-
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
1263-
ucx_req->module = module;
1262+
OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req);
1263+
ucx_req->super.module = module;
12641264
assert(NULL != ucx_req);
12651265

12661266
ret = get_accumulate_req(origin_addr, origin_count, origin_datatype,
@@ -1272,7 +1272,7 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
12721272
return ret;
12731273
}
12741274

1275-
*request = &ucx_req->super;
1275+
*request = &ucx_req->super.super;
12761276

12771277
return ret;
12781278
}
@@ -1288,47 +1288,47 @@ static inline int ompi_osc_ucx_acc_rputget(void *stage_addr, int stage_count,
12881288
OSC_UCX_GET_DEFAULT_EP(ep, module, target);
12891289
opal_common_ucx_wpmem_t *mem = module->mem;
12901290
uint64_t remote_addr = (module->state_addrs[target]) + OSC_UCX_STATE_REQ_FLAG_OFFSET;
1291-
ompi_osc_ucx_request_t *ucx_req = NULL;
1291+
ompi_osc_ucx_accumulate_request_t *ucx_req = NULL;
12921292
bool sync_check;
12931293
int ret = OMPI_SUCCESS;
12941294
CHECK_DYNAMIC_WIN(remote_addr, module, target, ret, true);
12951295

12961296
if (acc_type != NONE) {
1297-
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
1297+
OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req);
12981298
assert(NULL != ucx_req);
1299-
ucx_req->acc.op = op;
1300-
ucx_req->acc.acc_type = acc_type;
1301-
ucx_req->acc.phase = phase;
1302-
ucx_req->module = module;
1303-
ucx_req->acc.target = target;
1304-
ucx_req->acc.lock_acquired = lock_acquired;
1305-
ucx_req->acc.win = win;
1306-
ucx_req->acc.origin_addr = origin_addr;
1307-
ucx_req->acc.origin_count = origin_count;
1299+
ucx_req->op = op;
1300+
ucx_req->acc_type = acc_type;
1301+
ucx_req->phase = phase;
1302+
ucx_req->super.module = module;
1303+
ucx_req->target = target;
1304+
ucx_req->lock_acquired = lock_acquired;
1305+
ucx_req->win = win;
1306+
ucx_req->origin_addr = origin_addr;
1307+
ucx_req->origin_count = origin_count;
13081308
if (origin_dt != NULL) {
1309-
ucx_req->acc.origin_dt = origin_dt;
1309+
ucx_req->origin_dt = origin_dt;
13101310
if (!ompi_datatype_is_predefined(origin_dt)) {
1311-
OBJ_RETAIN(ucx_req->acc.origin_dt);
1311+
OBJ_RETAIN(ucx_req->origin_dt);
13121312
}
13131313
}
1314-
ucx_req->acc.stage_addr = stage_addr;
1315-
ucx_req->acc.stage_count = stage_count;
1314+
ucx_req->stage_addr = stage_addr;
1315+
ucx_req->stage_count = stage_count;
13161316
if (stage_dt != NULL) {
1317-
ucx_req->acc.stage_dt = stage_dt;
1317+
ucx_req->stage_dt = stage_dt;
13181318
if (!ompi_datatype_is_predefined(stage_dt)) {
1319-
OBJ_RETAIN(ucx_req->acc.stage_dt);
1319+
OBJ_RETAIN(ucx_req->stage_dt);
13201320
}
13211321
}
1322-
ucx_req->acc.target = target;
1322+
ucx_req->target = target;
13231323
if (target_dt != NULL) {
1324-
ucx_req->acc.target_dt = target_dt;
1324+
ucx_req->target_dt = target_dt;
13251325
if (!ompi_datatype_is_predefined(target_dt)) {
1326-
OBJ_RETAIN(ucx_req->acc.target_dt);
1326+
OBJ_RETAIN(ucx_req->target_dt);
13271327
}
13281328
}
1329-
ucx_req->acc.target_disp = target_disp;
1330-
ucx_req->acc.target_count = target_count;
1331-
ucx_req->acc.free_ptr = NULL;
1329+
ucx_req->target_disp = target_disp;
1330+
ucx_req->target_count = target_count;
1331+
ucx_req->free_ptr = NULL;
13321332
}
13331333
sync_check = module->skip_sync_check;
13341334
module->skip_sync_check = true; /* we already hold the acc lock, so no need for sync check*/
@@ -1478,39 +1478,41 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
14781478
}
14791479

14801480
void req_completion(void *request) {
1481-
ompi_osc_ucx_request_t *req = (ompi_osc_ucx_request_t *)request;
1481+
ompi_osc_ucx_generic_request_t *ucx_req = (ompi_osc_ucx_generic_request_t *)request;
14821482
int ret = OMPI_SUCCESS;
1483-
ompi_osc_ucx_module_t *module = req->module;
1484-
if (req->acc.acc_type != NONE) {
1485-
assert(req->acc.phase != ACC_INIT);
1483+
ompi_osc_ucx_module_t *module = ucx_req->super.module;
1484+
if (ucx_req->super.request_type == ACCUMULATE_REQ) {
1485+
/* This is an accumulate request */
1486+
ompi_osc_ucx_accumulate_request_t *req = (ompi_osc_ucx_accumulate_request_t *)request;
1487+
assert(req->phase != ACC_INIT);
14861488
void *free_addr = NULL;
14871489
bool release_lock = false;
14881490
ptrdiff_t temp_lb, temp_extent;
1489-
const void *origin_addr = req->acc.origin_addr;
1490-
int origin_count = req->acc.origin_count;
1491-
struct ompi_datatype_t *origin_dt = req->acc.origin_dt;
1492-
void *temp_addr = req->acc.stage_addr;
1493-
int temp_count = req->acc.stage_count;
1494-
struct ompi_datatype_t *temp_dt = req->acc.stage_dt;
1495-
int target = req->acc.target;
1496-
int target_count = req->acc.target_count;
1497-
int target_disp = req->acc.target_disp;
1498-
struct ompi_datatype_t *target_dt = req->acc.target_dt;
1499-
struct ompi_win_t *win = req->acc.win;
1500-
struct ompi_op_t *op = req->acc.op;
1501-
1502-
if (req->acc.phase != ACC_FINALIZE) {
1491+
const void *origin_addr = req->origin_addr;
1492+
int origin_count = req->origin_count;
1493+
struct ompi_datatype_t *origin_dt = req->origin_dt;
1494+
void *temp_addr = req->stage_addr;
1495+
int temp_count = req->stage_count;
1496+
struct ompi_datatype_t *temp_dt = req->stage_dt;
1497+
int target = req->target;
1498+
int target_count = req->target_count;
1499+
int target_disp = req->target_disp;
1500+
struct ompi_datatype_t *target_dt = req->target_dt;
1501+
struct ompi_win_t *win = req->win;
1502+
struct ompi_op_t *op = req->op;
1503+
1504+
if (req->phase != ACC_FINALIZE) {
15031505
/* Avoid calling flush while we are already in progress */
15041506
module->mem->skip_periodic_flush = true;
15051507
module->state_mem->skip_periodic_flush = true;
15061508
}
15071509

1508-
switch (req->acc.phase) {
1510+
switch (req->phase) {
15091511
case ACC_FINALIZE:
15101512
{
1511-
if (req->acc.free_ptr != NULL) {
1512-
free(req->acc.free_ptr);
1513-
req->acc.free_ptr = NULL;
1513+
if (req->free_ptr != NULL) {
1514+
free(req->free_ptr);
1515+
req->free_ptr = NULL;
15141516
}
15151517
if (origin_dt != NULL && !ompi_datatype_is_predefined(origin_dt)) {
15161518
OBJ_RELEASE(origin_dt);
@@ -1613,7 +1615,7 @@ void req_completion(void *request) {
16131615
free(origin_ucx_iov);
16141616
}
16151617

1616-
if (req->acc.acc_type == GET_ACCUMULATE) {
1618+
if (req->acc_type == GET_ACCUMULATE) {
16171619
/* Do fence to make sure target results are received before
16181620
* writing into target */
16191621
ret = opal_common_ucx_wpmem_fence(module->mem);
@@ -1647,16 +1649,15 @@ void req_completion(void *request) {
16471649
/* Ordering between previous put/get operations and unlock will be realized
16481650
* through the ucp fence inside the finalize function */
16491651
ompi_osc_ucx_nonblocking_ops_finalize(module, target,
1650-
req->acc.lock_acquired, win, free_addr);
1652+
req->lock_acquired, win, free_addr);
16511653
}
16521654

1653-
if (req->acc.phase != ACC_FINALIZE) {
1655+
if (req->phase != ACC_FINALIZE) {
16541656
module->mem->skip_periodic_flush = false;
16551657
module->state_mem->skip_periodic_flush = false;
16561658
}
16571659
}
1658-
16591660
OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS(module);
1660-
ompi_request_complete(&(req->super), true);
1661+
ompi_request_complete(&(ucx_req->super.super), true);
16611662
assert(module->ctx->num_incomplete_req_ops >= 0);
16621663
}

ompi/mca/osc/ucx/osc_ucx_component.c

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -483,9 +483,20 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
483483

484484
OBJ_CONSTRUCT(&mca_osc_ucx_component.requests, opal_free_list_t);
485485
ret = opal_free_list_init (&mca_osc_ucx_component.requests,
486-
sizeof(ompi_osc_ucx_request_t),
486+
sizeof(ompi_osc_ucx_generic_request_t),
487487
opal_cache_line_size,
488-
OBJ_CLASS(ompi_osc_ucx_request_t),
488+
OBJ_CLASS(ompi_osc_ucx_generic_request_t),
489+
0, 0, 8, 0, 8, NULL, 0, NULL, NULL, NULL);
490+
if (OMPI_SUCCESS != ret) {
491+
OSC_UCX_VERBOSE(1, "opal_free_list_init failed: %d", ret);
492+
goto select_unlock;
493+
}
494+
495+
OBJ_CONSTRUCT(&mca_osc_ucx_component.accumulate_requests, opal_free_list_t);
496+
ret = opal_free_list_init (&mca_osc_ucx_component.accumulate_requests,
497+
sizeof(ompi_osc_ucx_accumulate_request_t),
498+
opal_cache_line_size,
499+
OBJ_CLASS(ompi_osc_ucx_accumulate_request_t),
489500
0, 0, 8, 0, 8, NULL, 0, NULL, NULL, NULL);
490501
if (OMPI_SUCCESS != ret) {
491502
OSC_UCX_VERBOSE(1, "opal_free_list_init failed: %d", ret);
@@ -977,14 +988,14 @@ inline int ompi_osc_ucx_nonblocking_ops_finalize(ompi_osc_ucx_module_t *module,
977988
ucp_ep_h *ep;
978989
OSC_UCX_GET_DEFAULT_EP(ep, module, target);
979990
int ret = OMPI_SUCCESS;
980-
ompi_osc_ucx_request_t *ucx_req = NULL;
991+
ompi_osc_ucx_accumulate_request_t *ucx_req = NULL;
981992

982-
OMPI_OSC_UCX_REQUEST_ALLOC(win, ucx_req);
993+
OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC(win, ucx_req);
983994
assert(NULL != ucx_req);
984-
ucx_req->acc.free_ptr = free_ptr;
985-
ucx_req->acc.phase = ACC_FINALIZE;
986-
ucx_req->acc.acc_type = ANY;
987-
ucx_req->module = module;
995+
ucx_req->free_ptr = free_ptr;
996+
ucx_req->phase = ACC_FINALIZE;
997+
ucx_req->acc_type = ANY;
998+
ucx_req->super.module = module;
988999

9891000
/* Fence any still active operations */
9901001
ret = opal_common_ucx_wpmem_fence(module->mem);

0 commit comments

Comments
 (0)