xref: /freebsd-14.2/sys/dev/wg/if_wg.c (revision a9a57c84)
1 /* SPDX-License-Identifier: ISC
2  *
3  * Copyright (C) 2015-2021 Jason A. Donenfeld <[email protected]>. All Rights Reserved.
4  * Copyright (C) 2019-2021 Matt Dunwoodie <[email protected]>
5  * Copyright (c) 2019-2020 Rubicon Communications, LLC (Netgate)
6  * Copyright (c) 2021 Kyle Evans <[email protected]>
7  * Copyright (c) 2022 The FreeBSD Foundation
8  */
9 
10 #include "opt_inet.h"
11 #include "opt_inet6.h"
12 
13 #include <sys/param.h>
14 #include <sys/systm.h>
15 #include <sys/counter.h>
16 #include <sys/gtaskqueue.h>
17 #include <sys/jail.h>
18 #include <sys/kernel.h>
19 #include <sys/lock.h>
20 #include <sys/mbuf.h>
21 #include <sys/module.h>
22 #include <sys/nv.h>
23 #include <sys/priv.h>
24 #include <sys/protosw.h>
25 #include <sys/rmlock.h>
26 #include <sys/rwlock.h>
27 #include <sys/smp.h>
28 #include <sys/socket.h>
29 #include <sys/socketvar.h>
30 #include <sys/sockio.h>
31 #include <sys/sysctl.h>
32 #include <sys/sx.h>
33 #include <machine/_inttypes.h>
34 #include <net/bpf.h>
35 #include <net/ethernet.h>
36 #include <net/if.h>
37 #include <net/if_clone.h>
38 #include <net/if_types.h>
39 #include <net/if_var.h>
40 #include <net/netisr.h>
41 #include <net/radix.h>
42 #include <netinet/in.h>
43 #include <netinet6/in6_var.h>
44 #include <netinet/ip.h>
45 #include <netinet/ip6.h>
46 #include <netinet/ip_icmp.h>
47 #include <netinet/icmp6.h>
48 #include <netinet/udp_var.h>
49 #include <netinet6/nd6.h>
50 
51 #include "wg_noise.h"
52 #include "wg_cookie.h"
53 #include "version.h"
54 #include "if_wg.h"
55 
/*
 * Default/maximum interface MTU: 80 bytes below the link MTU are reserved
 * for encapsulation (presumably 40 IPv6 + 8 UDP + 32 WireGuard data
 * header/tag -- TODO confirm).
 */
#define DEFAULT_MTU		(ETHERMTU - 80)
#define MAX_MTU			(IF_MAXMTU - 80)

#define MAX_STAGED_PKT		128
#define MAX_QUEUED_PKT		1024
#define MAX_QUEUED_PKT_MASK	(MAX_QUEUED_PKT - 1)	/* needs power-of-two MAX_QUEUED_PKT */

#define MAX_QUEUED_HANDSHAKES	4096

#define REKEY_TIMEOUT_JITTER	334 /* 1/3 sec, round for arc4random_uniform */
#define MAX_TIMER_HANDSHAKES	(90 / REKEY_TIMEOUT)
#define NEW_HANDSHAKE_TIMEOUT	(REKEY_TIMEOUT + KEEPALIVE_TIMEOUT)
#define UNDERLOAD_TIMEOUT	1

/*
 * Print a diagnostic on the interface when IFF_DEBUG is set on it.
 *
 * Wrapped in do { } while (0) so the macro is a single statement: a bare
 * "if (...) if_printf(...)" expansion would silently capture an "else" that
 * follows a caller's "if (x) DPRINTF(...);".
 */
#define DPRINTF(sc, ...) do {						\
	if (if_getflags((sc)->sc_ifp) & IFF_DEBUG)			\
		if_printf((sc)->sc_ifp, ##__VA_ARGS__);			\
} while (0)

/* First byte indicating packet type on the wire */
#define WG_PKT_INITIATION htole32(1)
#define WG_PKT_RESPONSE htole32(2)
#define WG_PKT_COOKIE htole32(3)
#define WG_PKT_DATA htole32(4)

