1 //===-- TCPSocket.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #if defined(_MSC_VER)
10 #define _WINSOCK_DEPRECATED_NO_WARNINGS
11 #endif
12 
13 #include "lldb/Host/common/TCPSocket.h"
14 
15 #include "lldb/Host/Config.h"
16 #include "lldb/Host/MainLoop.h"
17 #include "lldb/Utility/Log.h"
18 
19 #include "llvm/Config/llvm-config.h"
20 #include "llvm/Support/Errno.h"
21 #include "llvm/Support/WindowsError.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 #if LLDB_ENABLE_POSIX
25 #include <arpa/inet.h>
26 #include <netinet/tcp.h>
27 #include <sys/socket.h>
28 #endif
29 
30 #if defined(_WIN32)
31 #include <winsock2.h>
32 #endif
33 
34 #ifdef _WIN32
35 #define CLOSE_SOCKET closesocket
36 typedef const char *set_socket_option_arg_type;
37 #else
38 #include <unistd.h>
39 #define CLOSE_SOCKET ::close
40 typedef const void *set_socket_option_arg_type;
41 #endif
42 
43 using namespace lldb;
44 using namespace lldb_private;
45 
46 static Status GetLastSocketError() {
47   std::error_code EC;
48 #ifdef _WIN32
49   EC = llvm::mapWindowsError(WSAGetLastError());
50 #else
51   EC = std::error_code(errno, std::generic_category());
52 #endif
53   return EC;
54 }
55 
namespace {
// Socket type this file passes to Socket::CreateSocket for every TCP
// descriptor it opens (Listen() and CreateSocket()).
constexpr int kType = SOCK_STREAM;
} // namespace
59 
60 TCPSocket::TCPSocket(bool should_close, bool child_processes_inherit)
61     : Socket(ProtocolTcp, should_close, child_processes_inherit) {}
62 
63 TCPSocket::TCPSocket(NativeSocket socket, const TCPSocket &listen_socket)
64     : Socket(ProtocolTcp, listen_socket.m_should_close_fd,
65              listen_socket.m_child_processes_inherit) {
66   m_socket = socket;
67 }
68 
69 TCPSocket::TCPSocket(NativeSocket socket, bool should_close,
70                      bool child_processes_inherit)
71     : Socket(ProtocolTcp, should_close, child_processes_inherit) {
72   m_socket = socket;
73 }
74 
75 TCPSocket::~TCPSocket() { CloseListenSockets(); }
76 
77 bool TCPSocket::IsValid() const {
78   return m_socket != kInvalidSocketValue || m_listen_sockets.size() != 0;
79 }
80 
81 // Return the port number that is being used by the socket.
82 uint16_t TCPSocket::GetLocalPortNumber() const {
83   if (m_socket != kInvalidSocketValue) {
84     SocketAddress sock_addr;
85     socklen_t sock_addr_len = sock_addr.GetMaxLength();
86     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
87       return sock_addr.GetPort();
88   } else if (!m_listen_sockets.empty()) {
89     SocketAddress sock_addr;
90     socklen_t sock_addr_len = sock_addr.GetMaxLength();
91     if (::getsockname(m_listen_sockets.begin()->first, sock_addr,
92                       &sock_addr_len) == 0)
93       return sock_addr.GetPort();
94   }
95   return 0;
96 }
97 
98 std::string TCPSocket::GetLocalIPAddress() const {
99   // We bound to port zero, so we need to figure out which port we actually
100   // bound to
101   if (m_socket != kInvalidSocketValue) {
102     SocketAddress sock_addr;
103     socklen_t sock_addr_len = sock_addr.GetMaxLength();
104     if (::getsockname(m_socket, sock_addr, &sock_addr_len) == 0)
105       return sock_addr.GetIPAddress();
106   }
107   return "";
108 }
109 
110 uint16_t TCPSocket::GetRemotePortNumber() const {
111   if (m_socket != kInvalidSocketValue) {
112     SocketAddress sock_addr;
113     socklen_t sock_addr_len = sock_addr.GetMaxLength();
114     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
115       return sock_addr.GetPort();
116   }
117   return 0;
118 }
119 
120 std::string TCPSocket::GetRemoteIPAddress() const {
121   // We bound to port zero, so we need to figure out which port we actually
122   // bound to
123   if (m_socket != kInvalidSocketValue) {
124     SocketAddress sock_addr;
125     socklen_t sock_addr_len = sock_addr.GetMaxLength();
126     if (::getpeername(m_socket, sock_addr, &sock_addr_len) == 0)
127       return sock_addr.GetIPAddress();
128   }
129   return "";
130 }
131 
132 std::string TCPSocket::GetRemoteConnectionURI() const {
133   if (m_socket != kInvalidSocketValue) {
134     return std::string(llvm::formatv(
135         "connect://[{0}]:{1}", GetRemoteIPAddress(), GetRemotePortNumber()));
136   }
137   return "";
138 }
139 
140 Status TCPSocket::CreateSocket(int domain) {
141   Status error;
142   if (IsValid())
143     error = Close();
144   if (error.Fail())
145     return error;
146   m_socket = Socket::CreateSocket(domain, kType, IPPROTO_TCP,
147                                   m_child_processes_inherit, error);
148   return error;
149 }
150 
151 Status TCPSocket::Connect(llvm::StringRef name) {
152 
153   Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_COMMUNICATION));
154   LLDB_LOGF(log, "TCPSocket::%s (host/port = %s)", __FUNCTION__, name.data());
155 
156   Status error;
157   std::string host_str;
158   std::string port_str;
159   uint16_t port;
160   if (llvm::Error decode_error =
161           DecodeHostAndPort(name, host_str, port_str, port))
162     return Status(std::move(decode_error));
163 
164   std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
165       host_str.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
166   for (SocketAddress &address : addresses) {
167     error = CreateSocket(address.GetFamily());
168     if (error.Fail())
169       continue;
170 
171     address.SetPort(port);
172 
173     if (-1 == llvm::sys::RetryAfterSignal(-1, ::connect,
174           GetNativeSocket(), &address.sockaddr(), address.GetLength())) {
175       CLOSE_SOCKET(GetNativeSocket());
176       continue;
177     }
178 
179     SetOptionNoDelay();
180 
181     error.Clear();
182     return error;
183   }
184 
185   error.SetErrorString("Failed to connect port");
186   return error;
187 }
188 
189 Status TCPSocket::Listen(llvm::StringRef name, int backlog) {
190   Log *log(lldb_private::GetLogIfAnyCategoriesSet(LIBLLDB_LOG_CONNECTION));
191   LLDB_LOGF(log, "TCPSocket::%s (%s)", __FUNCTION__, name.data());
192 
193   Status error;
194   std::string host_str;
195   std::string port_str;
196   uint16_t port;
197   if (llvm::Error decode_error =
198           DecodeHostAndPort(name, host_str, port_str, port))
199     return Status(std::move(decode_error));
200 
201   if (host_str == "*")
202     host_str = "0.0.0.0";
203   std::vector<SocketAddress> addresses = SocketAddress::GetAddressInfo(
204       host_str.c_str(), nullptr, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
205   for (SocketAddress &address : addresses) {
206     int fd = Socket::CreateSocket(address.GetFamily(), kType, IPPROTO_TCP,
207                                   m_child_processes_inherit, error);
208     if (error.Fail())
209       continue;
210 
211     // enable local address reuse
212     int option_value = 1;
213     set_socket_option_arg_type option_value_p =
214         reinterpret_cast<set_socket_option_arg_type>(&option_value);
215     ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, option_value_p,
216                  sizeof(option_value));
217 
218     SocketAddress listen_address = address;
219     if(!listen_address.IsLocalhost())
220       listen_address.SetToAnyAddress(address.GetFamily(), port);
221     else
222       listen_address.SetPort(port);
223 
224     int err =
225         ::bind(fd, &listen_address.sockaddr(), listen_address.GetLength());
226     if (-1 != err)
227       err = ::listen(fd, backlog);
228 
229     if (-1 == err) {
230       error = GetLastSocketError();
231       CLOSE_SOCKET(fd);
232       continue;
233     }
234 
235     if (port == 0) {
236       socklen_t sa_len = address.GetLength();
237       if (getsockname(fd, &address.sockaddr(), &sa_len) == 0)
238         port = address.GetPort();
239     }
240     m_listen_sockets[fd] = address;
241   }
242 
243   if (m_listen_sockets.empty()) {
244     assert(error.Fail());
245     return error;
246   }
247   return Status();
248 }
249 
250 void TCPSocket::CloseListenSockets() {
251   for (auto socket : m_listen_sockets)
252     CLOSE_SOCKET(socket.first);
253   m_listen_sockets.clear();
254 }
255 
256 Status TCPSocket::Accept(Socket *&conn_socket) {
257   Status error;
258   if (m_listen_sockets.size() == 0) {
259     error.SetErrorString("No open listening sockets!");
260     return error;
261   }
262 
263   int sock = -1;
264   int listen_sock = -1;
265   lldb_private::SocketAddress AcceptAddr;
266   MainLoop accept_loop;
267   std::vector<MainLoopBase::ReadHandleUP> handles;
268   for (auto socket : m_listen_sockets) {
269     auto fd = socket.first;
270     auto inherit = this->m_child_processes_inherit;
271     auto io_sp = IOObjectSP(new TCPSocket(socket.first, false, inherit));
272     handles.emplace_back(accept_loop.RegisterReadObject(
273         io_sp, [fd, inherit, &sock, &AcceptAddr, &error,
274                         &listen_sock](MainLoopBase &loop) {
275           socklen_t sa_len = AcceptAddr.GetMaxLength();
276           sock = AcceptSocket(fd, &AcceptAddr.sockaddr(), &sa_len, inherit,
277                               error);
278           listen_sock = fd;
279           loop.RequestTermination();
280         }, error));
281     if (error.Fail())
282       return error;
283   }
284 
285   bool accept_connection = false;
286   std::unique_ptr<TCPSocket> accepted_socket;
287   // Loop until we are happy with our connection
288   while (!accept_connection) {
289     accept_loop.Run();
290 
291     if (error.Fail())
292         return error;
293 
294     lldb_private::SocketAddress &AddrIn = m_listen_sockets[listen_sock];
295     if (!AddrIn.IsAnyAddr() && AcceptAddr != AddrIn) {
296       CLOSE_SOCKET(sock);
297       llvm::errs() << llvm::formatv(
298           "error: rejecting incoming connection from {0} (expecting {1})",
299           AcceptAddr.GetIPAddress(), AddrIn.GetIPAddress());
300       continue;
301     }
302     accept_connection = true;
303     accepted_socket.reset(new TCPSocket(sock, *this));
304   }
305 
306   if (!accepted_socket)
307     return error;
308 
309   // Keep our TCP packets coming without any delays.
310   accepted_socket->SetOptionNoDelay();
311   error.Clear();
312   conn_socket = accepted_socket.release();
313   return error;
314 }
315 
316 int TCPSocket::SetOptionNoDelay() {
317   return SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
318 }
319 
320 int TCPSocket::SetOptionReuseAddress() {
321   return SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
322 }
323