Skip to content

Commit f2b0088

Browse files
author
Xianzhe Dong
committed
feat: support mm embedding service and vlm embedding model factory.
1 parent 9bbd770 commit f2b0088

File tree

14 files changed

+290
-10
lines changed

14 files changed

+290
-10
lines changed

xllm/api_service/api_service.cpp

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ APIService::APIService(Master* master,
6464
auto vlm_master = dynamic_cast<VLMMaster*>(master);
6565
mm_chat_service_impl_ =
6666
std::make_unique<MMChatServiceImpl>(vlm_master, model_names);
67+
mm_embedding_service_impl_ =
68+
std::make_unique<MMEmbeddingServiceImpl>(vlm_master, model_names);
6769
} else if (FLAGS_backend == "dit") {
6870
image_generation_service_impl_ =
6971
std::make_unique<ImageGenerationServiceImpl>(
@@ -190,10 +192,13 @@ void APIService::Embeddings(::google::protobuf::RpcController* controller,
190192
// TODO with xllm-service
191193
}
192194

193-
void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
194-
const proto::HttpRequest* request,
195-
proto::HttpResponse* response,
196-
::google::protobuf::Closure* done) {
195+
namespace {
196+
template <typename EmbeddingCall, typename Service>
197+
void EmbeddingsImpl(std::unique_ptr<Service>& embedding_service_impl_,
198+
::google::protobuf::RpcController* controller,
199+
const proto::HttpRequest* request,
200+
proto::HttpResponse* response,
201+
::google::protobuf::Closure* done) {
197202
xllm::ClosureGuard done_guard(
198203
done,
199204
std::bind(request_in_metric, nullptr),
@@ -202,12 +207,13 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
202207
LOG(ERROR) << "brpc request | respose | controller is null";
203208
return;
204209
}
205-
206210
auto arena = response->GetArena();
207211
auto req_pb =
208-
google::protobuf::Arena::CreateMessage<proto::EmbeddingRequest>(arena);
212+
google::protobuf::Arena::CreateMessage<typename EmbeddingCall::ReqType>(
213+
arena);
209214
auto resp_pb =
210-
google::protobuf::Arena::CreateMessage<proto::EmbeddingResponse>(arena);
215+
google::protobuf::Arena::CreateMessage<typename EmbeddingCall::ResType>(
216+
arena);
211217

212218
auto ctrl = reinterpret_cast<brpc::Controller*>(controller);
213219
std::string error;
@@ -230,6 +236,22 @@ void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
230236
ctrl, done_guard.release(), req_pb, resp_pb);
231237
embedding_service_impl_->process_async(call);
232238
}
239+
} // namespace
240+
241+
void APIService::EmbeddingsHttp(::google::protobuf::RpcController* controller,
242+
const proto::HttpRequest* request,
243+
proto::HttpResponse* response,
244+
::google::protobuf::Closure* done) {
245+
if (FLAGS_backend == "llm") {
246+
CHECK(embedding_service_impl_) << " embedding service is invalid.";
247+
EmbeddingsImpl<EmbeddingCall, EmbeddingServiceImpl>(
248+
embedding_service_impl_, controller, request, response, done);
249+
} else if (FLAGS_backend == "vlm") {
250+
CHECK(mm_chat_service_impl_) << " mm embedding service is invalid.";
251+
EmbeddingsImpl<MMEmbeddingCall, MMEmbeddingServiceImpl>(
252+
mm_embedding_service_impl_, controller, request, response, done);
253+
}
254+
}
233255

234256
void APIService::ImageGeneration(::google::protobuf::RpcController* controller,
235257
const proto::ImageGenerationRequest* request,

xllm/api_service/api_service.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ class APIService : public proto::XllmAPIService {
120120
std::unique_ptr<ChatServiceImpl> chat_service_impl_;
121121
std::unique_ptr<MMChatServiceImpl> mm_chat_service_impl_;
122122
std::unique_ptr<EmbeddingServiceImpl> embedding_service_impl_;
123+
std::unique_ptr<MMEmbeddingServiceImpl> mm_embedding_service_impl_;
123124
std::unique_ptr<ModelsServiceImpl> models_service_impl_;
124125
std::unique_ptr<ImageGenerationServiceImpl> image_generation_service_impl_;
125126
std::unique_ptr<RerankServiceImpl> rerank_service_impl_;

xllm/api_service/embedding_service_impl.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ limitations under the License.
2828
namespace xllm {
2929
namespace {
3030

31+
template <typename EmbeddingCall>
3132
bool send_result_to_client_brpc(std::shared_ptr<EmbeddingCall> call,
3233
const std::string& request_id,
3334
int64_t created_time,
@@ -113,9 +114,64 @@ void EmbeddingServiceImpl::process_async_impl(
113114
}
114115
}
115116

116-
return send_result_to_client_brpc(
117+
return send_result_to_client_brpc<EmbeddingCall>(
117118
call, request_id, created_time, model, req_output);
118119
});
119120
}
120121

122+
MMEmbeddingServiceImpl::MMEmbeddingServiceImpl(
123+
VLMMaster* master,
124+
const std::vector<std::string>& models)
125+
: APIServiceImpl(models), master_(master) {
126+
CHECK(master_ != nullptr);
127+
}
128+
129+
void MMEmbeddingServiceImpl::process_async_impl(
130+
std::shared_ptr<MMEmbeddingCall> call) {
131+
const auto& rpc_request = call->request();
132+
// check if model is supported
133+
const auto& model = rpc_request.model();
134+
if (!models_.contains(model)) {
135+
call->finish_with_error(StatusCode::UNKNOWN, "Model not supported");
136+
return;
137+
}
138+
139+
// create RequestParams for embeddings request
140+
// set is_embeddings and max_tokens = 1 to control engine step once.
141+
RequestParams request_params(
142+
rpc_request, call->get_x_request_id(), call->get_x_request_time());
143+
144+
auto& req_messages = rpc_request.messages();
145+
146+
std::vector<Message> messages;
147+
MMInput mm_inputs;
148+
149+
static MMInputHelper helper;
150+
if (!helper.trans(req_messages, messages, mm_inputs.items_)) {
151+
call->finish_with_error(StatusCode::INVALID_ARGUMENT,
152+
"inputs argument is invalid.");
153+
return;
154+
}
155+
156+
// schedule the request
157+
master_->handle_request(
158+
std::move(messages),
159+
std::move(mm_inputs),
160+
std::move(request_params),
161+
[call,
162+
model,
163+
request_id = request_params.request_id,
164+
created_time = absl::ToUnixSeconds(absl::Now())](
165+
const RequestOutput& req_output) -> bool {
166+
if (req_output.status.has_value()) {
167+
const auto& status = req_output.status.value();
168+
if (!status.ok()) {
169+
return call->finish_with_error(status.code(), status.message());
170+
}
171+
}
172+
173+
return send_result_to_client_brpc<MMEmbeddingCall>(
174+
call, request_id, created_time, model, req_output);
175+
});
176+
}
121177
} // namespace xllm

xllm/api_service/embedding_service_impl.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include "api_service/api_service_impl.h"
2020
#include "api_service/call.h"
2121
#include "api_service/non_stream_call.h"
22+
#include "core/runtime/vlm_master.h"
2223
#include "embedding.pb.h"
2324

2425
namespace xllm {
@@ -40,4 +41,18 @@ class EmbeddingServiceImpl final : public APIServiceImpl<EmbeddingCall> {
4041
LLMMaster* master_ = nullptr;
4142
};
4243

44+
using MMEmbeddingCall =
45+
NonStreamCall<proto::MMEmbeddingRequest, proto::EmbeddingResponse>;
46+
class MMEmbeddingServiceImpl : public APIServiceImpl<MMEmbeddingCall> {
47+
public:
48+
MMEmbeddingServiceImpl(VLMMaster* master,
49+
const std::vector<std::string>& models);
50+
// brpc call_data needs to use shared_ptr
51+
void process_async_impl(std::shared_ptr<MMEmbeddingCall> call);
52+
53+
private:
54+
DISALLOW_COPY_AND_ASSIGN(MMEmbeddingServiceImpl);
55+
VLMMaster* master_ = nullptr;
56+
};
57+
4358
} // namespace xllm

xllm/api_service/non_stream_call.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ namespace xllm {
3333
template <typename Request, typename Response>
3434
class NonStreamCall : public Call {
3535
public:
36+
using ReqType = Request;
37+
using ResType = Response;
3638
NonStreamCall(brpc::Controller* controller,
3739
::google::protobuf::Closure* done,
3840
Request* request,

xllm/core/framework/model/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ cc_library(
3232
causal_vlm.h
3333
dit_model.h
3434
embedding_lm.h
35+
embedding_vlm.h
3536
model_args.h
3637
npu_dp_ep_padding.h
3738
model_input_params.h
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
https://github.com/jd-opensource/xllm/blob/main/LICENSE
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#pragma once
17+
18+
#include <c10/core/Device.h>
19+
#include <torch/torch.h>
20+
21+
#include <vector>
22+
23+
#include "causal_vlm.h"
24+
#include "core/framework/kv_cache/kv_cache.h"
25+
#include "core/framework/quant_args.h"
26+
#include "core/framework/state_dict/state_dict.h"
27+
#include "model_args.h"
28+
#include "model_input_params.h"
29+
30+
namespace xllm {
31+
32+
class EmbeddingVLM : public CausalVLM {
33+
public:
34+
~EmbeddingVLM() override = default;
35+
36+
// hidden_states: [num_tokens, hidden_size]
37+
// seleted_idxes: [num_tokens]
38+
// returns: [num_seqs, hidden_size]
39+
virtual torch::Tensor pooler(const torch::Tensor& hidden_states,
40+
const torch::Tensor& seleted_idxes) = 0;
41+
};
42+
43+
template <typename Model>
44+
class EmbeddingVLMImpl : public EmbeddingVLM {
45+
public:
46+
EmbeddingVLMImpl(Model model, const torch::TensorOptions& options)
47+
: model_(std::move(model)), options_(options) {}
48+
49+
torch::Tensor logits(const torch::Tensor& hidden_states,
50+
const torch::Tensor& seleted_idxes) override {
51+
return model_->logits(hidden_states, seleted_idxes);
52+
}
53+
54+
torch::Tensor pooler(const torch::Tensor& hidden_states,
55+
const torch::Tensor& seleted_idxes) override {
56+
return model_->pooler(hidden_states, seleted_idxes);
57+
}
58+
59+
void load_model(std::unique_ptr<ModelLoader> loader) override {
60+
model_->load_model(std::move(loader));
61+
}
62+
63+
torch::Device device() const override { return options_.device(); }
64+
65+
const torch::TensorOptions& options() const override { return options_; }
66+
67+
private:
68+
Model model_;
69+
70+
torch::TensorOptions options_;
71+
};
72+
73+
} // namespace xllm

xllm/core/framework/request/request_params.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,18 @@ RequestParams::RequestParams(const proto::EmbeddingRequest& request,
337337
max_tokens = 1;
338338
streaming = false;
339339
}
340+
RequestParams::RequestParams(const proto::MMEmbeddingRequest& request,
341+
const std::string& x_rid,
342+
const std::string& x_rtime) {
343+
if (request.has_service_request_id()) {
344+
service_request_id = request.service_request_id();
345+
}
346+
x_request_id = x_rid;
347+
x_request_time = x_rtime;
348+
is_embeddings = true;
349+
max_tokens = 1;
350+
streaming = false;
351+
}
340352

341353
RequestParams::RequestParams(const proto::RerankRequest& request,
342354
const std::string& x_rid,

xllm/core/framework/request/request_params.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ struct RequestParams {
4949
RequestParams(const proto::EmbeddingRequest& request,
5050
const std::string& x_rid,
5151
const std::string& x_rtime);
52+
RequestParams(const proto::MMEmbeddingRequest& request,
53+
const std::string& x_rid,
54+
const std::string& x_rtime);
5255
RequestParams(const proto::RerankRequest& request,
5356
const std::string& x_rid,
5457
const std::string& x_rtime);

xllm/core/runtime/embed_vlm_worker_impl.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ bool EmbedVLMWorkerImpl::init_model(ModelContext& context) {
4444
CHECK(model_ == nullptr) << "Model is already initialized.";
4545

4646
context.set_image_embedding_mode(true);
47-
model_ = create_vlm_model(context);
47+
model_ = create_embeddingvlm_model(context);
4848
CHECK(model_ != nullptr) << "Failed to create model.";
4949
model_executor_ = std::make_unique<Executor>(
5050
model_.get(), context.get_model_args(), device_, options_);
@@ -80,7 +80,18 @@ std::optional<ForwardOutput> EmbedVLMWorkerImpl::step(
8080

8181
// driver prepare model output
8282
ForwardOutput output;
83-
output.embedding = hidden_states;
83+
SampleOutput sample_output;
84+
85+
if (sampling_params.selected_token_idxes.defined() &&
86+
inputs.micro_inputs[0].sampling_params.is_embeddings) {
87+
EmbeddingVLM* em_model = dynamic_cast<EmbeddingVLM*>(model_.get());
88+
auto embeddings =
89+
em_model->pooler(hidden_states, sampling_params.selected_token_idxes);
90+
sample_output.embeddings = embeddings;
91+
output.sample_output = sample_output;
92+
output.embedding = embeddings;
93+
}
94+
8495
return output;
8596
}
8697

0 commit comments

Comments
 (0)