@@ -25,11 +25,35 @@ limitations under the License.
2525#include < torch_mlu/csrc/framework/distributed/process_group_cncl.hpp>
2626
2727#include " mlu_process_group.h"
28+ #elif defined(USE_CUDA)
29+ #include " cuda_process_group.h"
2830#endif
2931#include " common/global_flags.h"
3032#include " parallel_args.h"
3133#include " util/net.h"
3234
35+ namespace {
36+ std::unique_ptr<xllm::ProcessGroup> create_process_group (
37+ int rank,
38+ int world_size,
39+ int rank_size,
40+ int port,
41+ const std::string& host,
42+ const std::string& group_name,
43+ const torch::Device& device) {
44+ #if defined(USE_MLU)
45+ return std::make_unique<xllm::ProcessGroupCncl>(
46+ rank, world_size, rank_size, port, host, group_name, device);
47+ #elif defined(USE_CUDA)
48+ return std::make_unique<xllm::ProcessGroupNccl>(
49+ rank, world_size, rank_size, port, host, group_name, device);
50+ #else
51+ LOG (FATAL) << " Unsupported device type" ;
52+ return nullptr ;
53+ #endif
54+ }
55+ } // namespace
56+
3357namespace xllm {
3458
3559CollectiveCommunicator::CollectiveCommunicator (int global_rank,
@@ -90,40 +114,41 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
90114 mapping,
91115 dispatchAndCombinecommDomain,
92116 dispatchAndCombineHcclComm);
93- #elif defined(USE_MLU)
117+ #else
94118 parallel_args_ = std::make_unique<ParallelArgs>(
95119 global_rank, world_size, dp_size, nullptr , ep_size);
96120#endif
97121}
98122
99- #if defined(USE_MLU)
100- void CollectiveCommunicator::create_process_groups_cncl (
123+ void CollectiveCommunicator::create_process_groups (
101124 const std::string& master_addr,
102125 const torch::Device& device) {
103126 std::string host;
104127 int port;
105128 net::parse_host_port_from_addr (master_addr, host, port);
106129
107- std::vector<std::unique_ptr<ProcessGroup>> process_groups;
108130 int global_rank = parallel_args_->rank ();
109131 int world_size = parallel_args_->world_size ();
110132 int dp_size = parallel_args_->dp_size ();
111- process_group_ = std::make_unique<ProcessGroupCncl>(
133+
134+ process_group_ = create_process_group (
112135 global_rank, world_size, world_size, ++port, host, " world_group" , device);
136+
113137 int tp_size = world_size / dp_size;
114138 CHECK_EQ (tp_size * dp_size, world_size);
115139 int port_offset = global_rank / tp_size + 1 ;
116- tp_group_ = std::make_unique<ProcessGroupCncl>(global_rank,
117- world_size,
118- tp_size,
119- port + port_offset,
120- host,
121- " tp_group" ,
122- device);
140+
141+ tp_group_ = create_process_group (global_rank,
142+ world_size,
143+ tp_size,
144+ port + port_offset,
145+ host,
146+ " tp_group" ,
147+ device);
148+
123149 parallel_args_->process_group_ = process_group_.get ();
124150 parallel_args_->tp_group_ = tp_group_.get ();
125151}
126- #endif
127152
128153const ParallelArgs* CollectiveCommunicator::parallel_args () {
129154 // TODO: init communicator
0 commit comments