Skip to content

Commit cb70e57

Browse files
committed
[lldb] Make MCP server instance global (llvm#145616)
Rather than having one MCP server per debugger, make the MCP server global and pass a debugger id along with tool invocations that require one. This PR also adds a second tool to list the available debuggers with their targets so the model can decide which debugger instance to use. (cherry picked from commit e8abdfc)
1 parent 2b06012 commit cb70e57

File tree

13 files changed

+180
-136
lines changed

13 files changed

+180
-136
lines changed

lldb/include/lldb/Core/Debugger.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -617,10 +617,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
617617
void FlushProcessOutput(Process &process, bool flush_stdout,
618618
bool flush_stderr);
619619

620-
void AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
621-
void RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp);
622-
lldb::ProtocolServerSP GetProtocolServer(llvm::StringRef protocol) const;
623-
624620
SourceManager::SourceFileCache &GetSourceFileCache() {
625621
return m_source_file_cache;
626622
}
@@ -793,8 +789,6 @@ class Debugger : public std::enable_shared_from_this<Debugger>,
793789
mutable std::mutex m_progress_reports_mutex;
794790
/// @}
795791

796-
llvm::SmallVector<lldb::ProtocolServerSP> m_protocol_servers;
797-
798792
std::mutex m_destroy_callback_mutex;
799793
lldb::callback_token_t m_destroy_callback_next_token = 0;
800794
struct DestroyCallbackInfo {

lldb/include/lldb/Core/ProtocolServer.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ class ProtocolServer : public PluginInterface {
2020
ProtocolServer() = default;
2121
virtual ~ProtocolServer() = default;
2222

23-
static lldb::ProtocolServerSP Create(llvm::StringRef name,
24-
Debugger &debugger);
23+
static ProtocolServer *GetOrCreate(llvm::StringRef name);
24+
25+
static std::vector<llvm::StringRef> GetSupportedProtocols();
2526

2627
struct Connection {
2728
Socket::SocketProtocol protocol;

lldb/include/lldb/lldb-forward.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ typedef std::shared_ptr<lldb_private::Platform> PlatformSP;
389389
typedef std::shared_ptr<lldb_private::Process> ProcessSP;
390390
typedef std::shared_ptr<lldb_private::ProcessAttachInfo> ProcessAttachInfoSP;
391391
typedef std::shared_ptr<lldb_private::ProcessLaunchInfo> ProcessLaunchInfoSP;
392-
typedef std::shared_ptr<lldb_private::ProtocolServer> ProtocolServerSP;
392+
typedef std::unique_ptr<lldb_private::ProtocolServer> ProtocolServerUP;
393393
typedef std::weak_ptr<lldb_private::Process> ProcessWP;
394394
typedef std::shared_ptr<lldb_private::RegisterCheckpoint> RegisterCheckpointSP;
395395
typedef std::shared_ptr<lldb_private::RegisterContext> RegisterContextSP;

lldb/include/lldb/lldb-private-interfaces.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ typedef lldb::PlatformSP (*PlatformCreateInstance)(bool force,
8282
typedef lldb::ProcessSP (*ProcessCreateInstance)(
8383
lldb::TargetSP target_sp, lldb::ListenerSP listener_sp,
8484
const FileSpec *crash_file_path, bool can_connect);
85-
typedef lldb::ProtocolServerSP (*ProtocolServerCreateInstance)(
86-
Debugger &debugger);
85+
typedef lldb::ProtocolServerUP (*ProtocolServerCreateInstance)();
8786
typedef lldb::RegisterTypeBuilderSP (*RegisterTypeBuilderCreateInstance)(
8887
Target &target);
8988
typedef lldb::ScriptInterpreterSP (*ScriptInterpreterCreateInstance)(

lldb/source/Commands/CommandObjectProtocolServer.cpp

Lines changed: 9 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,6 @@ using namespace lldb_private;
2424
#define LLDB_OPTIONS_mcp
2525
#include "CommandOptions.inc"
2626

27-
static std::vector<llvm::StringRef> GetSupportedProtocols() {
28-
std::vector<llvm::StringRef> supported_protocols;
29-
size_t i = 0;
30-
31-
for (llvm::StringRef protocol_name =
32-
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
33-
!protocol_name.empty();
34-
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
35-
supported_protocols.push_back(protocol_name);
36-
}
37-
38-
return supported_protocols;
39-
}
40-
4127
class CommandObjectProtocolServerStart : public CommandObjectParsed {
4228
public:
4329
CommandObjectProtocolServerStart(CommandInterpreter &interpreter)
@@ -58,12 +44,11 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
5844
}
5945

6046
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
61-
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
62-
if (llvm::find(supported_protocols, protocol) ==
63-
supported_protocols.end()) {
47+
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
48+
if (!server) {
6449
result.AppendErrorWithFormatv(
6550
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
66-
llvm::join(GetSupportedProtocols(), ", "));
51+
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
6752
return;
6853
}
6954

@@ -73,10 +58,6 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
7358
}
7459
llvm::StringRef connection_uri = args.GetArgumentAtIndex(1);
7560

76-
ProtocolServerSP server_sp = GetDebugger().GetProtocolServer(protocol);
77-
if (!server_sp)
78-
server_sp = ProtocolServer::Create(protocol, GetDebugger());
79-
8061
const char *connection_error =
8162
"unsupported connection specifier, expected 'accept:///path' or "
8263
"'listen://[host]:port', got '{0}'.";
@@ -99,14 +80,12 @@ class CommandObjectProtocolServerStart : public CommandObjectParsed {
9980
formatv("[{0}]:{1}", uri->hostname.empty() ? "0.0.0.0" : uri->hostname,
10081
uri->port.value_or(0));
10182

102-
if (llvm::Error error = server_sp->Start(connection)) {
83+
if (llvm::Error error = server->Start(connection)) {
10384
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
10485
return;
10586
}
10687

107-
GetDebugger().AddProtocolServer(server_sp);
108-
109-
if (Socket *socket = server_sp->GetSocket()) {
88+
if (Socket *socket = server->GetSocket()) {
11089
std::string address =
11190
llvm::join(socket->GetListeningConnectionURI(), ", ");
11291
result.AppendMessageWithFormatv(
@@ -135,30 +114,18 @@ class CommandObjectProtocolServerStop : public CommandObjectParsed {
135114
}
136115

137116
llvm::StringRef protocol = args.GetArgumentAtIndex(0);
138-
std::vector<llvm::StringRef> supported_protocols = GetSupportedProtocols();
139-
if (llvm::find(supported_protocols, protocol) ==
140-
supported_protocols.end()) {
117+
ProtocolServer *server = ProtocolServer::GetOrCreate(protocol);
118+
if (!server) {
141119
result.AppendErrorWithFormatv(
142120
"unsupported protocol: {0}. Supported protocols are: {1}", protocol,
143-
llvm::join(GetSupportedProtocols(), ", "));
121+
llvm::join(ProtocolServer::GetSupportedProtocols(), ", "));
144122
return;
145123
}
146124

147-
Debugger &debugger = GetDebugger();
148-
149-
ProtocolServerSP server_sp = debugger.GetProtocolServer(protocol);
150-
if (!server_sp) {
151-
result.AppendError(
152-
llvm::formatv("no {0} protocol server running", protocol).str());
153-
return;
154-
}
155-
156-
if (llvm::Error error = server_sp->Stop()) {
125+
if (llvm::Error error = server->Stop()) {
157126
result.AppendErrorWithFormatv("{0}", llvm::fmt_consume(std::move(error)));
158127
return;
159128
}
160-
161-
debugger.RemoveProtocolServer(server_sp);
162129
}
163130
};
164131

lldb/source/Core/Debugger.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,26 +2380,3 @@ llvm::ThreadPoolInterface &Debugger::GetThreadPool() {
23802380
"Debugger::GetThreadPool called before Debugger::Initialize");
23812381
return *g_thread_pool;
23822382
}
2383-
2384-
void Debugger::AddProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
2385-
assert(protocol_server_sp &&
2386-
GetProtocolServer(protocol_server_sp->GetPluginName()) == nullptr);
2387-
m_protocol_servers.push_back(protocol_server_sp);
2388-
}
2389-
2390-
void Debugger::RemoveProtocolServer(lldb::ProtocolServerSP protocol_server_sp) {
2391-
auto it = llvm::find(m_protocol_servers, protocol_server_sp);
2392-
if (it != m_protocol_servers.end())
2393-
m_protocol_servers.erase(it);
2394-
}
2395-
2396-
lldb::ProtocolServerSP
2397-
Debugger::GetProtocolServer(llvm::StringRef protocol) const {
2398-
for (ProtocolServerSP protocol_server_sp : m_protocol_servers) {
2399-
if (!protocol_server_sp)
2400-
continue;
2401-
if (protocol_server_sp->GetPluginName() == protocol)
2402-
return protocol_server_sp;
2403-
}
2404-
return nullptr;
2405-
}

lldb/source/Core/ProtocolServer.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,36 @@
1212
using namespace lldb_private;
1313
using namespace lldb;
1414

15-
ProtocolServerSP ProtocolServer::Create(llvm::StringRef name,
16-
Debugger &debugger) {
15+
ProtocolServer *ProtocolServer::GetOrCreate(llvm::StringRef name) {
16+
static std::mutex g_mutex;
17+
static llvm::StringMap<ProtocolServerUP> g_protocol_server_instances;
18+
19+
std::lock_guard<std::mutex> guard(g_mutex);
20+
21+
auto it = g_protocol_server_instances.find(name);
22+
if (it != g_protocol_server_instances.end())
23+
return it->second.get();
24+
1725
if (ProtocolServerCreateInstance create_callback =
18-
PluginManager::GetProtocolCreateCallbackForPluginName(name))
19-
return create_callback(debugger);
26+
PluginManager::GetProtocolCreateCallbackForPluginName(name)) {
27+
auto pair =
28+
g_protocol_server_instances.try_emplace(name, create_callback());
29+
return pair.first->second.get();
30+
}
31+
2032
return nullptr;
2133
}
34+
35+
std::vector<llvm::StringRef> ProtocolServer::GetSupportedProtocols() {
36+
std::vector<llvm::StringRef> supported_protocols;
37+
size_t i = 0;
38+
39+
for (llvm::StringRef protocol_name =
40+
PluginManager::GetProtocolServerPluginNameAtIndex(i++);
41+
!protocol_name.empty();
42+
protocol_name = PluginManager::GetProtocolServerPluginNameAtIndex(i++)) {
43+
supported_protocols.push_back(protocol_name);
44+
}
45+
46+
return supported_protocols;
47+
}

lldb/source/Plugins/Protocol/MCP/Protocol.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ using Message = std::variant<Request, Response, Notification, Error>;
123123
bool fromJSON(const llvm::json::Value &, Message &, llvm::json::Path);
124124
llvm::json::Value toJSON(const Message &);
125125

126+
using ToolArguments = std::variant<std::monostate, llvm::json::Value>;
127+
126128
} // namespace lldb_private::mcp::protocol
127129

128130
#endif

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ LLDB_PLUGIN_DEFINE(ProtocolServerMCP)
2424

2525
static constexpr size_t kChunkSize = 1024;
2626

27-
ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
28-
: ProtocolServer(), m_debugger(debugger) {
27+
ProtocolServerMCP::ProtocolServerMCP() : ProtocolServer() {
2928
AddRequestHandler("initialize",
3029
std::bind(&ProtocolServerMCP::InitializeHandler, this,
3130
std::placeholders::_1));
@@ -39,8 +38,10 @@ ProtocolServerMCP::ProtocolServerMCP(Debugger &debugger)
3938
"notifications/initialized", [](const protocol::Notification &) {
4039
LLDB_LOG(GetLog(LLDBLog::Host), "MCP initialization complete");
4140
});
42-
AddTool(std::make_unique<LLDBCommandTool>(
43-
"lldb_command", "Run an lldb command.", m_debugger));
41+
AddTool(
42+
std::make_unique<CommandTool>("lldb_command", "Run an lldb command."));
43+
AddTool(std::make_unique<DebuggerListTool>(
44+
"lldb_debugger_list", "List debugger instances with their debugger_id."));
4445
}
4546

4647
ProtocolServerMCP::~ProtocolServerMCP() { llvm::consumeError(Stop()); }
@@ -54,8 +55,8 @@ void ProtocolServerMCP::Terminate() {
5455
PluginManager::UnregisterPlugin(CreateInstance);
5556
}
5657

57-
lldb::ProtocolServerSP ProtocolServerMCP::CreateInstance(Debugger &debugger) {
58-
return std::make_shared<ProtocolServerMCP>(debugger);
58+
lldb::ProtocolServerUP ProtocolServerMCP::CreateInstance() {
59+
return std::make_unique<ProtocolServerMCP>();
5960
}
6061

6162
llvm::StringRef ProtocolServerMCP::GetPluginDescriptionStatic() {
@@ -145,7 +146,7 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
145146
std::lock_guard<std::mutex> guard(m_server_mutex);
146147

147148
if (m_running)
148-
return llvm::createStringError("server already running");
149+
return llvm::createStringError("the MCP server is already running");
149150

150151
Status status;
151152
m_listener = Socket::Create(connection.protocol, false, status);
@@ -164,10 +165,10 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
164165
if (llvm::Error error = handles.takeError())
165166
return error;
166167

168+
m_running = true;
167169
m_listen_handlers = std::move(*handles);
168170
m_loop_thread = std::thread([=] {
169-
llvm::set_thread_name(
170-
llvm::formatv("debugger-{0}.mcp.runloop", m_debugger.GetID()));
171+
llvm::set_thread_name("protocol-server.mcp");
171172
m_loop.Run();
172173
});
173174

@@ -177,6 +178,8 @@ llvm::Error ProtocolServerMCP::Start(ProtocolServer::Connection connection) {
177178
llvm::Error ProtocolServerMCP::Stop() {
178179
{
179180
std::lock_guard<std::mutex> guard(m_server_mutex);
181+
if (!m_running)
182+
return createStringError("the MCP sever is not running");
180183
m_running = false;
181184
}
182185

@@ -313,11 +316,12 @@ ProtocolServerMCP::ToolsCallHandler(const protocol::Request &request) {
313316
if (it == m_tools.end())
314317
return llvm::createStringError(llvm::formatv("no tool \"{0}\"", tool_name));
315318

316-
const json::Value *args = param_obj->get("arguments");
317-
if (!args)
318-
return llvm::createStringError("no tool arguments");
319+
protocol::ToolArguments tool_args;
320+
if (const json::Value *args = param_obj->get("arguments"))
321+
tool_args = *args;
319322

320-
llvm::Expected<protocol::TextResult> text_result = it->second->Call(*args);
323+
llvm::Expected<protocol::TextResult> text_result =
324+
it->second->Call(tool_args);
321325
if (!text_result)
322326
return text_result.takeError();
323327

lldb/source/Plugins/Protocol/MCP/ProtocolServerMCP.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace lldb_private::mcp {
2121

2222
class ProtocolServerMCP : public ProtocolServer {
2323
public:
24-
ProtocolServerMCP(Debugger &debugger);
24+
ProtocolServerMCP();
2525
virtual ~ProtocolServerMCP() override;
2626

2727
virtual llvm::Error Start(ProtocolServer::Connection connection) override;
@@ -33,7 +33,7 @@ class ProtocolServerMCP : public ProtocolServer {
3333
static llvm::StringRef GetPluginNameStatic() { return "MCP"; }
3434
static llvm::StringRef GetPluginDescriptionStatic();
3535

36-
static lldb::ProtocolServerSP CreateInstance(Debugger &debugger);
36+
static lldb::ProtocolServerUP CreateInstance();
3737

3838
llvm::StringRef GetPluginName() override { return GetPluginNameStatic(); }
3939

@@ -71,8 +71,6 @@ class ProtocolServerMCP : public ProtocolServer {
7171
llvm::StringLiteral kName = "lldb-mcp";
7272
llvm::StringLiteral kVersion = "0.1.0";
7373

74-
Debugger &m_debugger;
75-
7674
bool m_running = false;
7775

7876
MainLoop m_loop;

0 commit comments

Comments
 (0)