From f3d0c59f446f1a6c59fe2e208c90a79542839f72 Mon Sep 17 00:00:00 2001 From: Jessie Yang Date: Tue, 9 Jan 2024 09:48:55 -0800 Subject: [PATCH 1/2] communicator: Rename allreduce_fn to iallreduce_fn The allreduce_fn is non-blocking. Rename it to iallreduce_fn to make it clear. Signed-off-by: Jessie Yang --- ompi/communicator/comm_cid.c | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/ompi/communicator/comm_cid.c b/ompi/communicator/comm_cid.c index 4424bdf84e0..15ba6703367 100644 --- a/ompi/communicator/comm_cid.c +++ b/ompi/communicator/comm_cid.c @@ -65,9 +65,9 @@ opal_atomic_int64_t ompi_comm_next_base_cid = 1; struct ompi_comm_cid_context_t; -typedef int (*ompi_comm_allreduce_impl_fn_t) (int *inbuf, int *outbuf, int count, struct ompi_op_t *op, - struct ompi_comm_cid_context_t *cid_context, - ompi_request_t **req); +typedef int (*ompi_comm_iallreduce_impl_fn_t) (int *inbuf, int *outbuf, int count, struct ompi_op_t *op, + struct ompi_comm_cid_context_t *cid_context, + ompi_request_t **req); struct ompi_comm_cid_context_t { @@ -78,7 +78,7 @@ struct ompi_comm_cid_context_t { ompi_communicator_t *comm; ompi_communicator_t *bridgecomm; - ompi_comm_allreduce_impl_fn_t allreduce_fn; + ompi_comm_iallreduce_impl_fn_t iallreduce_fn; int nextcid; int nextlocal_cid; @@ -225,18 +225,18 @@ static ompi_comm_cid_context_t *mca_comm_cid_context_alloc (ompi_communicator_t * for the current mode. 
*/ switch (mode) { case OMPI_COMM_CID_INTRA: - context->allreduce_fn = ompi_comm_allreduce_intra_nb; + context->iallreduce_fn = ompi_comm_allreduce_intra_nb; break; case OMPI_COMM_CID_INTER: - context->allreduce_fn = ompi_comm_allreduce_inter_nb; + context->iallreduce_fn = ompi_comm_allreduce_inter_nb; break; case OMPI_COMM_CID_GROUP: case OMPI_COMM_CID_GROUP_NEW: - context->allreduce_fn = ompi_comm_allreduce_group_nb; + context->iallreduce_fn = ompi_comm_allreduce_group_nb; context->pml_tag = ((int *) arg0)[0]; break; case OMPI_COMM_CID_INTRA_PMIX: - context->allreduce_fn = ompi_comm_allreduce_intra_pmix_nb; + context->iallreduce_fn = ompi_comm_allreduce_intra_pmix_nb; context->local_leader = ((int *) arg0)[0]; if (arg1) { context->port_string = strdup ((char *) arg1); @@ -244,19 +244,19 @@ static ompi_comm_cid_context_t *mca_comm_cid_context_alloc (ompi_communicator_t context->pmix_tag = strdup ((char *) pmix_tag); break; case OMPI_COMM_CID_INTRA_BRIDGE: - context->allreduce_fn = ompi_comm_allreduce_intra_bridge_nb; + context->iallreduce_fn = ompi_comm_allreduce_intra_bridge_nb; context->local_leader = ((int *) arg0)[0]; context->remote_leader = ((int *) arg1)[0]; break; #if OPAL_ENABLE_FT_MPI case OMPI_COMM_CID_INTRA_FT: - context->allreduce_fn = ompi_comm_ft_allreduce_intra_nb; + context->iallreduce_fn = ompi_comm_ft_allreduce_intra_nb; break; case OMPI_COMM_CID_INTER_FT: - context->allreduce_fn = ompi_comm_ft_allreduce_inter_nb; + context->iallreduce_fn = ompi_comm_ft_allreduce_inter_nb; break; case OMPI_COMM_CID_INTRA_PMIX_FT: - context->allreduce_fn = ompi_comm_ft_allreduce_intra_pmix_nb; + context->iallreduce_fn = ompi_comm_ft_allreduce_intra_pmix_nb; break; #endif /* OPAL_ENABLE_FT_MPI */ default: @@ -600,8 +600,8 @@ static int ompi_comm_allreduce_getnextcid (ompi_comm_request_t *request) #endif /* OPAL_ENABLE_FT_MPI */ } - ret = context->allreduce_fn (&context->nextlocal_cid, &context->nextcid, 1, MPI_MAX, - context, &subreq); + ret = 
context->iallreduce_fn (&context->nextlocal_cid, &context->nextcid, 1, MPI_MAX, + context, &subreq); /* there was a failure during non-blocking collective * all we can do is abort */ @@ -666,7 +666,7 @@ static int ompi_comm_checkcid (ompi_comm_request_t *request) ++context->iter; - ret = context->allreduce_fn (&context->flag, &context->rflag, 1, MPI_MIN, context, &subreq); + ret = context->iallreduce_fn (&context->flag, &context->rflag, 1, MPI_MIN, context, &subreq); if (OMPI_SUCCESS == ret) { ompi_comm_request_schedule_append (request, ompi_comm_nextcid_check_flag, &subreq, 1); } else { @@ -908,7 +908,7 @@ int ompi_comm_activate_nb (ompi_communicator_t **newcomm, ompi_communicator_t *c * 2. After the operation it is allowed to send messages over the new communicator. */ local_peers = context->max_local_peers; - ret = context->allreduce_fn (&local_peers, &context->max_local_peers, 1, MPI_MAX, context, + ret = context->iallreduce_fn (&local_peers, &context->max_local_peers, 1, MPI_MAX, context, &subreq); if (OMPI_SUCCESS != ret) { ompi_comm_request_return (request); return ret; From 23df181bd5c1b04def4e8ebbac5fad9476285cde Mon Sep 17 00:00:00 2001 From: Jessie Yang Date: Tue, 9 Jan 2024 10:30:00 -0800 Subject: [PATCH 2/2] communicator bugfix: disjoint function does not have the correct max_local_peers value The input buffer local_peers was passed to the non-blocking function iallreduce_fn as a stack variable. Remove it and pass MPI_IN_PLACE instead, so that the only buffer handed to iallreduce_fn is max_local_peers, which is part of the context struct, and the correct value is passed. Also perform the reduction and set the disjointness flags only for intra-communicators. 
Signed-off-by: Jessie Yang --- ompi/communicator/comm_cid.c | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/ompi/communicator/comm_cid.c b/ompi/communicator/comm_cid.c index 15ba6703367..6d08fd4d768 100644 --- a/ompi/communicator/comm_cid.c +++ b/ompi/communicator/comm_cid.c @@ -774,6 +774,11 @@ static int ompi_comm_activate_nb_complete (ompi_comm_request_t *request); /* Callback function to set communicator disjointness flags */ static inline void ompi_comm_set_disjointness_nb_complete(ompi_comm_cid_context_t *context) { + /* Only set the disjoint flags when it is intra-communicator */ + if (OMPI_COMM_IS_INTER(*context->newcommp)) { + return; + } + if (OMPI_COMM_IS_DISJOINT_SET(*context->newcommp)) { opal_show_help("help-comm.txt", "disjointness-set-again", true); return; @@ -870,7 +875,7 @@ int ompi_comm_activate_nb (ompi_communicator_t **newcomm, ompi_communicator_t *c ompi_comm_cid_context_t *context; ompi_comm_request_t *request; ompi_request_t *subreq; - int ret = 0, local_peers = -1; + int ret = 0; /* the caller should not pass NULL for comm (it may be the same as *newcomm) */ assert (NULL != comm); @@ -902,20 +907,19 @@ int ompi_comm_activate_nb (ompi_communicator_t **newcomm, ompi_communicator_t *c OMPI_COMM_SET_PML_ADDED(*newcomm); } - /** - * Dual-purpose barrier: - * 1. The communicator's disjointness is inferred from max_local_peers. - * 2. After the operation it is allowed to send messages over the new communicator. - */ - local_peers = context->max_local_peers; - ret = context->iallreduce_fn (&local_peers, &context->max_local_peers, 1, MPI_MAX, context, - &subreq); - if (OMPI_SUCCESS != ret) { - ompi_comm_request_return (request); - return ret; + if (OMPI_COMM_IS_INTRA(*newcomm)) { + /* The communicator's disjointness is inferred from max_local_peers. 
*/ + ret = context->iallreduce_fn (MPI_IN_PLACE, &context->max_local_peers, 1, MPI_MAX, context, + &subreq); + if (OMPI_SUCCESS != ret) { + ompi_comm_request_return (request); + return ret; + } + ompi_comm_request_schedule_append (request, ompi_comm_activate_nb_complete, &subreq, 1); + } else { + ompi_comm_request_schedule_append (request, ompi_comm_activate_nb_complete, NULL, 0); } - - ompi_comm_request_schedule_append (request, ompi_comm_activate_nb_complete, &subreq, 1); + ompi_comm_request_start (request); *req = &request->super;