@@ -840,7 +840,8 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
840840 OBJ_CONSTRUCT (& module -> pending_posts , opal_list_t );
841841 module -> start_grp_ranks = NULL ;
842842 module -> lock_all_is_nocheck = false;
843- module -> acc_lock_refcnt = 0 ;
843+ module -> acc_lock_refcnt = calloc (comm_size , sizeof (uint64_t ));
844+ module -> pending_acc_finalize_refcnt = calloc (comm_size , sizeof (uint64_t ));
844845
845846 if (!module -> no_locks ) {
846847 OBJ_CONSTRUCT (& module -> outstanding_locks , opal_hash_table_t );
@@ -930,8 +931,13 @@ inline int ompi_osc_ucx_state_lock(
930931 OSC_UCX_GET_DEFAULT_EP (ep , module , target );
931932 int ret = OMPI_SUCCESS ;
932933
934+ assert (module -> pending_acc_finalize_refcnt [target ] >= 0 );
935+ while (module -> pending_acc_finalize_refcnt [target ] > 0 ) {
936+ opal_common_ucx_wpool_progress (mca_osc_ucx_component .wpool );
937+ }
938+
933939 if (force_lock || ompi_osc_need_acc_lock (module , target )) {
934- if (module -> acc_lock_refcnt == 0 ) {
940+ if (module -> acc_lock_refcnt [ target ] == 0 ) {
935941 for (;;) {
936942 ret = opal_common_ucx_wpmem_cmpswp (module -> state_mem ,
937943 TARGET_LOCK_UNLOCKED , TARGET_LOCK_EXCLUSIVE ,
@@ -948,13 +954,11 @@ inline int ompi_osc_ucx_state_lock(
948954 opal_common_ucx_wpool_progress (mca_osc_ucx_component .wpool );
949955 }
950956 }
951-
957+ module -> acc_lock_refcnt [ target ] ++ ;
952958 * lock_acquired = true;
953- module -> acc_lock_refcnt ++ ;
954959 } else {
955960 * lock_acquired = false;
956961 }
957-
958962 return OMPI_SUCCESS ;
959963}
960964
@@ -977,16 +981,16 @@ inline int ompi_osc_ucx_state_unlock(
977981 return OMPI_ERROR ;
978982 }
979983
980- if (module -> acc_lock_refcnt == 1 ) {
984+ if (module -> acc_lock_refcnt [ target ] == 1 ) {
981985 ret = opal_common_ucx_wpmem_fetch (module -> state_mem ,
982986 UCP_ATOMIC_FETCH_OP_SWAP , TARGET_LOCK_UNLOCKED ,
983987 target , & result_value , sizeof (result_value ),
984988 remote_addr , ep );
985989 assert (result_value == TARGET_LOCK_EXCLUSIVE );
986990 }
987- module -> acc_lock_refcnt -- ;
988- assert (module -> acc_lock_refcnt >= 0 );
989- } else if (NULL != free_ptr ){
991+ module -> acc_lock_refcnt [ target ] -- ;
992+ assert (module -> acc_lock_refcnt [ target ] >= 0 );
993+ } else if (NULL != free_ptr ) {
990994 /* flush before freeing the buffer */
991995 ret = opal_common_ucx_ctx_flush (module -> ctx , OPAL_COMMON_UCX_SCOPE_EP , target );
992996 }
@@ -1016,6 +1020,8 @@ inline int ompi_osc_ucx_nonblocking_ops_finalize(ompi_osc_ucx_module_t *module,
10161020 ucx_req -> phase = ACC_FINALIZE ;
10171021 ucx_req -> acc_type = ANY ;
10181022 ucx_req -> super .module = module ;
1023+ ucx_req -> lock_acquired = lock_acquired ;
1024+ ucx_req -> target = target ;
10191025
10201026 /* Fence any still active operations */
10211027 ret = opal_common_ucx_wpmem_fence (module -> mem );
@@ -1024,7 +1030,8 @@ inline int ompi_osc_ucx_nonblocking_ops_finalize(ompi_osc_ucx_module_t *module,
10241030 return OMPI_ERROR ;
10251031 }
10261032
1027- if (lock_acquired ) {
1033+ module -> pending_acc_finalize_refcnt [target ]++ ;
1034+ if (lock_acquired && module -> acc_lock_refcnt [target ] == 1 ) {
10281035 OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS (module );
10291036 ret = opal_common_ucx_wpmem_fetch_nb (module -> state_mem ,
10301037 UCP_ATOMIC_FETCH_OP_SWAP , TARGET_LOCK_UNLOCKED ,
@@ -1036,7 +1043,7 @@ inline int ompi_osc_ucx_nonblocking_ops_finalize(ompi_osc_ucx_module_t *module,
10361043 return ret ;
10371044 }
10381045 } else {
1039- /* Lock is not acquired, but still, we need to know when the
1046+ /* Lock is not acquired/ or ref count is not 1 , but still, we need to know when the
10401047 * acc is finalized so that we can free the temp buffers */
10411048 OSC_UCX_INCREMENT_OUTSTANDING_NB_OPS (module );
10421049 ret = opal_common_ucx_wpmem_flush_ep_nb (module -> mem , target , ompi_osc_ucx_req_completion , ucx_req , ep );
@@ -1054,6 +1061,11 @@ inline int ompi_osc_ucx_nonblocking_ops_finalize(ompi_osc_ucx_module_t *module,
10541061 }
10551062 }
10561063
1064+ if (lock_acquired ) {
1065+ module -> acc_lock_refcnt [target ]-- ;
1066+ assert (module -> acc_lock_refcnt [target ] >= 0 );
1067+ }
1068+
10571069 return ret ;
10581070}
10591071
0 commit comments