Skip to content

Commit a774b93

Browse files
feat: add custom PA Operation for AclGraph. (#412)
1 parent e5fb3e4 commit a774b93

19 files changed

+1031
-217
lines changed

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,20 +28,20 @@ if(USE_NPU)
2828
if(DEVICE_TYPE STREQUAL "USE_A3")
2929
message("downloading a3 arm xllm kernels")
3030
file(DOWNLOAD
31-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.2-Linux.a3.arm.rpm"
31+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a3.arm.rpm"
3232
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
3333
)
3434
else()
3535
if(DEVICE_ARCH STREQUAL "ARM")
3636
message("downloading a2 arm xllm_kernels")
3737
file(DOWNLOAD
38-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.2-Linux.a2.arm.rpm"
38+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.arm.rpm"
3939
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
4040
)
4141
else()
4242
message("downloading a2 x86 xllm_kernels")
4343
file(DOWNLOAD
44-
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.2-Linux.a2.x86.rpm"
44+
"https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.x86.rpm"
4545
"${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
4646
)
4747
endif()

xllm/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ target_link_libraries(xllm PRIVATE glog::glog brpc leveldb::leveldb ZLIB::ZLIB p
3434
add_dependencies(xllm brpc-static)
3535

3636
if(USE_NPU)
37-
set(COMMON_LIBS Python::Python ascendcl hccl c_sec nnopbase ms_tools_ext)
37+
set(COMMON_LIBS Python::Python ascendcl atb_customize hccl c_sec nnopbase ms_tools_ext)
3838
elseif(USE_MLU)
3939
set(COMMON_LIBS Python::Python)
4040
endif()

xllm/core/common/global_flags.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,15 @@ DEFINE_bool(enable_acl_graph,
8686
"Whether to enable ACL graph execution for decode phase.");
8787

8888
DEFINE_int32(max_seq_len_for_graph_mode,
89-
20480,
90-
"Maximum number of tokens per sequence for ACL graph execution.");
89+
0,
90+
"Maximum number of tokens per sequence for ACL graph execution. "
91+
"If 0, use model max_position_embeddings.");
9192

93+
DEFINE_bool(enable_acl_graph_no_padding,
94+
false,
95+
"Whether to enable ACL graph execution for decode phase without "
96+
"padding. If true, graph will be caputured with every actual num "
97+
"tokens, as stride is 1.");
9298
// --- vlm config ---
9399

94100
DEFINE_int32(limit_image_per_prompt,

xllm/core/common/global_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ DECLARE_bool(enable_acl_graph);
8787

8888
DECLARE_int32(max_seq_len_for_graph_mode);
8989

90+
DECLARE_bool(enable_acl_graph_no_padding);
91+
9092
DECLARE_bool(enable_chunked_prefill);
9193

9294
DECLARE_string(master_node_addr);

xllm/core/distributed_runtime/spawn_worker_server/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ cc_binary(
1717
ascendcl
1818
nnopbase
1919
atb
20+
atb_customize
2021
c_sec
2122
spdlog::spdlog
2223
)

xllm/core/framework/model/model_input_params.h

100755100644
Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ struct ModelInputParams {
9797
params.kv_cache_start_offsets = safe_to(kv_cache_start_offsets, device);
9898

9999
// Copy graph_buffer to device
100-
params.graph_buffer = safe_to(graph_buffer, device, true);
100+
// params.graph_buffer = safe_to(graph_buffer, device, true);
101+
params.graph_buffer.attn_mask =
102+
safe_to(graph_buffer.attn_mask, device, true);
103+
params.graph_buffer.tiling_data =
104+
safe_to(graph_buffer.tiling_data, device, true);
101105

102106
return params;
103107
}
@@ -206,7 +210,12 @@ struct ModelInputParams {
206210
torch::Tensor kv_cache_start_offsets;
207211
// Graph execution buffer for temporary tensor storage
208212
// Used by ACL Graph Executor to avoid repeated memory allocation
209-
torch::Tensor graph_buffer;
213+
214+
struct GraphBuffer {
215+
torch::Tensor attn_mask;
216+
torch::Tensor tiling_data;
217+
};
218+
GraphBuffer graph_buffer;
210219
};
211220

212221
} // namespace xllm

xllm/core/framework/request/mm_data.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,22 @@ struct MMData {
8181
return true;
8282
}
8383

84+
template <typename T>
85+
bool update(uint32_t type, const MMKey& key, const T& value) {
86+
const auto& itor = data_.find(key);
87+
if (itor != data_.end()) {
88+
// Key exists, update it
89+
data_[key] = value;
90+
ty_ |= type;
91+
return true;
92+
} else {
93+
// Key doesn't exist, add it (same as add method)
94+
ty_ |= type;
95+
data_.insert({key, value});
96+
return true;
97+
}
98+
}
99+
84100
template <typename T>
85101
std::optional<T> get(const MMKey& key) const {
86102
if (!valid()) return std::nullopt;

xllm/core/kernels/npu/impl/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ cc_test(
4040
xllm_kernels
4141
c_sec
4242
atb
43+
opapi
4344
spdlog::spdlog
4445
)
4546

@@ -55,6 +56,7 @@ cc_test(
5556
xllm_kernels
5657
c_sec
5758
atb
59+
opapi
5860
spdlog::spdlog
5961
)
6062

@@ -70,6 +72,7 @@ cc_test(
7072
xllm_kernels
7173
c_sec
7274
atb
75+
opapi
7376
spdlog::spdlog
7477
)
7578

@@ -85,6 +88,7 @@ cc_test(
8588
xllm_kernels
8689
c_sec
8790
atb
91+
opapi
8892
spdlog::spdlog
8993
)
9094

@@ -100,5 +104,6 @@ cc_test(
100104
xllm_kernels
101105
c_sec
102106
atb
107+
opapi
103108
spdlog::spdlog
104109
)

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,12 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters(
382382

383383
param.mlpLinearTransposeType = {1, -1, 1, -1};
384384

385-
param.enableSplitFuse = (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
385+
param.enableSplitFuse =
386+
(FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
387+
388+
// not support MTP model yet
389+
param.enableAclGraph =
390+
FLAGS_enable_acl_graph && !is_prefill && args.n_layers() > 1;
386391

387392
param.moeLinearTransposeType = (layer_id_ < args.first_k_dense_replace())
388393
? std::vector<int>{-1, -1, -1, -1}
@@ -406,7 +411,7 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters(
406411
param.enableSwiGLUQuantForSharedExperts = false; // TODO
407412

408413
param.useQKNorm = args.use_qk_norm();
409-
if(args.use_qk_norm()){
414+
if (args.use_qk_norm()) {
410415
WEIGHT_COUNT_PER_LAYER = 70;
411416
WEIGHT_MAPPING_W8A8["self_attn.q_norm.weight"] = Q_NORM_WEIGHT;
412417
WEIGHT_MAPPING_W8A8["self_attn.k_norm.weight"] = K_NORM_WEIGHT;
@@ -1086,8 +1091,9 @@ torch::Tensor Glm4MoeDecoderImpl::forward(
10861091
std::vector<std::atomic<bool>*> event_flag,
10871092
int node_id) {
10881093
atb::Status st;
1089-
if (input_params.decode_seq_range.second !=
1090-
input_params.q_seq_lens.size(0) - 1) {
1094+
bool is_prefill = input_params.decode_seq_range.second !=
1095+
input_params.q_seq_lens.size(0) - 1;
1096+
if (is_prefill) {
10911097
build_node_variant_pack(prefill_node_,
10921098
x,
10931099
cos_pos,
@@ -1200,6 +1206,13 @@ void Glm4MoeDecoderImpl::build_node_variant_pack(
12001206
node.variantPack.inTensors.at(input_idx++) =
12011207
atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_);
12021208

1209+
if (FLAGS_enable_acl_graph && !is_prefill &&
1210+
input_params.graph_buffer.tiling_data.defined()) {
1211+
node.variantPack.inTensors.at(input_idx++) =
1212+
atb_speed::Utils::AtTensor2Tensor(
1213+
input_params.graph_buffer.tiling_data);
1214+
}
1215+
12031216
for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
12041217
CHECK_THROW(node.inTensors.at(i) == nullptr,
12051218
model_name_ << " inTensor " << i << " is NULL");

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ class Glm4MoeDecoderImpl : public NpuBaseLayer {
170170
const ModelInputParams& input_params,
171171
torch::Tensor& expert_array,
172172
bool is_prefill);
173-
174173
std::string model_name_;
175174

176175
int32_t device_id_;

0 commit comments

Comments
 (0)