1 //===-- Socket.cpp ----------------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "lldb/Host/Socket.h"
11 
12 #include "lldb/Core/Log.h"
13 #include "lldb/Core/RegularExpression.h"
14 #include "lldb/Host/Config.h"
15 #include "lldb/Host/Host.h"
16 #include "lldb/Host/SocketAddress.h"
17 #include "lldb/Host/TimeValue.h"
18 #include "lldb/Interpreter/Args.h"
19 
20 #ifndef LLDB_DISABLE_POSIX
21 #include <arpa/inet.h>
22 #include <netdb.h>
23 #include <netinet/in.h>
24 #include <netinet/tcp.h>
25 #include <sys/socket.h>
26 #include <sys/un.h>
27 #endif
28 
29 using namespace lldb;
30 using namespace lldb_private;
31 
32 #if defined(_WIN32)
33 typedef const char * set_socket_option_arg_type;
34 typedef char * get_socket_option_arg_type;
35 const NativeSocket Socket::kInvalidSocketValue = INVALID_SOCKET;
36 #else // #if defined(_WIN32)
37 typedef const void * set_socket_option_arg_type;
38 typedef void * get_socket_option_arg_type;
39 const NativeSocket Socket::kInvalidSocketValue = -1;
40 #endif // #if defined(_WIN32)
41 
42 Socket::Socket(NativeSocket socket, SocketProtocol protocol, bool should_close)
43     : IOObject(eFDTypeSocket, should_close)
44     , m_protocol(protocol)
45     , m_socket(socket)
46 {
47 
48 }
49 
50 Socket::~Socket()
51 {
52     Close();
53 }
54 
55 Error Socket::TcpConnect(llvm::StringRef host_and_port, Socket *&socket)
56 {
57     // Store the result in a unique_ptr in case we error out, the memory will get correctly freed.
58     std::unique_ptr<Socket> final_socket;
59     NativeSocket sock = kInvalidSocketValue;
60     Error error;
61 
62     Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_HOST));
63     if (log)
64         log->Printf ("Socket::TcpConnect (host/port = %s)", host_and_port.data());
65 
66     std::string host_str;
67     std::string port_str;
68     int32_t port = INT32_MIN;
69     if (!DecodeHostAndPort (host_and_port, host_str, port_str, port, &error))
70         return error;
71 
72     // Create the socket
73     sock = ::socket (AF_INET, SOCK_STREAM, IPPROTO_TCP);
74     if (sock == kInvalidSocketValue)
75     {
76         // TODO: On Windows, use WSAGetLastError().
77         error.SetErrorToErrno();
78         return error;
79     }
80 
81     // Since they both refer to the same socket descriptor, arbitrarily choose the send socket to
82     // be the owner.
83     final_socket.reset(new Socket(sock, ProtocolTcp, true));
84 
85     // Enable local address reuse
86     final_socket->SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
87 
88     struct sockaddr_in sa;
89     ::memset (&sa, 0, sizeof (sa));
90     sa.sin_family = AF_INET;
91     sa.sin_port = htons (port);
92 
93     int inet_pton_result = ::inet_pton (AF_INET, host_str.c_str(), &sa.sin_addr);
94 
95     if (inet_pton_result <= 0)
96     {
97         struct hostent *host_entry = gethostbyname (host_str.c_str());
98         if (host_entry)
99             host_str = ::inet_ntoa (*(struct in_addr *)*host_entry->h_addr_list);
100         inet_pton_result = ::inet_pton (AF_INET, host_str.c_str(), &sa.sin_addr);
101         if (inet_pton_result <= 0)
102         {
103             // TODO: On Windows, use WSAGetLastError()
104             if (inet_pton_result == -1)
105                 error.SetErrorToErrno();
106             else
107                 error.SetErrorStringWithFormat("invalid host string: '%s'", host_str.c_str());
108 
109             return error;
110         }
111     }
112 
113     if (-1 == ::connect (sock, (const struct sockaddr *)&sa, sizeof(sa)))
114     {
115         // TODO: On Windows, use WSAGetLastError()
116         error.SetErrorToErrno();
117         return error;
118     }
119 
120     // Keep our TCP packets coming without any delays.
121     final_socket->SetOption(IPPROTO_TCP, TCP_NODELAY, 1);
122     error.Clear();
123     socket = final_socket.release();
124     return error;
125 }
126 
127 Error Socket::TcpListen(llvm::StringRef host_and_port, Socket *&socket, Predicate<uint16_t>* predicate)
128 {
129     std::unique_ptr<Socket> listen_socket;
130     NativeSocket listen_sock = kInvalidSocketValue;
131     Error error;
132 
133     const sa_family_t family = AF_INET;
134     const int socktype = SOCK_STREAM;
135     const int protocol = IPPROTO_TCP;
136     listen_sock = ::socket (family, socktype, protocol);
137     if (listen_sock == kInvalidSocketValue)
138     {
139         error.SetErrorToErrno();
140         return error;
141     }
142 
143     listen_socket.reset(new Socket(listen_sock, ProtocolTcp, true));
144 
145     // enable local address reuse
146     listen_socket->SetOption(SOL_SOCKET, SO_REUSEADDR, 1);
147 
148     Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_CONNECTION));
149     if (log)
150         log->Printf ("ConnectionFileDescriptor::SocketListen (%s)", host_and_port.data());
151 
152     std::string host_str;
153     std::string port_str;
154     int32_t port = INT32_MIN;
155     if (!DecodeHostAndPort (host_and_port, host_str, port_str, port, &error))
156         return error;
157 
158     SocketAddress anyaddr;
159     if (anyaddr.SetToAnyAddress (family, port))
160     {
161         int err = ::bind (listen_sock, anyaddr, anyaddr.GetLength());
162         if (err == -1)
163         {
164             // TODO: On Windows, use WSAGetLastError()
165             error.SetErrorToErrno();
166             return error;
167         }
168 
169         err = ::listen (listen_sock, 1);
170         if (err == -1)
171         {
172             // TODO: On Windows, use WSAGetLastError()
173             error.SetErrorToErrno();
174             return error;
175         }
176 
177         // We were asked to listen on port zero which means we
178         // must now read the actual port that was given to us
179         // as port zero is a special code for "find an open port
180         // for me".
181         if (port == 0)
182             port = listen_socket->GetPortNumber();
183 
184         // Set the port predicate since when doing a listen://<host>:<port>
185         // it often needs to accept the incoming connection which is a blocking
186         // system call. Allowing access to the bound port using a predicate allows
187         // us to wait for the port predicate to be set to a non-zero value from
188         // another thread in an efficient manor.
189         if (predicate)
190             predicate->SetValue(port, eBroadcastAlways);
191 
192         socket = listen_socket.release();
193     }
194 
195     return error;
196 }
197 
198 Error Socket::BlockingAccept(llvm::StringRef host_and_port, Socket *&socket)
199 {
200     Error error;
201     std::string host_str;
202     std::string port_str;
203     int32_t port;
204     if (!DecodeHostAndPort(host_and_port, host_str, port_str, port, &error))
205         return error;
206 
207     const sa_family_t family = AF_INET;
208     const int socktype = SOCK_STREAM;
209     const int protocol = IPPROTO_TCP;
210     SocketAddress listen_addr;
211     if (host_str.empty())
212         listen_addr.SetToLocalhost(family, port);
213     else if (host_str.compare("*") == 0)
214         listen_addr.SetToAnyAddress(family, port);
215     else
216     {
217         if (!listen_addr.getaddrinfo(host_str.c_str(), port_str.c_str(), family, socktype, protocol))
218         {
219             error.SetErrorStringWithFormat("unable to resolve hostname '%s'", host_str.c_str());
220             return error;
221         }
222     }
223 
224     bool accept_connection = false;
225     std::unique_ptr<Socket> accepted_socket;
226 
227     // Loop until we are happy with our connection
228     while (!accept_connection)
229     {
230         struct sockaddr_in accept_addr;
231         ::memset (&accept_addr, 0, sizeof accept_addr);
232 #if !(defined (__linux__) || defined(_WIN32))
233         accept_addr.sin_len = sizeof accept_addr;
234 #endif
235         socklen_t accept_addr_len = sizeof accept_addr;
236 
237         int sock = ::accept (this->GetNativeSocket(), (struct sockaddr *)&accept_addr, &accept_addr_len);
238 
239         if (sock == kInvalidSocketValue)
240         {
241             // TODO: On Windows, use WSAGetLastError()
242             error.SetErrorToErrno();
243             break;
244         }
245 
246         bool is_same_addr = true;
247 #if !(defined(__linux__) || (defined(_WIN32)))
248         is_same_addr = (accept_addr_len == listen_addr.sockaddr_in().sin_len);
249 #endif
250         if (is_same_addr)
251             is_same_addr = (accept_addr.sin_addr.s_addr == listen_addr.sockaddr_in().sin_addr.s_addr);
252 
253         if (is_same_addr || (listen_addr.sockaddr_in().sin_addr.s_addr == INADDR_ANY))
254         {
255             accept_connection = true;
256             // Since both sockets have the same descriptor, arbitrarily choose the send
257             // socket to be the owner.
258             accepted_socket.reset(new Socket(sock, ProtocolTcp, true));
259         }
260         else
261         {
262             const uint8_t *accept_ip = (const uint8_t *)&accept_addr.sin_addr.s_addr;
263             const uint8_t *listen_ip = (const uint8_t *)&listen_addr.sockaddr_in().sin_addr.s_addr;
264             ::fprintf (stderr, "error: rejecting incoming connection from %u.%u.%u.%u (expecting %u.%u.%u.%u)\n",
265                         accept_ip[0], accept_ip[1], accept_ip[2], accept_ip[3],
266                         listen_ip[0], listen_ip[1], listen_ip[2], listen_ip[3]);
267             accepted_socket.reset();
268         }
269     }
270 
271     if (!accepted_socket)
272         return error;
273 
274     // Keep our TCP packets coming without any delays.
275     accepted_socket->SetOption (IPPROTO_TCP, TCP_NODELAY, 1);
276     error.Clear();
277     socket = accepted_socket.release();
278     return error;
279 
280 }
281 
282 Error Socket::UdpConnect(llvm::StringRef host_and_port, Socket *&send_socket, Socket *&recv_socket)
283 {
284     std::unique_ptr<Socket> final_send_socket;
285     std::unique_ptr<Socket> final_recv_socket;
286     NativeSocket final_send_fd = kInvalidSocketValue;
287     NativeSocket final_recv_fd = kInvalidSocketValue;
288 
289     Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_CONNECTION));
290     if (log)
291         log->Printf ("Socket::UdpConnect (host/port = %s)", host_and_port.data());
292 
293     Error error;
294     std::string host_str;
295     std::string port_str;
296     int32_t port = INT32_MIN;
297     if (!DecodeHostAndPort (host_and_port, host_str, port_str, port, &error))
298         return error;
299 
300     // Setup the receiving end of the UDP connection on this localhost
301     // on port zero. After we bind to port zero we can read the port.
302     final_recv_fd = ::socket (AF_INET, SOCK_DGRAM, 0);
303     if (final_recv_fd == kInvalidSocketValue)
304     {
305         // Socket creation failed...
306         // TODO: On Windows, use WSAGetLastError().
307         error.SetErrorToErrno();
308     }
309     else
310     {
311         final_recv_socket.reset(new Socket(final_recv_fd, ProtocolUdp, true));
312 
313         // Socket was created, now lets bind to the requested port
314         SocketAddress addr;
315         addr.SetToAnyAddress (AF_INET, 0);
316 
317         if (::bind (final_recv_fd, addr, addr.GetLength()) == -1)
318         {
319             // Bind failed...
320             // TODO: On Windows use WSAGetLastError()
321             error.SetErrorToErrno();
322         }
323     }
324 
325     assert(error.Fail() == !(final_recv_socket && final_recv_socket->IsValid()));
326     if (error.Fail())
327         return error;
328 
329     // At this point we have setup the receive port, now we need to
330     // setup the UDP send socket
331 
332     struct addrinfo hints;
333     struct addrinfo *service_info_list = NULL;
334 
335     ::memset (&hints, 0, sizeof(hints));
336     hints.ai_family = AF_INET;
337     hints.ai_socktype = SOCK_DGRAM;
338     int err = ::getaddrinfo (host_str.c_str(), port_str.c_str(), &hints, &service_info_list);
339     if (err != 0)
340     {
341         error.SetErrorStringWithFormat("getaddrinfo(%s, %s, &hints, &info) returned error %i (%s)",
342                                        host_str.c_str(),
343                                        port_str.c_str(),
344                                        err,
345                                        gai_strerror(err));
346         return error;
347     }
348 
349     for (struct addrinfo *service_info_ptr = service_info_list;
350          service_info_ptr != NULL;
351          service_info_ptr = service_info_ptr->ai_next)
352     {
353         final_send_fd = ::socket (service_info_ptr->ai_family,
354                                   service_info_ptr->ai_socktype,
355                                   service_info_ptr->ai_protocol);
356 
357         if (final_send_fd != kInvalidSocketValue)
358         {
359             final_send_socket.reset(new Socket(final_send_fd, ProtocolUdp, true));
360             final_send_socket->m_udp_send_sockaddr = service_info_ptr;
361             break;
362         }
363         else
364             continue;
365     }
366 
367     :: freeaddrinfo (service_info_list);
368 
369     if (final_send_fd == kInvalidSocketValue)
370     {
371         // TODO: On Windows, use WSAGetLastError().
372         error.SetErrorToErrno();
373         return error;
374     }
375 
376     send_socket = final_send_socket.release();
377     recv_socket = final_recv_socket.release();
378     error.Clear();
379     return error;
380 }
381 
382 Error Socket::UnixDomainConnect(llvm::StringRef name, Socket *&socket)
383 {
384     Error error;
385 #ifndef LLDB_DISABLE_POSIX
386     std::unique_ptr<Socket> final_socket;
387 
388     // Open the socket that was passed in as an option
389     struct sockaddr_un saddr_un;
390     int fd = ::socket (AF_UNIX, SOCK_STREAM, 0);
391     if (fd == kInvalidSocketValue)
392     {
393         error.SetErrorToErrno();
394         return error;
395     }
396 
397     final_socket.reset(new Socket(fd, ProtocolUnixDomain, true));
398 
399     saddr_un.sun_family = AF_UNIX;
400     ::strncpy(saddr_un.sun_path, name.data(), sizeof(saddr_un.sun_path) - 1);
401     saddr_un.sun_path[sizeof(saddr_un.sun_path) - 1] = '\0';
402 #if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)
403     saddr_un.sun_len = SUN_LEN (&saddr_un);
404 #endif
405 
406     if (::connect (fd, (struct sockaddr *)&saddr_un, SUN_LEN (&saddr_un)) < 0)
407     {
408         error.SetErrorToErrno();
409         return error;
410     }
411 
412     socket = final_socket.release();
413 #else
414     error.SetErrorString("Unix domain sockets are not supported on this platform.");
415 #endif
416     return error;
417 }
418 
419 Error Socket::UnixDomainAccept(llvm::StringRef name, Socket *&socket)
420 {
421     Error error;
422 #ifndef LLDB_DISABLE_POSIX
423     struct sockaddr_un saddr_un;
424     std::unique_ptr<Socket> listen_socket;
425     std::unique_ptr<Socket> final_socket;
426     NativeSocket listen_fd = kInvalidSocketValue;
427     NativeSocket socket_fd = kInvalidSocketValue;
428 
429     listen_fd = ::socket (AF_UNIX, SOCK_STREAM, 0);
430     if (listen_fd == kInvalidSocketValue)
431     {
432         error.SetErrorToErrno();
433         return error;
434     }
435 
436     listen_socket.reset(new Socket(listen_fd, ProtocolUnixDomain, true));
437 
438     saddr_un.sun_family = AF_UNIX;
439     ::strncpy(saddr_un.sun_path, name.data(), sizeof(saddr_un.sun_path) - 1);
440     saddr_un.sun_path[sizeof(saddr_un.sun_path) - 1] = '\0';
441 #if defined(__APPLE__) || defined(__FreeBSD__) || defined(__NetBSD__)
442     saddr_un.sun_len = SUN_LEN (&saddr_un);
443 #endif
444 
445     Host::Unlink (name.data());
446     bool success = false;
447     if (::bind (listen_fd, (struct sockaddr *)&saddr_un, SUN_LEN (&saddr_un)) == 0)
448     {
449         if (::listen (listen_fd, 5) == 0)
450         {
451             socket_fd = ::accept (listen_fd, NULL, 0);
452             if (socket_fd > 0)
453             {
454                 final_socket.reset(new Socket(socket_fd, ProtocolUnixDomain, true));
455                 success = true;
456             }
457         }
458     }
459 
460     if (!success)
461     {
462         error.SetErrorToErrno();
463         return error;
464     }
465     // We are done with the listen port
466     listen_socket.reset();
467 
468     socket = final_socket.release();
469 #else
470     error.SetErrorString("Unix domain sockets are not supported on this platform.");
471 #endif
472     return error;
473 }
474 
475 bool
476 Socket::DecodeHostAndPort(llvm::StringRef host_and_port,
477                           std::string &host_str,
478                           std::string &port_str,
479                           int32_t& port,
480                           Error *error_ptr)
481 {
482     static RegularExpression g_regex ("([^:]+):([0-9]+)");
483     RegularExpression::Match regex_match(2);
484     if (g_regex.Execute (host_and_port.data(), &regex_match))
485     {
486         if (regex_match.GetMatchAtIndex (host_and_port.data(), 1, host_str) &&
487             regex_match.GetMatchAtIndex (host_and_port.data(), 2, port_str))
488         {
489             port = Args::StringToSInt32 (port_str.c_str(), INT32_MIN);
490             if (port != INT32_MIN)
491             {
492                 if (error_ptr)
493                     error_ptr->Clear();
494                 return true;
495             }
496         }
497     }
498 
499     // If this was unsuccessful, then check if it's simply a signed 32-bit integer, representing
500     // a port with an empty host.
501     host_str.clear();
502     port_str.clear();
503     port = Args::StringToSInt32(host_and_port.data(), INT32_MIN);
504     if (port != INT32_MIN)
505     {
506         port_str = host_and_port;
507         return true;
508     }
509 
510     if (error_ptr)
511         error_ptr->SetErrorStringWithFormat("invalid host:port specification: '%s'", host_and_port.data());
512     return false;
513 }
514 
515 IOObject::WaitableHandle Socket::GetWaitableHandle()
516 {
517     // TODO: On Windows, use WSAEventSelect
518     return m_socket;
519 }
520 
521 Error Socket::Read (void *buf, size_t &num_bytes)
522 {
523     Error error;
524     int bytes_received = 0;
525     do
526     {
527         bytes_received = ::recv (m_socket, static_cast<char *>(buf), num_bytes, 0);
528         // TODO: Use WSAGetLastError on windows.
529     } while (bytes_received < 0 && errno == EINTR);
530 
531     if (bytes_received < 0)
532     {
533         error.SetErrorToErrno();
534         num_bytes = 0;
535     }
536     else
537         num_bytes = bytes_received;
538 
539     Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_HOST | LIBLLDB_LOG_COMMUNICATION));
540     if (log)
541     {
542         log->Printf ("%p Socket::Read() (socket = %" PRIu64 ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64 " (error = %s)",
543                      static_cast<void*>(this),
544                      static_cast<uint64_t>(m_socket),
545                      buf,
546                      static_cast<uint64_t>(num_bytes),
547                      static_cast<int64_t>(bytes_received),
548                      error.AsCString());
549     }
550 
551     return error;
552 }
553 
554 Error Socket::Write (const void *buf, size_t &num_bytes)
555 {
556     Error error;
557     int bytes_sent = 0;
558     do
559     {
560         if (m_protocol == ProtocolUdp)
561         {
562             bytes_sent = ::sendto (m_socket,
563                                     static_cast<const char*>(buf),
564                                     num_bytes,
565                                     0,
566                                     m_udp_send_sockaddr,
567                                     m_udp_send_sockaddr.GetLength());
568         }
569         else
570             bytes_sent = ::send (m_socket, static_cast<const char *>(buf), num_bytes, 0);
571         // TODO: Use WSAGetLastError on windows.
572     } while (bytes_sent < 0 && errno == EINTR);
573 
574     if (bytes_sent < 0)
575     {
576         // TODO: On Windows, use WSAGEtLastError.
577         error.SetErrorToErrno();
578         num_bytes = 0;
579     }
580     else
581         num_bytes = bytes_sent;
582 
583     Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_HOST));
584     if (log)
585     {
586         log->Printf ("%p Socket::Write() (socket = %" PRIu64 ", src = %p, src_len = %" PRIu64 ", flags = 0) => %" PRIi64 " (error = %s)",
587                         static_cast<void*>(this),
588                         static_cast<uint64_t>(m_socket),
589                         buf,
590                         static_cast<uint64_t>(num_bytes),
591                         static_cast<int64_t>(bytes_sent),
592                         error.AsCString());
593     }
594 
595     return error;
596 }
597 
598 Error Socket::PreDisconnect()
599 {
600     Error error;
601     return error;
602 }
603 
604 Error Socket::Close()
605 {
606     Error error;
607     if (!IsValid() || !m_should_close_fd)
608         return error;
609 
610     Log *log(lldb_private::GetLogIfAnyCategoriesSet (LIBLLDB_LOG_CONNECTION));
611     if (log)
612         log->Printf ("%p Socket::Close (fd = %i)", static_cast<void*>(this), m_socket);
613 
614 #if defined(_WIN32)
615     bool success = !!closesocket(m_socket);
616 #else
617     bool success = !!::close (m_socket);
618 #endif
619     // A reference to a FD was passed in, set it to an invalid value
620     m_socket = kInvalidSocketValue;
621     if (!success)
622     {
623         // TODO: On Windows, use WSAGetLastError().
624         error.SetErrorToErrno();
625     }
626 
627     return error;
628 }
629 
630 
631 int Socket::GetOption(int level, int option_name, int &option_value)
632 {
633     get_socket_option_arg_type option_value_p = reinterpret_cast<get_socket_option_arg_type>(&option_value);
634     socklen_t option_value_size = sizeof(int);
635 	return ::getsockopt(m_socket, level, option_name, option_value_p, &option_value_size);
636 }
637 
638 int Socket::SetOption(int level, int option_name, int option_value)
639 {
640     set_socket_option_arg_type option_value_p = reinterpret_cast<get_socket_option_arg_type>(&option_value);
641 	return ::setsockopt(m_socket, level, option_name, option_value_p, sizeof(option_value));
642 }
643 
644 uint16_t Socket::GetPortNumber(const NativeSocket& socket)
645 {
646     // We bound to port zero, so we need to figure out which port we actually bound to
647     if (socket >= 0)
648     {
649         SocketAddress sock_addr;
650         socklen_t sock_addr_len = sock_addr.GetMaxLength ();
651         if (::getsockname (socket, sock_addr, &sock_addr_len) == 0)
652             return sock_addr.GetPort ();
653     }
654     return 0;
655 }
656 
657 // Return the port number that is being used by the socket.
658 uint16_t Socket::GetPortNumber() const
659 {
660     return GetPortNumber(m_socket);
661 }
662