1 /* SPDX-License-Identifier: GPL-2.0 */ 2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ 3 4 #ifndef _LINUX_SKMSG_H 5 #define _LINUX_SKMSG_H 6 7 #include <linux/bpf.h> 8 #include <linux/filter.h> 9 #include <linux/scatterlist.h> 10 #include <linux/skbuff.h> 11 12 #include <net/sock.h> 13 #include <net/tcp.h> 14 #include <net/strparser.h> 15 16 #define MAX_MSG_FRAGS MAX_SKB_FRAGS 17 18 enum __sk_action { 19 __SK_DROP = 0, 20 __SK_PASS, 21 __SK_REDIRECT, 22 __SK_NONE, 23 }; 24 25 struct sk_msg_sg { 26 u32 start; 27 u32 curr; 28 u32 end; 29 u32 size; 30 u32 copybreak; 31 bool copy[MAX_MSG_FRAGS]; 32 struct scatterlist data[MAX_MSG_FRAGS]; 33 }; 34 35 struct sk_msg { 36 struct sk_msg_sg sg; 37 void *data; 38 void *data_end; 39 u32 apply_bytes; 40 u32 cork_bytes; 41 u32 flags; 42 struct sk_buff *skb; 43 struct sock *sk_redir; 44 struct sock *sk; 45 struct list_head list; 46 }; 47 48 struct sk_psock_progs { 49 struct bpf_prog *msg_parser; 50 struct bpf_prog *skb_parser; 51 struct bpf_prog *skb_verdict; 52 }; 53 54 enum sk_psock_state_bits { 55 SK_PSOCK_TX_ENABLED, 56 }; 57 58 struct sk_psock_link { 59 struct list_head list; 60 struct bpf_map *map; 61 void *link_raw; 62 }; 63 64 struct sk_psock_parser { 65 struct strparser strp; 66 bool enabled; 67 void (*saved_data_ready)(struct sock *sk); 68 }; 69 70 struct sk_psock_work_state { 71 struct sk_buff *skb; 72 u32 len; 73 u32 off; 74 }; 75 76 struct sk_psock { 77 struct sock *sk; 78 struct sock *sk_redir; 79 u32 apply_bytes; 80 u32 cork_bytes; 81 u32 eval; 82 struct sk_msg *cork; 83 struct sk_psock_progs progs; 84 struct sk_psock_parser parser; 85 struct sk_buff_head ingress_skb; 86 struct list_head ingress_msg; 87 unsigned long state; 88 struct list_head link; 89 spinlock_t link_lock; 90 refcount_t refcnt; 91 void (*saved_unhash)(struct sock *sk); 92 void (*saved_close)(struct sock *sk, long timeout); 93 void (*saved_write_space)(struct sock *sk); 94 struct proto *sk_proto; 95 struct sk_psock_work_state work_state; 96 struct work_struct work; 97 union { 98 struct rcu_head rcu; 99 struct work_struct gc; 100 }; 101 }; 102 103 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len, 104 int elem_first_coalesce); 105 void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len); 106 int sk_msg_free(struct sock *sk, struct sk_msg *msg); 107 int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg); 108 void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes); 109 void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, 110 u32 bytes); 111 112 void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes); 113 114 int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 115 struct sk_msg *msg, u32 bytes); 116 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, 117 struct sk_msg *msg, u32 bytes); 118 119 static inline void sk_msg_check_to_free(struct sk_msg *msg, u32 i, u32 bytes) 120 { 121 WARN_ON(i == msg->sg.end && bytes); 122 } 123 124 static inline void sk_msg_apply_bytes(struct sk_psock *psock, u32 bytes) 125 { 126 if (psock->apply_bytes) { 127 if (psock->apply_bytes < bytes) 128 psock->apply_bytes = 0; 129 else 130 psock->apply_bytes -= bytes; 131 } 132 } 133 134 #define sk_msg_iter_var_prev(var) \ 135 do { \ 136 if (var == 0) \ 137 var = MAX_MSG_FRAGS - 1; \ 138 else \ 139 var--; \ 140 } while (0) 141 142 #define sk_msg_iter_var_next(var) \ 143 do { \ 144 var++; \ 145 if (var == MAX_MSG_FRAGS) \ 146 var = 0; \ 147 } while (0) 148 149 #define sk_msg_iter_prev(msg, which) \ 150 sk_msg_iter_var_prev(msg->sg.which) 151 152 #define sk_msg_iter_next(msg, which) \ 153 sk_msg_iter_var_next(msg->sg.which) 154 155 static inline void sk_msg_clear_meta(struct sk_msg *msg) 156 { 157 memset(&msg->sg, 0, offsetofend(struct sk_msg_sg, copy)); 158 } 159 160 static inline void sk_msg_init(struct sk_msg *msg) 161 { 162 memset(msg, 0, sizeof(*msg)); 163 sg_init_marker(msg->sg.data, ARRAY_SIZE(msg->sg.data)); 164 } 165 166 static inline void sk_msg_xfer(struct sk_msg *dst, struct sk_msg *src, 167 int which, u32 size) 168 { 169 dst->sg.data[which] = src->sg.data[which]; 170 dst->sg.data[which].length = size; 171 src->sg.data[which].length -= size; 172 src->sg.data[which].offset += size; 173 } 174 175 static inline u32 sk_msg_elem_used(const struct sk_msg *msg) 176 { 177 return msg->sg.end >= msg->sg.start ? 178 msg->sg.end - msg->sg.start : 179 msg->sg.end + (MAX_MSG_FRAGS - msg->sg.start); 180 } 181 182 static inline bool sk_msg_full(const struct sk_msg *msg) 183 { 184 return (msg->sg.end == msg->sg.start) && msg->sg.size; 185 } 186 187 static inline struct scatterlist *sk_msg_elem(struct sk_msg *msg, int which) 188 { 189 return &msg->sg.data[which]; 190 } 191 192 static inline struct page *sk_msg_page(struct sk_msg *msg, int which) 193 { 194 return sg_page(sk_msg_elem(msg, which)); 195 } 196 197 static inline bool sk_msg_to_ingress(const struct sk_msg *msg) 198 { 199 return msg->flags & BPF_F_INGRESS; 200 } 201 202 static inline void sk_msg_compute_data_pointers(struct sk_msg *msg) 203 { 204 struct scatterlist *sge = sk_msg_elem(msg, msg->sg.start); 205 206 if (msg->sg.copy[msg->sg.start]) { 207 msg->data = NULL; 208 msg->data_end = NULL; 209 } else { 210 msg->data = sg_virt(sge); 211 msg->data_end = msg->data + sge->length; 212 } 213 } 214 215 static inline void sk_msg_page_add(struct sk_msg *msg, struct page *page, 216 u32 len, u32 offset) 217 { 218 struct scatterlist *sge; 219 220 get_page(page); 221 sge = sk_msg_elem(msg, msg->sg.end); 222 sg_set_page(sge, page, len, offset); 223 sg_unmark_end(sge); 224 225 msg->sg.copy[msg->sg.end] = true; 226 msg->sg.size += len; 227 sk_msg_iter_next(msg, end); 228 } 229 230 static inline struct sk_psock *sk_psock(const struct sock *sk) 231 { 232 return rcu_dereference_sk_user_data(sk); 233 } 234 235 static inline bool sk_has_psock(struct sock *sk) 236 { 237 return sk_psock(sk) != NULL && sk->sk_prot->recvmsg == tcp_bpf_recvmsg; 238 } 239 240 static inline void sk_psock_queue_msg(struct sk_psock *psock, 241 struct sk_msg *msg) 242 { 243 list_add_tail(&msg->list, &psock->ingress_msg); 244 } 245 246 static inline void sk_psock_report_error(struct sk_psock *psock, int err) 247 { 248 struct sock *sk = psock->sk; 249 250 sk->sk_err = err; 251 sk->sk_error_report(sk); 252 } 253 254 struct sk_psock *sk_psock_init(struct sock *sk, int node); 255 256 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock); 257 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock); 258 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock); 259 260 int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, 261 struct sk_msg *msg); 262 263 static inline struct sk_psock_link *sk_psock_init_link(void) 264 { 265 return kzalloc(sizeof(struct sk_psock_link), 266 GFP_ATOMIC | __GFP_NOWARN); 267 } 268 269 static inline void sk_psock_free_link(struct sk_psock_link *link) 270 { 271 kfree(link); 272 } 273 274 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock); 275 #if defined(CONFIG_BPF_STREAM_PARSER) 276 void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link); 277 #else 278 static inline void sk_psock_unlink(struct sock *sk, 279 struct sk_psock_link *link) 280 { 281 } 282 #endif 283 284 void __sk_psock_purge_ingress_msg(struct sk_psock *psock); 285 286 static inline void sk_psock_cork_free(struct sk_psock *psock) 287 { 288 if (psock->cork) { 289 sk_msg_free(psock->sk, psock->cork); 290 kfree(psock->cork); 291 psock->cork = NULL; 292 } 293 } 294 295 static inline void sk_psock_update_proto(struct sock *sk, 296 struct sk_psock *psock, 297 struct proto *ops) 298 { 299 psock->saved_unhash = sk->sk_prot->unhash; 300 psock->saved_close = sk->sk_prot->close; 301 psock->saved_write_space = sk->sk_write_space; 302 303 psock->sk_proto = sk->sk_prot; 304 sk->sk_prot = ops; 305 } 306 307 static inline void sk_psock_restore_proto(struct sock *sk, 308 struct sk_psock *psock) 309 { 310 if (psock->sk_proto) { 311 sk->sk_prot = psock->sk_proto; 312 psock->sk_proto = NULL; 313 } 314 } 315 316 static inline void sk_psock_set_state(struct sk_psock *psock, 317 enum sk_psock_state_bits bit) 318 { 319 set_bit(bit, &psock->state); 320 } 321 322 static inline void sk_psock_clear_state(struct sk_psock *psock, 323 enum sk_psock_state_bits bit) 324 { 325 clear_bit(bit, &psock->state); 326 } 327 328 static inline bool sk_psock_test_state(const struct sk_psock *psock, 329 enum sk_psock_state_bits bit) 330 { 331 return test_bit(bit, &psock->state); 332 } 333 334 static inline struct sk_psock *sk_psock_get(struct sock *sk) 335 { 336 struct sk_psock *psock; 337 338 rcu_read_lock(); 339 psock = sk_psock(sk); 340 if (psock && !refcount_inc_not_zero(&psock->refcnt)) 341 psock = NULL; 342 rcu_read_unlock(); 343 return psock; 344 } 345 346 void sk_psock_stop(struct sock *sk, struct sk_psock *psock); 347 void sk_psock_destroy(struct rcu_head *rcu); 348 void sk_psock_drop(struct sock *sk, struct sk_psock *psock); 349 350 static inline void sk_psock_put(struct sock *sk, struct sk_psock *psock) 351 { 352 if (refcount_dec_and_test(&psock->refcnt)) 353 sk_psock_drop(sk, psock); 354 } 355 356 static inline void psock_set_prog(struct bpf_prog **pprog, 357 struct bpf_prog *prog) 358 { 359 prog = xchg(pprog, prog); 360 if (prog) 361 bpf_prog_put(prog); 362 } 363 364 static inline void psock_progs_drop(struct sk_psock_progs *progs) 365 { 366 psock_set_prog(&progs->msg_parser, NULL); 367 psock_set_prog(&progs->skb_parser, NULL); 368 psock_set_prog(&progs->skb_verdict, NULL); 369 } 370 371 #endif /* _LINUX_SKMSG_H */ 372