1 /*-
2 * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3 *
4 * Copyright (c) 2013 The FreeBSD Foundation
5 * Copyright (c) 2013 Mariusz Zaborski <[email protected]>
6 * All rights reserved.
7 *
8 * This software was developed by Pawel Jakub Dawidek under sponsorship from
9 * the FreeBSD Foundation.
10 *
11 * Redistribution and use in source and binary forms, with or without
12 * modification, are permitted provided that the following conditions
13 * are met:
14 * 1. Redistributions of source code must retain the above copyright
15 * notice, this list of conditions and the following disclaimer.
16 * 2. Redistributions in binary form must reproduce the above copyright
17 * notice, this list of conditions and the following disclaimer in the
18 * documentation and/or other materials provided with the distribution.
19 *
20 * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30 * SUCH DAMAGE.
31 */
32
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35
#include <sys/param.h>
#include <sys/socket.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
46
47 #ifdef HAVE_PJDLOG
48 #include <pjdlog.h>
49 #endif
50
51 #include "common_impl.h"
52 #include "msgio.h"
53
#if !defined(HAVE_PJDLOG)
/*
 * Without pjdlog, map its assertion helpers onto assert(3) and abort(3).
 * Any message arguments passed to PJDLOG_RASSERT() are discarded.
 */
#include <assert.h>
#define PJDLOG_ABORT(...) abort()
#define PJDLOG_ASSERT(...) assert(__VA_ARGS__)
#define PJDLOG_RASSERT(expr, ...) assert(expr)
#endif
60
/*
 * Maximum number of descriptors sent in one package (one sendmsg() call).
 *
 * Descriptors travel one per control message.  To work around limitations
 * of 32-bit emulation on 64-bit kernels, each control message is counted
 * with a machine-independent size of 24 bytes: 12 bytes of header, 4 pad
 * bytes, the 4-byte descriptor, and another 4 pad bytes.  A package holds
 * as many such messages as fit in MCLBYTES bytes.
 */
#define PKG_MAX_SIZE (MCLBYTES / (12 + 4 + 4 + 4))
68
/*
 * Fill cmsg with a single-descriptor SCM_RIGHTS control message carrying fd.
 * The caller must supply at least CMSG_SPACE(sizeof(int)) bytes at cmsg.
 * Always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_type = SCM_RIGHTS;
	cmsg->cmsg_level = SOL_SOCKET;
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

	return (0);
}
82
/*
 * Extract the descriptor from a single-descriptor SCM_RIGHTS control
 * message.
 *
 * Returns the descriptor, or -1 with errno set to EINVAL when cmsg is NULL,
 * is not SOL_SOCKET/SCM_RIGHTS, or does not carry exactly one int.
 */
