1 //===-- TCPSocket.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #if defined(_MSC_VER)
10 #define _WINSOCK_DEPRECATED_NO_WARNINGS
11 #endif
12
13 #include "lldb/Host/common/TCPSocket.h"
14
15 #include "lldb/Host/Config.h"
16 #include "lldb/Host/MainLoop.h"
17 #include "lldb/Utility/LLDBLog.h"
18 #include "lldb/Utility/Log.h"
19
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Errno.h"
#include "llvm/Support/WindowsError.h"
#include "llvm/Support/raw_ostream.h"

#include <cassert>
#include <memory>
24
25 #if LLDB_ENABLE_POSIX
26 #include <arpa/inet.h>
27 #include <netinet/tcp.h>
28 #include <sys/socket.h>
29 #endif
30
31 #if defined(_WIN32)
32 #include <winsock2.h>
33 #endif
34
35 #ifdef _WIN32
36 #define CLOSE_SOCKET closesocket
37 typedef const char *set_socket_option_arg_type;
38 #else
39 #include <unistd.h>
40 #define CLOSE_SOCKET ::close
41 typedef const void *set_socket_option_arg_type;
42 #endif
43
44 using namespace lldb;
45 using namespace lldb_private;
46
GetLastSocketError()47 static Status GetLastSocketError() {
48 std::error_code EC;
49 #ifdef _WIN32
50 EC = llvm::mapWindowsError(WSAGetLastError());
51 #else
52 EC = std::error_code(errno, std::generic_category());
53 #endif
54 return EC;
55 }
56
// Socket type used for every socket this file creates (stream sockets).
static constexpr int kType = SOCK_STREAM;
58
TCPSocket(bool should_close,bool child_processes_inherit)59 TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit)
60 : Socket(ProtocolTcp, should_close, child_processes_inherit) {}
61
TCPSocket(NativeSocket socket,const TCPSocket & listen_socket)62 TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
63 : Socket(ProtocolTcp, listen_socket.m_should_close_fd,
64 listen_socket.m_child_processes_inherit) {
65 m_socket = socket;
66 }
67
TCPSocket(NativeSocket socket,bool should_close,bool child_processes_inherit)68 TCPSocket::TCPSocket(NativeSocket socket, bool should_close,
69 bool child_processes_inherit)
70 : Socket(ProtocolTcp, should_close, child_processes_inherit) {
71 m_socket = socket;
72 }
73
~TCPSocket()74 TCPSocket::~TCPSocket() { CloseListenSockets(); }
75
IsValid() const76 bool TCPSocket::IsValid() const {
77 return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
78 }
79
80 // Return the port number that is being used by the socket.
GetLocalPortNumber() const81 uint16_t TCPSocket::GetLocalPortNumber() const {
82 if (m_socket != kInvalidSocketValue) {
83 SocketAddress sock_addr;
84 socklen_t sock_addr_len = sock_addr.GetMaxLength();
85 if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
86 return sock_addr.GetPort();
87 } else if (!m_listen_sockets.empty()) {
88 SocketAddress sock_addr;
89 socklen_t sock_addr_len = sock_addr.GetMaxLength();
90 if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
91 &sock_addr_len) == 0)
92 return sock_addr.GetPort();
93 }
94 return 0;
95 }
96
GetLocalIPAddress() const97 std::string TCPSocket::GetLocalIPAddress() const {
98 // We bound to port zero, so we need to figure out which port we actually
99 // bound to
100 if (m_socket != kInvalidSocketValue) {
101 SocketAddress sock_addr;
102 socklen_t sock_addr_len = sock_addr.GetMaxLength();
103 if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
104 return sock_addr.GetIPAddress();
105 }
106 return "";
107 }
108
GetRemotePortNumber() const109 uint16_t TCPSocket::GetRemotePortNumber() const {
110 if (m_socket != kInvalidSocketValue) {
111 SocketAddress sock_addr;
112 socklen_t sock_addr_len = sock_addr.GetMaxLength();
113 if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
114 return sock_addr.GetPort();
115 }
116 return 0;
117 }
118
GetRemoteIPAddress() const119 std::string TCPSocket::GetRemoteIPAddress() const {
120 // We bound to port zero, so we need to figure out which port we actually
121 // bound to
122 if (m_socket != kInvalidSocketValue) {
123 SocketAddress sock_addr;
124 socklen_t sock_addr_len = sock_addr.GetMaxLength();
125 if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
126 return sock_addr.GetIPAddress();
127 }
128 return "";
129 }
130
GetRemoteConnectionURI() const131 std::string TCPSocket::GetRemoteConnectionURI() const {
132 if (m_socket != kInvalidSocketValue) {
133 return std::string(llvm::formatv(
134 "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
135 }
136 return "";
137 }
138
CreateSocket(int domain)139 Status TCPSocket::CreateSocket(int domain) {
140 Status error;
141 if (IsValid())
142 error = Close();
143 if (error.Fail())
144 return error;
145 m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP,
146 m_child_processes_inherit, error);
147 return error;
148 }
149
Connect(llvm::StringRef name)150 Status TCPSocket::Connect(llvm::StringRef name) {
151
152 Log *log = GetLog(LLDBLog::Communication);
153 LLDB_LOGF(log, "TCPSocket::%s (host/port = %s)", __FUNCTION__, name.data());
154
155 Status error;
156 llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
157 if (!host_port)
158 return Status(host_port.takeError());
159
160 std::vector<SocketAddress> addresses =
161 SocketAddress::GetAddressInfo(host_port->hostname.c_str(), nullptr,
162 AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
163 for (SocketAddress &address : addresses) {
164 error = CreateSocket(address.GetFamily());
165 if (error.Fail())
166 continue;
167
168 address.SetPort(host_port->port);
169
170 if (-1 == llvm::sys::RetryAfterSignal(-1, ::connect, GetNativeSocket(),
171 &address.sockaddr(),
172 address.GetLength())) {
173 Close();
174 continue;
175 }
176
177 SetOptionNoDelay();
178
179 error.Clear();
180 return error;
181 }
182
183 error.SetErrorString("Failed to connect port");
184 return error;
185 }
186
Listen(llvm::StringRef name,int backlog)187 Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
188 Log *log = GetLog(LLDBLog::Connection);
189 LLDB_LOGF(log, "TCPSocket::%s (%s)", __FUNCTION__, name.data());
190
191 Status error;
192 llvm::Expected<HostAndPort> host_port = DecodeHostAndPort(name);
193 if (!host_port)
194 return Status(host_port.takeError());
195
196 if (host_port->hostname == "*")
197 host_port->hostname = "0.0.0.0";
198 std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
199 host_port->hostname.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
200 for (SocketAddress &address : addresses) {
201 int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP,
202 m_child_processes_inherit, error);
203 if (error.Fail())
204 continue;
205
206 // enable local address reuse
207 int option_value = 1;
208 set_socket_option_arg_type option_value_p =
209 reinterpret_cast<set_socket_option_arg_type>(&option_value);
210 ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p,
211 sizeof(option_value));
212
213 SocketAddress listen_address = address;
214 if(!listen_address.IsLocalhost())
215 listen_address.SetToAnyAddress(address.GetFamily(), host_port->port);
216 else
217 listen_address.SetPort(host_port->port);
218
219 int err =
220 ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
221 if (-1 != err)
222 err = ::listen(fd, backlog);
223
224 if (-1 == err) {
225 error = GetLastSocketError();
226 CLOSE_SOCKET(fd);
227 continue;
228 }
229
230 if (host_port->port == 0) {
231 socklen_t sa_len = address.GetLength();
232 if (getsockname(fd, &address.sockaddr(), &sa_len) == 0)
233 host_port->port = address.GetPort();
234 }
235 m_listen_sockets[fd] = address;
236 }
237
238 if (m_listen_sockets.empty()) {
239 assert(error.Fail());
240 return error;
241 }
242 return Status();
243 }
244
CloseListenSockets()245 void TCPSocket::CloseListenSockets() {
246 for (auto socket : m_listen_sockets)
247 CLOSE_SOCKET(socket.first);
248 m_listen_sockets.clear();
249 }
250
Accept(Socket * & conn_socket)251 Status TCPSocket::Accept(Socket *&conn_socket) {
252 Status error;
253 if (m_listen_sockets.size() == 0) {
254 error.SetErrorString("No open listening sockets!");
255 return error;
256 }
257
258 int sock = -1;
259 int listen_sock = -1;
260 lldb_private::SocketAddress AcceptAddr;
261 MainLoop accept_loop;
262 std::vector<MainLoopBase::ReadHandleUP> handles;
263 for (auto socket : m_listen_sockets) {
264 auto fd = socket.first;
265 auto inherit = this->m_child_processes_inherit;
266 auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit));
267 handles.emplace_back(accept_loop.RegisterReadObject(
268 io_sp, [fd, inherit, &sock, &AcceptAddr, &error,
269 &listen_sock](MainLoopBase &loop) {
270 socklen_t sa_len = AcceptAddr.GetMaxLength();
271 sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit,
272 error);
273 listen_sock = fd;
274 loop.RequestTermination();
275 }, error));
276 if (error.Fail())
277 return error;
278 }
279
280 bool accept_connection = false;
281 std::unique_ptr<TCPSocket> accepted_socket;
282 // Loop until we are happy with our connection
283 while (!accept_connection) {
284 accept_loop.Run();
285
286 if (error.Fail())
287 return error;
288
289 lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock];
290 if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
291 CLOSE_SOCKET(sock);
292 llvm::errs() << llvm::formatv(
293 "error: rejecting incoming connection from {0} (expecting {1})",
294 AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
295 continue;
296 }
297 accept_connection = true;
298 accepted_socket.reset(new TCPSocket(sock, *this));
299 }
300
301 if (!accepted_socket)
302 return error;
303
304 // Keep our TCP packets coming without any delays.
305 accepted_socket->SetOptionNoDelay();
306 error.Clear();
307 conn_socket = accepted_socket.release();
308 return error;
309 }
310
SetOptionNoDelay()311 int TCPSocket::SetOptionNoDelay() {
312 return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
313 }
314
SetOptionReuseAddress()315 int TCPSocket::SetOptionReuseAddress() {
316 return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
317 }
318