#define WG_PKT_PADDING		16
#define WG_KEY_SIZE		32
80 
81 struct wg_pkt_initiation {
82 	uint32_t		t;
83 	uint32_t		s_idx;
84 	uint8_t			ue[NOISE_PUBLIC_KEY_LEN];
85 	uint8_t			es[NOISE_PUBLIC_KEY_LEN + NOISE_AUTHTAG_LEN];
86 	uint8_t			ets[NOISE_TIMESTAMP_LEN + NOISE_AUTHTAG_LEN];
87 	struct cookie_macs	m;
88 };
89 
90 struct wg_pkt_response {
91 	uint32_t		t;
92 	uint32_t		s_idx;
93 	uint32_t		r_idx;
94 	uint8_t			ue[NOISE_PUBLIC_KEY_LEN];
95 	uint8_t			en[0 + NOISE_AUTHTAG_LEN];
96 	struct cookie_macs	m;
97 };
98 
99 struct wg_pkt_cookie {
100 	uint32_t		t;
101 	uint32_t		r_idx;
102 	uint8_t			nonce[COOKIE_NONCE_SIZE];
103 	uint8_t			ec[COOKIE_ENCRYPTED_SIZE];
104 };
105 
106 struct wg_pkt_data {
107 	uint32_t		t;
108 	uint32_t		r_idx;
109 	uint64_t		nonce;
110 	uint8_t			buf[];
111 };
112 
113 struct wg_endpoint {
114 	union {
115 		struct sockaddr		r_sa;
116 		struct sockaddr_in	r_sin;
117 #ifdef INET6
118 		struct sockaddr_in6	r_sin6;
119 #endif
120 	} e_remote;
121 	union {
122 		struct in_addr		l_in;
123 #ifdef INET6
124 		struct in6_pktinfo	l_pktinfo6;
125 #define l_in6 l_pktinfo6.ipi6_addr
126 #endif
127 	} e_local;
128 };
129 
130 struct aip_addr {
131 	uint8_t		length;
132 	union {
133 		uint8_t		bytes[16];
134 		uint32_t	ip;
135 		uint32_t	ip6[4];
136 		struct in_addr	in;
137 		struct in6_addr	in6;
138 	};
139 };
140 
141 struct wg_aip {
142 	struct radix_node	 a_nodes[2];
143 	LIST_ENTRY(wg_aip)	 a_entry;
144 	struct aip_addr		 a_addr;
145 	struct aip_addr		 a_mask;
146 	struct wg_peer		*a_peer;
147 	sa_family_t		 a_af;
148 };
149 
150 struct wg_packet {
151 	STAILQ_ENTRY(wg_packet)	 p_serial;
152 	STAILQ_ENTRY(wg_packet)	 p_parallel;
153 	struct wg_endpoint	 p_endpoint;
154 	struct noise_keypair	*p_keypair;
155 	uint64_t		 p_nonce;
156 	struct mbuf		*p_mbuf;
157 	int			 p_mtu;
158 	sa_family_t		 p_af;
159 	enum wg_ring_state {
160 		WG_PACKET_UNCRYPTED,
161 		WG_PACKET_CRYPTED,
162 		WG_PACKET_DEAD,
163 	}			 p_state;
164 };
165 
166 STAILQ_HEAD(wg_packet_list, wg_packet);
167 
168 struct wg_queue {
169 	struct mtx		 q_mtx;
170 	struct wg_packet_list	 q_queue;
171 	size_t			 q_len;
172 };
173 
174 struct wg_peer {
175 	TAILQ_ENTRY(wg_peer)		 p_entry;
176 	uint64_t			 p_id;
177 	struct wg_softc			*p_sc;
178 
179 	struct noise_remote		*p_remote;
180 	struct cookie_maker		 p_cookie;
181 
182 	struct rwlock			 p_endpoint_lock;
183 	struct wg_endpoint		 p_endpoint;
184 
185 	struct wg_queue	 		 p_stage_queue;
186 	struct wg_queue	 		 p_encrypt_serial;
187 	struct wg_queue	 		 p_decrypt_serial;
188 
189 	bool				 p_enabled;
190 	bool				 p_need_another_keepalive;
191 	uint16_t			 p_persistent_keepalive_interval;
192 	struct callout			 p_new_handshake;
193 	struct callout			 p_send_keepalive;
194 	struct callout			 p_retry_handshake;
195 	struct callout			 p_zero_key_material;
196 	struct callout			 p_persistent_keepalive;
197 
198 	struct mtx			 p_handshake_mtx;
199 	struct timespec			 p_handshake_complete;	/* nanotime */
200 	int				 p_handshake_retries;
201 
202 	struct grouptask		 p_send;
203 	struct grouptask		 p_recv;
204 
205 	counter_u64_t			 p_tx_bytes;
206 	counter_u64_t			 p_rx_bytes;
207 
208 	LIST_HEAD(, wg_aip)		 p_aips;
209 	size_t				 p_aips_num;
210 };
211 
212 struct wg_socket {
213 	struct socket	*so_so4;
214 	struct socket	*so_so6;
215 	uint32_t	 so_user_cookie;
216 	int		 so_fibnum;
217 	in_port_t	 so_port;
218 };
219 
220 struct wg_softc {
221 	LIST_ENTRY(wg_softc)	 sc_entry;
222 	if_t			 sc_ifp;
223 	int			 sc_flags;
224 
225 	struct ucred		*sc_ucred;
226 	struct wg_socket	 sc_socket;
227 
228 	TAILQ_HEAD(,wg_peer)	 sc_peers;
229 	size_t			 sc_peers_num;
230 
231 	struct noise_local	*sc_local;
232 	struct cookie_checker	 sc_cookie;
233 
234 	struct radix_node_head	*sc_aip4;
235 	struct radix_node_head	*sc_aip6;
236 
237 	struct grouptask	 sc_handshake;
238 	struct wg_queue		 sc_handshake_queue;
239 
240 	struct grouptask	*sc_encrypt;
241 	struct grouptask	*sc_decrypt;
242 	struct wg_queue		 sc_encrypt_parallel;
243 	struct wg_queue		 sc_decrypt_parallel;
244 	u_int			 sc_encrypt_last_cpu;
245 	u_int			 sc_decrypt_last_cpu;
246 
247 	struct sx		 sc_lock;
248 };
249 
250 #define	WGF_DYING	0x0001
251 
252 #define MAX_LOOPS	8
253 #define MTAG_WGLOOP	0x77676c70 /* wglp */
254 
255 #define	GROUPTASK_DRAIN(gtask)			\
256 	gtaskqueue_drain((gtask)->gt_taskqueue, &(gtask)->gt_task)
257 
258 #define BPF_MTAP2_AF(ifp, m, af) do { \
259 		uint32_t __bpf_tap_af = (af); \
260 		BPF_MTAP2(ifp, &__bpf_tap_af, sizeof(__bpf_tap_af), m); \
261 	} while (0)
262 
263 static int clone_count;
264 static uma_zone_t wg_packet_zone;
265 static volatile unsigned long peer_counter = 0;
266 static const char wgname[] = "wg";
267 static unsigned wg_osd_jail_slot;
268 
269 static struct sx wg_sx;
270 SX_SYSINIT(wg_sx, &wg_sx, "wg_sx");
271 
272 static LIST_HEAD(, wg_softc) wg_list = LIST_HEAD_INITIALIZER(wg_list);
273 
274 static TASKQGROUP_DEFINE(wg_tqg, mp_ncpus, 1);
275 
276 MALLOC_DEFINE(M_WG, "WG", "wireguard");
277 
278 VNET_DEFINE_STATIC(struct if_clone *, wg_cloner);
279 
280 #define	V_wg_cloner	VNET(wg_cloner)
281 #define	WG_CAPS		IFCAP_LINKSTATE
282 
283 struct wg_timespec64 {
284 	uint64_t	tv_sec;
285 	uint64_t	tv_nsec;
286 };
287 
288 static int wg_socket_init(struct wg_softc *, in_port_t);
289 static int wg_socket_bind(struct socket **, struct socket **, in_port_t *);
290 static void wg_socket_set(struct wg_softc *, struct socket *, struct socket *);
291 static void wg_socket_uninit(struct wg_softc *);
292 static int wg_socket_set_sockopt(struct socket *, struct socket *, int, void *, size_t);
293 static int wg_socket_set_cookie(struct wg_softc *, uint32_t);
294 static int wg_socket_set_fibnum(struct wg_softc *, int);
295 static int wg_send(struct wg_softc *, struct wg_endpoint *, struct mbuf *);
296 static void wg_timers_enable(struct wg_peer *);
297 static void wg_timers_disable(struct wg_peer *);
298 static void wg_timers_set_persistent_keepalive(struct wg_peer *, uint16_t);
299 static void wg_timers_get_last_handshake(struct wg_peer *, struct wg_timespec64 *);
300 static void wg_timers_event_data_sent(struct wg_peer *);
301 static void wg_timers_event_data_received(struct wg_peer *);
302 static void wg_timers_event_any_authenticated_packet_sent(struct wg_peer *);
303 static void wg_timers_event_any_authenticated_packet_received(struct wg_peer *);
304 static void wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *);
305 static void wg_timers_event_handshake_initiated(struct wg_peer *);
306 static void wg_timers_event_handshake_complete(struct wg_peer *);
307 static void wg_timers_event_session_derived(struct wg_peer *);
308 static void wg_timers_event_want_initiation(struct wg_peer *);
309 static void wg_timers_run_send_initiation(struct wg_peer *, bool);
310 static void wg_timers_run_retry_handshake(void *);
311 static void wg_timers_run_send_keepalive(void *);
312 static void wg_timers_run_new_handshake(void *);
313 static void wg_timers_run_zero_key_material(void *);
314 static void wg_timers_run_persistent_keepalive(void *);
315 static int wg_aip_add(struct wg_softc *, struct wg_peer *, sa_family_t, const void *, uint8_t);
316 static struct wg_peer *wg_aip_lookup(struct wg_softc *, sa_family_t, void *);
317 static void wg_aip_remove_all(struct wg_softc *, struct wg_peer *);
318 static struct wg_peer *wg_peer_alloc(struct wg_softc *, const uint8_t [WG_KEY_SIZE]);
319 static void wg_peer_free_deferred(struct noise_remote *);
320 static void wg_peer_destroy(struct wg_peer *);
321 static void wg_peer_destroy_all(struct wg_softc *);
322 static void wg_peer_send_buf(struct wg_peer *, uint8_t *, size_t);
323 static void wg_send_initiation(struct wg_peer *);
324 static void wg_send_response(struct wg_peer *);
325 static void wg_send_cookie(struct wg_softc *, struct cookie_macs *, uint32_t, struct wg_endpoint *);
326 static void wg_peer_set_endpoint(struct wg_peer *, struct wg_endpoint *);
327 static void wg_peer_clear_src(struct wg_peer *);
328 static void wg_peer_get_endpoint(struct wg_peer *, struct wg_endpoint *);
329 static void wg_send_buf(struct wg_softc *, struct wg_endpoint *, uint8_t *, size_t);
330 static void wg_send_keepalive(struct wg_peer *);
331 static void wg_handshake(struct wg_softc *, struct wg_packet *);
332 static void wg_encrypt(struct wg_softc *, struct wg_packet *);
333 static void wg_decrypt(struct wg_softc *, struct wg_packet *);
334 static void wg_softc_handshake_receive(struct wg_softc *);
335 static void wg_softc_decrypt(struct wg_softc *);
336 static void wg_softc_encrypt(struct wg_softc *);
337 static void wg_encrypt_dispatch(struct wg_softc *);
338 static void wg_decrypt_dispatch(struct wg_softc *);
339 static void wg_deliver_out(struct wg_peer *);
340 static void wg_deliver_in(struct wg_peer *);
341 static struct wg_packet *wg_packet_alloc(struct mbuf *);
342 static void wg_packet_free(struct wg_packet *);
343 static void wg_queue_init(struct wg_queue *, const char *);
344 static void wg_queue_deinit(struct wg_queue *);
345 static size_t wg_queue_len(struct wg_queue *);
346 static int wg_queue_enqueue_handshake(struct wg_queue *, struct wg_packet *);
347 static struct wg_packet *wg_queue_dequeue_handshake(struct wg_queue *);
348 static void wg_queue_push_staged(struct wg_queue *, struct wg_packet *);
349 static void wg_queue_enlist_staged(struct wg_queue *, struct wg_packet_list *);
350 static void wg_queue_delist_staged(struct wg_queue *, struct wg_packet_list *);
351 static void wg_queue_purge(struct wg_queue *);
352 static int wg_queue_both(struct wg_queue *, struct wg_queue *, struct wg_packet *);
353 static struct wg_packet *wg_queue_dequeue_serial(struct wg_queue *);
354 static struct wg_packet *wg_queue_dequeue_parallel(struct wg_queue *);
355 static bool wg_input(struct mbuf *, int, struct inpcb *, const struct sockaddr *, void *);
356 static void wg_peer_send_staged(struct wg_peer *);
357 static int wg_clone_create(struct if_clone *ifc, char *name, size_t len,
358 	struct ifc_data *ifd, if_t *ifpp);
359 static void wg_qflush(if_t);
360 static inline int determine_af_and_pullup(struct mbuf **m, sa_family_t *af);
361 static int wg_xmit(if_t, struct mbuf *, sa_family_t, uint32_t);
362 static int wg_transmit(if_t, struct mbuf *);
363 static int wg_output(if_t, struct mbuf *, const struct sockaddr *, struct route *);
364 static int wg_clone_destroy(struct if_clone *ifc, if_t ifp,
365 	uint32_t flags);
366 static bool wgc_privileged(struct wg_softc *);
367 static int wgc_get(struct wg_softc *, struct wg_data_io *);
368 static int wgc_set(struct wg_softc *, struct wg_data_io *);
369 static int wg_up(struct wg_softc *);
370 static void wg_down(struct wg_softc *);
371 static void wg_reassign(if_t, struct vnet *, char *unused);
372 static void wg_init(void *);
373 static int wg_ioctl(if_t, u_long, caddr_t);
374 static void vnet_wg_init(const void *);
375 static void vnet_wg_uninit(const void *);
376 static int wg_module_init(void);
377 static void wg_module_deinit(void);
378 
379 /* TODO Peer */
380 static struct wg_peer *
wg_peer_alloc(struct wg_softc * sc,const uint8_t pub_key[WG_KEY_SIZE])381 wg_peer_alloc(struct wg_softc *sc, const uint8_t pub_key[WG_KEY_SIZE])
382 {
383 	struct wg_peer *peer;
384 
385 	sx_assert(&sc->sc_lock, SX_XLOCKED);
386 
387 	peer = malloc(sizeof(*peer), M_WG, M_WAITOK | M_ZERO);
388 	peer->p_remote = noise_remote_alloc(sc->sc_local, peer, pub_key);
389 	peer->p_tx_bytes = counter_u64_alloc(M_WAITOK);
390 	peer->p_rx_bytes = counter_u64_alloc(M_WAITOK);
391 	peer->p_id = peer_counter++;
392 	peer->p_sc = sc;
393 
394 	cookie_maker_init(&peer->p_cookie, pub_key);
395 
396 	rw_init(&peer->p_endpoint_lock, "wg_peer_endpoint");
397 
398 	wg_queue_init(&peer->p_stage_queue, "stageq");
399 	wg_queue_init(&peer->p_encrypt_serial, "txq");
400 	wg_queue_init(&peer->p_decrypt_serial, "rxq");
401 
402 	peer->p_enabled = false;
403 	peer->p_need_another_keepalive = false;
404 	peer->p_persistent_keepalive_interval = 0;
405 	callout_init(&peer->p_new_handshake, true);
406 	callout_init(&peer->p_send_keepalive, true);
407 	callout_init(&peer->p_retry_handshake, true);
408 	callout_init(&peer->p_persistent_keepalive, true);
409 	callout_init(&peer->p_zero_key_material, true);
410 
411 	mtx_init(&peer->p_handshake_mtx, "peer handshake", NULL, MTX_DEF);
412 	bzero(&peer->p_handshake_complete, sizeof(peer->p_handshake_complete));
413 	peer->p_handshake_retries = 0;
414 
415 	GROUPTASK_INIT(&peer->p_send, 0, (gtask_fn_t *)wg_deliver_out, peer);
416 	taskqgroup_attach(qgroup_wg_tqg, &peer->p_send, peer, NULL, NULL, "wg send");
417 	GROUPTASK_INIT(&peer->p_recv, 0, (gtask_fn_t *)wg_deliver_in, peer);
418 	taskqgroup_attach(qgroup_wg_tqg, &peer->p_recv, peer, NULL, NULL, "wg recv");
419 
420 	LIST_INIT(&peer->p_aips);
421 	peer->p_aips_num = 0;
422 
423 	return (peer);
424 }
425 
426 static void
wg_peer_free_deferred(struct noise_remote * r)427 wg_peer_free_deferred(struct noise_remote *r)
428 {
429 	struct wg_peer *peer = noise_remote_arg(r);
430 
431 	/* While there are no references remaining, we may still have
432 	 * p_{send,recv} executing (think empty queue, but wg_deliver_{in,out}
433 	 * needs to check the queue. We should wait for them and then free. */
434 	GROUPTASK_DRAIN(&peer->p_recv);
435 	GROUPTASK_DRAIN(&peer->p_send);
436 	taskqgroup_detach(qgroup_wg_tqg, &peer->p_recv);
437 	taskqgroup_detach(qgroup_wg_tqg, &peer->p_send);
438 
439 	wg_queue_deinit(&peer->p_decrypt_serial);
440 	wg_queue_deinit(&peer->p_encrypt_serial);
441 	wg_queue_deinit(&peer->p_stage_queue);
442 
443 	counter_u64_free(peer->p_tx_bytes);
444 	counter_u64_free(peer->p_rx_bytes);
445 	rw_destroy(&peer->p_endpoint_lock);
446 	mtx_destroy(&peer->p_handshake_mtx);
447 
448 	cookie_maker_free(&peer->p_cookie);
449 
450 	free(peer, M_WG);
451 }
452 
453 static void
wg_peer_destroy(struct wg_peer * peer)454 wg_peer_destroy(struct wg_peer *peer)
455 {
456 	struct wg_softc *sc = peer->p_sc;
457 	sx_assert(&sc->sc_lock, SX_XLOCKED);
458 
459 	/* Disable remote and timers. This will prevent any new handshakes
460 	 * occuring. */
461 	noise_remote_disable(peer->p_remote);
462 	wg_timers_disable(peer);
463 
464 	/* Now we can remove all allowed IPs so no more packets will be routed
465 	 * to the peer. */
466 	wg_aip_remove_all(sc, peer);
467 
468 	/* Remove peer from the interface, then free. Some references may still
469 	 * exist to p_remote, so noise_remote_free will wait until they're all
470 	 * put to call wg_peer_free_deferred. */
471 	sc->sc_peers_num--;
472 	TAILQ_REMOVE(&sc->sc_peers, peer, p_entry);
473 	DPRINTF(sc, "Peer %" PRIu64 " destroyed\n", peer->p_id);
474 	noise_remote_free(peer->p_remote, wg_peer_free_deferred);
475 }
476 
477 static void
wg_peer_destroy_all(struct wg_softc * sc)478 wg_peer_destroy_all(struct wg_softc *sc)
479 {
480 	struct wg_peer *peer, *tpeer;
481 	TAILQ_FOREACH_SAFE(peer, &sc->sc_peers, p_entry, tpeer)
482 		wg_peer_destroy(peer);
483 }
484 
485 static void
wg_peer_set_endpoint(struct wg_peer * peer,struct wg_endpoint * e)486 wg_peer_set_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
487 {
488 	MPASS(e->e_remote.r_sa.sa_family != 0);
489 	if (memcmp(e, &peer->p_endpoint, sizeof(*e)) == 0)
490 		return;
491 
492 	rw_wlock(&peer->p_endpoint_lock);
493 	peer->p_endpoint = *e;
494 	rw_wunlock(&peer->p_endpoint_lock);
495 }
496 
497 static void
wg_peer_clear_src(struct wg_peer * peer)498 wg_peer_clear_src(struct wg_peer *peer)
499 {
500 	rw_wlock(&peer->p_endpoint_lock);
501 	bzero(&peer->p_endpoint.e_local, sizeof(peer->p_endpoint.e_local));
502 	rw_wunlock(&peer->p_endpoint_lock);
503 }
504 
505 static void
wg_peer_get_endpoint(struct wg_peer * peer,struct wg_endpoint * e)506 wg_peer_get_endpoint(struct wg_peer *peer, struct wg_endpoint *e)
507 {
508 	rw_rlock(&peer->p_endpoint_lock);
509 	*e = peer->p_endpoint;
510 	rw_runlock(&peer->p_endpoint_lock);
511 }
512 
513 /* Allowed IP */
514 static int
wg_aip_add(struct wg_softc * sc,struct wg_peer * peer,sa_family_t af,const void * addr,uint8_t cidr)515 wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void *addr, uint8_t cidr)
516 {
517 	struct radix_node_head	*root;
518 	struct radix_node	*node;
519 	struct wg_aip		*aip;
520 	int			 ret = 0;
521 
522 	aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
523 	aip->a_peer = peer;
524 	aip->a_af = af;
525 
526 	switch (af) {
527 #ifdef INET
528 	case AF_INET:
529 		if (cidr > 32) cidr = 32;
530 		root = sc->sc_aip4;
531 		aip->a_addr.in = *(const struct in_addr *)addr;
532 		aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
533 		aip->a_addr.ip &= aip->a_mask.ip;
534 		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
535 		break;
536 #endif
537 #ifdef INET6
538 	case AF_INET6:
539 		if (cidr > 128) cidr = 128;
540 		root = sc->sc_aip6;
541 		aip->a_addr.in6 = *(const struct in6_addr *)addr;
542 		in6_prefixlen2mask(&aip->a_mask.in6, cidr);
543 		for (int i = 0; i < 4; i++)
544 			aip->a_addr.ip6[i] &= aip->a_mask.ip6[i];
545 		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
546 		break;
547 #endif
548 	default:
549 		free(aip, M_WG);
550 		return (EAFNOSUPPORT);
551 	}
552 
553 	RADIX_NODE_HEAD_LOCK(root);
554 	node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
555 	if (node == aip->a_nodes) {
556 		LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
557 		peer->p_aips_num++;
558 	} else if (!node)
559 		node = root->rnh_lookup(&aip->a_addr, &aip->a_mask, &root->rh);
560 	if (!node) {
561 		free(aip, M_WG);
562 		ret = ENOMEM;
563 	} else if (node != aip->a_nodes) {
564 		free(aip, M_WG);
565 		aip = (struct wg_aip *)node;
566 		if (aip->a_peer != peer) {
567 			LIST_REMOVE(aip, a_entry);
568 			aip->a_peer->p_aips_num--;
569 			aip->a_peer = peer;
570 			LIST_INSERT_HEAD(&peer->p_aips, aip, a_entry);
571 			aip->a_peer->p_aips_num++;
572 		}
573 	}
574 	RADIX_NODE_HEAD_UNLOCK(root);
575 	return (ret);
576 }
577 
578 static struct wg_peer *
wg_aip_lookup(struct wg_softc * sc,sa_family_t af,void * a)579 wg_aip_lookup(struct wg_softc *sc, sa_family_t af, void *a)
580 {
581 	struct radix_node_head	*root;
582 	struct radix_node	*node;
583 	struct wg_peer		*peer;
584 	struct aip_addr		 addr;
585 	RADIX_NODE_HEAD_RLOCK_TRACKER;
586 
587 	switch (af) {
588 	case AF_INET:
589 		root = sc->sc_aip4;
590 		memcpy(&addr.in, a, sizeof(addr.in));
591 		addr.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
592 		break;
593 	case AF_INET6:
594 		root = sc->sc_aip6;
595 		memcpy(&addr.in6, a, sizeof(addr.in6));
596 		addr.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
597 		break;
598 	default:
599 		return NULL;
600 	}
601 
602 	RADIX_NODE_HEAD_RLOCK(root);
603 	node = root->rnh_matchaddr(&addr, &root->rh);
604 	if (node != NULL) {
605 		peer = ((struct wg_aip *)node)->a_peer;
606 		noise_remote_ref(peer->p_remote);
607 	} else {
608 		peer = NULL;
609 	}
610 	RADIX_NODE_HEAD_RUNLOCK(root);
611 
612 	return (peer);
613 }
614 
615 static void
wg_aip_remove_all(struct wg_softc * sc,struct wg_peer * peer)616 wg_aip_remove_all(struct wg_softc *sc, struct wg_peer *peer)
617 {
618 	struct wg_aip		*aip, *taip;
619 
620 	RADIX_NODE_HEAD_LOCK(sc->sc_aip4);
621 	LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
622 		if (aip->a_af == AF_INET) {
623 			if (sc->sc_aip4->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip4->rh) == NULL)
624 				panic("failed to delete aip %p", aip);
625 			LIST_REMOVE(aip, a_entry);
626 			peer->p_aips_num--;
627 			free(aip, M_WG);
628 		}
629 	}
630 	RADIX_NODE_HEAD_UNLOCK(sc->sc_aip4);
631 
632 	RADIX_NODE_HEAD_LOCK(sc->sc_aip6);
633 	LIST_FOREACH_SAFE(aip, &peer->p_aips, a_entry, taip) {
634 		if (aip->a_af == AF_INET6) {
635 			if (sc->sc_aip6->rnh_deladdr(&aip->a_addr, &aip->a_mask, &sc->sc_aip6->rh) == NULL)
636 				panic("failed to delete aip %p", aip);
637 			LIST_REMOVE(aip, a_entry);
638 			peer->p_aips_num--;
639 			free(aip, M_WG);
640 		}
641 	}
642 	RADIX_NODE_HEAD_UNLOCK(sc->sc_aip6);
643 
644 	if (!LIST_EMPTY(&peer->p_aips) || peer->p_aips_num != 0)
645 		panic("wg_aip_remove_all could not delete all %p", peer);
646 }
647 
648 static int
wg_socket_init(struct wg_softc * sc,in_port_t port)649 wg_socket_init(struct wg_softc *sc, in_port_t port)
650 {
651 	struct ucred *cred = sc->sc_ucred;
652 	struct socket *so4 = NULL, *so6 = NULL;
653 	int rc;
654 
655 	sx_assert(&sc->sc_lock, SX_XLOCKED);
656 
657 	if (!cred)
658 		return (EBUSY);
659 
660 	/*
661 	 * For socket creation, we use the creds of the thread that created the
662 	 * tunnel rather than the current thread to maintain the semantics that
663 	 * WireGuard has on Linux with network namespaces -- that the sockets
664 	 * are created in their home vnet so that they can be configured and
665 	 * functionally attached to a foreign vnet as the jail's only interface
666 	 * to the network.
667 	 */
668 #ifdef INET
669 	rc = socreate(AF_INET, &so4, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
670 	if (rc)
671 		goto out;
672 
673 	rc = udp_set_kernel_tunneling(so4, wg_input, NULL, sc);
674 	/*
675 	 * udp_set_kernel_tunneling can only fail if there is already a tunneling function set.
676 	 * This should never happen with a new socket.
677 	 */
678 	MPASS(rc == 0);
679 #endif
680 
681 #ifdef INET6
682 	rc = socreate(AF_INET6, &so6, SOCK_DGRAM, IPPROTO_UDP, cred, curthread);
683 	if (rc)
684 		goto out;
685 	rc = udp_set_kernel_tunneling(so6, wg_input, NULL, sc);
686 	MPASS(rc == 0);
687 #endif
688 
689 	if (sc->sc_socket.so_user_cookie) {
690 		rc = wg_socket_set_sockopt(so4, so6, SO_USER_COOKIE, &sc->sc_socket.so_user_cookie, sizeof(sc->sc_socket.so_user_cookie));
691 		if (rc)
692 			goto out;
693 	}
694 	rc = wg_socket_set_sockopt(so4, so6, SO_SETFIB, &sc->sc_socket.so_fibnum, sizeof(sc->sc_socket.so_fibnum));
695 	if (rc)
696 		goto out;
697 
698 	rc = wg_socket_bind(&so4, &so6, &port);
699 	if (!rc) {
700 		sc->sc_socket.so_port = port;
701 		wg_socket_set(sc, so4, so6);
702 	}
703 out:
704 	if (rc) {
705 		if (so4 != NULL)
706 			soclose(so4);
707 		if (so6 != NULL)
708 			soclose(so6);
709 	}
710 	return (rc);
711 }
712 
wg_socket_set_sockopt(struct socket * so4,struct socket * so6,int name,void * val,size_t len)713 static int wg_socket_set_sockopt(struct socket *so4, struct socket *so6, int name, void *val, size_t len)
714 {
715 	int ret4 = 0, ret6 = 0;
716 	struct sockopt sopt = {
717 		.sopt_dir = SOPT_SET,
718 		.sopt_level = SOL_SOCKET,
719 		.sopt_name = name,
720 		.sopt_val = val,
721 		.sopt_valsize = len
722 	};
723 
724 	if (so4)
725 		ret4 = sosetopt(so4, &sopt);
726 	if (so6)
727 		ret6 = sosetopt(so6, &sopt);
728 	return (ret4 ?: ret6);
729 }
730 
wg_socket_set_cookie(struct wg_softc * sc,uint32_t user_cookie)731 static int wg_socket_set_cookie(struct wg_softc *sc, uint32_t user_cookie)
732 {
733 	struct wg_socket *so = &sc->sc_socket;
734 	int ret;
735 
736 	sx_assert(&sc->sc_lock, SX_XLOCKED);
737 	ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_USER_COOKIE, &user_cookie, sizeof(user_cookie));
738 	if (!ret)
739 		so->so_user_cookie = user_cookie;
740 	return (ret);
741 }
742 
wg_socket_set_fibnum(struct wg_softc * sc,int fibnum)743 static int wg_socket_set_fibnum(struct wg_softc *sc, int fibnum)
744 {
745 	struct wg_socket *so = &sc->sc_socket;
746 	int ret;
747 
748 	sx_assert(&sc->sc_lock, SX_XLOCKED);
749 
750 	ret = wg_socket_set_sockopt(so->so_so4, so->so_so6, SO_SETFIB, &fibnum, sizeof(fibnum));
751 	if (!ret)
752 		so->so_fibnum = fibnum;
753 	return (ret);
754 }
755 
/*
 * Detach both UDP sockets from the interface and close them by installing
 * NULL replacements.
 */