static int
msghdr_get_fd(struct cmsghdr *cmsg)
{
	int received;

	if (cmsg != NULL && cmsg->cmsg_level == SOL_SOCKET &&
	    cmsg->cmsg_type == SCM_RIGHTS &&
	    cmsg->cmsg_len == CMSG_LEN(sizeof(received))) {
		memcpy(&received, CMSG_DATA(cmsg), sizeof(received));
#ifndef MSG_CMSG_CLOEXEC
		/*
		 * Without MSG_CMSG_CLOEXEC the close-on-exec flag cannot be set
		 * atomically at receive time; set it now anyway for
		 * consistency (best effort, result ignored).
		 */
		(void)fcntl(received, F_SETFD, FD_CLOEXEC);
#endif
		return (received);
	}

	errno = EINVAL;
	return (-1);
}
107
/*
 * Block until fd is ready for reading (doread) or for writing (!doread).
 *
 * poll(2) is used rather than select(2): FD_SET() has undefined behaviour
 * for descriptors >= FD_SETSIZE, and a process holding many descriptors can
 * easily be handed such a socket.  As before, poll() errors (including
 * EINTR) are ignored; callers follow up with the actual I/O call and check
 * its result.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	(void)poll(&pfd, 1, -1);
}
120
/*
 * Wait until sock is readable, then receive one message into msg, retrying
 * recvmsg() after EINTR.  MSG_CMSG_CLOEXEC is requested when the platform
 * defines it.  Returns 0 on success, -1 with errno set on failure.
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
	int rflags = 0;

	PJDLOG_ASSERT(sock >= 0);

#ifdef MSG_CMSG_CLOEXEC
	rflags |= MSG_CMSG_CLOEXEC;
#endif

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
146
/*
 * Wait until sock is writable, then send msg, retrying sendmsg() after
 * EINTR.  Returns 0 on success, -1 with errno set on failure.
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
165
166 int
cred_send(int sock)167 cred_send(int sock)
168 {
169 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
170 struct msghdr msg;
171 struct cmsghdr *cmsg;
172 struct iovec iov;
173 uint8_t dummy;
174
175 bzero(credbuf, sizeof(credbuf));
176 bzero(&msg, sizeof(msg));
177 bzero(&iov, sizeof(iov));
178
179 /*
180 * XXX: We send one byte along with the control message, because
181 * setting msg_iov to NULL only works if this is the first
182 * packet send over the socket. Once we send some data we
183 * won't be able to send credentials anymore. This is most
184 * likely a kernel bug.
185 */
186 dummy = 0;
187 iov.iov_base = &dummy;
188 iov.iov_len = sizeof(dummy);
189
190 msg.msg_iov = &iov;
191 msg.msg_iovlen = 1;
192 msg.msg_control = credbuf;
193 msg.msg_controllen = sizeof(credbuf);
194
195 cmsg = CMSG_FIRSTHDR(&msg);
196 cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
197 cmsg->cmsg_level = SOL_SOCKET;
198 cmsg->cmsg_type = SCM_CREDS;
199
200 if (msg_send(sock, &msg) == -1)
201 return (-1);
202
203 return (0);
204 }
205
206 int
cred_recv(int sock,struct cmsgcred * cred)207 cred_recv(int sock, struct cmsgcred *cred)
208 {
209 unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
210 struct msghdr msg;
211 struct cmsghdr *cmsg;
212 struct iovec iov;
213 uint8_t dummy;
214
215 bzero(credbuf, sizeof(credbuf));
216 bzero(&msg, sizeof(msg));
217 bzero(&iov, sizeof(iov));
218
219 iov.iov_base = &dummy;
220 iov.iov_len = sizeof(dummy);
221
222 msg.msg_iov = &iov;
223 msg.msg_iovlen = 1;
224 msg.msg_control = credbuf;
225 msg.msg_controllen = sizeof(credbuf);
226
227 if (msg_recv(sock, &msg) == -1)
228 return (-1);
229
230 cmsg = CMSG_FIRSTHDR(&msg);
231 if (cmsg == NULL ||
232 cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
233 cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
234 errno = EINVAL;
235 return (-1);
236 }
237 bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
238
239 return (0);
240 }
241
/*
 * Send nfds descriptors from fds over sock in a single message, one
 * SCM_RIGHTS control message per descriptor, plus a one-byte data payload.
 * Returns 0 on success, -1 with errno set on failure (errno is preserved
 * across the cleanup free()).
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct msghdr hdr;
	struct iovec payload;
	struct cmsghdr *cm;
	unsigned int n;
	int result, saved_errno;
	uint8_t zero;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	bzero(&hdr, sizeof(hdr));

	/* XXX: See cred_send() for why a dummy data byte is sent. */
	zero = 0;
	payload.iov_base = &zero;
	payload.iov_len = sizeof(zero);
	hdr.msg_iov = &payload;
	hdr.msg_iovlen = 1;

	/* Zeroed room for nfds single-descriptor control messages. */
	hdr.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	hdr.msg_control = calloc(1, hdr.msg_controllen);
	if (hdr.msg_control == NULL)
		return (-1);

	result = -1;

	n = 0;
	cm = CMSG_FIRSTHDR(&hdr);
	while (n < nfds && cm != NULL) {
		if (msghdr_add_fd(cm, fds[n]) == -1)
			goto out;
		n++;
		cm = CMSG_NXTHDR(&hdr, cm);
	}

	if (msg_send(sock, &hdr) != -1)
		result = 0;
out:
	saved_errno = errno;
	free(hdr.msg_control);
	errno = saved_errno;
	return (result);
}
290
/*
 * Close every descriptor carried by any SOL_SOCKET/SCM_RIGHTS control
 * message in msg, however many descriptors each message holds.  Only the
 * part of each message lying inside the received control buffer
 * (msg_control .. msg_control + msg_controllen) is read.
 */
static void
msghdr_close_fds(struct msghdr *msg)
{
	struct cmsghdr *cmsg;
	unsigned char *data, *end;
	size_t avail, k, n;
	int fd;

	end = (unsigned char *)msg->msg_control + msg->msg_controllen;
	for (cmsg = CMSG_FIRSTHDR(msg); cmsg != NULL;
	    cmsg = CMSG_NXTHDR(msg, cmsg)) {
		if (cmsg->cmsg_level != SOL_SOCKET ||
		    cmsg->cmsg_type != SCM_RIGHTS ||
		    cmsg->cmsg_len < CMSG_LEN(0))
			continue;
		data = CMSG_DATA(cmsg);
		if (data >= end)
			continue;
		/* Clamp to the buffer in case the header overstates its size. */
		avail = (size_t)(end - data);
		n = cmsg->cmsg_len - CMSG_LEN(0);
		if (n > avail)
			n = avail;
		n /= sizeof(int);
		for (k = 0; k < n; k++) {
			bcopy(data + k * sizeof(int), &fd, sizeof(fd));
			if (fd >= 0)
				(void)close(fd);
		}
	}
}

