@@ -382,7 +382,12 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters(
 
   param.mlpLinearTransposeType = {1, -1, 1, -1};
 
-  param.enableSplitFuse = (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
+  param.enableSplitFuse =
+      (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
+
+  // ACL graph mode does not support the MTP model yet
+  param.enableAclGraph =
+      FLAGS_enable_acl_graph && !is_prefill && args.n_layers() > 1;
 
   param.moeLinearTransposeType = (layer_id_ < args.first_k_dense_replace())
                                      ? std::vector<int>{-1, -1, -1, -1}
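
Note on the new enableAclGraph gate: as I read it, graph capture is enabled only for decode passes (fixed shapes) and only on models with more than one layer, which keeps single-layer MTP draft models on the eager path, matching the comment. A minimal sketch of the same predicate as a standalone helper (the helper name and free-function form are illustrative, not part of this patch):

// Sketch only: restates the gate above as a testable predicate. The function
// name and signature are hypothetical; FLAGS_enable_acl_graph is the gflag
// already used in this file.
inline bool acl_graph_enabled(bool is_prefill, int64_t n_layers) {
  // Graph mode wants fixed decode shapes and a non-MTP (multi-layer) model.
  return FLAGS_enable_acl_graph && !is_prefill && n_layers > 1;
}
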
@@ -406,7 +411,7 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters(
   param.enableSwiGLUQuantForSharedExperts = false;  // TODO
 
   param.useQKNorm = args.use_qk_norm();
-  if (args.use_qk_norm()){
+  if (args.use_qk_norm()) {
     WEIGHT_COUNT_PER_LAYER = 70;
     WEIGHT_MAPPING_W8A8["self_attn.q_norm.weight"] = Q_NORM_WEIGHT;
     WEIGHT_MAPPING_W8A8["self_attn.k_norm.weight"] = K_NORM_WEIGHT;
@@ -1086,8 +1091,9 @@ torch::Tensor Glm4MoeDecoderImpl::forward(
     std::vector<std::atomic<bool>*> event_flag,
     int node_id) {
   atb::Status st;
-  if (input_params.decode_seq_range.second !=
-      input_params.q_seq_lens.size(0) - 1) {
+  bool is_prefill = input_params.decode_seq_range.second !=
+                    input_params.q_seq_lens.size(0) - 1;
+  if (is_prefill) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,
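
Note on the is_prefill extraction: naming the condition documents the batching convention it relies on. As I understand it (inferred from this check alone, so hedged), decode sequences are packed so that a pure-decode batch has decode_seq_range ending at the last sequence index; any other value means prefill work remains and the prefill graph must run. A runnable toy example with made-up values:

#include <cstdint>
#include <iostream>
#include <utility>

// Sketch only: mirrors the derivation above with hypothetical values.
// decode_seq_range is taken to be an inclusive index range of decode
// sequences; num_seqs stands in for q_seq_lens.size(0).
bool infer_is_prefill(std::pair<int64_t, int64_t> decode_seq_range,
                      int64_t num_seqs) {
  return decode_seq_range.second != num_seqs - 1;
}

int main() {
  std::cout << infer_is_prefill({0, 3}, 4) << '\n';  // 0: pure decode batch
  std::cout << infer_is_prefill({0, 1}, 4) << '\n';  // 1: prefill work remains
}
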
@@ -1200,6 +1206,13 @@ void Glm4MoeDecoderImpl::build_node_variant_pack(
   node.variantPack.inTensors.at(input_idx++) =
       atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_);
 
+  if (FLAGS_enable_acl_graph && !is_prefill &&
+      input_params.graph_buffer.tiling_data.defined()) {
+    node.variantPack.inTensors.at(input_idx++) =
+        atb_speed::Utils::AtTensor2Tensor(
+            input_params.graph_buffer.tiling_data);
+  }
+
   for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
     CHECK_THROW(node.inTensors.at(i) == nullptr,
                 model_name_ << " inTensor " << i << " is NULL");
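
Note on the conditional tiling-data input: since the tensor is appended only when graph mode is active for a decode pass and the buffer is defined, the number of tensors packed into variantPack.inTensors now varies with the mode, and it must stay in sync with the input count the decode op was built with. A hypothetical guard, placed after all inputs are packed, in the style of the CHECK_THROW usage above (expected_input_count is an assumed helper, not in this PR):

// Sketch only: expected_input_count() is hypothetical. Following the existing
// CHECK_THROW(error-condition, message) pattern, this would fail fast if the
// packed tensor count drifts from what the op declares.
CHECK_THROW(input_idx != expected_input_count(is_prefill),
            model_name_ << " packed " << input_idx << " inTensors");
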