static void
wg_socket_uninit(struct wg_softc *sc)
{
	struct socket *const none = NULL;

	wg_socket_set(sc, none, none);
}
761 
762 static void
wg_socket_set(struct wg_softc * sc,struct socket * new_so4,struct socket * new_so6)763 wg_socket_set(struct wg_softc *sc, struct socket *new_so4, struct socket *new_so6)
764 {
765 	struct wg_socket *so = &sc->sc_socket;
766 	struct socket *so4, *so6;
767 
768 	sx_assert(&sc->sc_lock, SX_XLOCKED);
769 
770 	so4 = atomic_load_ptr(&so->so_so4);
771 	so6 = atomic_load_ptr(&so->so_so6);
772 	atomic_store_ptr(&so->so_so4, new_so4);
773 	atomic_store_ptr(&so->so_so6, new_so6);
774 
775 	if (!so4 && !so6)
776 		return;
777 	NET_EPOCH_WAIT();
778 	if (so4)
779 		soclose(so4);
780 	if (so6)
781 		soclose(so6);
782 }
783 
784 static int
wg_socket_bind(struct socket ** in_so4,struct socket ** in_so6,in_port_t * requested_port)785 wg_socket_bind(struct socket **in_so4, struct socket **in_so6, in_port_t *requested_port)
786 {
787 	struct socket *so4 = *in_so4, *so6 = *in_so6;
788 	int ret4 = 0, ret6 = 0;
789 	in_port_t port = *requested_port;
790 	struct sockaddr_in sin = {
791 		.sin_len = sizeof(struct sockaddr_in),
792 		.sin_family = AF_INET,
793 		.sin_port = htons(port)
794 	};
795 	struct sockaddr_in6 sin6 = {
796 		.sin6_len = sizeof(struct sockaddr_in6),
797 		.sin6_family = AF_INET6,
798 		.sin6_port = htons(port)
799 	};
800 
801 	if (so4) {
802 		ret4 = sobind(so4, (struct sockaddr *)&sin, curthread);
803 		if (ret4 && ret4 != EADDRNOTAVAIL)
804 			return (ret4);
805 		if (!ret4 && !sin.sin_port) {
806 			struct sockaddr_in *bound_sin;
807 			int ret = so4->so_proto->pr_sockaddr(so4,
808 			    (struct sockaddr **)&bound_sin);
809 			if (ret)
810 				return (ret);
811 			port = ntohs(bound_sin->sin_port);
812 			sin6.sin6_port = bound_sin->sin_port;
813 			free(bound_sin, M_SONAME);
814 		}
815 	}
816 
817 	if (so6) {
818 		ret6 = sobind(so6, (struct sockaddr *)&sin6, curthread);
819 		if (ret6 && ret6 != EADDRNOTAVAIL)
820 			return (ret6);
821 		if (!ret6 && !sin6.sin6_port) {
822 			struct sockaddr_in6 *bound_sin6;
823 			int ret = so6->so_proto->pr_sockaddr(so6,
824 			    (struct sockaddr **)&bound_sin6);
825 			if (ret)
826 				return (ret);
827 			port = ntohs(bound_sin6->sin6_port);
828 			free(bound_sin6, M_SONAME);
829 		}
830 	}
831 
832 	if (ret4 && ret6)
833 		return (ret4);
834 	*requested_port = port;
835 	if (ret4 && !ret6 && so4) {
836 		soclose(so4);
837 		*in_so4 = NULL;
838 	} else if (ret6 && !ret4 && so6) {
839 		soclose(so6);
840 		*in_so6 = NULL;
841 	}
842 	return (0);
843 }
844 
845 static int
wg_send(struct wg_softc * sc,struct wg_endpoint * e,struct mbuf * m)846 wg_send(struct wg_softc *sc, struct wg_endpoint *e, struct mbuf *m)
847 {
848 	struct epoch_tracker et;
849 	struct sockaddr *sa;
850 	struct wg_socket *so = &sc->sc_socket;
851 	struct socket *so4, *so6;
852 	struct mbuf *control = NULL;
853 	int ret = 0;
854 	size_t len = m->m_pkthdr.len;
855 
856 	/* Get local control address before locking */
857 	if (e->e_remote.r_sa.sa_family == AF_INET) {
858 		if (e->e_local.l_in.s_addr != INADDR_ANY)
859 			control = sbcreatecontrol((caddr_t)&e->e_local.l_in,
860 			    sizeof(struct in_addr), IP_SENDSRCADDR,
861 			    IPPROTO_IP, M_NOWAIT);
862 #ifdef INET6
863 	} else if (e->e_remote.r_sa.sa_family == AF_INET6) {
864 		if (!IN6_IS_ADDR_UNSPECIFIED(&e->e_local.l_in6))
865 			control = sbcreatecontrol((caddr_t)&e->e_local.l_pktinfo6,
866 			    sizeof(struct in6_pktinfo), IPV6_PKTINFO,
867 			    IPPROTO_IPV6, M_NOWAIT);
868 #endif
869 	} else {
870 		m_freem(m);
871 		return (EAFNOSUPPORT);
872 	}
873 
874 	/* Get remote address */
875 	sa = &e->e_remote.r_sa;
876 
877 	NET_EPOCH_ENTER(et);
878 	so4 = atomic_load_ptr(&so->so_so4);
879 	so6 = atomic_load_ptr(&so->so_so6);
880 	if (e->e_remote.r_sa.sa_family == AF_INET && so4 != NULL)
881 		ret = sosend(so4, sa, NULL, m, control, 0, curthread);
882 	else if (e->e_remote.r_sa.sa_family == AF_INET6 && so6 != NULL)
883 		ret = sosend(so6, sa, NULL, m, control, 0, curthread);
884 	else {
885 		ret = ENOTCONN;
886 		m_freem(control);
887 		m_freem(m);
888 	}
889 	NET_EPOCH_EXIT(et);
890 	if (ret == 0) {
891 		if_inc_counter(sc->sc_ifp, IFCOUNTER_OPACKETS, 1);
892 		if_inc_counter(sc->sc_ifp, IFCOUNTER_OBYTES, len);
893 	}
894 	return (ret);
895 }
896 
897 static void
wg_send_buf(struct wg_softc * sc,struct wg_endpoint * e,uint8_t * buf,size_t len)898 wg_send_buf(struct wg_softc *sc, struct wg_endpoint *e, uint8_t *buf, size_t len)
899 {
900 	struct mbuf	*m;
901 	int		 ret = 0;
902 	bool		 retried = false;
903 
904 retry:
905 	m = m_get2(len, M_NOWAIT, MT_DATA, M_PKTHDR);
906 	if (!m) {
907 		ret = ENOMEM;
908 		goto out;
909 	}
910 	m_copyback(m, 0, len, buf);
911 
912 	if (ret == 0) {
913 		ret = wg_send(sc, e, m);
914 		/* Retry if we couldn't bind to e->e_local */
915 		if (ret == EADDRNOTAVAIL && !retried) {
916 			bzero(&e->e_local, sizeof(e->e_local));
917 			retried = true;
918 			goto retry;
919 		}
920 	} else {
921 		ret = wg_send(sc, e, m);
922 	}
923 out:
924 	if (ret)
925 		DPRINTF(sc, "Unable to send packet: %d\n", ret);
926 }
927 
928 /* Timers */
929 static void
wg_timers_enable(struct wg_peer * peer)930 wg_timers_enable(struct wg_peer *peer)
931 {
932 	atomic_store_bool(&peer->p_enabled, true);
933 	wg_timers_run_persistent_keepalive(peer);
934 }
935 
936 static void
wg_timers_disable(struct wg_peer * peer)937 wg_timers_disable(struct wg_peer *peer)
938 {
939 	/* By setting p_enabled = false, then calling NET_EPOCH_WAIT, we can be
940 	 * sure no new handshakes are created after the wait. This is because
941 	 * all callout_resets (scheduling the callout) are guarded by
942 	 * p_enabled. We can be sure all sections that read p_enabled and then
943 	 * optionally call callout_reset are finished as they are surrounded by
944 	 * NET_EPOCH_{ENTER,EXIT}.
945 	 *
946 	 * However, as new callouts may be scheduled during NET_EPOCH_WAIT (but
947 	 * not after), we stop all callouts leaving no callouts active.
948 	 *
949 	 * We should also pull NET_EPOCH_WAIT out of the FOREACH(peer) loops, but the
950 	 * performance impact is acceptable for the time being. */
951 	atomic_store_bool(&peer->p_enabled, false);
952 	NET_EPOCH_WAIT();
953 	atomic_store_bool(&peer->p_need_another_keepalive, false);
954 
955 	callout_stop(&peer->p_new_handshake);
956 	callout_stop(&peer->p_send_keepalive);
957 	callout_stop(&peer->p_retry_handshake);
958 	callout_stop(&peer->p_persistent_keepalive);
959 	callout_stop(&peer->p_zero_key_material);
960 }
961 
962 static void
wg_timers_set_persistent_keepalive(struct wg_peer * peer,uint16_t interval)963 wg_timers_set_persistent_keepalive(struct wg_peer *peer, uint16_t interval)
964 {
965 	struct epoch_tracker et;
966 	if (interval != peer->p_persistent_keepalive_interval) {
967 		atomic_store_16(&peer->p_persistent_keepalive_interval, interval);
968 		NET_EPOCH_ENTER(et);
969 		if (atomic_load_bool(&peer->p_enabled))
970 			wg_timers_run_persistent_keepalive(peer);
971 		NET_EPOCH_EXIT(et);
972 	}
973 }
974 
975 static void
wg_timers_get_last_handshake(struct wg_peer * peer,struct wg_timespec64 * time)976 wg_timers_get_last_handshake(struct wg_peer *peer, struct wg_timespec64 *time)
977 {
978 	mtx_lock(&peer->p_handshake_mtx);
979 	time->tv_sec = peer->p_handshake_complete.tv_sec;
980 	time->tv_nsec = peer->p_handshake_complete.tv_nsec;
981 	mtx_unlock(&peer->p_handshake_mtx);
982 }
983 
984 static void
wg_timers_event_data_sent(struct wg_peer * peer)985 wg_timers_event_data_sent(struct wg_peer *peer)
986 {
987 	struct epoch_tracker et;
988 	NET_EPOCH_ENTER(et);
989 	if (atomic_load_bool(&peer->p_enabled) &&
990 	    !callout_pending(&peer->p_new_handshake))
991 		callout_reset(&peer->p_new_handshake, MSEC_2_TICKS(
992 		    NEW_HANDSHAKE_TIMEOUT * 1000 +
993 		    arc4random_uniform(REKEY_TIMEOUT_JITTER)),
994 		    wg_timers_run_new_handshake, peer);
995 	NET_EPOCH_EXIT(et);
996 }
997 
998 static void
wg_timers_event_data_received(struct wg_peer * peer)999 wg_timers_event_data_received(struct wg_peer *peer)
1000 {
1001 	struct epoch_tracker et;
1002 	NET_EPOCH_ENTER(et);
1003 	if (atomic_load_bool(&peer->p_enabled)) {
1004 		if (!callout_pending(&peer->p_send_keepalive))
1005 			callout_reset(&peer->p_send_keepalive,
1006 			    MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1007 			    wg_timers_run_send_keepalive, peer);
1008 		else
1009 			atomic_store_bool(&peer->p_need_another_keepalive,
1010 			    true);
1011 	}
1012 	NET_EPOCH_EXIT(et);
1013 }
1014 
1015 static void
wg_timers_event_any_authenticated_packet_sent(struct wg_peer * peer)1016 wg_timers_event_any_authenticated_packet_sent(struct wg_peer *peer)
1017 {
1018 	callout_stop(&peer->p_send_keepalive);
1019 }
1020 
1021 static void
wg_timers_event_any_authenticated_packet_received(struct wg_peer * peer)1022 wg_timers_event_any_authenticated_packet_received(struct wg_peer *peer)
1023 {
1024 	callout_stop(&peer->p_new_handshake);
1025 }
1026 
1027 static void
wg_timers_event_any_authenticated_packet_traversal(struct wg_peer * peer)1028 wg_timers_event_any_authenticated_packet_traversal(struct wg_peer *peer)
1029 {
1030 	struct epoch_tracker et;
1031 	uint16_t interval;
1032 	NET_EPOCH_ENTER(et);
1033 	interval = atomic_load_16(&peer->p_persistent_keepalive_interval);
1034 	if (atomic_load_bool(&peer->p_enabled) && interval > 0)
1035 		callout_reset(&peer->p_persistent_keepalive,
1036 		     MSEC_2_TICKS(interval * 1000),
1037 		     wg_timers_run_persistent_keepalive, peer);
1038 	NET_EPOCH_EXIT(et);
1039 }
1040 
1041 static void
wg_timers_event_handshake_initiated(struct wg_peer * peer)1042 wg_timers_event_handshake_initiated(struct wg_peer *peer)
1043 {
1044 	struct epoch_tracker et;
1045 	NET_EPOCH_ENTER(et);
1046 	if (atomic_load_bool(&peer->p_enabled))
1047 		callout_reset(&peer->p_retry_handshake, MSEC_2_TICKS(
1048 		    REKEY_TIMEOUT * 1000 +
1049 		    arc4random_uniform(REKEY_TIMEOUT_JITTER)),
1050 		    wg_timers_run_retry_handshake, peer);
1051 	NET_EPOCH_EXIT(et);
1052 }
1053 
1054 static void
wg_timers_event_handshake_complete(struct wg_peer * peer)1055 wg_timers_event_handshake_complete(struct wg_peer *peer)
1056 {
1057 	struct epoch_tracker et;
1058 	NET_EPOCH_ENTER(et);
1059 	if (atomic_load_bool(&peer->p_enabled)) {
1060 		mtx_lock(&peer->p_handshake_mtx);
1061 		callout_stop(&peer->p_retry_handshake);
1062 		peer->p_handshake_retries = 0;
1063 		getnanotime(&peer->p_handshake_complete);
1064 		mtx_unlock(&peer->p_handshake_mtx);
1065 		wg_timers_run_send_keepalive(peer);
1066 	}
1067 	NET_EPOCH_EXIT(et);
1068 }
1069 
1070 static void
wg_timers_event_session_derived(struct wg_peer * peer)1071 wg_timers_event_session_derived(struct wg_peer *peer)
1072 {
1073 	struct epoch_tracker et;
1074 	NET_EPOCH_ENTER(et);
1075 	if (atomic_load_bool(&peer->p_enabled))
1076 		callout_reset(&peer->p_zero_key_material,
1077 		    MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1078 		    wg_timers_run_zero_key_material, peer);
1079 	NET_EPOCH_EXIT(et);
1080 }
1081 
1082 static void
wg_timers_event_want_initiation(struct wg_peer * peer)1083 wg_timers_event_want_initiation(struct wg_peer *peer)
1084 {
1085 	struct epoch_tracker et;
1086 	NET_EPOCH_ENTER(et);
1087 	if (atomic_load_bool(&peer->p_enabled))
1088 		wg_timers_run_send_initiation(peer, false);
1089 	NET_EPOCH_EXIT(et);
1090 }
1091 
1092 static void
wg_timers_run_send_initiation(struct wg_peer * peer,bool is_retry)1093 wg_timers_run_send_initiation(struct wg_peer *peer, bool is_retry)
1094 {
1095 	if (!is_retry)
1096 		peer->p_handshake_retries = 0;
1097 	if (noise_remote_initiation_expired(peer->p_remote) == ETIMEDOUT)
1098 		wg_send_initiation(peer);
1099 }
1100 
1101 static void
wg_timers_run_retry_handshake(void * _peer)1102 wg_timers_run_retry_handshake(void *_peer)
1103 {
1104 	struct epoch_tracker et;
1105 	struct wg_peer *peer = _peer;
1106 
1107 	mtx_lock(&peer->p_handshake_mtx);
1108 	if (peer->p_handshake_retries <= MAX_TIMER_HANDSHAKES) {
1109 		peer->p_handshake_retries++;
1110 		mtx_unlock(&peer->p_handshake_mtx);
1111 
1112 		DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1113 		    "after %d seconds, retrying (try %d)\n", peer->p_id,
1114 		    REKEY_TIMEOUT, peer->p_handshake_retries + 1);
1115 		wg_peer_clear_src(peer);
1116 		wg_timers_run_send_initiation(peer, true);
1117 	} else {
1118 		mtx_unlock(&peer->p_handshake_mtx);
1119 
1120 		DPRINTF(peer->p_sc, "Handshake for peer %" PRIu64 " did not complete "
1121 		    "after %d retries, giving up\n", peer->p_id,
1122 		    MAX_TIMER_HANDSHAKES + 2);
1123 
1124 		callout_stop(&peer->p_send_keepalive);
1125 		wg_queue_purge(&peer->p_stage_queue);
1126 		NET_EPOCH_ENTER(et);
1127 		if (atomic_load_bool(&peer->p_enabled) &&
1128 		    !callout_pending(&peer->p_zero_key_material))
1129 			callout_reset(&peer->p_zero_key_material,
1130 			    MSEC_2_TICKS(REJECT_AFTER_TIME * 3 * 1000),
1131 			    wg_timers_run_zero_key_material, peer);
1132 		NET_EPOCH_EXIT(et);
1133 	}
1134 }
1135 
1136 static void
wg_timers_run_send_keepalive(void * _peer)1137 wg_timers_run_send_keepalive(void *_peer)
1138 {
1139 	struct epoch_tracker et;
1140 	struct wg_peer *peer = _peer;
1141 
1142 	wg_send_keepalive(peer);
1143 	NET_EPOCH_ENTER(et);
1144 	if (atomic_load_bool(&peer->p_enabled) &&
1145 	    atomic_load_bool(&peer->p_need_another_keepalive)) {
1146 		atomic_store_bool(&peer->p_need_another_keepalive, false);
1147 		callout_reset(&peer->p_send_keepalive,
1148 		    MSEC_2_TICKS(KEEPALIVE_TIMEOUT * 1000),
1149 		    wg_timers_run_send_keepalive, peer);
1150 	}
1151 	NET_EPOCH_EXIT(et);
1152 }
1153 
1154 static void
wg_timers_run_new_handshake(void * _peer)1155 wg_timers_run_new_handshake(void *_peer)
1156 {
1157 	struct wg_peer *peer = _peer;
1158 
1159 	DPRINTF(peer->p_sc, "Retrying handshake with peer %" PRIu64 " because we "
1160 	    "stopped hearing back after %d seconds\n",
1161 	    peer->p_id, NEW_HANDSHAKE_TIMEOUT);
1162 
1163 	wg_peer_clear_src(peer);
1164 	wg_timers_run_send_initiation(peer, false);
1165 }
1166 
1167 static void
wg_timers_run_zero_key_material(void * _peer)1168 wg_timers_run_zero_key_material(void *_peer)
1169 {
1170 	struct wg_peer *peer = _peer;
1171 
1172 	DPRINTF(peer->p_sc, "Zeroing out keys for peer %" PRIu64 ", since we "
1173 	    "haven't received a new one in %d seconds\n",
1174 	    peer->p_id, REJECT_AFTER_TIME * 3);
1175 	noise_remote_keypairs_clear(peer->p_remote);
1176 }
1177 
1178 static void
wg_timers_run_persistent_keepalive(void * _peer)1179 wg_timers_run_persistent_keepalive(void *_peer)
1180 {
1181 	struct wg_peer *peer = _peer;
1182 
1183 	if (atomic_load_16(&peer->p_persistent_keepalive_interval) > 0)
1184 		wg_send_keepalive(peer);
1185 }
1186 
1187 /* TODO Handshake */
1188 static void
wg_peer_send_buf(struct wg_peer * peer,uint8_t * buf,size_t len)1189 wg_peer_send_buf(struct wg_peer *peer, uint8_t *buf, size_t len)
1190 {
1191 	struct wg_endpoint endpoint;
1192 
1193 	counter_u64_add(peer->p_tx_bytes, len);
1194 	wg_timers_event_any_authenticated_packet_traversal(peer);
1195 	wg_timers_event_any_authenticated_packet_sent(peer);
1196 	wg_peer_get_endpoint(peer, &endpoint);
1197 	wg_send_buf(peer->p_sc, &endpoint, buf, len);
1198 }
1199 
1200 static void
wg_send_initiation(struct wg_peer * peer)1201 wg_send_initiation(struct wg_peer *peer)
1202 {
1203 	struct wg_pkt_initiation pkt;
1204 
1205 	if (noise_create_initiation(peer->p_remote, &pkt.s_idx, pkt.ue,
1206 	    pkt.es, pkt.ets) != 0)
1207 		return;
1208 
1209 	DPRINTF(peer->p_sc, "Sending handshake initiation to peer %" PRIu64 "\n", peer->p_id);
1210 
1211 	pkt.t = WG_PKT_INITIATION;
1212 	cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1213 	    sizeof(pkt) - sizeof(pkt.m));
1214 	wg_peer_send_buf(peer, (uint8_t *)&pkt, sizeof(pkt));
1215 	wg_timers_event_handshake_initiated(peer);
1216 }
1217 
1218 static void
wg_send_response(struct wg_peer * peer)1219 wg_send_response(struct wg_peer *peer)
1220 {
1221 	struct wg_pkt_response pkt;
1222 
1223 	if (noise_create_response(peer->p_remote, &pkt.s_idx, &pkt.r_idx,
1224 	    pkt.ue, pkt.en) != 0)
1225 		return;
1226 
1227 	DPRINTF(peer->p_sc, "Sending handshake response to peer %" PRIu64 "\n", peer->p_id);
1228 
1229 	wg_timers_event_session_derived(peer);
1230 	pkt.t = WG_PKT_RESPONSE;
1231 	cookie_maker_mac(&peer->p_cookie, &pkt.m, &pkt,
1232 	     sizeof(pkt)-sizeof(pkt.m));
1233 	wg_peer_send_buf(peer, (uint8_t*)&pkt, sizeof(pkt));
1234 }
1235 
1236 static void
wg_send_cookie(struct wg_softc * sc,struct cookie_macs * cm,uint32_t idx,struct wg_endpoint * e)1237 wg_send_cookie(struct wg_softc *sc, struct cookie_macs *cm, uint32_t idx,
1238     struct wg_endpoint *e)
1239 {
1240 	struct wg_pkt_cookie	pkt;
1241 
1242 	DPRINTF(sc, "Sending cookie response for denied handshake message\n");
1243 
1244 	pkt.t = WG_PKT_COOKIE;
1245 	pkt.r_idx = idx;
1246 
1247 	cookie_checker_create_payload(&sc->sc_cookie, cm, pkt.nonce,
1248 	    pkt.ec, &e->e_remote.r_sa);
1249 	wg_send_buf(sc, e, (uint8_t *)&pkt, sizeof(pkt));
1250 }
1251 
1252 static void
wg_send_keepalive(struct wg_peer * peer)1253 wg_send_keepalive(struct wg_peer *peer)
1254 {
1255 	struct wg_packet *pkt;
1256 	struct mbuf *m;
1257 
1258 	if (wg_queue_len(&peer->p_stage_queue) > 0)
1259 		goto send;
1260 	if ((m = m_gethdr(M_NOWAIT, MT_DATA)) == NULL)
1261 		return;
1262 	if ((pkt = wg_packet_alloc(m)) == NULL) {
1263 		m_freem(m);
1264 		return;
1265 	}
1266 	wg_queue_push_staged(&peer->p_stage_queue, pkt);
1267 	DPRINTF(peer->p_sc, "Sending keepalive packet to peer %" PRIu64 "\n", peer->p_id);
1268 send:
1269 	wg_peer_send_staged(peer);
1270 }
1271 
/*
 * Process one message taken from the handshake queue: a handshake
 * initiation, a handshake response, or a cookie reply.  pkt is always
 * consumed (freed before returning).
 *
 * The interface counts as "under load" while the handshake queue holds
 * at least MAX_QUEUED_HANDSHAKES / 8 entries, and for UNDERLOAD_TIMEOUT
 * seconds afterwards.  That flag is passed to
 * cookie_checker_validate_macs().  An EAGAIN result from it makes us
 * answer with a cookie reply instead of processing the message.
 *
 * Initiations and responses update the authenticated-packet timers and
 * the rx counters.  Cookie replies update only the rx counters.
 *
 * NOTE(review): wg_last_underload is a function-local static, so it is
 * shared by all wg interfaces and updated without synchronization.
 */
