xref: /freebsd-12.1/lib/libnv/msgio.c (revision 3e3ba5f4)
1 /*-
2  * SPDX-License-Identifier: BSD-2-Clause-FreeBSD
3  *
4  * Copyright (c) 2013 The FreeBSD Foundation
5  * Copyright (c) 2013 Mariusz Zaborski <oshogbo@FreeBSD.org>
6  * All rights reserved.
7  *
8  * This software was developed by Pawel Jakub Dawidek under sponsorship from
9  * the FreeBSD Foundation.
10  *
11  * Redistribution and use in source and binary forms, with or without
12  * modification, are permitted provided that the following conditions
13  * are met:
14  * 1. Redistributions of source code must retain the above copyright
15  *    notice, this list of conditions and the following disclaimer.
16  * 2. Redistributions in binary form must reproduce the above copyright
17  *    notice, this list of conditions and the following disclaimer in the
18  *    documentation and/or other materials provided with the distribution.
19  *
20  * THIS SOFTWARE IS PROVIDED BY THE AUTHORS AND CONTRIBUTORS ``AS IS'' AND
21  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
23  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS OR CONTRIBUTORS BE LIABLE
24  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
26  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
27  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
28  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
29  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
30  * SUCH DAMAGE.
31  */
32 
33 #include <sys/cdefs.h>
34 __FBSDID("$FreeBSD$");
35 
#include <sys/param.h>
#include <sys/socket.h>

#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
46 
47 #ifdef HAVE_PJDLOG
48 #include <pjdlog.h>
49 #endif
50 
51 #include "common_impl.h"
52 #include "msgio.h"
53 
/*
 * When pjdlog is not available, map its assertion and abort macros onto
 * the standard C library.  Any extra arguments accepted by
 * PJDLOG_RASSERT() and PJDLOG_ABORT() are discarded.
 */
#if !defined(HAVE_PJDLOG)
#include <assert.h>
#define	PJDLOG_ASSERT(...)		assert(__VA_ARGS__)
#define	PJDLOG_RASSERT(expr, ...)	assert(expr)
#define	PJDLOG_ABORT(...)		abort()
#endif
60 
/*
 * Upper bound on the number of descriptors carried by a single message.
 *
 * The limit is machine-independent on purpose, to work around limitations
 * of 32-bit emulation on 64-bit kernels.  Every descriptor travels in its
 * own control message, which is accounted as 24 bytes: a 12-byte header,
 * 4 bytes of padding, the 4-byte descriptor and 4 more bytes of padding.
 */
#define	PKG_MAX_SIZE	(MCLBYTES / 24)
68 
/*
 * Fill in one control message header so that it carries a single file
 * descriptor as SOL_SOCKET/SCM_RIGHTS data.
 *
 * The descriptor must be non-negative (asserted).  Always returns 0.
 */
static int
msghdr_add_fd(struct cmsghdr *cmsg, int fd)
{

	PJDLOG_ASSERT(fd >= 0);

	cmsg->cmsg_len = CMSG_LEN(sizeof(fd));
	cmsg->cmsg_level = SOL_SOCKET;
	cmsg->cmsg_type = SCM_RIGHTS;
	/* The payload may be unaligned for an int, so copy it bytewise. */
	memcpy(CMSG_DATA(cmsg), &fd, sizeof(fd));

	return (0);
}
82 
/*
 * Extract the file descriptor stored in a control message.
 *
 * The message must be a SOL_SOCKET/SCM_RIGHTS message holding exactly one
 * int.  Returns the descriptor, or -1 with errno set to EINVAL when cmsg
 * is NULL or does not have that shape.
 */
static int
msghdr_get_fd(struct cmsghdr *cmsg)
{
	int received;

	if (cmsg == NULL)
		goto invalid;
	if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS)
		goto invalid;
	if (cmsg->cmsg_len != CMSG_LEN(sizeof(received)))
		goto invalid;

	memcpy(&received, CMSG_DATA(cmsg), sizeof(received));
