1 //===-- RNBSocketTest.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "gtest/gtest.h"

#include <arpa/inet.h>
#include <sys/sysctl.h>
#include <sys/wait.h>
#include <unistd.h>

#include <cstdio>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>

#include "RNBDefs.h"
#include "RNBSocket.h"
#include "lldb/Host/Socket.h"
#include "lldb/Host/StringConvert.h"
#include "lldb/Host/common/TCPSocket.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Testing/Support/Error.h"

using namespace lldb_private;
23 
// Payloads exchanged by every test in this file: the RNBSocket side sends
// `hello` and the peer process replies with `goodbye`. File-local and const
// so they cannot collide with other test translation units or be mutated.
static const std::string hello = "Hello, world!";
static const std::string goodbye = "Goodbye!";
26 
27 static void ServerCallbackv4(const void *baton, in_port_t port) {
28   auto child_pid = fork();
29   if (child_pid == 0) {
30     char addr_buffer[256];
31     sprintf(addr_buffer, "%s:%d", baton, port);
32     llvm::Expected<std::unique_ptr<Socket>> socket_or_err =
33         Socket::TcpConnect(addr_buffer, false);
34     ASSERT_THAT_EXPECTED(socket_or_err, llvm::Succeeded());
35     Socket *client_socket = socket_or_err->get();
36 
37     char buffer[32];
38     size_t read_size = 32;
39     Status err = client_socket->Read((void *)&buffer[0], read_size);
40     if (err.Fail())
41       abort();
42     std::string Recv(&buffer[0], read_size);
43     if (Recv != hello)
44       abort();
45     size_t write_size = goodbye.length();
46     err = client_socket->Write(goodbye.c_str(), write_size);
47     if (err.Fail())
48       abort();
49     if (write_size != goodbye.length())
50       abort();
51     delete client_socket;
52     exit(0);
53   }
54 }
55 
56 void TestSocketListen(const char *addr) {
57   // Skip IPv6 tests if there isn't a valid interafce
58   auto addresses = lldb_private::SocketAddress::GetAddressInfo(
59       addr, NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
60   if (addresses.size() == 0)
61     return;
62 
63   char addr_wrap[256];
64   if (addresses.front().GetFamily() == AF_INET6)
65     sprintf(addr_wrap, "[%s]", addr);
66   else
67     sprintf(addr_wrap, "%s", addr);
68 
69   RNBSocket server_socket;
70   auto result =
71       server_socket.Listen(addr, 0, ServerCallbackv4, (const void *)addr_wrap);
72   ASSERT_TRUE(result == rnb_success);
73   result = server_socket.Write(hello.c_str(), hello.length());
74   ASSERT_TRUE(result == rnb_success);
75   std::string bye;
76   result = server_socket.Read(bye);
77   ASSERT_TRUE(result == rnb_success);
78   ASSERT_EQ(bye, goodbye);
79 
80   int exit_status;
81   wait(&exit_status);
82   ASSERT_EQ(exit_status, 0);
83 }
84 
85 TEST(RNBSocket, LoopBackListenIPv4) { TestSocketListen("127.0.0.1"); }
86 
87 TEST(RNBSocket, LoopBackListenIPv6) { TestSocketListen("::1"); }
88 
89 TEST(RNBSocket, AnyListen) { TestSocketListen("*"); }
90 
91 void TestSocketConnect(const char *addr) {
92   // Skip IPv6 tests if there isn't a valid interafce
93   auto addresses = lldb_private::SocketAddress::GetAddressInfo(
94       addr, NULL, AF_UNSPEC, SOCK_STREAM, IPPROTO_TCP);
95   if (addresses.size() == 0)
96     return;
97 
98   char addr_wrap[256];
99   if (addresses.front().GetFamily() == AF_INET6)
100     sprintf(addr_wrap, "[%s]:0", addr);
101   else
102     sprintf(addr_wrap, "%s:0", addr);
103 
104   Socket *server_socket;
105   Predicate<uint16_t> port_predicate;
106   port_predicate.SetValue(0, eBroadcastNever);
107   llvm::Expected<std::unique_ptr<Socket>> socket_or_err =
108       Socket::TcpListen(addr_wrap, false, &port_predicate);
109   ASSERT_THAT_EXPECTED(socket_or_err, llvm::Succeeded());
110   server_socket = socket_or_err->get();
111 
112   auto port = ((TCPSocket *)server_socket)->GetLocalPortNumber();
113   auto child_pid = fork();
114   if (child_pid != 0) {
115     RNBSocket client_socket;
116     auto result = client_socket.Connect(addr, port);
117     ASSERT_TRUE(result == rnb_success);
118     result = client_socket.Write(hello.c_str(), hello.length());
119     ASSERT_TRUE(result == rnb_success);
120     std::string bye;
121     result = client_socket.Read(bye);
122     ASSERT_TRUE(result == rnb_success);
123     ASSERT_EQ(bye, goodbye);
124   } else {
125     Socket *connected_socket;
126     Status err = server_socket->Accept(connected_socket);
127     if (err.Fail()) {
128       llvm::errs() << err.AsCString();
129       abort();
130     }
131     char buffer[32];
132     size_t read_size = 32;
133     err = connected_socket->Read((void *)&buffer[0], read_size);
134     if (err.Fail()) {
135       llvm::errs() << err.AsCString();
136       abort();
137     }
138     std::string Recv(&buffer[0], read_size);
139     if (Recv != hello) {
140       llvm::errs() << err.AsCString();
141       abort();
142     }
143     size_t write_size = goodbye.length();
144     err = connected_socket->Write(goodbye.c_str(), write_size);
145     if (err.Fail()) {
146       llvm::errs() << err.AsCString();
147       abort();
148     }
149     if (write_size != goodbye.length()) {
150       llvm::errs() << err.AsCString();
151       abort();
152     }
153     exit(0);
154   }
155   int exit_status;
156   wait(&exit_status);
157   ASSERT_EQ(exit_status, 0);
158 }
159 
160 TEST(RNBSocket, LoopBackConnectIPv4) { TestSocketConnect("127.0.0.1"); }
161 
162 TEST(RNBSocket, LoopBackConnectIPv6) { TestSocketConnect("::1"); }
163