static void
wg_handshake(struct wg_softc *sc, struct wg_packet *pkt)
{
	static sbintime_t		 wg_last_underload; /* sbinuptime */
	struct noise_remote		*remote = NULL;
	struct wg_pkt_initiation	*init;
	struct wg_pkt_response		*resp;
	struct wg_pkt_cookie		*cook;
	struct wg_endpoint		*e;
	struct wg_peer			*peer;
	struct mbuf			*m;
	bool				 underload;
	int				 res;

	/* Are we under load now, or were we within UNDERLOAD_TIMEOUT? */
	if (wg_queue_len(&sc->sc_handshake_queue) >= MAX_QUEUED_HANDSHAKES / 8) {
		wg_last_underload = getsbinuptime();
		underload = true;
	} else if (wg_last_underload != 0) {
		underload = getsbinuptime() <
		    wg_last_underload + UNDERLOAD_TIMEOUT * SBT_1S;
		if (!underload)
			wg_last_underload = 0;
	} else {
		underload = false;
	}

	e = &pkt->p_endpoint;

	/* Make the whole message contiguous so it can be cast in place. */
	m = m_pullup(pkt->p_mbuf, pkt->p_mbuf->m_pkthdr.len);
	pkt->p_mbuf = m;
	if (m == NULL)
		goto error;

	switch (*mtod(m, uint32_t *)) {
	case WG_PKT_INITIATION:
		init = mtod(m, struct wg_pkt_initiation *);

		res = cookie_checker_validate_macs(&sc->sc_cookie, &init->m,
		    init, sizeof(*init) - sizeof(init->m), underload,
		    &e->e_remote.r_sa, if_getvnet(sc->sc_ifp));
		switch (res) {
		case 0:
			break;
		case EINVAL:
			DPRINTF(sc, "Invalid initiation MAC\n");
			goto error;
		case ECONNREFUSED:
			DPRINTF(sc, "Handshake ratelimited\n");
			goto error;
		case EAGAIN:
			wg_send_cookie(sc, &init->m, init->s_idx, e);
			goto error;
		default:
			panic("unexpected response: %d\n", res);
		}

		if (noise_consume_initiation(sc->sc_local, &remote,
		    init->s_idx, init->ue, init->es, init->ets) != 0) {
			DPRINTF(sc, "Invalid handshake initiation\n");
			goto error;
		}

		peer = noise_remote_arg(remote);
		DPRINTF(sc, "Receiving handshake initiation from peer %" PRIu64 "\n", peer->p_id);

		wg_peer_set_endpoint(peer, e);
		wg_send_response(peer);
		break;

	case WG_PKT_RESPONSE:
		resp = mtod(m, struct wg_pkt_response *);

		res = cookie_checker_validate_macs(&sc->sc_cookie, &resp->m,
		    resp, sizeof(*resp) - sizeof(resp->m), underload,
		    &e->e_remote.r_sa, if_getvnet(sc->sc_ifp));
		switch (res) {
		case 0:
			break;
		case EINVAL:
			DPRINTF(sc, "Invalid response MAC\n");
			goto error;
		case ECONNREFUSED:
			DPRINTF(sc, "Handshake ratelimited\n");
			goto error;
		case EAGAIN:
			wg_send_cookie(sc, &resp->m, resp->s_idx, e);
			goto error;
		default:
			panic("unexpected response: %d\n", res);
		}

		if (noise_consume_response(sc->sc_local, &remote,
		    resp->s_idx, resp->r_idx, resp->ue, resp->en) != 0) {
			DPRINTF(sc, "Invalid handshake response\n");
			goto error;
		}

		peer = noise_remote_arg(remote);
		DPRINTF(sc, "Receiving handshake response from peer %" PRIu64 "\n", peer->p_id);

		wg_peer_set_endpoint(peer, e);
		wg_timers_event_session_derived(peer);
		wg_timers_event_handshake_complete(peer);
		break;

	case WG_PKT_COOKIE:
		cook = mtod(m, struct wg_pkt_cookie *);

		remote = noise_remote_index(sc->sc_local, cook->r_idx);
		if (remote == NULL) {
			DPRINTF(sc, "Unknown cookie index\n");
			goto error;
		}
		peer = noise_remote_arg(remote);

		if (cookie_maker_consume_payload(&peer->p_cookie,
		    cook->nonce, cook->ec) != 0) {
			DPRINTF(sc, "Could not decrypt cookie response\n");
			goto error;
		}
		DPRINTF(sc, "Receiving cookie response\n");

		/* A cookie reply does not touch the authenticated-packet timers. */
		goto not_authenticated;

	default:
		panic("invalid packet in handshake queue");
	}

	wg_timers_event_any_authenticated_packet_received(peer);
	wg_timers_event_any_authenticated_packet_traversal(peer);

not_authenticated:
	counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len);
	if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
	if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len);
error:
	if (remote != NULL)
		noise_remote_put(remote);
	wg_packet_free(pkt);
}
1405 
1406 static void
wg_softc_handshake_receive(struct wg_softc * sc)1407 wg_softc_handshake_receive(struct wg_softc *sc)
1408 {
1409 	struct wg_packet *pkt;
1410 	while ((pkt = wg_queue_dequeue_handshake(&sc->sc_handshake_queue)) != NULL)
1411 		wg_handshake(sc, pkt);
1412 }
1413 
1414 static void
wg_mbuf_reset(struct mbuf * m)1415 wg_mbuf_reset(struct mbuf *m)
1416 {
1417 
1418 	struct m_tag *t, *tmp;
1419 
1420 	/*
1421 	 * We want to reset the mbuf to a newly allocated state, containing
1422 	 * just the packet contents. Unfortunately FreeBSD doesn't seem to
1423 	 * offer this anywhere, so we have to make it up as we go. If we can
1424 	 * get this in kern/kern_mbuf.c, that would be best.
1425 	 *
1426 	 * Notice: this may break things unexpectedly but it is better to fail
1427 	 *         closed in the extreme case than leak informtion in every
1428 	 *         case.
1429 	 *
1430 	 * With that said, all this attempts to do is remove any extraneous
1431 	 * information that could be present.
1432 	 */
1433 
1434 	M_ASSERTPKTHDR(m);
1435 
1436 	m->m_flags &= ~(M_BCAST|M_MCAST|M_VLANTAG|M_PROMISC|M_PROTOFLAGS);
1437 
1438 	M_HASHTYPE_CLEAR(m);
1439 #ifdef NUMA
1440         m->m_pkthdr.numa_domain = M_NODOM;
1441 #endif
1442 	SLIST_FOREACH_SAFE(t, &m->m_pkthdr.tags, m_tag_link, tmp) {
1443 		if ((t->m_tag_id != 0 || t->m_tag_cookie != MTAG_WGLOOP) &&
1444 		    t->m_tag_id != PACKET_TAG_MACLABEL)
1445 			m_tag_delete(m, t);
1446 	}
1447 
1448 	KASSERT((m->m_pkthdr.csum_flags & CSUM_SND_TAG) == 0,
1449 	    ("%s: mbuf %p has a send tag", __func__, m));
1450 
1451 	m->m_pkthdr.csum_flags = 0;
1452 	m->m_pkthdr.PH_per.sixtyfour[0] = 0;
1453 	m->m_pkthdr.PH_loc.sixtyfour[0] = 0;
1454 }
1455 
1456 static inline unsigned int
calculate_padding(struct wg_packet * pkt)1457 calculate_padding(struct wg_packet *pkt)
1458 {
1459 	unsigned int padded_size, last_unit = pkt->p_mbuf->m_pkthdr.len;
1460 
1461 	/* Keepalive packets don't set p_mtu, but also have a length of zero. */
1462 	if (__predict_false(pkt->p_mtu == 0)) {
1463 		padded_size = (last_unit + (WG_PKT_PADDING - 1)) &
1464 		    ~(WG_PKT_PADDING - 1);
1465 		return (padded_size - last_unit);
1466 	}
1467 
1468 	if (__predict_false(last_unit > pkt->p_mtu))
1469 		last_unit %= pkt->p_mtu;
1470 
1471 	padded_size = (last_unit + (WG_PKT_PADDING - 1)) & ~(WG_PKT_PADDING - 1);
1472 	if (pkt->p_mtu < padded_size)
1473 		padded_size = pkt->p_mtu;
1474 	return (padded_size - last_unit);
1475 }
1476 
1477 static void
wg_encrypt(struct wg_softc * sc,struct wg_packet * pkt)1478 wg_encrypt(struct wg_softc *sc, struct wg_packet *pkt)
1479 {
1480 	static const uint8_t	 padding[WG_PKT_PADDING] = { 0 };
1481 	struct wg_pkt_data	*data;
1482 	struct wg_peer		*peer;
1483 	struct noise_remote	*remote;
1484 	struct mbuf		*m;
1485 	uint32_t		 idx;
1486 	unsigned int		 padlen;
1487 	enum wg_ring_state	 state = WG_PACKET_DEAD;
1488 
1489 	remote = noise_keypair_remote(pkt->p_keypair);
1490 	peer = noise_remote_arg(remote);
1491 	m = pkt->p_mbuf;
1492 
1493 	/* Pad the packet */
1494 	padlen = calculate_padding(pkt);
1495 	if (padlen != 0 && !m_append(m, padlen, padding))
1496 		goto out;
1497 
1498 	/* Do encryption */
1499 	if (noise_keypair_encrypt(pkt->p_keypair, &idx, pkt->p_nonce, m) != 0)
1500 		goto out;
1501 
1502 	/* Put header into packet */
1503 	M_PREPEND(m, sizeof(struct wg_pkt_data), M_NOWAIT);
1504 	if (m == NULL)
1505 		goto out;
1506 	data = mtod(m, struct wg_pkt_data *);
1507 	data->t = WG_PKT_DATA;
1508 	data->r_idx = idx;
1509 	data->nonce = htole64(pkt->p_nonce);
1510 
1511 	wg_mbuf_reset(m);
1512 	state = WG_PACKET_CRYPTED;
1513 out:
1514 	pkt->p_mbuf = m;
1515 	atomic_store_rel_int(&pkt->p_state, state);
1516 	GROUPTASK_ENQUEUE(&peer->p_send);
1517 	noise_remote_put(remote);
1518 }
1519 
1520 static void
wg_decrypt(struct wg_softc * sc,struct wg_packet * pkt)1521 wg_decrypt(struct wg_softc *sc, struct wg_packet *pkt)
1522 {
1523 	struct wg_peer		*peer, *allowed_peer;
1524 	struct noise_remote	*remote;
1525 	struct mbuf		*m;
1526 	int			 len;
1527 	enum wg_ring_state	 state = WG_PACKET_DEAD;
1528 
1529 	remote = noise_keypair_remote(pkt->p_keypair);
1530 	peer = noise_remote_arg(remote);
1531 	m = pkt->p_mbuf;
1532 
1533 	/* Read nonce and then adjust to remove the header. */
1534 	pkt->p_nonce = le64toh(mtod(m, struct wg_pkt_data *)->nonce);
1535 	m_adj(m, sizeof(struct wg_pkt_data));
1536 
1537 	if (noise_keypair_decrypt(pkt->p_keypair, pkt->p_nonce, m) != 0)
1538 		goto out;
1539 
1540 	/* A packet with length 0 is a keepalive packet */
1541 	if (__predict_false(m->m_pkthdr.len == 0)) {
1542 		DPRINTF(sc, "Receiving keepalive packet from peer "
1543 		    "%" PRIu64 "\n", peer->p_id);
1544 		state = WG_PACKET_CRYPTED;
1545 		goto out;
1546 	}
1547 
1548 	/*
1549 	 * We can let the network stack handle the intricate validation of the
1550 	 * IP header, we just worry about the sizeof and the version, so we can
1551 	 * read the source address in wg_aip_lookup.
1552 	 */
1553 
1554 	if (determine_af_and_pullup(&m, &pkt->p_af) == 0) {
1555 		if (pkt->p_af == AF_INET) {
1556 			struct ip *ip = mtod(m, struct ip *);
1557 			allowed_peer = wg_aip_lookup(sc, AF_INET, &ip->ip_src);
1558 			len = ntohs(ip->ip_len);
1559 			if (len >= sizeof(struct ip) && len < m->m_pkthdr.len)
1560 				m_adj(m, len - m->m_pkthdr.len);
1561 		} else if (pkt->p_af == AF_INET6) {
1562 			struct ip6_hdr *ip6 = mtod(m, struct ip6_hdr *);
1563 			allowed_peer = wg_aip_lookup(sc, AF_INET6, &ip6->ip6_src);
1564 			len = ntohs(ip6->ip6_plen) + sizeof(struct ip6_hdr);
1565 			if (len < m->m_pkthdr.len)
1566 				m_adj(m, len - m->m_pkthdr.len);
1567 		} else
1568 			panic("determine_af_and_pullup returned unexpected value");
1569 	} else {
1570 		DPRINTF(sc, "Packet is neither ipv4 nor ipv6 from peer %" PRIu64 "\n", peer->p_id);
1571 		goto out;
1572 	}
1573 
1574 	/* We only want to compare the address, not dereference, so drop the ref. */
1575 	if (allowed_peer != NULL)
1576 		noise_remote_put(allowed_peer->p_remote);
1577 
1578 	if (__predict_false(peer != allowed_peer)) {
1579 		DPRINTF(sc, "Packet has unallowed src IP from peer %" PRIu64 "\n", peer->p_id);
1580 		goto out;
1581 	}
1582 
1583 	wg_mbuf_reset(m);
1584 	state = WG_PACKET_CRYPTED;
1585 out:
1586 	pkt->p_mbuf = m;
1587 	atomic_store_rel_int(&pkt->p_state, state);
1588 	GROUPTASK_ENQUEUE(&peer->p_recv);
1589 	noise_remote_put(remote);
1590 }
1591 
1592 static void
wg_softc_decrypt(struct wg_softc * sc)1593 wg_softc_decrypt(struct wg_softc *sc)
1594 {
1595 	struct wg_packet *pkt;
1596 
1597 	while ((pkt = wg_queue_dequeue_parallel(&sc->sc_decrypt_parallel)) != NULL)
1598 		wg_decrypt(sc, pkt);
1599 }
1600 
1601 static void
wg_softc_encrypt(struct wg_softc * sc)1602 wg_softc_encrypt(struct wg_softc *sc)
1603 {
1604 	struct wg_packet *pkt;
1605 
1606 	while ((pkt = wg_queue_dequeue_parallel(&sc->sc_encrypt_parallel)) != NULL)
1607 		wg_encrypt(sc, pkt);
1608 }
1609 
1610 static void
wg_encrypt_dispatch(struct wg_softc * sc)1611 wg_encrypt_dispatch(struct wg_softc *sc)
1612 {
1613 	/*
1614 	 * The update to encrypt_last_cpu is racey such that we may
1615 	 * reschedule the task for the same CPU multiple times, but
1616 	 * the race doesn't really matter.
1617 	 */
1618 	u_int cpu = (sc->sc_encrypt_last_cpu + 1) % mp_ncpus;
1619 	sc->sc_encrypt_last_cpu = cpu;
1620 	GROUPTASK_ENQUEUE(&sc->sc_encrypt[cpu]);
1621 }
1622 
1623 static void
wg_decrypt_dispatch(struct wg_softc * sc)1624 wg_decrypt_dispatch(struct wg_softc *sc)
1625 {
1626 	u_int cpu = (sc->sc_decrypt_last_cpu + 1) % mp_ncpus;
1627 	sc->sc_decrypt_last_cpu = cpu;
1628 	GROUPTASK_ENQUEUE(&sc->sc_decrypt[cpu]);
1629 }
1630 
1631 static void
wg_deliver_out(struct wg_peer * peer)1632 wg_deliver_out(struct wg_peer *peer)
1633 {
1634 	struct wg_endpoint	 endpoint;
1635 	struct wg_softc		*sc = peer->p_sc;
1636 	struct wg_packet	*pkt;
1637 	struct mbuf		*m;
1638 	int			 rc, len;
1639 
1640 	wg_peer_get_endpoint(peer, &endpoint);
1641 
1642 	while ((pkt = wg_queue_dequeue_serial(&peer->p_encrypt_serial)) != NULL) {
1643 		if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1644 			goto error;
1645 
1646 		m = pkt->p_mbuf;
1647 		pkt->p_mbuf = NULL;
1648 
1649 		len = m->m_pkthdr.len;
1650 
1651 		wg_timers_event_any_authenticated_packet_traversal(peer);
1652 		wg_timers_event_any_authenticated_packet_sent(peer);
1653 		rc = wg_send(sc, &endpoint, m);
1654 		if (rc == 0) {
1655 			if (len > (sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN))
1656 				wg_timers_event_data_sent(peer);
1657 			counter_u64_add(peer->p_tx_bytes, len);
1658 		} else if (rc == EADDRNOTAVAIL) {
1659 			wg_peer_clear_src(peer);
1660 			wg_peer_get_endpoint(peer, &endpoint);
1661 			goto error;
1662 		} else {
1663 			goto error;
1664 		}
1665 		wg_packet_free(pkt);
1666 		if (noise_keep_key_fresh_send(peer->p_remote))
1667 			wg_timers_event_want_initiation(peer);
1668 		continue;
1669 error:
1670 		if_inc_counter(sc->sc_ifp, IFCOUNTER_OERRORS, 1);
1671 		wg_packet_free(pkt);
1672 	}
1673 }
1674 
#ifdef DEV_NETMAP
/*
 * Hand a decrypted packet to the netmap RX ring through if_input(), which
 * reaches netmap's freebsd_generic_rx_handler().  A synthetic Ethernet
 * header is prepended first:
 *  - source 02:02:02:02:02:02;
 *  - broadcast destination;
 *  - EtherType IPv4 for AF_INET, IPv6 otherwise.
 * If the header cannot be prepended, the packet is dropped and counted
 * in IFCOUNTER_IQDROPS.
 */