#ifndef MSG_CMSG_CLOEXEC
	/*
	 * Without MSG_CMSG_CLOEXEC the close-on-exec flag cannot be set
	 * atomically at receive time; set it here anyway for consistency.
	 * A failure is deliberately ignored.
	 */
	(void) fcntl(received, F_SETFD, FD_CLOEXEC);
#endif
	return (received);

invalid:
	errno = EINVAL;
	return (-1);
}
107 
/*
 * Block until the descriptor is ready for reading (doread == true) or for
 * writing (doread == false).
 *
 * poll(2) is used rather than select(2): FD_SET() on a descriptor that is
 * greater than or equal to FD_SETSIZE writes outside the fd_set, whereas
 * poll(2) places no limit on the descriptor value.
 *
 * The result of the wait is intentionally ignored; callers go on to issue
 * the actual I/O call and handle its errors (including EINTR) themselves.
 */
static void
fd_wait(int fd, bool doread)
{
	struct pollfd pfd;

	PJDLOG_ASSERT(fd >= 0);

	pfd.fd = fd;
	pfd.events = doread ? POLLIN : POLLOUT;
	pfd.revents = 0;
	/* A negative timeout waits indefinitely. */
	(void)poll(&pfd, 1, -1);
}
120 
/*
 * Receive one message on the socket, retrying for as long as recvmsg(2)
 * fails with EINTR.  When MSG_CMSG_CLOEXEC is available it is passed to
 * recvmsg(2).
 *
 * Returns 0 on success, or -1 with errno left as set by recvmsg(2).
 */
