Skip to content

Commit 26af848

Browse files
committed
feat: revert the original code before refactoring the multi-stream [1/2].
Signed-off-by: Tao Peng <[email protected]>
1 parent 9bbd770 commit 26af848

33 files changed

+709
-1157
lines changed

xllm/core/framework/model/causal_lm.h

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,10 @@ class CausalLM : public torch::nn::Module {
4343
// tokens: [num_tokens]
4444
// positions: [num_tokens]
4545
// returns: [num_tokens, hidden_size]
46-
virtual torch::Tensor forward(
47-
const std::vector<torch::Tensor>& tokens,
48-
const std::vector<torch::Tensor>& positions,
49-
std::vector<KVCache>& kv_caches,
50-
const std::vector<ModelInputParams>& parameters) = 0;
46+
virtual torch::Tensor forward(const torch::Tensor& tokens,
47+
const torch::Tensor& positions,
48+
std::vector<KVCache>& kv_caches,
49+
const ModelInputParams& parameters) = 0;
5150

5251
// hidden_states: [num_tokens, hidden_size]
5352
// seleted_idxes: [num_tokens]
@@ -68,9 +67,8 @@ class CausalLM : public torch::nn::Module {
6867

6968
virtual layer::LmHead get_lm_head() = 0;
7069
virtual void set_lm_head(layer::LmHead& head) = 0;
71-
virtual std::vector<layer::WordEmbedding> get_word_embedding() = 0;
72-
virtual void set_word_embedding(
73-
std::vector<layer::WordEmbedding>& embedding) = 0;
70+
virtual layer::WordEmbedding get_word_embedding() = 0;
71+
virtual void set_word_embedding(layer::WordEmbedding& embedding) = 0;
7472
};
7573

7674
template <typename Model>
@@ -79,11 +77,10 @@ class CausalLMImpl : public CausalLM {
7977
CausalLMImpl(Model model, const torch::TensorOptions& options)
8078
: model_(std::move(model)), options_(options) {}
8179

82-
torch::Tensor forward(
83-
const std::vector<torch::Tensor>& tokens,
84-
const std::vector<torch::Tensor>& positions,
85-
std::vector<KVCache>& kv_caches,
86-
const std::vector<ModelInputParams>& parameters) override {
80+
torch::Tensor forward(const torch::Tensor& tokens,
81+
const torch::Tensor& positions,
82+
std::vector<KVCache>& kv_caches,
83+
const ModelInputParams& parameters) override {
8784
return model_->forward(tokens, positions, kv_caches, parameters);
8885
}
8986

@@ -109,12 +106,11 @@ class CausalLMImpl : public CausalLM {
109106

110107
void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };
111108

112-
std::vector<layer::WordEmbedding> get_word_embedding() override {
109+
layer::WordEmbedding get_word_embedding() override {
113110
return model_->get_word_embedding();
114111
};
115112

116-
void set_word_embedding(
117-
std::vector<layer::WordEmbedding>& embedding) override {
113+
void set_word_embedding(layer::WordEmbedding& embedding) override {
118114
model_->set_word_embedding(embedding);
119115
};
120116

xllm/core/framework/model/causal_vlm.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,10 @@ class CausalVLMImpl : public CausalVLM {
4040
CausalVLMImpl(Model model, const torch::TensorOptions& options)
4141
: model_(std::move(model)), options_(options) {}
4242

43-
torch::Tensor forward(
44-
const std::vector<torch::Tensor>& tokens,
45-
const std::vector<torch::Tensor>& positions,
46-
std::vector<KVCache>& kv_caches,
47-
const std::vector<ModelInputParams>& parameters) override {
43+
torch::Tensor forward(const torch::Tensor& tokens,
44+
const torch::Tensor& positions,
45+
std::vector<KVCache>& kv_caches,
46+
const ModelInputParams& parameters) override {
4847
return model_->forward(tokens, positions, kv_caches, parameters);
4948
}
5049

@@ -68,12 +67,11 @@ class CausalVLMImpl : public CausalVLM {
6867

6968
void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };
7069

71-
std::vector<layer::WordEmbedding> get_word_embedding() override {
70+
layer::WordEmbedding get_word_embedding() override {
7271
return model_->get_word_embedding();
7372
};
7473

75-
void set_word_embedding(
76-
std::vector<layer::WordEmbedding>& embedding) override {
74+
void set_word_embedding(layer::WordEmbedding& embedding) override {
7775
model_->set_word_embedding(embedding);
7876
};
7977

xllm/core/framework/model_context.cpp

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,8 @@ ModelContext::ModelContext(const ParallelArgs& input_parallel_args,
4040
int32_t device_id = tensor_options.device().index();
4141
aclError ret = aclrtSetDevice(device_id);
4242
atb::CreateContext(&context_);
43-
std::vector<aclrtStream> streams;
44-
streams.push_back(c10_npu::getCurrentNPUStream(device_id).stream());
45-
for (int i = 0; i < 1; ++i) {
46-
aclrtStream sub_stream;
47-
aclError ret = aclrtCreateStream(&sub_stream);
48-
if (ret != ACL_ERROR_NONE) {
49-
ATB_SPEED_LOG_ERROR("Failed to create aclrtStream: " << ret);
50-
}
51-
streams.push_back(sub_stream);
52-
}
53-
context_->SetExecuteStreams(streams);
43+
void* stream = c10_npu::getCurrentNPUStream(device_id).stream();
44+
context_->SetExecuteStream(stream);
5445
context_->SetAsyncTilingCopyStatus(true);
5546
#endif
5647
}

xllm/core/layers/npu/npu_base_layer.cpp

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,10 @@ NpuBaseLayer::NpuBaseLayer(const ModelContext& context) : BaseLayer(context) {
3232
work_space_ = AtbWorkspace(device_);
3333
}
3434

35-
atb::Status NpuBaseLayer::execute_node(
36-
atb_speed::Model::Node& node,
37-
int node_id,
38-
std::vector<aclrtEvent*> event,
39-
std::vector<std::atomic<bool>*> event_flag) {
35+
atb::Status NpuBaseLayer::execute_node(atb_speed::Model::Node& node,
36+
int node_id,
37+
aclrtEvent* event,
38+
std::atomic<bool>* event_flag) {
4039
// TODO(by [email protected]): Stream management needs to be refactored
4140
// for better separation of concerns Current issues:
4241
// 1. ACLGraph capture requires execution on a non-default stream, so we
@@ -93,28 +92,25 @@ atb::Status NpuBaseLayer::execute_node(
9392
return st;
9493
}
9594

96-
atb::Status NpuBaseLayer::execute_plan(
97-
const atb_speed::Model::Node& node,
98-
const std::string& op_name,
99-
std::vector<aclrtEvent*> event,
100-
std::vector<std::atomic<bool>*> event_flag) {
95+
atb::Status NpuBaseLayer::execute_plan(const atb_speed::Model::Node& node,
96+
const std::string& op_name,
97+
aclrtEvent* event,
98+
std::atomic<bool>* event_flag) {
10199
atb::Status st = node.operation->Execute(
102100
node.variantPack, (uint8_t*)node.workspace, node.workspaceSize, context_);
103101
LOG_IF(ERROR, st != 0) << name_ << " execute plan fail, error code: " << st;
104-
for (auto i = 0; i < event.size(); ++i) {
105-
if (st == 0 && event[i] != nullptr) {
106-
aclrtStream stream = context_->GetExecuteStream();
102+
if (st == 0 && event != nullptr) {
103+
aclrtStream stream = context_->GetExecuteStream();
107104

108-
aclrtEvent* aclrt_event = reinterpret_cast<aclrtEvent*>(event[i]);
105+
aclrtEvent* aclrt_event = reinterpret_cast<aclrtEvent*>(event);
109106

110-
auto ret = aclrtRecordEvent(*aclrt_event, stream);
111-
if (ret != ACL_SUCCESS) {
112-
LOG(ERROR) << "Record event failed.";
113-
return st;
114-
}
115-
116-
event_flag[i]->store(true, std::memory_order_release);
107+
auto ret = aclrtRecordEvent(*aclrt_event, stream);
108+
if (ret != ACL_SUCCESS) {
109+
LOG(ERROR) << "Record event failed.";
110+
return st;
117111
}
112+
113+
event_flag->store(true, std::memory_order_release);
118114
}
119115

120116
return st;

xllm/core/layers/npu/npu_base_layer.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,13 @@ class NpuBaseLayer : public BaseLayer {
6161

6262
atb::Status execute_node(atb_speed::Model::Node& node,
6363
int nodeId = 0,
64-
std::vector<aclrtEvent*> event = {nullptr, nullptr},
65-
std::vector<std::atomic<bool>*> event_flag = {
66-
nullptr,
67-
nullptr});
64+
aclrtEvent* event = nullptr,
65+
std::atomic<bool>* event_flag = nullptr);
6866

6967
atb::Status execute_plan(const atb_speed::Model::Node& node,
7068
const std::string& op_name,
71-
std::vector<aclrtEvent*> event,
72-
std::vector<std::atomic<bool>*> event_flag);
69+
aclrtEvent* event,
70+
std::atomic<bool>* event_flag);
7371

7472
virtual void run_task(std::string taskName,
7573
std::function<int()> task) const override;

0 commit comments

Comments
 (0)