static void
wg_deliver_netmap(if_t ifp, struct mbuf *m, int af)
{
	struct ether_header	*eh;
	uint16_t		 type;

	M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
	if (__predict_false(m == NULL)) {
		if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
		return;
	}

	type = (af == AF_INET) ? ETHERTYPE_IP : ETHERTYPE_IPV6;
	eh = mtod(m, struct ether_header *);
	eh->ether_type = htons(type);
	memcpy(eh->ether_shost, "\x02\x02\x02\x02\x02\x02", ETHER_ADDR_LEN);
	memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
	if_input(ifp, m);
}
#endif
1699 
1700 static void
wg_deliver_in(struct wg_peer * peer)1701 wg_deliver_in(struct wg_peer *peer)
1702 {
1703 	struct wg_softc		*sc = peer->p_sc;
1704 	if_t			 ifp = sc->sc_ifp;
1705 	struct wg_packet	*pkt;
1706 	struct mbuf		*m;
1707 	struct epoch_tracker	 et;
1708 	int			 af;
1709 
1710 	while ((pkt = wg_queue_dequeue_serial(&peer->p_decrypt_serial)) != NULL) {
1711 		if (atomic_load_acq_int(&pkt->p_state) != WG_PACKET_CRYPTED)
1712 			goto error;
1713 
1714 		m = pkt->p_mbuf;
1715 		if (noise_keypair_nonce_check(pkt->p_keypair, pkt->p_nonce) != 0)
1716 			goto error;
1717 
1718 		if (noise_keypair_received_with(pkt->p_keypair) == ECONNRESET)
1719 			wg_timers_event_handshake_complete(peer);
1720 
1721 		wg_timers_event_any_authenticated_packet_received(peer);
1722 		wg_timers_event_any_authenticated_packet_traversal(peer);
1723 		wg_peer_set_endpoint(peer, &pkt->p_endpoint);
1724 
1725 		counter_u64_add(peer->p_rx_bytes, m->m_pkthdr.len +
1726 		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1727 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1);
1728 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IBYTES, m->m_pkthdr.len +
1729 		    sizeof(struct wg_pkt_data) + NOISE_AUTHTAG_LEN);
1730 
1731 		if (m->m_pkthdr.len == 0)
1732 			goto done;
1733 
1734 		af = pkt->p_af;
1735 		MPASS(af == AF_INET || af == AF_INET6);
1736 		pkt->p_mbuf = NULL;
1737 
1738 		m->m_pkthdr.rcvif = ifp;
1739 
1740 		NET_EPOCH_ENTER(et);
1741 		BPF_MTAP2_AF(ifp, m, af);
1742 
1743 		CURVNET_SET(if_getvnet(ifp));
1744 		M_SETFIB(m, if_getfib(ifp));
1745 #ifdef DEV_NETMAP
1746 		if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
1747 			wg_deliver_netmap(ifp, m, af);
1748 		else
1749 #endif
1750 		if (af == AF_INET)
1751 			netisr_dispatch(NETISR_IP, m);
1752 		else if (af == AF_INET6)
1753 			netisr_dispatch(NETISR_IPV6, m);
1754 		CURVNET_RESTORE();
1755 		NET_EPOCH_EXIT(et);
1756 
1757 		wg_timers_event_data_received(peer);
1758 
1759 done:
1760 		if (noise_keep_key_fresh_recv(peer->p_remote))
1761 			wg_timers_event_want_initiation(peer);
1762 		wg_packet_free(pkt);
1763 		continue;
1764 error:
1765 		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
1766 		wg_packet_free(pkt);
1767 	}
1768 }
1769 
1770 static struct wg_packet *
wg_packet_alloc(struct mbuf * m)1771 wg_packet_alloc(struct mbuf *m)
1772 {
1773 	struct wg_packet *pkt;
1774 
1775 	if ((pkt = uma_zalloc(wg_packet_zone, M_NOWAIT | M_ZERO)) == NULL)
1776 		return (NULL);
1777 	pkt->p_mbuf = m;
1778 	return (pkt);
1779 }
1780 
1781 static void
wg_packet_free(struct wg_packet * pkt)1782 wg_packet_free(struct wg_packet *pkt)
1783 {
1784 	if (pkt->p_keypair != NULL)
1785 		noise_keypair_put(pkt->p_keypair);
1786 	if (pkt->p_mbuf != NULL)
1787 		m_freem(pkt->p_mbuf);
1788 	uma_zfree(wg_packet_zone, pkt);
1789 }
1790 
1791 static void
wg_queue_init(struct wg_queue * queue,const char * name)1792 wg_queue_init(struct wg_queue *queue, const char *name)
1793 {
1794 	mtx_init(&queue->q_mtx, name, NULL, MTX_DEF);
1795 	STAILQ_INIT(&queue->q_queue);
1796 	queue->q_len = 0;
1797 }
1798 
1799 static void
wg_queue_deinit(struct wg_queue * queue)1800 wg_queue_deinit(struct wg_queue *queue)
1801 {
1802 	wg_queue_purge(queue);
1803 	mtx_destroy(&queue->q_mtx);
1804 }
1805 
1806 static size_t
wg_queue_len(struct wg_queue * queue)1807 wg_queue_len(struct wg_queue *queue)
1808 {
1809 	return (queue->q_len);
1810 }
1811 
1812 static int
wg_queue_enqueue_handshake(struct wg_queue * hs,struct wg_packet * pkt)1813 wg_queue_enqueue_handshake(struct wg_queue *hs, struct wg_packet *pkt)
1814 {
1815 	int ret = 0;
1816 	mtx_lock(&hs->q_mtx);
1817 	if (hs->q_len < MAX_QUEUED_HANDSHAKES) {
1818 		STAILQ_INSERT_TAIL(&hs->q_queue, pkt, p_parallel);
1819 		hs->q_len++;
1820 	} else {
1821 		ret = ENOBUFS;
1822 	}
1823 	mtx_unlock(&hs->q_mtx);
1824 	if (ret != 0)
1825 		wg_packet_free(pkt);
1826 	return (ret);
1827 }
1828 
1829 static struct wg_packet *
wg_queue_dequeue_handshake(struct wg_queue * hs)1830 wg_queue_dequeue_handshake(struct wg_queue *hs)
1831 {
1832 	struct wg_packet *pkt;
1833 	mtx_lock(&hs->q_mtx);
1834 	if ((pkt = STAILQ_FIRST(&hs->q_queue)) != NULL) {
1835 		STAILQ_REMOVE_HEAD(&hs->q_queue, p_parallel);
1836 		hs->q_len--;
1837 	}
1838 	mtx_unlock(&hs->q_mtx);
1839 	return (pkt);
1840 }
1841 
1842 static void
wg_queue_push_staged(struct wg_queue * staged,struct wg_packet * pkt)1843 wg_queue_push_staged(struct wg_queue *staged, struct wg_packet *pkt)
1844 {
1845 	struct wg_packet *old = NULL;
1846 
1847 	mtx_lock(&staged->q_mtx);
1848 	if (staged->q_len >= MAX_STAGED_PKT) {
1849 		old = STAILQ_FIRST(&staged->q_queue);
1850 		STAILQ_REMOVE_HEAD(&staged->q_queue, p_parallel);
1851 		staged->q_len--;
1852 	}
1853 	STAILQ_INSERT_TAIL(&staged->q_queue, pkt, p_parallel);
1854 	staged->q_len++;
1855 	mtx_unlock(&staged->q_mtx);
1856 
1857 	if (old != NULL)
1858 		wg_packet_free(old);
1859 }
1860 
1861 static void
wg_queue_enlist_staged(struct wg_queue * staged,struct wg_packet_list * list)1862 wg_queue_enlist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1863 {
1864 	struct wg_packet *pkt, *tpkt;
1865 	STAILQ_FOREACH_SAFE(pkt, list, p_parallel, tpkt)
1866 		wg_queue_push_staged(staged, pkt);
1867 }
1868 
1869 static void
wg_queue_delist_staged(struct wg_queue * staged,struct wg_packet_list * list)1870 wg_queue_delist_staged(struct wg_queue *staged, struct wg_packet_list *list)
1871 {
1872 	STAILQ_INIT(list);
1873 	mtx_lock(&staged->q_mtx);
1874 	STAILQ_CONCAT(list, &staged->q_queue);
1875 	staged->q_len = 0;
1876 	mtx_unlock(&staged->q_mtx);
1877 }
1878 
1879 static void
wg_queue_purge(struct wg_queue * staged)1880 wg_queue_purge(struct wg_queue *staged)
1881 {
1882 	struct wg_packet_list list;
1883 	struct wg_packet *pkt, *tpkt;
1884 	wg_queue_delist_staged(staged, &list);
1885 	STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt)
1886 		wg_packet_free(pkt);
1887 }
1888 
1889 static int
wg_queue_both(struct wg_queue * parallel,struct wg_queue * serial,struct wg_packet * pkt)1890 wg_queue_both(struct wg_queue *parallel, struct wg_queue *serial, struct wg_packet *pkt)
1891 {
1892 	pkt->p_state = WG_PACKET_UNCRYPTED;
1893 
1894 	mtx_lock(&serial->q_mtx);
1895 	if (serial->q_len < MAX_QUEUED_PKT) {
1896 		serial->q_len++;
1897 		STAILQ_INSERT_TAIL(&serial->q_queue, pkt, p_serial);
1898 	} else {
1899 		mtx_unlock(&serial->q_mtx);
1900 		wg_packet_free(pkt);
1901 		return (ENOBUFS);
1902 	}
1903 	mtx_unlock(&serial->q_mtx);
1904 
1905 	mtx_lock(&parallel->q_mtx);
1906 	if (parallel->q_len < MAX_QUEUED_PKT) {
1907 		parallel->q_len++;
1908 		STAILQ_INSERT_TAIL(&parallel->q_queue, pkt, p_parallel);
1909 	} else {
1910 		mtx_unlock(&parallel->q_mtx);
1911 		pkt->p_state = WG_PACKET_DEAD;
1912 		return (ENOBUFS);
1913 	}
1914 	mtx_unlock(&parallel->q_mtx);
1915 
1916 	return (0);
1917 }
1918 
1919 static struct wg_packet *
wg_queue_dequeue_serial(struct wg_queue * serial)1920 wg_queue_dequeue_serial(struct wg_queue *serial)
1921 {
1922 	struct wg_packet *pkt = NULL;
1923 	mtx_lock(&serial->q_mtx);
1924 	if (serial->q_len > 0 && STAILQ_FIRST(&serial->q_queue)->p_state != WG_PACKET_UNCRYPTED) {
1925 		serial->q_len--;
1926 		pkt = STAILQ_FIRST(&serial->q_queue);
1927 		STAILQ_REMOVE_HEAD(&serial->q_queue, p_serial);
1928 	}
1929 	mtx_unlock(&serial->q_mtx);
1930 	return (pkt);
1931 }
1932 
1933 static struct wg_packet *
wg_queue_dequeue_parallel(struct wg_queue * parallel)1934 wg_queue_dequeue_parallel(struct wg_queue *parallel)
1935 {
1936 	struct wg_packet *pkt = NULL;
1937 	mtx_lock(&parallel->q_mtx);
1938 	if (parallel->q_len > 0) {
1939 		parallel->q_len--;
1940 		pkt = STAILQ_FIRST(&parallel->q_queue);
1941 		STAILQ_REMOVE_HEAD(&parallel->q_queue, p_parallel);
1942 	}
1943 	mtx_unlock(&parallel->q_mtx);
1944 	return (pkt);
1945 }
1946 
1947 static bool
wg_input(struct mbuf * m,int offset,struct inpcb * inpcb,const struct sockaddr * sa,void * _sc)1948 wg_input(struct mbuf *m, int offset, struct inpcb *inpcb,
1949     const struct sockaddr *sa, void *_sc)
1950 {
1951 #ifdef INET
1952 	const struct sockaddr_in	*sin;
1953 #endif
1954 #ifdef INET6
1955 	const struct sockaddr_in6	*sin6;
1956 #endif
1957 	struct noise_remote		*remote;
1958 	struct wg_pkt_data		*data;
1959 	struct wg_packet		*pkt;
1960 	struct wg_peer			*peer;
1961 	struct wg_softc			*sc = _sc;
1962 	struct mbuf			*defragged;
1963 
1964 	defragged = m_defrag(m, M_NOWAIT);
1965 	if (defragged)
1966 		m = defragged;
1967 	m = m_unshare(m, M_NOWAIT);
1968 	if (!m) {
1969 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1970 		return true;
1971 	}
1972 
1973 	/* Caller provided us with `sa`, no need for this header. */
1974 	m_adj(m, offset + sizeof(struct udphdr));
1975 
1976 	/* Pullup enough to read packet type */
1977 	if ((m = m_pullup(m, sizeof(uint32_t))) == NULL) {
1978 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1979 		return true;
1980 	}
1981 
1982 	if ((pkt = wg_packet_alloc(m)) == NULL) {
1983 		if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
1984 		m_freem(m);
1985 		return true;
1986 	}
1987 
1988 	/* Save send/recv address and port for later. */
1989 	switch (sa->sa_family) {
1990 #ifdef INET
1991 	case AF_INET:
1992 		sin = (const struct sockaddr_in *)sa;
1993 		pkt->p_endpoint.e_remote.r_sin = sin[0];
1994 		pkt->p_endpoint.e_local.l_in = sin[1].sin_addr;
1995 		break;
1996 #endif
1997 #ifdef INET6
1998 	case AF_INET6:
1999 		sin6 = (const struct sockaddr_in6 *)sa;
2000 		pkt->p_endpoint.e_remote.r_sin6 = sin6[0];
2001 		pkt->p_endpoint.e_local.l_in6 = sin6[1].sin6_addr;
2002 		break;
2003 #endif
2004 	default:
2005 		goto error;
2006 	}
2007 
2008 	if ((m->m_pkthdr.len == sizeof(struct wg_pkt_initiation) &&
2009 		*mtod(m, uint32_t *) == WG_PKT_INITIATION) ||
2010 	    (m->m_pkthdr.len == sizeof(struct wg_pkt_response) &&
2011 		*mtod(m, uint32_t *) == WG_PKT_RESPONSE) ||
2012 	    (m->m_pkthdr.len == sizeof(struct wg_pkt_cookie) &&
2013 		*mtod(m, uint32_t *) == WG_PKT_COOKIE)) {
2014 
2015 		if (wg_queue_enqueue_handshake(&sc->sc_handshake_queue, pkt) != 0) {
2016 			if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2017 			DPRINTF(sc, "Dropping handshake packet\n");
2018 		}
2019 		GROUPTASK_ENQUEUE(&sc->sc_handshake);
2020 	} else if (m->m_pkthdr.len >= sizeof(struct wg_pkt_data) +
2021 	    NOISE_AUTHTAG_LEN && *mtod(m, uint32_t *) == WG_PKT_DATA) {
2022 
2023 		/* Pullup whole header to read r_idx below. */
2024 		if ((pkt->p_mbuf = m_pullup(m, sizeof(struct wg_pkt_data))) == NULL)
2025 			goto error;
2026 
2027 		data = mtod(pkt->p_mbuf, struct wg_pkt_data *);
2028 		if ((pkt->p_keypair = noise_keypair_lookup(sc->sc_local, data->r_idx)) == NULL)
2029 			goto error;
2030 
2031 		remote = noise_keypair_remote(pkt->p_keypair);
2032 		peer = noise_remote_arg(remote);
2033 		if (wg_queue_both(&sc->sc_decrypt_parallel, &peer->p_decrypt_serial, pkt) != 0)
2034 			if_inc_counter(sc->sc_ifp, IFCOUNTER_IQDROPS, 1);
2035 		wg_decrypt_dispatch(sc);
2036 		noise_remote_put(remote);
2037 	} else {
2038 		goto error;
2039 	}
2040 	return true;
2041 error:
2042 	if_inc_counter(sc->sc_ifp, IFCOUNTER_IERRORS, 1);
2043 	wg_packet_free(pkt);
2044 	return true;
2045 }
2046 
2047 static void
wg_peer_send_staged(struct wg_peer * peer)2048 wg_peer_send_staged(struct wg_peer *peer)
2049 {
2050 	struct wg_packet_list	 list;
2051 	struct noise_keypair	*keypair;
2052 	struct wg_packet	*pkt, *tpkt;
2053 	struct wg_softc		*sc = peer->p_sc;
2054 
2055 	wg_queue_delist_staged(&peer->p_stage_queue, &list);
2056 
2057 	if (STAILQ_EMPTY(&list))
2058 		return;
2059 
2060 	if ((keypair = noise_keypair_current(peer->p_remote)) == NULL)
2061 		goto error;
2062 
2063 	STAILQ_FOREACH(pkt, &list, p_parallel) {
2064 		if (noise_keypair_nonce_next(keypair, &pkt->p_nonce) != 0)
2065 			goto error_keypair;
2066 	}
2067 	STAILQ_FOREACH_SAFE(pkt, &list, p_parallel, tpkt) {
2068 		pkt->p_keypair = noise_keypair_ref(keypair);
2069 		if (wg_queue_both(&sc->sc_encrypt_parallel, &peer->p_encrypt_serial, pkt) != 0)
2070 			if_inc_counter(sc->sc_ifp, IFCOUNTER_OQDROPS, 1);
2071 	}
2072 	wg_encrypt_dispatch(sc);
2073 	noise_keypair_put(keypair);
2074 	return;
2075 
2076 error_keypair:
2077 	noise_keypair_put(keypair);
2078 error:
2079 	wg_queue_enlist_staged(&peer->p_stage_queue, &list);
2080 	wg_timers_event_want_initiation(peer);
2081 }
2082 
2083 static inline void
xmit_err(if_t ifp,struct mbuf * m,struct wg_packet * pkt,sa_family_t af)2084 xmit_err(if_t ifp, struct mbuf *m, struct wg_packet *pkt, sa_family_t af)
2085 {
2086 	if_inc_counter(ifp, IFCOUNTER_OERRORS, 1);
2087 	switch (af) {
2088 #ifdef INET
2089 	case AF_INET:
2090 		icmp_error(m, ICMP_UNREACH, ICMP_UNREACH_HOST, 0, 0);
2091 		if (pkt)
2092 			pkt->p_mbuf = NULL;
2093 		m = NULL;
2094 		break;
2095 #endif
2096 #ifdef INET6
2097 	case AF_INET6:
2098 		icmp6_error(m, ICMP6_DST_UNREACH, 0, 0);
2099 		if (pkt)
2100 			pkt->p_mbuf = NULL;
2101 		m = NULL;
2102 		break;
2103 #endif
2104 	}
2105 	if (pkt)
2106 		wg_packet_free(pkt);
2107 	else if (m)
2108 		m_freem(m);
2109 }
2110 
2111 static int
wg_xmit(if_t ifp,struct mbuf * m,sa_family_t af,uint32_t mtu)2112 wg_xmit(if_t ifp, struct mbuf *m, sa_family_t af, uint32_t mtu)
2113 {
2114 	struct wg_packet	*pkt = NULL;
2115 	struct wg_softc		*sc = if_getsoftc(ifp);
2116 	struct wg_peer		*peer;
2117 	int			 rc = 0;
2118 	sa_family_t		 peer_af;
2119 
2120 	/* Work around lifetime issue in the ipv6 mld code. */
2121 	if (__predict_false((if_getflags(ifp) & IFF_DYING) || !sc)) {
2122 		rc = ENXIO;
2123 		goto err_xmit;
2124 	}
2125 
2126 	if ((pkt = wg_packet_alloc(m)) == NULL) {
2127 		rc = ENOBUFS;
2128 		goto err_xmit;
2129 	}
2130 	pkt->p_mtu = mtu;
2131 	pkt->p_af = af;
2132 
2133 	if (af == AF_INET) {
2134 		peer = wg_aip_lookup(sc, AF_INET, &mtod(m, struct ip *)->ip_dst);
2135 	} else if (af == AF_INET6) {
2136 		peer = wg_aip_lookup(sc, AF_INET6, &mtod(m, struct ip6_hdr *)->ip6_dst);
2137 	} else {
2138 		rc = EAFNOSUPPORT;
2139 		goto err_xmit;
2140 	}
2141 
2142 	BPF_MTAP2_AF(ifp, m, pkt->p_af);
2143 
2144 	if (__predict_false(peer == NULL)) {
2145 		rc = ENETUNREACH;
2146 		goto err_xmit;
2147 	}
2148 
2149 	if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP, MAX_LOOPS))) {
2150 		DPRINTF(sc, "Packet looped");
2151 		rc = ELOOP;
2152 		goto err_peer;
2153 	}
2154 
2155 	peer_af = peer->p_endpoint.e_remote.r_sa.sa_family;
2156 	if (__predict_false(peer_af != AF_INET && peer_af != AF_INET6)) {
2157 		DPRINTF(sc, "No valid endpoint has been configured or "
2158 			    "discovered for peer %" PRIu64 "\n", peer->p_id);
2159 		rc = EHOSTUNREACH;
2160 		goto err_peer;
2161 	}
2162 
2163 	wg_queue_push_staged(&peer->p_stage_queue, pkt);
2164 	wg_peer_send_staged(peer);
2165 	noise_remote_put(peer->p_remote);
2166 	return (0);
2167 
2168 err_peer:
2169 	noise_remote_put(peer->p_remote);
2170 err_xmit:
2171 	xmit_err(ifp, m, pkt, af);
2172 	return (rc);
2173 }
2174 
2175 static inline int
determine_af_and_pullup(struct mbuf ** m,sa_family_t * af)2176 determine_af_and_pullup(struct mbuf **m, sa_family_t *af)
2177 {
2178 	u_char ipv;
2179 	if ((*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2180 		*m = m_pullup(*m, sizeof(struct ip6_hdr));
2181 	else if ((*m)->m_pkthdr.len >= sizeof(struct ip))
2182 		*m = m_pullup(*m, sizeof(struct ip));
2183 	else
2184 		return (EAFNOSUPPORT);
2185 	if (*m == NULL)
2186 		return (ENOBUFS);
2187 	ipv = mtod(*m, struct ip *)->ip_v;
2188 	if (ipv == 4)
2189 		*af = AF_INET;
2190 	else if (ipv == 6 && (*m)->m_pkthdr.len >= sizeof(struct ip6_hdr))
2191 		*af = AF_INET6;
2192 	else
2193 		return (EAFNOSUPPORT);
2194 	return (0);
2195 }
2196 
2197 static int
determine_ethertype_and_pullup(struct mbuf ** m,int * etp)2198 determine_ethertype_and_pullup(struct mbuf **m, int *etp)
2199 {
2200 	struct ether_header *eh;
2201 
2202 	*m = m_pullup(*m, sizeof(struct ether_header));
2203 	if (__predict_false(*m == NULL))
2204 		return (ENOBUFS);
2205 	eh = mtod(*m, struct ether_header *);
2206 	*etp = ntohs(eh->ether_type);
2207 	if (*etp != ETHERTYPE_IP && *etp != ETHERTYPE_IPV6)
2208 		return (EAFNOSUPPORT);
2209 	return (0);
2210 }
2211 
2212 /*
2213  * This should only be invoked by netmap, via nm_os_generic_xmit_frame(), to
2214  * transmit packets from the netmap TX ring.
2215  */
2216 static int
wg_transmit(if_t ifp,struct mbuf * m)2217 wg_transmit(if_t ifp, struct mbuf *m)
2218 {
2219 	sa_family_t af;
2220 	int et, ret;
2221 	struct mbuf *defragged;
2222 
2223 	KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2224 	    ("%s: ifp %p is not in netmap mode", __func__, ifp));
2225 
2226 	defragged = m_defrag(m, M_NOWAIT);
2227 	if (defragged)
2228 		m = defragged;
2229 	m = m_unshare(m, M_NOWAIT);
2230 	if (!m) {
2231 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2232 		return (ENOBUFS);
2233 	}
2234 
2235 	ret = determine_ethertype_and_pullup(&m, &et);
2236 	if (ret) {
2237 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2238 		return (ret);
2239 	}
2240 	m_adj(m, sizeof(struct ether_header));
2241 
2242 	ret = determine_af_and_pullup(&m, &af);
2243 	if (ret) {
2244 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2245 		return (ret);
2246 	}
2247 
2248 	/*
2249 	 * netmap only gets to see transient errors, since it handles errors by
2250 	 * refusing to advance the transmit ring and retrying later.
2251 	 */
2252 	ret = wg_xmit(ifp, m, af, if_getmtu(ifp));
2253 	if (ret == ENOBUFS)
2254 		return (ret);
2255 	return (0);
2256 }
2257 
2258 #ifdef DEV_NETMAP
2259 /*
2260  * This should only be invoked by netmap, via nm_os_send_up(), to process
2261  * packets from the host TX ring.
2262  */
2263 static void
wg_if_input(if_t ifp,struct mbuf * m)2264 wg_if_input(if_t ifp, struct mbuf *m)
2265 {
2266 	int et;
2267 
2268 	KASSERT((if_getcapenable(ifp) & IFCAP_NETMAP) != 0,
2269 	    ("%s: ifp %p is not in netmap mode", __func__, ifp));
2270 
2271 	if (determine_ethertype_and_pullup(&m, &et) != 0) {
2272 		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2273 		m_freem(m);
2274 		return;
2275 	}
2276 	CURVNET_SET(if_getvnet(ifp));
2277 	switch (et) {
2278 	case ETHERTYPE_IP:
2279 		m_adj(m, sizeof(struct ether_header));
2280 		netisr_dispatch(NETISR_IP, m);
2281 		break;
2282 	case ETHERTYPE_IPV6:
2283 		m_adj(m, sizeof(struct ether_header));
2284 		netisr_dispatch(NETISR_IPV6, m);
2285 		break;
2286 	default:
2287 		__assert_unreachable();
2288 	}
2289 	CURVNET_RESTORE();
2290 }
2291 
2292 /*
2293  * Deliver a packet to the host RX ring.  Because the interface is in netmap
2294  * mode, the if_transmit() call should pass the packet to netmap_transmit().
2295  */
2296 static int
wg_xmit_netmap(if_t ifp,struct mbuf * m,int af)2297 wg_xmit_netmap(if_t ifp, struct mbuf *m, int af)
2298 {
2299 	struct ether_header *eh;
2300 
2301 	if (__predict_false(if_tunnel_check_nesting(ifp, m, MTAG_WGLOOP,
2302 	    MAX_LOOPS))) {
2303 		printf("%s:%d\n", __func__, __LINE__);
2304 		if_inc_counter(ifp, IFCOUNTER_IERRORS, 1);
2305 		m_freem(m);
2306 		return (ELOOP);
2307 	}
2308 
2309 	M_PREPEND(m, ETHER_HDR_LEN, M_NOWAIT);
2310 	if (__predict_false(m == NULL)) {
2311 		if_inc_counter(ifp, IFCOUNTER_IQDROPS, 1);
2312 		return (ENOBUFS);
2313 	}
2314 
2315 	eh = mtod(m, struct ether_header *);
2316 	eh->ether_type = af == AF_INET ?
2317 	    htons(ETHERTYPE_IP) : htons(ETHERTYPE_IPV6);
2318 	memcpy(eh->ether_shost, "\x06\x06\x06\x06\x06\x06", ETHER_ADDR_LEN);
2319 	memcpy(eh->ether_dhost, "\xff\xff\xff\xff\xff\xff", ETHER_ADDR_LEN);
2320 	return (if_transmit(ifp, m));
2321 }
2322 #endif /* DEV_NETMAP */
2323 
2324 static int
wg_output(if_t ifp,struct mbuf * m,const struct sockaddr * dst,struct route * ro)2325 wg_output(if_t ifp, struct mbuf *m, const struct sockaddr *dst, struct route *ro)
2326 {
2327 	sa_family_t parsed_af;
2328 	uint32_t af, mtu;
2329 	int ret;
2330 	struct mbuf *defragged;
2331 
2332 	/* BPF writes need to be handled specially. */
2333 	if (dst->sa_family == AF_UNSPEC || dst->sa_family == pseudo_AF_HDRCMPLT)
2334 		memcpy(&af, dst->sa_data, sizeof(af));
2335 	else
2336 		af = dst->sa_family;
2337 	if (af == AF_UNSPEC) {
2338 		xmit_err(ifp, m, NULL, af);
2339 		return (EAFNOSUPPORT);
2340 	}
2341 
2342 #ifdef DEV_NETMAP
2343 	if ((if_getcapenable(ifp) & IFCAP_NETMAP) != 0)
2344 		return (wg_xmit_netmap(ifp, m, af));
2345 #endif
2346 
2347 	defragged = m_defrag(m, M_NOWAIT);
2348 	if (defragged)
2349 		m = defragged;
2350 	m = m_unshare(m, M_NOWAIT);
2351 	if (!m) {
2352 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2353 		return (ENOBUFS);
2354 	}
2355 
2356 	ret = determine_af_and_pullup(&m, &parsed_af);
2357 	if (ret) {
2358 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2359 		return (ret);
2360 	}
2361 	if (parsed_af != af) {
2362 		xmit_err(ifp, m, NULL, AF_UNSPEC);
2363 		return (EAFNOSUPPORT);
2364 	}
2365 	mtu = (ro != NULL && ro->ro_mtu > 0) ? ro->ro_mtu : if_getmtu(ifp);
2366 	return (wg_xmit(ifp, m, parsed_af, mtu));
2367 }
2368 
2369 static int
wg_peer_add(struct wg_softc * sc,const nvlist_t * nvl)2370 wg_peer_add(struct wg_softc *sc, const nvlist_t *nvl)
2371 {
2372 	uint8_t			 public[WG_KEY_SIZE];
2373 	const void *pub_key, *preshared_key = NULL;
2374 	const struct sockaddr *endpoint;
2375 	int err;
2376 	size_t size;
2377 	struct noise_remote *remote;
2378 	struct wg_peer *peer = NULL;
2379 	bool need_insert = false;
2380 
2381 	sx_assert(&sc->sc_lock, SX_XLOCKED);
2382 
2383 	if (!nvlist_exists_binary(nvl, "public-key")) {
2384 		return (EINVAL);
2385 	}
2386 	pub_key = nvlist_get_binary(nvl, "public-key", &size);
2387 	if (size != WG_KEY_SIZE) {
2388 		return (EINVAL);
2389 	}
2390 	if (noise_local_keys(sc->sc_local, public, NULL) == 0 &&
2391 	    bcmp(public, pub_key, WG_KEY_SIZE) == 0) {
2392 		return (0); // Silently ignored; not actually a failure.
2393 	}
2394 	if ((remote = noise_remote_lookup(sc->sc_local, pub_key)) != NULL)
2395 		peer = noise_remote_arg(remote);
2396 	if (nvlist_exists_bool(nvl, "remove") &&
2397 		nvlist_get_bool(nvl, "remove")) {
2398 		if (remote != NULL) {
2399 			wg_peer_destroy(peer);
2400 			noise_remote_put(remote);
2401 		}
2402 		return (0);
2403 	}
2404 	if (nvlist_exists_bool(nvl, "replace-allowedips") &&
2405 		nvlist_get_bool(nvl, "replace-allowedips") &&
2406 	    peer != NULL) {
2407 
2408 		wg_aip_remove_all(sc, peer);
2409 	}
2410 	if (peer == NULL) {
2411 		peer = wg_peer_alloc(sc, pub_key);
2412 		need_insert = true;
2413 	}
2414 	if (nvlist_exists_binary(nvl, "endpoint")) {
2415 		endpoint = nvlist_get_binary(nvl, "endpoint", &size);
2416 		if (size > sizeof(peer->p_endpoint.e_remote)) {
2417 			err = EINVAL;
2418 			goto out;
2419 		}
2420 		memcpy(&peer->p_endpoint.e_remote, endpoint, size);
2421 	}
2422 	if (nvlist_exists_binary(nvl, "preshared-key")) {
2423 		preshared_key = nvlist_get_binary(nvl, "preshared-key", &size);
2424 		if (size != WG_KEY_SIZE) {
2425 			err = EINVAL;
2426 			goto out;
2427 		}
2428 		noise_remote_set_psk(peer->p_remote, preshared_key);
2429 	}
2430 	if (nvlist_exists_number(nvl, "persistent-keepalive-interval")) {
2431 		uint64_t pki = nvlist_get_number(nvl, "persistent-keepalive-interval");
2432 		if (pki > UINT16_MAX) {
2433 			err = EINVAL;
2434 			goto out;
2435 		}
2436 		wg_timers_set_persistent_keepalive(peer, pki);
2437 	}
2438 	if (nvlist_exists_nvlist_array(nvl, "allowed-ips")) {
2439 		const void *addr;
2440 		uint64_t cidr;
2441 		const nvlist_t * const * aipl;
2442 		size_t allowedip_count;
2443 
2444 		aipl = nvlist_get_nvlist_array(nvl, "allowed-ips", &allowedip_count);
2445 		for (size_t idx = 0; idx < allowedip_count; idx++) {
2446 			if (!nvlist_exists_number(aipl[idx], "cidr"))
2447 				continue;
2448 			cidr = nvlist_get_number(aipl[idx], "cidr");
2449 			if (nvlist_exists_binary(aipl[idx], "ipv4")) {
2450 				addr = nvlist_get_binary(aipl[idx], "ipv4", &size);
2451 				if (addr == NULL || cidr > 32 || size != sizeof(struct in_addr)) {
2452 					err = EINVAL;
2453 					goto out;
2454 				}
2455 				if ((err = wg_aip_add(sc, peer, AF_INET, addr, cidr)) != 0)
2456 					goto out;
2457 			} else if (nvlist_exists_binary(aipl[idx], "ipv6")) {
2458 				addr = nvlist_get_binary(aipl[idx], "ipv6", &size);
2459 				if (addr == NULL || cidr > 128 || size != sizeof(struct in6_addr)) {
2460 					err = EINVAL;
2461 					goto out;
2462 				}
2463 				if ((err = wg_aip_add(sc, peer, AF_INET6, addr, cidr)) != 0)
2464 					goto out;
2465 			} else {
2466 				continue;
2467 			}
2468 		}
2469 	}
2470 	if (need_insert) {
2471 		if ((err = noise_remote_enable(peer->p_remote)) != 0)
2472 			goto out;
2473 		TAILQ_INSERT_TAIL(&sc->sc_peers, peer, p_entry);
2474 		sc->sc_peers_num++;
2475 		if (if_getlinkstate(sc->sc_ifp) == LINK_STATE_UP)
2476 			wg_timers_enable(peer);
2477 	}
2478 	if (remote != NULL)
2479 		noise_remote_put(remote);
2480 	return (0);
2481 out:
2482 	if (need_insert) /* If we fail, only destroy if it was new. */
2483 		wg_peer_destroy(peer);
2484 	if (remote != NULL)
2485 		noise_remote_put(remote);
2486 	return (err);
2487 }
2488 
2489 static int
wgc_set(struct wg_softc * sc,struct wg_data_io * wgd)2490 wgc_set(struct wg_softc *sc, struct wg_data_io *wgd)
2491 {
2492 	uint8_t public[WG_KEY_SIZE], private[WG_KEY_SIZE];
2493 	if_t ifp;
2494 	void *nvlpacked;
2495 	nvlist_t *nvl;
2496 	ssize_t size;
2497 	int err;
2498 
2499 	ifp = sc->sc_ifp;
2500 	if (wgd->wgd_size == 0 || wgd->wgd_data == NULL)
2501 		return (EFAULT);
2502 
2503 	/* Can nvlists be streamed in? It's not nice to impose arbitrary limits like that but
2504 	 * there needs to be _some_ limitation. */
2505 	if (wgd->wgd_size >= UINT32_MAX / 2)
2506 		return (E2BIG);
2507 
2508 	nvlpacked = malloc(wgd->wgd_size, M_TEMP, M_WAITOK | M_ZERO);
2509 
2510 	err = copyin(wgd->wgd_data, nvlpacked, wgd->wgd_size);
2511 	if (err)
2512 		goto out;
2513 	nvl = nvlist_unpack(nvlpacked, wgd->wgd_size, 0);
2514 	if (nvl == NULL) {
2515 		err = EBADMSG;
2516 		goto out;
2517 	}
2518 	sx_xlock(&sc->sc_lock);
2519 	if (nvlist_exists_bool(nvl, "replace-peers") &&
2520 		nvlist_get_bool(nvl, "replace-peers"))
2521 		wg_peer_destroy_all(sc);
2522 	if (nvlist_exists_number(nvl, "listen-port")) {
2523 		uint64_t new_port = nvlist_get_number(nvl, "listen-port");
2524 		if (new_port > UINT16_MAX) {
2525 			err = EINVAL;
2526 			goto out_locked;
2527 		}
2528 		if (new_port != sc->sc_socket.so_port) {
2529 			if ((if_getdrvflags(ifp) & IFF_DRV_RUNNING) != 0) {
2530 				if ((err = wg_socket_init(sc, new_port)) != 0)
2531 					goto out_locked;
2532 			} else
2533 				sc->sc_socket.so_port = new_port;
2534 		}
2535 	}
2536 	if (nvlist_exists_binary(nvl, "private-key")) {
2537 		const void *key = nvlist_get_binary(nvl, "private-key", &size);
2538 		if (size != WG_KEY_SIZE) {
2539 			err = EINVAL;
2540 			goto out_locked;
2541 		}
2542 
2543 		if (noise_local_keys(sc->sc_local, NULL, private) != 0 ||
2544 		    timingsafe_bcmp(private, key, WG_KEY_SIZE) != 0) {
2545 			struct wg_peer *peer;
2546 
2547 			if (curve25519_generate_public(public, key)) {
2548 				/* Peer conflict: remove conflicting peer. */
2549 				struct noise_remote *remote;
2550 				if ((remote = noise_remote_lookup(sc->sc_local,
2551 				    public)) != NULL) {
2552 					peer = noise_remote_arg(remote);
2553 					wg_peer_destroy(peer);
2554 					noise_remote_put(remote);
2555 				}
2556 			}
2557 
2558 			/*
2559 			 * Set the private key and invalidate all existing
2560 			 * handshakes.
2561 			 */
2562 			/* Note: we might be removing the private key. */
2563 			noise_local_private(sc->sc_local, key);
2564 			if (noise_local_keys(sc->sc_local, NULL, NULL) == 0)
2565 				cookie_checker_update(&sc->sc_cookie, public);
2566 			else
2567 				cookie_checker_update(&sc->sc_cookie, NULL);
2568 		}
2569 	}
2570 	if (nvlist_exists_number(nvl, "user-cookie")) {
2571 		uint64_t user_cookie = nvlist_get_number(nvl, "user-cookie");
2572 		if (user_cookie > UINT32_MAX) {
2573 			err = EINVAL;
2574 			goto out_locked;
2575 		}
2576 		err = wg_socket_set_cookie(sc, user_cookie);
2577 		if (err)
2578 			goto out_locked;
2579 	}
2580 	if (nvlist_exists_nvlist_array(nvl, "peers")) {
2581 		size_t peercount;
2582 		const nvlist_t * const*nvl_peers;
2583 
2584 		nvl_peers = nvlist_get_nvlist_array(nvl, "peers", &peercount);
2585 		for (int i = 0; i < peercount; i++) {
2586 			err = wg_peer_add(sc, nvl_peers[i]);
2587 			if (err != 0)
2588 				goto out_locked;
2589 		}
2590 	}
2591 
2592 out_locked:
2593 	sx_xunlock(&sc->sc_lock);
2594 	nvlist_destroy(nvl);
2595 out:
2596 	zfree(nvlpacked, M_TEMP);
2597 	return (err);
2598 }
2599 
2600 static int
wgc_get(struct wg_softc * sc,struct wg_data_io * wgd)2601 wgc_get(struct wg_softc *sc, struct wg_data_io *wgd)
2602 {
2603 	uint8_t public_key[WG_KEY_SIZE] = { 0 };
2604 	uint8_t private_key[WG_KEY_SIZE] = { 0 };
2605 	uint8_t preshared_key[NOISE_SYMMETRIC_KEY_LEN] = { 0 };
2606 	nvlist_t *nvl, *nvl_peer, *nvl_aip, **nvl_peers, **nvl_aips;
2607 	size_t size, peer_count, aip_count, i, j;
2608 	struct wg_timespec64 ts64;
2609 	struct wg_peer *peer;
2610 	struct wg_aip *aip;
2611 	void *packed;
2612 	int err = 0;
2613 
2614 	nvl = nvlist_create(0);
2615 	if (!nvl)
2616 		return (ENOMEM);
2617 
2618 	sx_slock(&sc->sc_lock);
2619 
2620 	if (sc->sc_socket.so_port != 0)
2621 		nvlist_add_number(nvl, "listen-port", sc->sc_socket.so_port);
2622 	if (sc->sc_socket.so_user_cookie != 0)
2623 		nvlist_add_number(nvl, "user-cookie", sc->sc_socket.so_user_cookie);
2624 	if (noise_local_keys(sc->sc_local, public_key, private_key) == 0) {
2625 		nvlist_add_binary(nvl, "public-key", public_key, WG_KEY_SIZE);
2626 		if (wgc_privileged(sc))
2627 			nvlist_add_binary(nvl, "private-key", private_key, WG_KEY_SIZE);
2628 		explicit_bzero(private_key, sizeof(private_key));
2629 	}
2630 	peer_count = sc->sc_peers_num;
2631 	if (peer_count) {
2632 		nvl_peers = mallocarray(peer_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2633 		i = 0;
2634 		TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2635 			if (i >= peer_count)
2636 				panic("peers changed from under us");
2637 
2638 			nvl_peers[i++] = nvl_peer = nvlist_create(0);
2639 			if (!nvl_peer) {
2640 				err = ENOMEM;
2641 				goto err_peer;
2642 			}
2643 
2644 			(void)noise_remote_keys(peer->p_remote, public_key, preshared_key);
2645 			nvlist_add_binary(nvl_peer, "public-key", public_key, sizeof(public_key));
2646 			if (wgc_privileged(sc))
2647 				nvlist_add_binary(nvl_peer, "preshared-key", preshared_key, sizeof(preshared_key));
2648 			explicit_bzero(preshared_key, sizeof(preshared_key));
2649 			if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET)
2650 				nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in));
2651 			else if (peer->p_endpoint.e_remote.r_sa.sa_family == AF_INET6)
2652 				nvlist_add_binary(nvl_peer, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr_in6));
2653 			wg_timers_get_last_handshake(peer, &ts64);
2654 			nvlist_add_binary(nvl_peer, "last-handshake-time", &ts64, sizeof(ts64));
2655 			nvlist_add_number(nvl_peer, "persistent-keepalive-interval", peer->p_persistent_keepalive_interval);
2656 			nvlist_add_number(nvl_peer, "rx-bytes", counter_u64_fetch(peer->p_rx_bytes));
2657 			nvlist_add_number(nvl_peer, "tx-bytes", counter_u64_fetch(peer->p_tx_bytes));
2658 
2659 			aip_count = peer->p_aips_num;
2660 			if (aip_count) {
2661 				nvl_aips = mallocarray(aip_count, sizeof(void *), M_NVLIST, M_WAITOK | M_ZERO);
2662 				j = 0;
2663 				LIST_FOREACH(aip, &peer->p_aips, a_entry) {
2664 					if (j >= aip_count)
2665 						panic("aips changed from under us");
2666 
2667 					nvl_aips[j++] = nvl_aip = nvlist_create(0);
2668 					if (!nvl_aip) {
2669 						err = ENOMEM;
2670 						goto err_aip;
2671 					}
2672 					if (aip->a_af == AF_INET) {
2673 						nvlist_add_binary(nvl_aip, "ipv4", &aip->a_addr.in, sizeof(aip->a_addr.in));
2674 						nvlist_add_number(nvl_aip, "cidr", bitcount32(aip->a_mask.ip));
2675 					}
2676 #ifdef INET6
2677 					else if (aip->a_af == AF_INET6) {
2678 						nvlist_add_binary(nvl_aip, "ipv6", &aip->a_addr.in6, sizeof(aip->a_addr.in6));
2679 						nvlist_add_number(nvl_aip, "cidr", in6_mask2len(&aip->a_mask.in6, NULL));
2680 					}
2681 #endif
2682 				}
2683 				nvlist_add_nvlist_array(nvl_peer, "allowed-ips", (const nvlist_t *const *)nvl_aips, aip_count);
2684 			err_aip:
2685 				for (j = 0; j < aip_count; ++j)
2686 					nvlist_destroy(nvl_aips[j]);
2687 				free(nvl_aips, M_NVLIST);
2688 				if (err)
2689 					goto err_peer;
2690 			}
2691 		}
2692 		nvlist_add_nvlist_array(nvl, "peers", (const nvlist_t * const *)nvl_peers, peer_count);
2693 	err_peer:
2694 		for (i = 0; i < peer_count; ++i)
2695 			nvlist_destroy(nvl_peers[i]);
2696 		free(nvl_peers, M_NVLIST);
2697 		if (err) {
2698 			sx_sunlock(&sc->sc_lock);
2699 			goto err;
2700 		}
2701 	}
2702 	sx_sunlock(&sc->sc_lock);
2703 	packed = nvlist_pack(nvl, &size);
2704 	if (!packed) {
2705 		err = ENOMEM;
2706 		goto err;
2707 	}
2708 	if (!wgd->wgd_size) {
2709 		wgd->wgd_size = size;
2710 		goto out;
2711 	}
2712 	if (wgd->wgd_size < size) {
2713 		err = ENOSPC;
2714 		goto out;
2715 	}
2716 	err = copyout(packed, wgd->wgd_data, size);
2717 	wgd->wgd_size = size;
2718 
2719 out:
2720 	zfree(packed, M_NVLIST);
2721 err:
2722 	nvlist_destroy(nvl);
2723 	return (err);
2724 }
2725 
2726 static int
wg_ioctl(if_t ifp,u_long cmd,caddr_t data)2727 wg_ioctl(if_t ifp, u_long cmd, caddr_t data)
2728 {
2729 	struct wg_data_io *wgd = (struct wg_data_io *)data;
2730 	struct ifreq *ifr = (struct ifreq *)data;
2731 	struct wg_softc *sc;
2732 	int ret = 0;
2733 
2734 	sx_slock(&wg_sx);
2735 	sc = if_getsoftc(ifp);
2736 	if (!sc) {
2737 		ret = ENXIO;
2738 		goto out;
2739 	}
2740 
2741 	switch (cmd) {
2742 	case SIOCSWG:
2743 		ret = priv_check(curthread, PRIV_NET_WG);
2744 		if (ret == 0)
2745 			ret = wgc_set(sc, wgd);
2746 		break;
2747 	case SIOCGWG:
2748 		ret = wgc_get(sc, wgd);
2749 		break;
2750 	/* Interface IOCTLs */
2751 	case SIOCSIFADDR:
2752 		/*
2753 		 * This differs from *BSD norms, but is more uniform with how
2754 		 * WireGuard behaves elsewhere.
2755 		 */
2756 		break;
2757 	case SIOCSIFFLAGS:
2758 		if (if_getflags(ifp) & IFF_UP)
2759 			ret = wg_up(sc);
2760 		else
2761 			wg_down(sc);
2762 		break;
2763 	case SIOCSIFMTU:
2764 		if (ifr->ifr_mtu <= 0 || ifr->ifr_mtu > MAX_MTU)
2765 			ret = EINVAL;
2766 		else
2767 			if_setmtu(ifp, ifr->ifr_mtu);
2768 		break;
2769 	case SIOCADDMULTI:
2770 	case SIOCDELMULTI:
2771 		break;
2772 	case SIOCGTUNFIB:
2773 		ifr->ifr_fib = sc->sc_socket.so_fibnum;
2774 		break;
2775 	case SIOCSTUNFIB:
2776 		ret = priv_check(curthread, PRIV_NET_WG);
2777 		if (ret)
2778 			break;
2779 		ret = priv_check(curthread, PRIV_NET_SETIFFIB);
2780 		if (ret)
2781 			break;
2782 		sx_xlock(&sc->sc_lock);
2783 		ret = wg_socket_set_fibnum(sc, ifr->ifr_fib);
2784 		sx_xunlock(&sc->sc_lock);
2785 		break;
2786 	default:
2787 		ret = ENOTTY;
2788 	}
2789 
2790 out:
2791 	sx_sunlock(&wg_sx);
2792 	return (ret);
2793 }
2794 
2795 static int
wg_up(struct wg_softc * sc)2796 wg_up(struct wg_softc *sc)
2797 {
2798 	if_t ifp = sc->sc_ifp;
2799 	struct wg_peer *peer;
2800 	int rc = EBUSY;
2801 
2802 	sx_xlock(&sc->sc_lock);
2803 	/* Jail's being removed, no more wg_up(). */
2804 	if ((sc->sc_flags & WGF_DYING) != 0)
2805 		goto out;
2806 
2807 	/* Silent success if we're already running. */
2808 	rc = 0;
2809 	if (if_getdrvflags(ifp) & IFF_DRV_RUNNING)
2810 		goto out;
2811 	if_setdrvflagbits(ifp, IFF_DRV_RUNNING, 0);
2812 
2813 	rc = wg_socket_init(sc, sc->sc_socket.so_port);
2814 	if (rc == 0) {
2815 		TAILQ_FOREACH(peer, &sc->sc_peers, p_entry)
2816 			wg_timers_enable(peer);
2817 		if_link_state_change(sc->sc_ifp, LINK_STATE_UP);
2818 	} else {
2819 		if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2820 		DPRINTF(sc, "Unable to initialize sockets: %d\n", rc);
2821 	}
2822 out:
2823 	sx_xunlock(&sc->sc_lock);
2824 	return (rc);
2825 }
2826 
2827 static void
wg_down(struct wg_softc * sc)2828 wg_down(struct wg_softc *sc)
2829 {
2830 	if_t ifp = sc->sc_ifp;
2831 	struct wg_peer *peer;
2832 
2833 	sx_xlock(&sc->sc_lock);
2834 	if (!(if_getdrvflags(ifp) & IFF_DRV_RUNNING)) {
2835 		sx_xunlock(&sc->sc_lock);
2836 		return;
2837 	}
2838 	if_setdrvflagbits(ifp, 0, IFF_DRV_RUNNING);
2839 
2840 	TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2841 		wg_queue_purge(&peer->p_stage_queue);
2842 		wg_timers_disable(peer);
2843 	}
2844 
2845 	wg_queue_purge(&sc->sc_handshake_queue);
2846 
2847 	TAILQ_FOREACH(peer, &sc->sc_peers, p_entry) {
2848 		noise_remote_handshake_clear(peer->p_remote);
2849 		noise_remote_keypairs_clear(peer->p_remote);
2850 	}
2851 
2852 	if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2853 	wg_socket_uninit(sc);
2854 
2855 	sx_xunlock(&sc->sc_lock);
2856 }
2857 
2858 static int
wg_clone_create(struct if_clone * ifc,char * name,size_t len,struct ifc_data * ifd,struct ifnet ** ifpp)2859 wg_clone_create(struct if_clone *ifc, char *name, size_t len,
2860     struct ifc_data *ifd, struct ifnet **ifpp)
2861 {
2862 	struct wg_softc *sc;
2863 	if_t ifp;
2864 
2865 	sc = malloc(sizeof(*sc), M_WG, M_WAITOK | M_ZERO);
2866 
2867 	sc->sc_local = noise_local_alloc(sc);
2868 
2869 	sc->sc_encrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2870 
2871 	sc->sc_decrypt = mallocarray(sizeof(struct grouptask), mp_ncpus, M_WG, M_WAITOK | M_ZERO);
2872 
2873 	if (!rn_inithead((void **)&sc->sc_aip4, offsetof(struct aip_addr, in) * NBBY))
2874 		goto free_decrypt;
2875 
2876 	if (!rn_inithead((void **)&sc->sc_aip6, offsetof(struct aip_addr, in6) * NBBY))
2877 		goto free_aip4;
2878 
2879 	atomic_add_int(&clone_count, 1);
2880 	ifp = sc->sc_ifp = if_alloc(IFT_WIREGUARD);
2881 
2882 	sc->sc_ucred = crhold(curthread->td_ucred);
2883 	sc->sc_socket.so_fibnum = curthread->td_proc->p_fibnum;
2884 	sc->sc_socket.so_port = 0;
2885 
2886 	TAILQ_INIT(&sc->sc_peers);
2887 	sc->sc_peers_num = 0;
2888 
2889 	cookie_checker_init(&sc->sc_cookie);
2890 
2891 	RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip4);
2892 	RADIX_NODE_HEAD_LOCK_INIT(sc->sc_aip6);
2893 
2894 	GROUPTASK_INIT(&sc->sc_handshake, 0, (gtask_fn_t *)wg_softc_handshake_receive, sc);
2895 	taskqgroup_attach(qgroup_wg_tqg, &sc->sc_handshake, sc, NULL, NULL, "wg tx initiation");
2896 	wg_queue_init(&sc->sc_handshake_queue, "hsq");
2897 
2898 	for (int i = 0; i < mp_ncpus; i++) {
2899 		GROUPTASK_INIT(&sc->sc_encrypt[i], 0,
2900 		     (gtask_fn_t *)wg_softc_encrypt, sc);
2901 		taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_encrypt[i], sc, i, NULL, NULL, "wg encrypt");
2902 		GROUPTASK_INIT(&sc->sc_decrypt[i], 0,
2903 		    (gtask_fn_t *)wg_softc_decrypt, sc);
2904 		taskqgroup_attach_cpu(qgroup_wg_tqg, &sc->sc_decrypt[i], sc, i, NULL, NULL, "wg decrypt");
2905 	}
2906 
2907 	wg_queue_init(&sc->sc_encrypt_parallel, "encp");
2908 	wg_queue_init(&sc->sc_decrypt_parallel, "decp");
2909 
2910 	sx_init(&sc->sc_lock, "wg softc lock");
2911 
2912 	if_setsoftc(ifp, sc);
2913 	if_setcapabilities(ifp, WG_CAPS);
2914 	if_setcapenable(ifp, WG_CAPS);
2915 	if_initname(ifp, wgname, ifd->unit);
2916 
2917 	if_setmtu(ifp, DEFAULT_MTU);
2918 	if_setflags(ifp, IFF_NOARP | IFF_MULTICAST);
2919 	if_setinitfn(ifp, wg_init);
2920 	if_setreassignfn(ifp, wg_reassign);
2921 	if_setqflushfn(ifp, wg_qflush);
2922 	if_settransmitfn(ifp, wg_transmit);
2923 #ifdef DEV_NETMAP
2924 	if_setinputfn(ifp, wg_if_input);
2925 #endif
2926 	if_setoutputfn(ifp, wg_output);
2927 	if_setioctlfn(ifp, wg_ioctl);
2928 	if_attach(ifp);
2929 	bpfattach(ifp, DLT_NULL, sizeof(uint32_t));
2930 #ifdef INET6
2931 	ND_IFINFO(ifp)->flags &= ~ND6_IFF_AUTO_LINKLOCAL;
2932 	ND_IFINFO(ifp)->flags |= ND6_IFF_NO_DAD;
2933 #endif
2934 	sx_xlock(&wg_sx);
2935 	LIST_INSERT_HEAD(&wg_list, sc, sc_entry);
2936 	sx_xunlock(&wg_sx);
2937 	*ifpp = ifp;
2938 	return (0);
2939 free_aip4:
2940 	RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
2941 	free(sc->sc_aip4, M_RTABLE);
2942 free_decrypt:
2943 	free(sc->sc_decrypt, M_WG);
2944 	free(sc->sc_encrypt, M_WG);
2945 	noise_local_free(sc->sc_local, NULL);
2946 	free(sc, M_WG);
2947 	return (ENOMEM);
2948 }
2949 
2950 static void
wg_clone_deferred_free(struct noise_local * l)2951 wg_clone_deferred_free(struct noise_local *l)
2952 {
2953 	struct wg_softc *sc = noise_local_arg(l);
2954 
2955 	free(sc, M_WG);
2956 	atomic_add_int(&clone_count, -1);
2957 }
2958 
2959 static int
wg_clone_destroy(struct if_clone * ifc,if_t ifp,uint32_t flags)2960 wg_clone_destroy(struct if_clone *ifc, if_t ifp, uint32_t flags)
2961 {
2962 	struct wg_softc *sc = if_getsoftc(ifp);
2963 	struct ucred *cred;
2964 
2965 	sx_xlock(&wg_sx);
2966 	if_setsoftc(ifp, NULL);
2967 	sx_xlock(&sc->sc_lock);
2968 	sc->sc_flags |= WGF_DYING;
2969 	cred = sc->sc_ucred;
2970 	sc->sc_ucred = NULL;
2971 	sx_xunlock(&sc->sc_lock);
2972 	LIST_REMOVE(sc, sc_entry);
2973 	sx_xunlock(&wg_sx);
2974 
2975 	if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
2976 	CURVNET_SET(if_getvnet(sc->sc_ifp));
2977 	if_purgeaddrs(sc->sc_ifp);
2978 	CURVNET_RESTORE();
2979 
2980 	sx_xlock(&sc->sc_lock);
2981 	wg_socket_uninit(sc);
2982 	sx_xunlock(&sc->sc_lock);
2983 
2984 	/*
2985 	 * No guarantees that all traffic have passed until the epoch has
2986 	 * elapsed with the socket closed.
2987 	 */
2988 	NET_EPOCH_WAIT();
2989 
2990 	taskqgroup_drain_all(qgroup_wg_tqg);
2991 	sx_xlock(&sc->sc_lock);
2992 	wg_peer_destroy_all(sc);
2993 	NET_EPOCH_DRAIN_CALLBACKS();
2994 	sx_xunlock(&sc->sc_lock);
2995 	sx_destroy(&sc->sc_lock);
2996 	taskqgroup_detach(qgroup_wg_tqg, &sc->sc_handshake);
2997 	for (int i = 0; i < mp_ncpus; i++) {
2998 		taskqgroup_detach(qgroup_wg_tqg, &sc->sc_encrypt[i]);
2999 		taskqgroup_detach(qgroup_wg_tqg, &sc->sc_decrypt[i]);
3000 	}
3001 	free(sc->sc_encrypt, M_WG);
3002 	free(sc->sc_decrypt, M_WG);
3003 	wg_queue_deinit(&sc->sc_handshake_queue);
3004 	wg_queue_deinit(&sc->sc_encrypt_parallel);
3005 	wg_queue_deinit(&sc->sc_decrypt_parallel);
3006 
3007 	RADIX_NODE_HEAD_DESTROY(sc->sc_aip4);
3008 	RADIX_NODE_HEAD_DESTROY(sc->sc_aip6);
3009 	rn_detachhead((void **)&sc->sc_aip4);
3010 	rn_detachhead((void **)&sc->sc_aip6);
3011 
3012 	cookie_checker_free(&sc->sc_cookie);
3013 
3014 	if (cred != NULL)
3015 		crfree(cred);
3016 	bpfdetach(sc->sc_ifp);
3017 	if_detach(sc->sc_ifp);
3018 	if_free(sc->sc_ifp);
3019 
3020 	noise_local_free(sc->sc_local, wg_clone_deferred_free);
3021 
3022 	return (0);
3023 }
3024 
3025 static void
wg_qflush(if_t ifp __unused)3026 wg_qflush(if_t ifp __unused)
3027 {
3028 }
3029 
3030 /*
3031  * Privileged information (private-key, preshared-key) are only exported for
3032  * root and jailed root by default.
3033  */
3034 static bool
wgc_privileged(struct wg_softc * sc)3035 wgc_privileged(struct wg_softc *sc)
3036 {
3037 	struct thread *td;
3038 
3039 	td = curthread;
3040 	return (priv_check(td, PRIV_NET_WG) == 0);
3041 }
3042 
3043 static void
wg_reassign(if_t ifp,struct vnet * new_vnet __unused,char * unused __unused)3044 wg_reassign(if_t ifp, struct vnet *new_vnet __unused,
3045     char *unused __unused)
3046 {
3047 	struct wg_softc *sc;
3048 
3049 	sc = if_getsoftc(ifp);
3050 	wg_down(sc);
3051 }
3052 
/*
 * if_init method: the opaque argument is the wg softc.  It brings the
 * interface up via wg_up().
 */