static int
msg_recv(int sock, struct msghdr *msg)
{
#ifdef MSG_CMSG_CLOEXEC
	const int rflags = MSG_CMSG_CLOEXEC;
#else
	const int rflags = 0;
#endif

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, true);
		if (recvmsg(sock, msg, rflags) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
146 
/*
 * Send one message on the socket, retrying for as long as sendmsg(2)
 * fails with EINTR.
 *
 * Returns 0 on success, or -1 with errno left as set by sendmsg(2).
 */
static int
msg_send(int sock, const struct msghdr *msg)
{

	PJDLOG_ASSERT(sock >= 0);

	do {
		fd_wait(sock, false);
		if (sendmsg(sock, msg, 0) != -1)
			return (0);
	} while (errno == EINTR);

	return (-1);
}
165 
166 int
cred_send(int sock)167 cred_send(int sock)
168 {
169 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
170 	struct msghdr msg;
171 	struct cmsghdr *cmsg;
172 	struct iovec iov;
173 	uint8_t dummy;
174 
175 	bzero(credbuf, sizeof(credbuf));
176 	bzero(&msg, sizeof(msg));
177 	bzero(&iov, sizeof(iov));
178 
179 	/*
180 	 * XXX: We send one byte along with the control message, because
181 	 *      setting msg_iov to NULL only works if this is the first
182 	 *      packet send over the socket. Once we send some data we
183 	 *      won't be able to send credentials anymore. This is most
184 	 *      likely a kernel bug.
185 	 */
186 	dummy = 0;
187 	iov.iov_base = &dummy;
188 	iov.iov_len = sizeof(dummy);
189 
190 	msg.msg_iov = &iov;
191 	msg.msg_iovlen = 1;
192 	msg.msg_control = credbuf;
193 	msg.msg_controllen = sizeof(credbuf);
194 
195 	cmsg = CMSG_FIRSTHDR(&msg);
196 	cmsg->cmsg_len = CMSG_LEN(sizeof(struct cmsgcred));
197 	cmsg->cmsg_level = SOL_SOCKET;
198 	cmsg->cmsg_type = SCM_CREDS;
199 
200 	if (msg_send(sock, &msg) == -1)
201 		return (-1);
202 
203 	return (0);
204 }
205 
206 int
cred_recv(int sock,struct cmsgcred * cred)207 cred_recv(int sock, struct cmsgcred *cred)
208 {
209 	unsigned char credbuf[CMSG_SPACE(sizeof(struct cmsgcred))];
210 	struct msghdr msg;
211 	struct cmsghdr *cmsg;
212 	struct iovec iov;
213 	uint8_t dummy;
214 
215 	bzero(credbuf, sizeof(credbuf));
216 	bzero(&msg, sizeof(msg));
217 	bzero(&iov, sizeof(iov));
218 
219 	iov.iov_base = &dummy;
220 	iov.iov_len = sizeof(dummy);
221 
222 	msg.msg_iov = &iov;
223 	msg.msg_iovlen = 1;
224 	msg.msg_control = credbuf;
225 	msg.msg_controllen = sizeof(credbuf);
226 
227 	if (msg_recv(sock, &msg) == -1)
228 		return (-1);
229 
230 	cmsg = CMSG_FIRSTHDR(&msg);
231 	if (cmsg == NULL ||
232 	    cmsg->cmsg_len != CMSG_LEN(sizeof(struct cmsgcred)) ||
233 	    cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_CREDS) {
234 		errno = EINVAL;
235 		return (-1);
236 	}
237 	bcopy(CMSG_DATA(cmsg), cred, sizeof(*cred));
238 
239 	return (0);
240 }
241 
/*
 * Send nfds descriptors in a single message, one SCM_RIGHTS control
 * message per descriptor, together with one dummy data byte.
 *
 * sock must be non-negative, fds non-NULL and nfds non-zero (asserted).
 * Returns 0 on success and -1 on failure; errno is preserved across the
 * release of the control buffer.
 */
static int
fd_package_send(int sock, const int *fds, size_t nfds)
{
	struct cmsghdr *hdr;
	struct msghdr msg;
	struct iovec iov;
	unsigned int n;
	uint8_t dummy;
	int saved_errno, status;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(fds != NULL);
	PJDLOG_ASSERT(nfds > 0);

	memset(&msg, 0, sizeof(msg));

	/*
	 * XXX: A dummy data byte accompanies the control data; see the
	 *      comment in cred_send() for the reason.
	 */
	dummy = 0;
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	status = -1;

	/* Walk the control buffer and the descriptor array in lock step. */
	hdr = CMSG_FIRSTHDR(&msg);
	for (n = 0; n < nfds && hdr != NULL; n++) {
		if (msghdr_add_fd(hdr, fds[n]) == -1)
			goto out;
		hdr = CMSG_NXTHDR(&msg, hdr);
	}

	if (msg_send(sock, &msg) != -1)
		status = 0;
out:
	saved_errno = errno;
	free(msg.msg_control);
	errno = saved_errno;
	return (status);
}
290 
/*
 * Receive exactly nfds descriptors from a single message into fds[].
 *
 * Each descriptor is expected in its own SCM_RIGHTS control message.  If
 * the message carries fewer or more control messages than expected, or one
 * of them is not a valid descriptor message, every descriptor found in the
 * message is closed and the call fails with errno set to EINVAL.
 *
 * sock must be non-negative, fds non-NULL and nfds non-zero (asserted).
 * Returns 0 on success and -1 on failure; errno is preserved across the
 * release of the control buffer.
 */
static int
fd_package_recv(int sock, int *fds, size_t nfds)
{
	struct cmsghdr *hdr;
	struct msghdr msg;
	struct iovec iov;
	unsigned int n;
	uint8_t dummy;
	int saved_errno, status;

	PJDLOG_ASSERT(sock >= 0);
	PJDLOG_ASSERT(nfds > 0);
	PJDLOG_ASSERT(fds != NULL);

	memset(&msg, 0, sizeof(msg));
	memset(&iov, 0, sizeof(iov));

	/*
	 * XXX: A dummy data byte accompanies the control data; see the
	 *      comment in cred_send() for the reason.
	 */
	iov.iov_base = &dummy;
	iov.iov_len = sizeof(dummy);
	msg.msg_iov = &iov;
	msg.msg_iovlen = 1;

	msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int));
	msg.msg_control = calloc(1, msg.msg_controllen);
	if (msg.msg_control == NULL)
		return (-1);

	status = -1;

	if (msg_recv(sock, &msg) == -1)
		goto out;

	/* Collect descriptors until one side runs out or one is invalid. */
	hdr = CMSG_FIRSTHDR(&msg);
	n = 0;
	while (n < nfds && hdr != NULL) {
		fds[n] = msghdr_get_fd(hdr);
		if (fds[n] < 0)
			break;
		n++;
		hdr = CMSG_NXTHDR(&msg, hdr);
	}

	/* Success only when both sequences were consumed completely. */
	if (hdr == NULL && n >= nfds) {
		status = 0;
		goto out;
	}

	/*
	 * All received descriptors have to be closed, even when a
	 * different control message (e.g. SCM_CREDS) sits in between.
	 */
	for (hdr = CMSG_FIRSTHDR(&msg); hdr != NULL;
	    hdr = CMSG_NXTHDR(&msg, hdr)) {
		int stray;

		stray = msghdr_get_fd(hdr);
		if (stray >= 0)
			close(stray);
	}
	errno = EINVAL;
