
Commit 742b1bf

feat: add process group for cuda device. (#282)
1 parent fcec9c9 commit 742b1bf

12 files changed (+194 -54 lines)

xllm/core/distributed_runtime/worker_server.cpp

Lines changed: 2 additions & 2 deletions
@@ -95,8 +95,8 @@ void WorkerServer::create_server(const runtime::Options& options,

   CollectiveCommunicator comm(worker_global_rank, world_size, dp_size, ep_size);
   const ParallelArgs* parallel_args = comm.parallel_args();
-#if defined(USE_MLU)
-  comm.create_process_groups_cncl(master_node_addr, device);
+#if defined(USE_MLU) || defined(USE_CUDA)
+  comm.create_process_groups(master_node_addr, device);
 #endif

   WorkerType worker_type =

xllm/core/framework/parallel_state/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -11,12 +11,14 @@ cc_library(
     process_group.h
     $<$<BOOL:${USE_NPU}>:npu_process_group.h>
     $<$<BOOL:${USE_MLU}>:mlu_process_group.h>
+    $<$<BOOL:${USE_CUDA}>:cuda_process_group.h>
     collective_communicator.h
   SRCS
     mapping_npu.cpp
     parallel_state.cpp
     $<$<BOOL:${USE_NPU}>:npu_process_group.cpp>
     $<$<BOOL:${USE_MLU}>:mlu_process_group.cpp>
+    $<$<BOOL:${USE_CUDA}>:cuda_process_group.cpp>
     collective_communicator.cpp
   DEPS
     :common

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 38 additions & 13 deletions
@@ -25,11 +25,35 @@ limitations under the License.
 #include <torch_mlu/csrc/framework/distributed/process_group_cncl.hpp>

 #include "mlu_process_group.h"
+#elif defined(USE_CUDA)
+#include "cuda_process_group.h"
 #endif
 #include "common/global_flags.h"
 #include "parallel_args.h"
 #include "util/net.h"

+namespace {
+std::unique_ptr<xllm::ProcessGroup> create_process_group(
+    int rank,
+    int world_size,
+    int rank_size,
+    int port,
+    const std::string& host,
+    const std::string& group_name,
+    const torch::Device& device) {
+#if defined(USE_MLU)
+  return std::make_unique<xllm::ProcessGroupCncl>(
+      rank, world_size, rank_size, port, host, group_name, device);
+#elif defined(USE_CUDA)
+  return std::make_unique<xllm::ProcessGroupNccl>(
+      rank, world_size, rank_size, port, host, group_name, device);
+#else
+  LOG(FATAL) << "Unsupported device type";
+  return nullptr;
+#endif
+}
+}  // namespace
+
 namespace xllm {

 CollectiveCommunicator::CollectiveCommunicator(int global_rank,
@@ -90,40 +114,41 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
       mapping,
       dispatchAndCombinecommDomain,
       dispatchAndCombineHcclComm);
-#elif defined(USE_MLU)
+#else
   parallel_args_ = std::make_unique<ParallelArgs>(
       global_rank, world_size, dp_size, nullptr, ep_size);
 #endif
 }

-#if defined(USE_MLU)
-void CollectiveCommunicator::create_process_groups_cncl(
+void CollectiveCommunicator::create_process_groups(
     const std::string& master_addr,
     const torch::Device& device) {
   std::string host;
   int port;
   net::parse_host_port_from_addr(master_addr, host, port);

-  std::vector<std::unique_ptr<ProcessGroup>> process_groups;
   int global_rank = parallel_args_->rank();
   int world_size = parallel_args_->world_size();
   int dp_size = parallel_args_->dp_size();
-  process_group_ = std::make_unique<ProcessGroupCncl>(
+
+  process_group_ = create_process_group(
       global_rank, world_size, world_size, ++port, host, "world_group", device);
+
   int tp_size = world_size / dp_size;
   CHECK_EQ(tp_size * dp_size, world_size);
   int port_offset = global_rank / tp_size + 1;
-  tp_group_ = std::make_unique<ProcessGroupCncl>(global_rank,
-                                                 world_size,
-                                                 tp_size,
-                                                 port + port_offset,
-                                                 host,
-                                                 "tp_group",
-                                                 device);
+
+  tp_group_ = create_process_group(global_rank,
+                                   world_size,
+                                   tp_size,
+                                   port + port_offset,
+                                   host,
+                                   "tp_group",
+                                   device);
+
   parallel_args_->process_group_ = process_group_.get();
   parallel_args_->tp_group_ = tp_group_.get();
 }
-#endif

 const ParallelArgs* CollectiveCommunicator::parallel_args() {
   // TODO: init communicator
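
Note: a minimal caller-side sketch of the now device-agnostic flow (not part of the commit). The rank/size values, master address, device index, and include path are illustrative assumptions; the real call site is WorkerServer::create_server shown above.

#include <torch/torch.h>

#include "framework/parallel_state/collective_communicator.h"  // assumed include path

void init_worker_comm() {
  const int global_rank = 0;  // illustrative values
  const int world_size = 4;
  const int dp_size = 1;
  const int ep_size = 1;

  xllm::CollectiveCommunicator comm(global_rank, world_size, dp_size, ep_size);
  const xllm::ParallelArgs* parallel_args = comm.parallel_args();

#if defined(USE_CUDA)
  // Unified entry point; create_process_group() above selects the NCCL
  // implementation at compile time for CUDA builds.
  comm.create_process_groups("10.0.0.1:29500", torch::Device(torch::kCUDA, 0));
#endif
  (void)parallel_args;  // used by the worker to build model-parallel state
}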

xllm/core/framework/parallel_state/collective_communicator.h

Lines changed: 2 additions & 6 deletions
@@ -31,10 +31,8 @@ class CollectiveCommunicator {
                          int ep_size);
   ~CollectiveCommunicator() = default;

-#if defined(USE_MLU)
-  void create_process_groups_cncl(const std::string& master_addr,
-                                  const torch::Device& device);
-#endif
+  void create_process_groups(const std::string& master_addr,
+                             const torch::Device& device);

   // init communicator and return parallel args.
   const ParallelArgs* parallel_args();
@@ -43,9 +41,7 @@
   std::unique_ptr<ParallelArgs> parallel_args_;
   std::unique_ptr<ProcessGroup> process_group_;
   std::unique_ptr<ProcessGroup> dp_local_process_group_;
-#if defined(USE_MLU)
   std::unique_ptr<ProcessGroup> tp_group_;
-#endif
 };

 }  // namespace xllm
xllm/core/framework/parallel_state/cuda_process_group.cpp

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "cuda_process_group.h"
+
+#include <torch/csrc/distributed/c10d/TCPStore.hpp>
+
+#include "parallel_state.h"
+
+namespace xllm {
+
+ProcessGroupNccl::ProcessGroupNccl(int rank,
+                                   int world_size,
+                                   int rank_size,
+                                   int port,
+                                   const std::string& host,
+                                   const std::string& group_name,
+                                   const torch::Device& device)
+    : ProcessGroup(rank, rank_size, device),
+      world_size_(rank_size),
+      rank_(rank) {
+  c10::intrusive_ptr<c10d::ProcessGroupNCCL::Options> nccl_pg_options =
+      c10d::ProcessGroupNCCL::Options::create();
+  nccl_pg_options->is_high_priority_stream = false;
+
+  if (world_size != rank_size) {
+    auto [local_rank, group_ranks] =
+        parallel_state::get_group_rank(world_size, rank, rank_size);
+    nccl_pg_options->global_ranks_in_group = group_ranks;
+    rank_ = local_rank;
+  }
+
+  c10d::TCPStoreOptions tcp_options;
+  tcp_options.isServer = (rank_ == 0);
+  tcp_options.port = port;
+
+  c10::intrusive_ptr<c10d::Store> store =
+      c10::make_intrusive<c10d::TCPStore>(host, tcp_options);
+  nccl_pg_ = std::make_unique<c10d::ProcessGroupNCCL>(
+      store, rank_, rank_size, nccl_pg_options);
+}
+
+ProcessGroupNccl::~ProcessGroupNccl() { nccl_pg_->shutdown(); }
+
+void ProcessGroupNccl::allreduce(torch::Tensor& input) {
+  std::vector<torch::Tensor> input_tensors = {input};
+  nccl_pg_->allreduce(input_tensors)->wait();
+}
+
+void ProcessGroupNccl::allgather(torch::Tensor input,
+                                 std::vector<torch::Tensor>& outputs) {
+  std::vector<torch::Tensor> input_tensors = {input};
+  std::vector<std::vector<torch::Tensor>> output_tensors = {outputs};
+  nccl_pg_->allgather(output_tensors, input_tensors)->wait();
+}
+
+}  // namespace xllm
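
Note: a minimal usage sketch for the new ProcessGroupNccl (not part of the commit). Each rank would run this in its own process; the host, port, tensor shape, and include path are illustrative assumptions.

#include <vector>

#include <torch/torch.h>

#include "framework/parallel_state/cuda_process_group.h"  // assumed include path

void run_rank(int rank) {
  const int world_size = 2;
  torch::Device device(torch::kCUDA, rank);

  // world_size == rank_size here, so no sub-group splitting via get_group_rank();
  // rank 0 hosts the TCPStore (isServer is set when rank_ == 0), rank 1 connects.
  xllm::ProcessGroupNccl pg(rank,
                            world_size,
                            /*rank_size=*/world_size,
                            /*port=*/29501,
                            /*host=*/"127.0.0.1",
                            "world_group",
                            device);

  // Each rank contributes ones; after allreduce every rank holds a tensor of 2s.
  torch::Tensor t = torch::ones({4}, torch::TensorOptions().device(device));
  pg.allreduce(t);

  // allgather collects one tensor per rank into pre-allocated outputs.
  std::vector<torch::Tensor> outputs;
  for (int i = 0; i < world_size; ++i) {
    outputs.push_back(torch::empty_like(t));
  }
  pg.allgather(t, outputs);
}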

xllm/core/framework/parallel_state/cuda_process_group.h

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
+
+#include "process_group.h"
+
+namespace xllm {
+
+class ProcessGroupNccl : public ProcessGroup {
+ public:
+  ProcessGroupNccl(int rank,
+                   int world_size,
+                   int rank_size,
+                   int port,
+                   const std::string& host,
+                   const std::string& group_name,
+                   const torch::Device& device);
+
+  ~ProcessGroupNccl() override;
+
+  void allreduce(torch::Tensor& input) override;
+
+  void allgather(torch::Tensor input,
+                 std::vector<torch::Tensor>& outputs) override;
+
+ private:
+  // rank of current process
+  int rank_ = 0;
+
+  // number of processes
+  int world_size_ = 0;
+
+  // nccl process group
+  std::unique_ptr<c10d::ProcessGroupNCCL> nccl_pg_;
+};
+
+}  // namespace xllm

xllm/core/framework/parallel_state/mlu_process_group.cpp

Lines changed: 2 additions & 19 deletions
@@ -17,23 +17,7 @@ limitations under the License.

 #include <torch/csrc/distributed/c10d/TCPStore.hpp>

-namespace {
-
-std::pair<int, std::vector<uint64_t>> get_group_rank(int world_size,
-                                                     int global_rank,
-                                                     int split_size) {
-  int target_group_index = global_rank / split_size;
-  uint64_t start_rank = target_group_index * split_size;
-  uint64_t end_rank = start_rank + split_size;
-  std::vector<uint64_t> group_rank;
-  int index = global_rank - start_rank;
-  for (uint64_t rank = start_rank; rank < end_rank; rank++) {
-    group_rank.push_back(rank);
-  }
-  return {index, group_rank};
-}
-
-}  // namespace
+#include "parallel_state.h"

 namespace xllm {

@@ -52,7 +36,7 @@ ProcessGroupCncl::ProcessGroupCncl(int rank,
   cncl_pg_options->group_name = group_name;
   if (world_size != rank_size) {
     auto [local_rank, group_ranks] =
-        get_group_rank(world_size, rank, rank_size);
+        parallel_state::get_group_rank(world_size, rank, rank_size);
     cncl_pg_options->global_ranks_in_group = group_ranks;
     rank_ = local_rank;
   }
@@ -67,7 +51,6 @@ ProcessGroupCncl::ProcessGroupCncl(int rank,
       store, rank, world_size, cncl_pg_options);
 }

-// Destructor.
 ProcessGroupCncl::~ProcessGroupCncl() { cncl_pg_->shutdown(); }

 void ProcessGroupCncl::allreduce(torch::Tensor& input) {

xllm/core/framework/parallel_state/mlu_process_group.h

Lines changed: 5 additions & 9 deletions
@@ -23,7 +23,6 @@ namespace xllm {

 class ProcessGroupCncl : public ProcessGroup {
  public:
-  // Constructor.
   ProcessGroupCncl(int rank,
                    int world_size,
                    int rank_size,
@@ -32,11 +31,6 @@ class ProcessGroupCncl : public ProcessGroup {
                    const std::string& group_name,
                    const torch::Device& device);

-  int rank() override { return rank_; }
-
-  int world_size() override { return world_size_; }
-
-  // Destructor.
   ~ProcessGroupCncl() override;

   void allreduce(torch::Tensor& input) override;
@@ -45,12 +39,14 @@ class ProcessGroupCncl : public ProcessGroup {
                  std::vector<torch::Tensor>& outputs) override;

  private:
-  std::shared_ptr<torch_mlu::ProcessGroupCNCL> cncl_pg_ = nullptr;
-  // rank of current process.
+  // rank of current process
   int rank_ = 0;

-  // number of processes.
+  // number of processes
   int world_size_ = 0;
+
+  // cncl process group
+  std::unique_ptr<torch_mlu::ProcessGroupCNCL> cncl_pg_;
 };

 }  // namespace xllm

xllm/core/framework/parallel_state/parallel_args.h

Lines changed: 1 addition & 2 deletions
@@ -111,11 +111,10 @@ struct ParallelArgs {

   // atb hccl dispatchAndCombineHcclComm
   PROPERTY(HcclComm, dispatchAndCombineHcclComm);
-#elif defined(USE_MLU)
+#endif
   ProcessGroup* tp_group_ = nullptr;
   ProcessGroup* moe_ep_group_ = nullptr;
   ProcessGroup* moe_tp_group_ = nullptr;
-#endif
 };

 }  // namespace xllm

xllm/core/framework/parallel_state/parallel_state.cpp

Lines changed: 14 additions & 0 deletions
@@ -125,5 +125,19 @@ std::vector<std::unique_ptr<ProcessGroup>> create_npu_process_groups(
 #endif
 }

+std::pair<int, std::vector<uint64_t>> get_group_rank(int world_size,
+                                                     int global_rank,
+                                                     int split_size) {
+  int target_group_index = global_rank / split_size;
+  uint64_t start_rank = target_group_index * split_size;
+  uint64_t end_rank = start_rank + split_size;
+  std::vector<uint64_t> group_rank;
+  int index = global_rank - start_rank;
+  for (uint64_t rank = start_rank; rank < end_rank; rank++) {
+    group_rank.push_back(rank);
+  }
+  return {index, group_rank};
+}
+
 }  // namespace parallel_state
 }  // namespace xllm
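
Note: get_group_rank was previously an anonymous-namespace helper in mlu_process_group.cpp and is now shared by the MLU and CUDA process groups. A small standalone check of its behavior (the function body is copied verbatim from the diff; the main() harness is illustrative):

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Maps a global rank to its local rank and to the list of global ranks in its
// group of `split_size` consecutive members.
std::pair<int, std::vector<uint64_t>> get_group_rank(int world_size,
                                                     int global_rank,
                                                     int split_size) {
  int target_group_index = global_rank / split_size;
  uint64_t start_rank = target_group_index * split_size;
  uint64_t end_rank = start_rank + split_size;
  std::vector<uint64_t> group_rank;
  int index = global_rank - start_rank;
  for (uint64_t rank = start_rank; rank < end_rank; rank++) {
    group_rank.push_back(rank);
  }
  return {index, group_rank};
}

int main() {
  // world_size = 8 split into groups of 4: global rank 5 is local rank 1
  // inside the group {4, 5, 6, 7}.
  auto [local_rank, group_ranks] = get_group_rank(8, 5, 4);
  assert(local_rank == 1);
  assert((group_ranks == std::vector<uint64_t>{4, 5, 6, 7}));
  return 0;
}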