static void
wg_init(void *xsc)
{
	wg_up(xsc);
}
3061 
3062 static void
vnet_wg_init(const void * unused __unused)3063 vnet_wg_init(const void *unused __unused)
3064 {
3065 	struct if_clone_addreq req = {
3066 		.create_f = wg_clone_create,
3067 		.destroy_f = wg_clone_destroy,
3068 		.flags = IFC_F_AUTOUNIT,
3069 	};
3070 	V_wg_cloner = ifc_attach_cloner(wgname, &req);
3071 }
3072 VNET_SYSINIT(vnet_wg_init, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3073 	     vnet_wg_init, NULL);
3074 
3075 static void
vnet_wg_uninit(const void * unused __unused)3076 vnet_wg_uninit(const void *unused __unused)
3077 {
3078 	if (V_wg_cloner)
3079 		ifc_detach_cloner(V_wg_cloner);
3080 }
3081 VNET_SYSUNINIT(vnet_wg_uninit, SI_SUB_PROTO_IFATTACHDOMAIN, SI_ORDER_ANY,
3082 	       vnet_wg_uninit, NULL);
3083 
3084 static int
wg_prison_remove(void * obj,void * data __unused)3085 wg_prison_remove(void *obj, void *data __unused)
3086 {
3087 	const struct prison *pr = obj;
3088 	struct wg_softc *sc;
3089 
3090 	/*
3091 	 * Do a pass through all if_wg interfaces and release creds on any from
3092 	 * the jail that are supposed to be going away.  This will, in turn, let
3093 	 * the jail die so that we don't end up with Schrödinger's jail.
3094 	 */
3095 	sx_slock(&wg_sx);
3096 	LIST_FOREACH(sc, &wg_list, sc_entry) {
3097 		sx_xlock(&sc->sc_lock);
3098 		if (!(sc->sc_flags & WGF_DYING) && sc->sc_ucred && sc->sc_ucred->cr_prison == pr) {
3099 			struct ucred *cred = sc->sc_ucred;
3100 			DPRINTF(sc, "Creating jail exiting\n");
3101 			if_link_state_change(sc->sc_ifp, LINK_STATE_DOWN);
3102 			wg_socket_uninit(sc);
3103 			sc->sc_ucred = NULL;
3104 			crfree(cred);
3105 			sc->sc_flags |= WGF_DYING;
3106 		}
3107 		sx_xunlock(&sc->sc_lock);
3108 	}
3109 	sx_sunlock(&wg_sx);
3110 
3111 	return (0);
3112 }
3113 
#ifdef SELFTESTS
#include "selftest/allowedips.c"
/*
 * Run the built-in self tests (allowed IPs, noise counter, cookie).  All
 * three are always executed, with no short-circuit; the result is true
 * only if every one of them passed.
 */