out:
	saved_errno = errno;
	free(msg.msg_control);
	errno = saved_errno;
	return (status);
}
357 
358 int
fd_recv(int sock,int * fds,size_t nfds)359 fd_recv(int sock, int *fds, size_t nfds)
360 {
361 	unsigned int i, step, j;
362 	int ret, serrno;
363 
364 	if (nfds == 0 || fds == NULL) {
365 		errno = EINVAL;
366 		return (-1);
367 	}
368 
369 	ret = i = step = 0;
370 	while (i < nfds) {
371 		if (PKG_MAX_SIZE < nfds - i)
372 			step = PKG_MAX_SIZE;
373 		else
374 			step = nfds - i;
375 		ret = fd_package_recv(sock, fds + i, step);
376 		if (ret != 0) {
377 			/* Close all received descriptors. */
378 			serrno = errno;
379 			for (j = 0; j < i; j++)
380 				close(fds[j]);
381 			errno = serrno;
382 			break;
383 		}
384 		i += step;
385 	}
386 
387 	return (ret);
388 }
389 
390 int
fd_send(int sock,const int * fds,size_t nfds)391 fd_send(int sock, const int *fds, size_t nfds)
392 {
393 	unsigned int i, step;
394 	int ret;
395 
396 	if (nfds == 0 || fds == NULL) {
397 		errno = EINVAL;
398 		return (-1);
399 	}
400 
401 	ret = i = step = 0;
402 	while (i < nfds) {
403 		if (PKG_MAX_SIZE < nfds - i)
404 			step = PKG_MAX_SIZE;
405 		else
406 			step = nfds - i;
407 		ret = fd_package_send(sock, fds + i, step);
408 		if (ret != 0)
409 			break;
410 		i += step;
411 	}
412 
413 	return (ret);
414 }
415 
416 int
buf_send(int sock,void * buf,size_t size)417 buf_send(int sock, void *buf, size_t size)
418 {
419 	ssize_t done;
420 	unsigned char *ptr;
421 
422 	PJDLOG_ASSERT(sock >= 0);
423 	PJDLOG_ASSERT(size > 0);
424 	PJDLOG_ASSERT(buf != NULL);
425 
426 	ptr = buf;
427 	do {
428 		fd_wait(sock, false);
429 		done = send(sock, ptr, size, 0);
430 		if (done == -1) {
431 			if (errno == EINTR)
432 				continue;
433 			return (-1);
434 		} else if (done == 0) {
435 			errno = ENOTCONN;
436 			return (-1);
437 		}
438 		size -= done;
439 		ptr += done;
440 	} while (size > 0);
441 
442 	return (0);
443 }
444 
445 int
buf_recv(int sock,void * buf,size_t size)446 buf_recv(int sock, void *buf, size_t size)
447 {
448 	ssize_t done;
449 	unsigned char *ptr;
450 
451 	PJDLOG_ASSERT(sock >= 0);
452 	PJDLOG_ASSERT(buf != NULL);
453 
454 	ptr = buf;
455 	while (size > 0) {
456 		fd_wait(sock, true);
457 		done = recv(sock, ptr, size, 0);
458 		if (done == -1) {
459 			if (errno == EINTR)
460 				continue;
461 			return (-1);
462 		} else if (done == 0) {
463 			errno = ENOTCONN;
464 			return (-1);
465 		}
466 		size -= done;
467 		ptr += done;
468 	}
469 
470 	return (0);
471 }
472