@@ -40,7 +40,7 @@ typedef struct ucx_iovec {
4040 size_t len ;
4141} ucx_iovec_t ;
4242
43- int ompi_osc_ucx_outstanding_ops_flush_threshold = 64 ;
43+ size_t ompi_osc_ucx_outstanding_ops_flush_threshold = 64 ;
4444
4545static inline int check_sync_state (ompi_osc_ucx_module_t * module , int target ,
4646 bool is_req_ops ) {
@@ -423,7 +423,7 @@ static int do_atomic_op_intrinsic(
423423 struct ompi_datatype_t * dt ,
424424 ptrdiff_t target_disp ,
425425 void * result_addr ,
426- ompi_osc_ucx_request_t * ucx_req )
426+ ompi_osc_ucx_accumulate_request_t * ucx_req )
427427{
428428 int ret = OMPI_SUCCESS ;
429429 size_t origin_dt_bytes ;
@@ -598,7 +598,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
598598 int target , ptrdiff_t target_disp , int target_count ,
599599 struct ompi_datatype_t * target_dt ,
600600 struct ompi_op_t * op , struct ompi_win_t * win ,
601- ompi_osc_ucx_request_t * ucx_req ) {
601+ ompi_osc_ucx_accumulate_request_t * ucx_req ) {
602602
603603 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
604604 int ret = OMPI_SUCCESS ;
@@ -735,7 +735,7 @@ int accumulate_req(const void *origin_addr, int origin_count,
735735
736736 if (NULL != ucx_req ) {
737737 // nothing to wait for, mark request as completed
738- ompi_request_complete (& ucx_req -> super , true);
738+ ompi_request_complete (& ucx_req -> super . super , true);
739739 }
740740
741741 return ompi_osc_ucx_state_unlock (module , target , lock_acquired , free_ptr );
@@ -927,7 +927,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
927927 int target , ptrdiff_t target_disp ,
928928 int target_count , struct ompi_datatype_t * target_dt ,
929929 struct ompi_op_t * op , struct ompi_win_t * win ,
930- ompi_osc_ucx_request_t * ucx_req ) {
930+ ompi_osc_ucx_accumulate_request_t * ucx_req ) {
931931 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
932932 int ret = OMPI_SUCCESS ;
933933 uint64_t remote_addr = (module -> addrs [target ]) + target_disp *
@@ -1067,7 +1067,7 @@ int get_accumulate_req(const void *origin_addr, int origin_count,
10671067
10681068 if (NULL != ucx_req ) {
10691069 // nothing to wait for, mark request as completed
1070- ompi_request_complete (& ucx_req -> super , true);
1070+ ompi_request_complete (& ucx_req -> super . super , true);
10711071 }
10721072
10731073
@@ -1110,7 +1110,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
11101110 OSC_UCX_GET_DEFAULT_EP (ep , module , target );
11111111 opal_common_ucx_wpmem_t * mem = module -> mem ;
11121112 uint64_t remote_addr = (module -> state_addrs [target ]) + OSC_UCX_STATE_REQ_FLAG_OFFSET ;
1113- ompi_osc_ucx_request_t * ucx_req = NULL ;
1113+ ompi_osc_ucx_generic_request_t * ucx_req = NULL ;
11141114 int ret = OMPI_SUCCESS ;
11151115
11161116 ret = check_sync_state (module , target , true);
@@ -1126,8 +1126,8 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
11261126 return ret ;
11271127 }
11281128
1129- OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
1130- ucx_req -> module = module ;
1129+ OMPI_OSC_UCX_GENERIC_REQUEST_ALLOC (win , ucx_req , RPUT_REQ );
1130+ ucx_req -> super . module = module ;
11311131
11321132 OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS (module );
11331133 ret = opal_common_ucx_wpmem_flush_ep_nb (mem , target , req_completion , ucx_req , ep );
@@ -1151,7 +1151,7 @@ int ompi_osc_ucx_rput(const void *origin_addr, int origin_count,
11511151 }
11521152 }
11531153
1154- * request = & ucx_req -> super ;
1154+ * request = & ucx_req -> super . super ;
11551155
11561156 return ret ;
11571157}
@@ -1166,7 +1166,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
11661166 OSC_UCX_GET_DEFAULT_EP (ep , module , target );
11671167 opal_common_ucx_wpmem_t * mem = module -> mem ;
11681168 uint64_t remote_addr = (module -> state_addrs [target ]) + OSC_UCX_STATE_REQ_FLAG_OFFSET ;
1169- ompi_osc_ucx_request_t * ucx_req = NULL ;
1169+ ompi_osc_ucx_generic_request_t * ucx_req = NULL ;
11701170 int ret = OMPI_SUCCESS ;
11711171
11721172 ret = check_sync_state (module , target , true);
@@ -1182,8 +1182,8 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
11821182 return ret ;
11831183 }
11841184
1185- OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
1186- ucx_req -> module = module ;
1185+ OMPI_OSC_UCX_GENERIC_REQUEST_ALLOC (win , ucx_req , RGET_REQ );
1186+ ucx_req -> super . module = module ;
11871187
11881188 OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS (module );
11891189 ret = opal_common_ucx_wpmem_flush_ep_nb (mem , target , req_completion , ucx_req , ep );
@@ -1207,7 +1207,7 @@ int ompi_osc_ucx_rget(void *origin_addr, int origin_count,
12071207 }
12081208 }
12091209
1210- * request = & ucx_req -> super ;
1210+ * request = & ucx_req -> super . super ;
12111211
12121212 return ret ;
12131213}
@@ -1218,16 +1218,16 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
12181218 struct ompi_datatype_t * target_dt , struct ompi_op_t * op ,
12191219 struct ompi_win_t * win , struct ompi_request_t * * request ) {
12201220 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
1221- ompi_osc_ucx_request_t * ucx_req = NULL ;
1221+ ompi_osc_ucx_accumulate_request_t * ucx_req = NULL ;
12221222 int ret = OMPI_SUCCESS ;
12231223
12241224 ret = check_sync_state (module , target , true);
12251225 if (ret != OMPI_SUCCESS ) {
12261226 return ret ;
12271227 }
12281228
1229- OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
1230- ucx_req -> module = module ;
1229+ OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC (win , ucx_req );
1230+ ucx_req -> super . module = module ;
12311231 assert (NULL != ucx_req );
12321232
12331233 ret = accumulate_req (origin_addr , origin_count , origin_dt , target , target_disp ,
@@ -1237,7 +1237,7 @@ int ompi_osc_ucx_raccumulate(const void *origin_addr, int origin_count,
12371237 return ret ;
12381238 }
12391239
1240- * request = & ucx_req -> super ;
1240+ * request = & ucx_req -> super . super ;
12411241
12421242 return ret ;
12431243}
@@ -1251,16 +1251,16 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
12511251 struct ompi_op_t * op , struct ompi_win_t * win ,
12521252 struct ompi_request_t * * request ) {
12531253 ompi_osc_ucx_module_t * module = (ompi_osc_ucx_module_t * ) win -> w_osc_module ;
1254- ompi_osc_ucx_request_t * ucx_req = NULL ;
1254+ ompi_osc_ucx_accumulate_request_t * ucx_req = NULL ;
12551255 int ret = OMPI_SUCCESS ;
12561256
12571257 ret = check_sync_state (module , target , true);
12581258 if (ret != OMPI_SUCCESS ) {
12591259 return ret ;
12601260 }
12611261
1262- OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
1263- ucx_req -> module = module ;
1262+ OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC (win , ucx_req );
1263+ ucx_req -> super . module = module ;
12641264 assert (NULL != ucx_req );
12651265
12661266 ret = get_accumulate_req (origin_addr , origin_count , origin_datatype ,
@@ -1272,7 +1272,7 @@ int ompi_osc_ucx_rget_accumulate(const void *origin_addr, int origin_count,
12721272 return ret ;
12731273 }
12741274
1275- * request = & ucx_req -> super ;
1275+ * request = & ucx_req -> super . super ;
12761276
12771277 return ret ;
12781278}
@@ -1288,47 +1288,47 @@ static inline int ompi_osc_ucx_acc_rputget(void *stage_addr, int stage_count,
12881288 OSC_UCX_GET_DEFAULT_EP (ep , module , target );
12891289 opal_common_ucx_wpmem_t * mem = module -> mem ;
12901290 uint64_t remote_addr = (module -> state_addrs [target ]) + OSC_UCX_STATE_REQ_FLAG_OFFSET ;
1291- ompi_osc_ucx_request_t * ucx_req = NULL ;
1291+ ompi_osc_ucx_accumulate_request_t * ucx_req = NULL ;
12921292 bool sync_check ;
12931293 int ret = OMPI_SUCCESS ;
12941294 CHECK_DYNAMIC_WIN (remote_addr , module , target , ret , true);
12951295
12961296 if (acc_type != NONE ) {
1297- OMPI_OSC_UCX_REQUEST_ALLOC (win , ucx_req );
1297+ OMPI_OSC_UCX_ACCUMULATE_REQUEST_ALLOC (win , ucx_req );
12981298 assert (NULL != ucx_req );
1299- ucx_req -> acc . op = op ;
1300- ucx_req -> acc . acc_type = acc_type ;
1301- ucx_req -> acc . phase = phase ;
1302- ucx_req -> module = module ;
1303- ucx_req -> acc . target = target ;
1304- ucx_req -> acc . lock_acquired = lock_acquired ;
1305- ucx_req -> acc . win = win ;
1306- ucx_req -> acc . origin_addr = origin_addr ;
1307- ucx_req -> acc . origin_count = origin_count ;
1299+ ucx_req -> op = op ;
1300+ ucx_req -> acc_type = acc_type ;
1301+ ucx_req -> phase = phase ;
1302+ ucx_req -> super . module = module ;
1303+ ucx_req -> target = target ;
1304+ ucx_req -> lock_acquired = lock_acquired ;
1305+ ucx_req -> win = win ;
1306+ ucx_req -> origin_addr = origin_addr ;
1307+ ucx_req -> origin_count = origin_count ;
13081308 if (origin_dt != NULL ) {
1309- ucx_req -> acc . origin_dt = origin_dt ;
1309+ ucx_req -> origin_dt = origin_dt ;
13101310 if (!ompi_datatype_is_predefined (origin_dt )) {
1311- OBJ_RETAIN (ucx_req -> acc . origin_dt );
1311+ OBJ_RETAIN (ucx_req -> origin_dt );
13121312 }
13131313 }
1314- ucx_req -> acc . stage_addr = stage_addr ;
1315- ucx_req -> acc . stage_count = stage_count ;
1314+ ucx_req -> stage_addr = stage_addr ;
1315+ ucx_req -> stage_count = stage_count ;
13161316 if (stage_dt != NULL ) {
1317- ucx_req -> acc . stage_dt = stage_dt ;
1317+ ucx_req -> stage_dt = stage_dt ;
13181318 if (!ompi_datatype_is_predefined (stage_dt )) {
1319- OBJ_RETAIN (ucx_req -> acc . stage_dt );
1319+ OBJ_RETAIN (ucx_req -> stage_dt );
13201320 }
13211321 }
1322- ucx_req -> acc . target = target ;
1322+ ucx_req -> target = target ;
13231323 if (target_dt != NULL ) {
1324- ucx_req -> acc . target_dt = target_dt ;
1324+ ucx_req -> target_dt = target_dt ;
13251325 if (!ompi_datatype_is_predefined (target_dt )) {
1326- OBJ_RETAIN (ucx_req -> acc . target_dt );
1326+ OBJ_RETAIN (ucx_req -> target_dt );
13271327 }
13281328 }
1329- ucx_req -> acc . target_disp = target_disp ;
1330- ucx_req -> acc . target_count = target_count ;
1331- ucx_req -> acc . free_ptr = NULL ;
1329+ ucx_req -> target_disp = target_disp ;
1330+ ucx_req -> target_count = target_count ;
1331+ ucx_req -> free_ptr = NULL ;
13321332 }
13331333 sync_check = module -> skip_sync_check ;
13341334 module -> skip_sync_check = true; /* we already hold the acc lock, so no need for sync check*/
@@ -1478,39 +1478,41 @@ static int ompi_osc_ucx_get_accumulate_nonblocking(const void *origin_addr, int
14781478}
14791479
14801480void req_completion (void * request ) {
1481- ompi_osc_ucx_request_t * req = (ompi_osc_ucx_request_t * )request ;
1481+ ompi_osc_ucx_generic_request_t * ucx_req = (ompi_osc_ucx_generic_request_t * )request ;
14821482 int ret = OMPI_SUCCESS ;
1483- ompi_osc_ucx_module_t * module = req -> module ;
1484- if (req -> acc .acc_type != NONE ) {
1485- assert (req -> acc .phase != ACC_INIT );
1483+ ompi_osc_ucx_module_t * module = ucx_req -> super .module ;
1484+ if (ucx_req -> super .request_type == ACCUMULATE_REQ ) {
1485+ /* This is an accumulate request */
1486+ ompi_osc_ucx_accumulate_request_t * req = (ompi_osc_ucx_accumulate_request_t * )request ;
1487+ assert (req -> phase != ACC_INIT );
14861488 void * free_addr = NULL ;
14871489 bool release_lock = false;
14881490 ptrdiff_t temp_lb , temp_extent ;
1489- const void * origin_addr = req -> acc . origin_addr ;
1490- int origin_count = req -> acc . origin_count ;
1491- struct ompi_datatype_t * origin_dt = req -> acc . origin_dt ;
1492- void * temp_addr = req -> acc . stage_addr ;
1493- int temp_count = req -> acc . stage_count ;
1494- struct ompi_datatype_t * temp_dt = req -> acc . stage_dt ;
1495- int target = req -> acc . target ;
1496- int target_count = req -> acc . target_count ;
1497- int target_disp = req -> acc . target_disp ;
1498- struct ompi_datatype_t * target_dt = req -> acc . target_dt ;
1499- struct ompi_win_t * win = req -> acc . win ;
1500- struct ompi_op_t * op = req -> acc . op ;
1501-
1502- if (req -> acc . phase != ACC_FINALIZE ) {
1491+ const void * origin_addr = req -> origin_addr ;
1492+ int origin_count = req -> origin_count ;
1493+ struct ompi_datatype_t * origin_dt = req -> origin_dt ;
1494+ void * temp_addr = req -> stage_addr ;
1495+ int temp_count = req -> stage_count ;
1496+ struct ompi_datatype_t * temp_dt = req -> stage_dt ;
1497+ int target = req -> target ;
1498+ int target_count = req -> target_count ;
1499+ int target_disp = req -> target_disp ;
1500+ struct ompi_datatype_t * target_dt = req -> target_dt ;
1501+ struct ompi_win_t * win = req -> win ;
1502+ struct ompi_op_t * op = req -> op ;
1503+
1504+ if (req -> phase != ACC_FINALIZE ) {
15031505 /* Avoid calling flush while we are already in progress */
15041506 module -> mem -> skip_periodic_flush = true;
15051507 module -> state_mem -> skip_periodic_flush = true;
15061508 }
15071509
1508- switch (req -> acc . phase ) {
1510+ switch (req -> phase ) {
15091511 case ACC_FINALIZE :
15101512 {
1511- if (req -> acc . free_ptr != NULL ) {
1512- free (req -> acc . free_ptr );
1513- req -> acc . free_ptr = NULL ;
1513+ if (req -> free_ptr != NULL ) {
1514+ free (req -> free_ptr );
1515+ req -> free_ptr = NULL ;
15141516 }
15151517 if (origin_dt != NULL && !ompi_datatype_is_predefined (origin_dt )) {
15161518 OBJ_RELEASE (origin_dt );
@@ -1613,7 +1615,7 @@ void req_completion(void *request) {
16131615 free (origin_ucx_iov );
16141616 }
16151617
1616- if (req -> acc . acc_type == GET_ACCUMULATE ) {
1618+ if (req -> acc_type == GET_ACCUMULATE ) {
16171619 /* Do fence to make sure target results are received before
16181620 * writing into target */
16191621 ret = opal_common_ucx_wpmem_fence (module -> mem );
@@ -1647,16 +1649,15 @@ void req_completion(void *request) {
16471649 /* Ordering between previous put/get operations and unlock will be realized
16481650 * through the ucp fence inside the finalize function */
16491651 ompi_osc_ucx_nonblocking_ops_finalize (module , target ,
1650- req -> acc . lock_acquired , win , free_addr );
1652+ req -> lock_acquired , win , free_addr );
16511653 }
16521654
1653- if (req -> acc . phase != ACC_FINALIZE ) {
1655+ if (req -> phase != ACC_FINALIZE ) {
16541656 module -> mem -> skip_periodic_flush = false;
16551657 module -> state_mem -> skip_periodic_flush = false;
16561658 }
16571659 }
1658-
16591660 OSC_UCX_DECREMENT_OUTSTANDING_NB_OPS (module );
1660- ompi_request_complete (& (req -> super ), true);
1661+ ompi_request_complete (& (ucx_req -> super . super ), true);
16611662 assert (module -> ctx -> num_incomplete_req_ops >= 0 );
16621663}
0 commit comments