1 //===-- SocketTest.cpp ------------------------------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #if defined(_MSC_VER) && (_HAS_EXCEPTIONS == 0)
// Workaround for an MSVC standard library bug that prevents <thread> from
// being included when exceptions are disabled; including <eh.h> first avoids
// the problem.
14 #include <eh.h>
15 #endif
16 
17 #include <cstdio>
18 #include <functional>
19 #include <thread>
20 
21 #include "gtest/gtest.h"
22 
23 #include "lldb/Host/Config.h"
24 #include "lldb/Host/Socket.h"
25 #include "lldb/Host/common/TCPSocket.h"
26 #include "lldb/Host/common/UDPSocket.h"
27 
28 #ifndef LLDB_DISABLE_POSIX
29 #include "lldb/Host/posix/DomainSocket.h"
30 #endif
31 
32 using namespace lldb_private;
33 
34 class SocketTest : public testing::Test {
35 public:
36   void SetUp() override {
37 #if defined(_MSC_VER)
38     WSADATA data;
39     ::WSAStartup(MAKEWORD(2, 2), &data);
40 #endif
41   }
42 
43   void TearDown() override {
44 #if defined(_MSC_VER)
45     ::WSACleanup();
46 #endif
47   }
48 
49 protected:
50   static void AcceptThread(Socket *listen_socket,
51                            const char *listen_remote_address,
52                            bool child_processes_inherit, Socket **accept_socket,
53                            Error *error) {
54     *error = listen_socket->Accept(listen_remote_address,
55                                    child_processes_inherit, *accept_socket);
56   }
57 
58   template <typename SocketType>
59   void CreateConnectedSockets(
60       const char *listen_remote_address,
61       const std::function<std::string(const SocketType &)> &get_connect_addr,
62       std::unique_ptr<SocketType> *a_up, std::unique_ptr<SocketType> *b_up) {
63     bool child_processes_inherit = false;
64     Error error;
65     std::unique_ptr<SocketType> listen_socket_up(
66         new SocketType(child_processes_inherit, error));
67     EXPECT_FALSE(error.Fail());
68     error = listen_socket_up->Listen(listen_remote_address, 5);
69     EXPECT_FALSE(error.Fail());
70     EXPECT_TRUE(listen_socket_up->IsValid());
71 
72     Error accept_error;
73     Socket *accept_socket;
74     std::thread accept_thread(AcceptThread, listen_socket_up.get(),
75                               listen_remote_address, child_processes_inherit,
76                               &accept_socket, &accept_error);
77 
78     std::string connect_remote_address = get_connect_addr(*listen_socket_up);
79     std::unique_ptr<SocketType> connect_socket_up(
80         new SocketType(child_processes_inherit, error));
81     EXPECT_FALSE(error.Fail());
82     error = connect_socket_up->Connect(connect_remote_address);
83     EXPECT_FALSE(error.Fail());
84     EXPECT_TRUE(connect_socket_up->IsValid());
85 
86     a_up->swap(connect_socket_up);
87     EXPECT_TRUE(error.Success());
88     EXPECT_NE(nullptr, a_up->get());
89     EXPECT_TRUE((*a_up)->IsValid());
90 
91     accept_thread.join();
92     b_up->reset(static_cast<SocketType *>(accept_socket));
93     EXPECT_TRUE(accept_error.Success());
94     EXPECT_NE(nullptr, b_up->get());
95     EXPECT_TRUE((*b_up)->IsValid());
96 
97     listen_socket_up.reset();
98   }
99 };
100 
101 TEST_F(SocketTest, DecodeHostAndPort) {
102   std::string host_str;
103   std::string port_str;
104   int32_t port;
105   Error error;
106   EXPECT_TRUE(Socket::DecodeHostAndPort("localhost:1138", host_str, port_str,
107                                         port, &error));
108   EXPECT_STREQ("localhost", host_str.c_str());
109   EXPECT_STREQ("1138", port_str.c_str());
110   EXPECT_EQ(1138, port);
111   EXPECT_TRUE(error.Success());
112 
113   EXPECT_FALSE(Socket::DecodeHostAndPort("google.com:65536", host_str, port_str,
114                                          port, &error));
115   EXPECT_TRUE(error.Fail());
116   EXPECT_STREQ("invalid host:port specification: 'google.com:65536'",
117                error.AsCString());
118 
119   EXPECT_FALSE(Socket::DecodeHostAndPort("google.com:-1138", host_str, port_str,
120                                          port, &error));
121   EXPECT_TRUE(error.Fail());
122   EXPECT_STREQ("invalid host:port specification: 'google.com:-1138'",
123                error.AsCString());
124 
125   EXPECT_FALSE(Socket::DecodeHostAndPort("google.com:65536", host_str, port_str,
126                                          port, &error));
127   EXPECT_TRUE(error.Fail());
128   EXPECT_STREQ("invalid host:port specification: 'google.com:65536'",
129                error.AsCString());
130 
131   EXPECT_TRUE(
132       Socket::DecodeHostAndPort("12345", host_str, port_str, port, &error));
133   EXPECT_STREQ("", host_str.c_str());
134   EXPECT_STREQ("12345", port_str.c_str());
135   EXPECT_EQ(12345, port);
136   EXPECT_TRUE(error.Success());
137 
138   EXPECT_TRUE(
139       Socket::DecodeHostAndPort("*:0", host_str, port_str, port, &error));
140   EXPECT_STREQ("*", host_str.c_str());
141   EXPECT_STREQ("0", port_str.c_str());
142   EXPECT_EQ(0, port);
143   EXPECT_TRUE(error.Success());
144 
145   EXPECT_TRUE(
146       Socket::DecodeHostAndPort("*:65535", host_str, port_str, port, &error));
147   EXPECT_STREQ("*", host_str.c_str());
148   EXPECT_STREQ("65535", port_str.c_str());
149   EXPECT_EQ(65535, port);
150   EXPECT_TRUE(error.Success());
151 }
152 
153 #ifndef LLDB_DISABLE_POSIX
154 TEST_F(SocketTest, DomainListenConnectAccept) {
155   char *file_name_str = tempnam(nullptr, nullptr);
156   EXPECT_NE(nullptr, file_name_str);
157   const std::string file_name(file_name_str);
158   free(file_name_str);
159 
160   std::unique_ptr<DomainSocket> socket_a_up;
161   std::unique_ptr<DomainSocket> socket_b_up;
162   CreateConnectedSockets<DomainSocket>(
163       file_name.c_str(), [=](const DomainSocket &) { return file_name; },
164       &socket_a_up, &socket_b_up);
165 }
166 #endif
167 
168 TEST_F(SocketTest, TCPListen0ConnectAccept) {
169   std::unique_ptr<TCPSocket> socket_a_up;
170   std::unique_ptr<TCPSocket> socket_b_up;
171   CreateConnectedSockets<TCPSocket>(
172       "127.0.0.1:0",
173       [=](const TCPSocket &s) {
174         char connect_remote_address[64];
175         snprintf(connect_remote_address, sizeof(connect_remote_address),
176                  "localhost:%u", s.GetLocalPortNumber());
177         return std::string(connect_remote_address);
178       },
179       &socket_a_up, &socket_b_up);
180 }
181 
182 TEST_F(SocketTest, TCPGetAddress) {
183   std::unique_ptr<TCPSocket> socket_a_up;
184   std::unique_ptr<TCPSocket> socket_b_up;
185   CreateConnectedSockets<TCPSocket>(
186       "127.0.0.1:0",
187       [=](const TCPSocket &s) {
188         char connect_remote_address[64];
189         snprintf(connect_remote_address, sizeof(connect_remote_address),
190                  "localhost:%u", s.GetLocalPortNumber());
191         return std::string(connect_remote_address);
192       },
193       &socket_a_up, &socket_b_up);
194 
195   EXPECT_EQ(socket_a_up->GetLocalPortNumber(),
196             socket_b_up->GetRemotePortNumber());
197   EXPECT_EQ(socket_b_up->GetLocalPortNumber(),
198             socket_a_up->GetRemotePortNumber());
199   EXPECT_NE(socket_a_up->GetLocalPortNumber(),
200             socket_b_up->GetLocalPortNumber());
201   EXPECT_STREQ("127.0.0.1", socket_a_up->GetRemoteIPAddress().c_str());
202   EXPECT_STREQ("127.0.0.1", socket_b_up->GetRemoteIPAddress().c_str());
203 }
204 
205 TEST_F(SocketTest, UDPConnect) {
206   Socket *socket_a;
207   Socket *socket_b;
208 
209   bool child_processes_inherit = false;
210   auto error = UDPSocket::Connect("127.0.0.1:0", child_processes_inherit,
211                                   socket_a, socket_b);
212 
213   std::unique_ptr<Socket> a_up(socket_a);
214   std::unique_ptr<Socket> b_up(socket_b);
215 
216   EXPECT_TRUE(error.Success());
217   EXPECT_TRUE(a_up->IsValid());
218   EXPECT_TRUE(b_up->IsValid());
219 }
220