diff --git a/ompi/mca/pml/ucx/pml_ucx.c b/ompi/mca/pml/ucx/pml_ucx.c index 13119041c4..d9073a59ea 100644 --- a/ompi/mca/pml/ucx/pml_ucx.c +++ b/ompi/mca/pml/ucx/pml_ucx.c @@ -207,6 +207,44 @@ int mca_pml_ucx_cleanup(void) return OMPI_SUCCESS; } +ucp_ep_h mca_pml_ucx_add_proc(ompi_communicator_t *comm, int dst) +{ + ucp_address_t *address; + ucs_status_t status; + size_t addrlen; + ucp_ep_h ep; + int ret; + + ompi_proc_t *proc0 = ompi_comm_peer_lookup(comm, 0); + ompi_proc_t *proc_peer = ompi_comm_peer_lookup(comm, dst); + + /* Note, mca_pml_base_pml_check_selected, doesn't use 3rd argument */ + if (OMPI_SUCCESS != (ret = mca_pml_base_pml_check_selected("ucx", + &proc0, + dst))) { + return NULL; + } + + ret = mca_pml_ucx_recv_worker_address(proc_peer, &address, &addrlen); + if (ret < 0) { + PML_UCX_ERROR("Failed to receive worker address from proc: %d", proc_peer->super.proc_name.vpid); + return NULL; + } + + PML_UCX_VERBOSE(2, "connecting to proc. %d", proc_peer->super.proc_name.vpid); + status = ucp_ep_create(ompi_pml_ucx.ucp_worker, address, &ep); + free(address); + if (UCS_OK != status) { + PML_UCX_ERROR("Failed to connect to proc: %d, %s", proc_peer->super.proc_name.vpid, + ucs_status_string(status)); + return NULL; + } + + proc_peer->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML] = ep; + + return ep; +} + int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs) { ucp_address_t *address; @@ -225,6 +263,7 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs) for (i = 0; i < nprocs; ++i) { ret = mca_pml_ucx_recv_worker_address(procs[i], &address, &addrlen); if (ret < 0) { + PML_UCX_ERROR("Failed to receive worker address from proc: %d", procs[i]->super.proc_name.vpid); return ret; } @@ -238,7 +277,8 @@ int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs) free(address); if (UCS_OK != status) { - PML_UCX_ERROR("Failed to connect"); + PML_UCX_ERROR("Failed to connect to proc: %d, %s", procs[i]->super.proc_name.vpid, + ucs_status_string(status)); return OMPI_ERROR; } @@ -426,7 +466,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat struct ompi_request_t **request) { mca_pml_ucx_persistent_request_t *req; - + ucp_ep_h ep; req = (mca_pml_ucx_persistent_request_t *)PML_UCX_FREELIST_GET(&ompi_pml_ucx.persistent_reqs); if (req == NULL) { @@ -436,6 +476,12 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat PML_UCX_TRACE_SEND("isend_init request *%p=%p", buf, count, datatype, dst, tag, mode, comm, (void*)request, (void*)req) + ep = mca_pml_ucx_get_ep(comm, dst); + if (OPAL_UNLIKELY(NULL == ep)) { + PML_UCX_ERROR("Failed to get ep for rank %d", dst); + return OMPI_ERROR; + } + req->ompi.req_state = OMPI_REQUEST_INACTIVE; req->flags = MCA_PML_UCX_REQUEST_FLAG_SEND; req->buffer = (void *)buf; @@ -443,7 +489,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat req->datatype = mca_pml_ucx_get_datatype(datatype); req->tag = PML_UCX_MAKE_SEND_TAG(tag, comm); req->send.mode = mode; - req->send.ep = mca_pml_ucx_get_ep(comm, dst); + req->send.ep = ep; *request = &req->ompi; return OMPI_SUCCESS; @@ -455,13 +501,20 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype, struct ompi_request_t **request) { ompi_request_t *req; + ucp_ep_h ep; PML_UCX_TRACE_SEND("isend request *%p", buf, count, datatype, dst, tag, mode, comm, (void*)request) /* TODO special care to sync/buffered send */ - req = (ompi_request_t*)ucp_tag_send_nb(mca_pml_ucx_get_ep(comm, dst), buf, count, + ep = mca_pml_ucx_get_ep(comm, dst); + if (OPAL_UNLIKELY(NULL == ep)) { + PML_UCX_ERROR("Failed to get ep for rank %d", dst); + return OMPI_ERROR; + } + + req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count, mca_pml_ucx_get_datatype(datatype), PML_UCX_MAKE_SEND_TAG(tag, comm), mca_pml_ucx_send_completion); @@ -484,12 +537,19 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i struct ompi_communicator_t* comm) { ompi_request_t *req; + ucp_ep_h ep; PML_UCX_TRACE_SEND("%s", buf, count, datatype, dst, tag, mode, comm, "send"); /* TODO special care to sync/buffered send */ - req = (ompi_request_t*)ucp_tag_send_nb(mca_pml_ucx_get_ep(comm, dst), buf, count, + ep = mca_pml_ucx_get_ep(comm, dst); + if (OPAL_UNLIKELY(NULL == ep)) { + PML_UCX_ERROR("Failed to get ep for rank %d", dst); + return OMPI_ERROR; + } + + req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count, mca_pml_ucx_get_datatype(datatype), PML_UCX_MAKE_SEND_TAG(tag, comm), mca_pml_ucx_send_completion); diff --git a/ompi/mca/pml/ucx/pml_ucx.h b/ompi/mca/pml/ucx/pml_ucx.h index d684ecb462..2f50cb2777 100644 --- a/ompi/mca/pml/ucx/pml_ucx.h +++ b/ompi/mca/pml/ucx/pml_ucx.h @@ -85,6 +85,7 @@ int mca_pml_ucx_close(void); int mca_pml_ucx_init(void); int mca_pml_ucx_cleanup(void); +ucp_ep_h mca_pml_ucx_add_proc(ompi_communicator_t *comm, int dst); int mca_pml_ucx_add_procs(struct ompi_proc_t **procs, size_t nprocs); int mca_pml_ucx_del_procs(struct ompi_proc_t **procs, size_t nprocs); @@ -146,4 +147,5 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests); int mca_pml_ucx_dump(struct ompi_communicator_t* comm, int verbose); + #endif /* PML_UCX_H_ */ diff --git a/ompi/mca/pml/ucx/pml_ucx_request.h b/ompi/mca/pml/ucx/pml_ucx_request.h index cf2cff68ae..bfa3019021 100644 --- a/ompi/mca/pml/ucx/pml_ucx_request.h +++ b/ompi/mca/pml/ucx/pml_ucx_request.h @@ -127,7 +127,12 @@ void mca_pml_ucx_request_cleanup(void *request); static inline ucp_ep_h mca_pml_ucx_get_ep(ompi_communicator_t *comm, int dst) { - return ompi_comm_peer_lookup(comm, dst)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]; + ucp_ep_h ep = ompi_comm_peer_lookup(comm,dst)->proc_endpoints[OMPI_PROC_ENDPOINT_TAG_PML]; + if (OPAL_UNLIKELY(NULL == ep)) { + ep = mca_pml_ucx_add_proc(comm, dst); + } + + return ep; } static inline void mca_pml_ucx_request_reset(ompi_request_t *req)