diff --git a/lldb/include/lldb/Host/Socket.h b/lldb/include/lldb/Host/Socket.h index e98797b36c8a5..4585eac12efb9 100644 --- a/lldb/include/lldb/Host/Socket.h +++ b/lldb/include/lldb/Host/Socket.h @@ -11,6 +11,7 @@ #include #include +#include #include "lldb/Host/MainLoopBase.h" #include "lldb/Utility/Timeout.h" @@ -151,6 +152,11 @@ class Socket : public IOObject { // If this Socket is connected then return the URI used to connect. virtual std::string GetRemoteConnectionURI() const { return ""; }; + // If the Socket is listening then return the URI for clients to connect. + virtual std::vector GetListeningConnectionURI() const { + return {}; + } + protected: Socket(SocketProtocol protocol, bool should_close); diff --git a/lldb/include/lldb/Host/common/TCPSocket.h b/lldb/include/lldb/Host/common/TCPSocket.h index ca36622691fe9..cb950c0015ea6 100644 --- a/lldb/include/lldb/Host/common/TCPSocket.h +++ b/lldb/include/lldb/Host/common/TCPSocket.h @@ -13,6 +13,8 @@ #include "lldb/Host/Socket.h" #include "lldb/Host/SocketAddress.h" #include +#include +#include namespace lldb_private { class TCPSocket : public Socket { @@ -52,6 +54,8 @@ class TCPSocket : public Socket { std::string GetRemoteConnectionURI() const override; + std::vector GetListeningConnectionURI() const override; + private: TCPSocket(NativeSocket socket, const TCPSocket &listen_socket); diff --git a/lldb/include/lldb/Host/posix/DomainSocket.h b/lldb/include/lldb/Host/posix/DomainSocket.h index d4e0d43ee169c..3dbe6206da2c5 100644 --- a/lldb/include/lldb/Host/posix/DomainSocket.h +++ b/lldb/include/lldb/Host/posix/DomainSocket.h @@ -10,6 +10,8 @@ #define LLDB_HOST_POSIX_DOMAINSOCKET_H #include "lldb/Host/Socket.h" +#include +#include namespace lldb_private { class DomainSocket : public Socket { @@ -27,6 +29,8 @@ class DomainSocket : public Socket { std::string GetRemoteConnectionURI() const override; + std::vector GetListeningConnectionURI() const override; + protected: DomainSocket(SocketProtocol protocol); diff --git a/lldb/source/Host/common/TCPSocket.cpp b/lldb/source/Host/common/TCPSocket.cpp index 5d863954ee886..d0055c3b6c44f 100644 --- a/lldb/source/Host/common/TCPSocket.cpp +++ b/lldb/source/Host/common/TCPSocket.cpp @@ -115,6 +115,14 @@ std::string TCPSocket::GetRemoteConnectionURI() const { return ""; } +std::vector TCPSocket::GetListeningConnectionURI() const { + std::vector URIs; + for (const auto &[fd, addr] : m_listen_sockets) + URIs.emplace_back(llvm::formatv("connection://[{0}]:{1}", + addr.GetIPAddress(), addr.GetPort())); + return URIs; +} + Status TCPSocket::CreateSocket(int domain) { Status error; if (IsValid()) diff --git a/lldb/source/Host/posix/DomainSocket.cpp b/lldb/source/Host/posix/DomainSocket.cpp index 0451834630d33..9a0b385d998bf 100644 --- a/lldb/source/Host/posix/DomainSocket.cpp +++ b/lldb/source/Host/posix/DomainSocket.cpp @@ -175,3 +175,17 @@ std::string DomainSocket::GetRemoteConnectionURI() const { "{0}://{1}", GetNameOffset() == 0 ? "unix-connect" : "unix-abstract-connect", name); } + +std::vector DomainSocket::GetListeningConnectionURI() const { + if (m_socket == kInvalidSocketValue) + return {}; + + struct sockaddr_un addr; + bzero(&addr, sizeof(struct sockaddr_un)); + addr.sun_family = AF_UNIX; + socklen_t addr_len = sizeof(struct sockaddr_un); + if (::getsockname(m_socket, (struct sockaddr *)&addr, &addr_len) != 0) + return {}; + + return {llvm::formatv("unix-connect://{0}", addr.sun_path)}; +} diff --git a/lldb/unittests/Host/SocketTest.cpp b/lldb/unittests/Host/SocketTest.cpp index b20cfe5464028..a74352c19725d 100644 --- a/lldb/unittests/Host/SocketTest.cpp +++ b/lldb/unittests/Host/SocketTest.cpp @@ -88,6 +88,28 @@ TEST_P(SocketTest, DomainListenConnectAccept) { CreateDomainConnectedSockets(Path, &socket_a_up, &socket_b_up); } +TEST_P(SocketTest, DomainListenGetListeningConnectionURI) { + llvm::SmallString<64> Path; + std::error_code EC = + llvm::sys::fs::createUniqueDirectory("DomainListenConnectAccept", Path); + ASSERT_FALSE(EC); + llvm::sys::path::append(Path, "test"); + + // Skip the test if the $TMPDIR is too long to hold a domain socket. + if (Path.size() > 107u) + return; + + auto listen_socket_up = std::make_unique( + /*should_close=*/true); + Status error = listen_socket_up->Listen(Path, 5); + ASSERT_THAT_ERROR(error.ToError(), llvm::Succeeded()); + ASSERT_TRUE(listen_socket_up->IsValid()); + + ASSERT_THAT( + listen_socket_up->GetListeningConnectionURI(), + testing::ElementsAre(llvm::formatv("unix-connect://{0}", Path).str())); +} + TEST_P(SocketTest, DomainMainLoopAccept) { llvm::SmallString<64> Path; std::error_code EC = @@ -225,12 +247,29 @@ TEST_P(SocketTest, TCPListen0GetPort) { if (!HostSupportsIPv4()) return; llvm::Expected> sock = - Socket::TcpListen("10.10.12.3:0", false); + Socket::TcpListen("10.10.12.3:0", 5); ASSERT_THAT_EXPECTED(sock, llvm::Succeeded()); ASSERT_TRUE(sock.get()->IsValid()); EXPECT_NE(sock.get()->GetLocalPortNumber(), 0); } +TEST_P(SocketTest, TCPListen0GetListeningConnectionURI) { + if (!HostSupportsProtocol()) + return; + + std::string addr = llvm::formatv("[{0}]:0", GetParam().localhost_ip).str(); + llvm::Expected> sock = Socket::TcpListen(addr); + ASSERT_THAT_EXPECTED(sock, llvm::Succeeded()); + ASSERT_TRUE(sock.get()->IsValid()); + + EXPECT_THAT( + sock.get()->GetListeningConnectionURI(), + testing::ElementsAre(llvm::formatv("connection://[{0}]:{1}", + GetParam().localhost_ip, + sock->get()->GetLocalPortNumber()) + .str())); +} + TEST_P(SocketTest, TCPGetConnectURI) { std::unique_ptr socket_a_up; std::unique_ptr socket_b_up;