/*
 * Receive exactly nfds descriptors from sock into fds.  A valid package is
 * one message with nfds single-descriptor SCM_RIGHTS control messages (see
 * fd_package_send()).
 *
 * Returns 0 on success.  Returns -1 with errno set on failure.  If the
 * package is malformed, errno is EINVAL and every descriptor that arrived
 * with it is closed, so none leak into the process.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct msghdr msg;
	struct cmsghdr *cmsg;
	size_t i;
	int serrno, ret;
	struct iovec iov;
	uint8_t dummy;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	bzero(&msg, sizeof(msg));
	bzero(&iov, sizeof(iov));

	/*
	 * XXX: Look into cred_send function for more details.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);

	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;
	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	ret = -1;

	if (msg_recv(sock, &msg) == -1)
		goto end;

	for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL;
	    i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) {
		fds[i] = msghdr_get_fd(cmsg);
		if (fds[i] < 0)
			break;
	}

	/*
	 * Accept only exactly nfds well-formed messages with nothing cut off.
	 * MSG_CTRUNC means the peer sent more control data than the buffer
	 * holds, i.e. more descriptors than the protocol allows.
	 */
	if (cmsg != NULL || i < nfds || (msg.msg_flags & MSG_CTRUNC) != 0) {
		/*
		 * Close all received descriptors: also those behind another
		 * control message (eg. SCM_CREDS) and those in SCM_RIGHTS
		 * messages that carry more than one descriptor, which
		 * msghdr_get_fd() rejects.
		 */
		msghdr_close_fds(&msg);
		errno = EINVAL;
		goto end;
	}

	ret = 0;
end:
	serrno = errno;
	free(msg.msg_control);
	errno = serrno;
	return (ret);
}
357
358 int
fd_recv(int sock,int * fds,size_t nfds)359 fd_recv(int sock, int *fds, size_t nfds)
360 {
361 unsigned int i, step, j;
362 int ret, serrno;
363
364 if (nfds == 0 || fds == NULL) {
365 errno = EINVAL;
366 return (-1);
367 }
368
369 ret = i = step = 0;
370 while (i < nfds) {
371 if (PKG_MAX_SIZE < nfds - i)
372 step = PKG_MAX_SIZE;
373 else
374 step = nfds - i;
375 ret = fd_package_recv(sock, fds + i, step);
376 if (ret != 0) {
377 /* Close all received descriptors. */
378 serrno = errno;
379 for (j = 0; j < i; j++)
380 close(fds[j]);
381 errno = serrno;
382 break;
383 }
384 i += step;
385 }
386
387 return (ret);
388 }
389
390 int
fd_send(int sock,const int * fds,size_t nfds)391 fd_send(int sock, const int *fds, size_t nfds)
392 {
393 unsigned int i, step;
394 int ret;
395
396 if (nfds == 0 || fds == NULL) {
397 errno = EINVAL;
398 return (-1);
399 }
400
401 ret = i = step = 0;
402 while (i < nfds) {
403 if (PKG_MAX_SIZE < nfds - i)
404 step = PKG_MAX_SIZE;
405 else
406 step = nfds - i;
407 ret = fd_package_send(sock, fds + i, step);
408 if (ret != 0)
409 break;
410 i += step;
411 }
412
413 return (ret);
414 }
415
416 int
buf_send(int sock,void * buf,size_t size)417 buf_send(int sock, void *buf, size_t size)
418 {
419 ssize_t done;
420 unsigned char *ptr;
421
422 PJDLOG_ASSERT(sock >= 0);
423 PJDLOG_ASSERT(size > 0);
424 PJDLOG_ASSERT(buf != NULL);
425
426 ptr = buf;
427 do {
428 fd_wait(sock, false);
429 done = send(sock, ptr, size, 0);
430 if (done == -1) {
431 if (errno == EINTR)
432 continue;
433 return (-1);
434 } else if (done == 0) {
435 errno = ENOTCONN;
436 return (-1);
437 }
438 size -= done;
439 ptr += done;
440 } while (size > 0);
441
442 return (0);
443 }
444
445 int
buf_recv(int sock,void * buf,size_t size)446 buf_recv(int sock, void *buf, size_t size)
447 {
448 ssize_t done;
449 unsigned char *ptr;
450
451 PJDLOG_ASSERT(sock >= 0);
452 PJDLOG_ASSERT(buf != NULL);
453
454 ptr = buf;
455 while (size > 0) {
456 fd_wait(sock, true);
457 done = recv(sock, ptr, size, 0);
458 if (done == -1) {
459 if (errno == EINTR)
460 continue;
461 return (-1);
462 } else if (done == 0) {
463 errno = ENOTCONN;
464 return (-1);
465 }
466 size -= done;
467 ptr += done;
468 }
469
470 return (0);
471 }
472