static bool
wg_run_selftests(void)
{
	bool aip_ok, counter_ok, cookie_ok;

	aip_ok = wg_allowedips_selftest();
	counter_ok = noise_counter_selftest();
	cookie_ok = cookie_selftest();

	return (aip_ok && counter_ok && cookie_ok);
}
#else
/* Self tests are compiled out: report success. */
static inline bool
wg_run_selftests(void)
{
	return (true);
}
#endif
3127 
3128 static int
wg_module_init(void)3129 wg_module_init(void)
3130 {
3131 	int ret;
3132 	osd_method_t methods[PR_MAXMETHOD] = {
3133 		[PR_METHOD_REMOVE] = wg_prison_remove,
3134 	};
3135 
3136 	wg_packet_zone = uma_zcreate("wg packet", sizeof(struct wg_packet),
3137 	     NULL, NULL, NULL, NULL, 0, 0);
3138 
3139 	ret = crypto_init();
3140 	if (ret != 0)
3141 		return (ret);
3142 	ret = cookie_init();
3143 	if (ret != 0)
3144 		return (ret);
3145 
3146 	wg_osd_jail_slot = osd_jail_register(NULL, methods);
3147 
3148 	if (!wg_run_selftests())
3149 		return (ENOTRECOVERABLE);
3150 
3151 	return (0);
3152 }
3153 
3154 static void
wg_module_deinit(void)3155 wg_module_deinit(void)
3156 {
3157 	VNET_ITERATOR_DECL(vnet_iter);
3158 	VNET_LIST_RLOCK();
3159 	VNET_FOREACH(vnet_iter) {
3160 		struct if_clone *clone = VNET_VNET(vnet_iter, wg_cloner);
3161 		if (clone) {
3162 			ifc_detach_cloner(clone);
3163 			VNET_VNET(vnet_iter, wg_cloner) = NULL;
3164 		}
3165 	}
3166 	VNET_LIST_RUNLOCK();
3167 	NET_EPOCH_WAIT();
3168 	MPASS(LIST_EMPTY(&wg_list));
3169 	if (wg_osd_jail_slot != 0)
3170 		osd_jail_deregister(wg_osd_jail_slot);
3171 	cookie_deinit();
3172 	crypto_deinit();
3173 	if (wg_packet_zone != NULL)
3174 		uma_zdestroy(wg_packet_zone);
3175 }
3176 
3177 static int
wg_module_event_handler(module_t mod,int what,void * arg)3178 wg_module_event_handler(module_t mod, int what, void *arg)
3179 {
3180 	switch (what) {
3181 		case MOD_LOAD:
3182 			return wg_module_init();
3183 		case MOD_UNLOAD:
3184 			wg_module_deinit();
3185 			break;
3186 		default:
3187 			return (EOPNOTSUPP);
3188 	}
3189 	return (0);
3190 }
3191 
3192 static moduledata_t wg_moduledata = {
3193 	"if_wg",
3194 	wg_module_event_handler,
3195 	NULL
3196 };
3197 
3198 DECLARE_MODULE(if_wg, wg_moduledata, SI_SUB_PSEUDO, SI_ORDER_ANY);
3199 MODULE_VERSION(if_wg, WIREGUARD_VERSION);
3200 MODULE_DEPEND(if_wg, crypto, 1, 1, 1);
3201