1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/init.h> 3 #include <linux/static_call.h> 4 #include <linux/bug.h> 5 #include <linux/smp.h> 6 #include <linux/sort.h> 7 #include <linux/slab.h> 8 #include <linux/module.h> 9 #include <linux/cpu.h> 10 #include <linux/processor.h> 11 #include <asm/sections.h> 12 13 extern struct static_call_site __start_static_call_sites[], 14 __stop_static_call_sites[]; 15 extern struct static_call_tramp_key __start_static_call_tramp_key[], 16 __stop_static_call_tramp_key[]; 17 18 static bool static_call_initialized; 19 20 /* mutex to protect key modules/sites */ 21 static DEFINE_MUTEX(static_call_mutex); 22 23 static void static_call_lock(void) 24 { 25 mutex_lock(&static_call_mutex); 26 } 27 28 static void static_call_unlock(void) 29 { 30 mutex_unlock(&static_call_mutex); 31 } 32 33 static inline void *static_call_addr(struct static_call_site *site) 34 { 35 return (void *)((long)site->addr + (long)&site->addr); 36 } 37 38 static inline unsigned long __static_call_key(const struct static_call_site *site) 39 { 40 return (long)site->key + (long)&site->key; 41 } 42 43 static inline struct static_call_key *static_call_key(const struct static_call_site *site) 44 { 45 return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS); 46 } 47 48 /* These assume the key is word-aligned. */ 49 static inline bool static_call_is_init(struct static_call_site *site) 50 { 51 return __static_call_key(site) & STATIC_CALL_SITE_INIT; 52 } 53 54 static inline bool static_call_is_tail(struct static_call_site *site) 55 { 56 return __static_call_key(site) & STATIC_CALL_SITE_TAIL; 57 } 58 59 static inline void static_call_set_init(struct static_call_site *site) 60 { 61 site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) - 62 (long)&site->key; 63 } 64 65 static int static_call_site_cmp(const void *_a, const void *_b) 66 { 67 const struct static_call_site *a = _a; 68 const struct static_call_site *b = _b; 69 const struct static_call_key *key_a = static_call_key(a); 70 const struct static_call_key *key_b = static_call_key(b); 71 72 if (key_a < key_b) 73 return -1; 74 75 if (key_a > key_b) 76 return 1; 77 78 return 0; 79 } 80 81 static void static_call_site_swap(void *_a, void *_b, int size) 82 { 83 long delta = (unsigned long)_a - (unsigned long)_b; 84 struct static_call_site *a = _a; 85 struct static_call_site *b = _b; 86 struct static_call_site tmp = *a; 87 88 a->addr = b->addr - delta; 89 a->key = b->key - delta; 90 91 b->addr = tmp.addr + delta; 92 b->key = tmp.key + delta; 93 } 94 95 static inline void static_call_sort_entries(struct static_call_site *start, 96 struct static_call_site *stop) 97 { 98 sort(start, stop - start, sizeof(struct static_call_site), 99 static_call_site_cmp, static_call_site_swap); 100 } 101 102 static inline bool static_call_key_has_mods(struct static_call_key *key) 103 { 104 return !(key->type & 1); 105 } 106 107 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) 108 { 109 if (!static_call_key_has_mods(key)) 110 return NULL; 111 112 return key->mods; 113 } 114 115 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) 116 { 117 if (static_call_key_has_mods(key)) 118 return NULL; 119 120 return (struct static_call_site *)(key->type & ~1); 121 } 122 123 void __static_call_update(struct static_call_key *key, void *tramp, void *func) 124 { 125 struct static_call_site *site, *stop; 126 struct static_call_mod *site_mod, first; 127 128 cpus_read_lock(); 129 static_call_lock(); 130 131 if (key->func == func) 132 goto done; 133 134 key->func = func; 135 136 arch_static_call_transform(NULL, tramp, func, false); 137 138 /* 139 * If uninitialized, we'll not update the callsites, but they still 140 * point to the trampoline and we just patched that. 141 */ 142 if (WARN_ON_ONCE(!static_call_initialized)) 143 goto done; 144 145 first = (struct static_call_mod){ 146 .next = static_call_key_next(key), 147 .mod = NULL, 148 .sites = static_call_key_sites(key), 149 }; 150 151 for (site_mod = &first; site_mod; site_mod = site_mod->next) { 152 struct module *mod = site_mod->mod; 153 154 if (!site_mod->sites) { 155 /* 156 * This can happen if the static call key is defined in 157 * a module which doesn't use it. 158 * 159 * It also happens in the has_mods case, where the 160 * 'first' entry has no sites associated with it. 161 */ 162 continue; 163 } 164 165 stop = __stop_static_call_sites; 166 167 #ifdef CONFIG_MODULES 168 if (mod) { 169 stop = mod->static_call_sites + 170 mod->num_static_call_sites; 171 } 172 #endif 173 174 for (site = site_mod->sites; 175 site < stop && static_call_key(site) == key; site++) { 176 void *site_addr = static_call_addr(site); 177 178 if (static_call_is_init(site)) { 179 /* 180 * Don't write to call sites which were in 181 * initmem and have since been freed. 182 */ 183 if (!mod && system_state >= SYSTEM_RUNNING) 184 continue; 185 if (mod && !within_module_init((unsigned long)site_addr, mod)) 186 continue; 187 } 188 189 if (!kernel_text_address((unsigned long)site_addr)) { 190 WARN_ONCE(1, "can't patch static call site at %pS", 191 site_addr); 192 continue; 193 } 194 195 arch_static_call_transform(site_addr, NULL, func, 196 static_call_is_tail(site)); 197 } 198 } 199 200 done: 201 static_call_unlock(); 202 cpus_read_unlock(); 203 } 204 EXPORT_SYMBOL_GPL(__static_call_update); 205 206 static int __static_call_init(struct module *mod, 207 struct static_call_site *start, 208 struct static_call_site *stop) 209 { 210 struct static_call_site *site; 211 struct static_call_key *key, *prev_key = NULL; 212 struct static_call_mod *site_mod; 213 214 if (start == stop) 215 return 0; 216 217 static_call_sort_entries(start, stop); 218 219 for (site = start; site < stop; site++) { 220 void *site_addr = static_call_addr(site); 221 222 if ((mod && within_module_init((unsigned long)site_addr, mod)) || 223 (!mod && init_section_contains(site_addr, 1))) 224 static_call_set_init(site); 225 226 key = static_call_key(site); 227 if (key != prev_key) { 228 prev_key = key; 229 230 /* 231 * For vmlinux (!mod) avoid the allocation by storing 232 * the sites pointer in the key itself. Also see 233 * __static_call_update()'s @first. 234 * 235 * This allows architectures (eg. x86) to call 236 * static_call_init() before memory allocation works. 237 */ 238 if (!mod) { 239 key->sites = site; 240 key->type |= 1; 241 goto do_transform; 242 } 243 244 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 245 if (!site_mod) 246 return -ENOMEM; 247 248 /* 249 * When the key has a direct sites pointer, extract 250 * that into an explicit struct static_call_mod, so we 251 * can have a list of modules. 252 */ 253 if (static_call_key_sites(key)) { 254 site_mod->mod = NULL; 255 site_mod->next = NULL; 256 site_mod->sites = static_call_key_sites(key); 257 258 key->mods = site_mod; 259 260 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 261 if (!site_mod) 262 return -ENOMEM; 263 } 264 265 site_mod->mod = mod; 266 site_mod->sites = site; 267 site_mod->next = static_call_key_next(key); 268 key->mods = site_mod; 269 } 270 271 do_transform: 272 arch_static_call_transform(site_addr, NULL, key->func, 273 static_call_is_tail(site)); 274 } 275 276 return 0; 277 } 278 279 static int addr_conflict(struct static_call_site *site, void *start, void *end) 280 { 281 unsigned long addr = (unsigned long)static_call_addr(site); 282 283 if (addr <= (unsigned long)end && 284 addr + CALL_INSN_SIZE > (unsigned long)start) 285 return 1; 286 287 return 0; 288 } 289 290 static int __static_call_text_reserved(struct static_call_site *iter_start, 291 struct static_call_site *iter_stop, 292 void *start, void *end) 293 { 294 struct static_call_site *iter = iter_start; 295 296 while (iter < iter_stop) { 297 if (addr_conflict(iter, start, end)) 298 return 1; 299 iter++; 300 } 301 302 return 0; 303 } 304 305 #ifdef CONFIG_MODULES 306 307 static int __static_call_mod_text_reserved(void *start, void *end) 308 { 309 struct module *mod; 310 int ret; 311 312 preempt_disable(); 313 mod = __module_text_address((unsigned long)start); 314 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); 315 if (!try_module_get(mod)) 316 mod = NULL; 317 preempt_enable(); 318 319 if (!mod) 320 return 0; 321 322 ret = __static_call_text_reserved(mod->static_call_sites, 323 mod->static_call_sites + mod->num_static_call_sites, 324 start, end); 325 326 module_put(mod); 327 328 return ret; 329 } 330 331 static unsigned long tramp_key_lookup(unsigned long addr) 332 { 333 struct static_call_tramp_key *start = __start_static_call_tramp_key; 334 struct static_call_tramp_key *stop = __stop_static_call_tramp_key; 335 struct static_call_tramp_key *tramp_key; 336 337 for (tramp_key = start; tramp_key != stop; tramp_key++) { 338 unsigned long tramp; 339 340 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; 341 if (tramp == addr) 342 return (long)tramp_key->key + (long)&tramp_key->key; 343 } 344 345 return 0; 346 } 347 348 static int static_call_add_module(struct module *mod) 349 { 350 struct static_call_site *start = mod->static_call_sites; 351 struct static_call_site *stop = start + mod->num_static_call_sites; 352 struct static_call_site *site; 353 354 for (site = start; site != stop; site++) { 355 unsigned long s_key = __static_call_key(site); 356 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS; 357 unsigned long key; 358 359 /* 360 * Is the key is exported, 'addr' points to the key, which 361 * means modules are allowed to call static_call_update() on 362 * it. 363 * 364 * Otherwise, the key isn't exported, and 'addr' points to the 365 * trampoline so we need to lookup the key. 366 * 367 * We go through this dance to prevent crazy modules from 368 * abusing sensitive static calls. 369 */ 370 if (!kernel_text_address(addr)) 371 continue; 372 373 key = tramp_key_lookup(addr); 374 if (!key) { 375 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n", 376 static_call_addr(site)); 377 return -EINVAL; 378 } 379 380 key |= s_key & STATIC_CALL_SITE_FLAGS; 381 site->key = key - (long)&site->key; 382 } 383 384 return __static_call_init(mod, start, stop); 385 } 386 387 static void static_call_del_module(struct module *mod) 388 { 389 struct static_call_site *start = mod->static_call_sites; 390 struct static_call_site *stop = mod->static_call_sites + 391 mod->num_static_call_sites; 392 struct static_call_key *key, *prev_key = NULL; 393 struct static_call_mod *site_mod, **prev; 394 struct static_call_site *site; 395 396 for (site = start; site < stop; site++) { 397 key = static_call_key(site); 398 if (key == prev_key) 399 continue; 400 401 prev_key = key; 402 403 for (prev = &key->mods, site_mod = key->mods; 404 site_mod && site_mod->mod != mod; 405 prev = &site_mod->next, site_mod = site_mod->next) 406 ; 407 408 if (!site_mod) 409 continue; 410 411 *prev = site_mod->next; 412 kfree(site_mod); 413 } 414 } 415 416 static int static_call_module_notify(struct notifier_block *nb, 417 unsigned long val, void *data) 418 { 419 struct module *mod = data; 420 int ret = 0; 421 422 cpus_read_lock(); 423 static_call_lock(); 424 425 switch (val) { 426 case MODULE_STATE_COMING: 427 ret = static_call_add_module(mod); 428 if (ret) { 429 WARN(1, "Failed to allocate memory for static calls"); 430 static_call_del_module(mod); 431 } 432 break; 433 case MODULE_STATE_GOING: 434 static_call_del_module(mod); 435 break; 436 } 437 438 static_call_unlock(); 439 cpus_read_unlock(); 440 441 return notifier_from_errno(ret); 442 } 443 444 static struct notifier_block static_call_module_nb = { 445 .notifier_call = static_call_module_notify, 446 }; 447 448 #else 449 450 static inline int __static_call_mod_text_reserved(void *start, void *end) 451 { 452 return 0; 453 } 454 455 #endif /* CONFIG_MODULES */ 456 457 int static_call_text_reserved(void *start, void *end) 458 { 459 int ret = __static_call_text_reserved(__start_static_call_sites, 460 __stop_static_call_sites, start, end); 461 462 if (ret) 463 return ret; 464 465 return __static_call_mod_text_reserved(start, end); 466 } 467 468 int __init static_call_init(void) 469 { 470 int ret; 471 472 if (static_call_initialized) 473 return 0; 474 475 cpus_read_lock(); 476 static_call_lock(); 477 ret = __static_call_init(NULL, __start_static_call_sites, 478 __stop_static_call_sites); 479 static_call_unlock(); 480 cpus_read_unlock(); 481 482 if (ret) { 483 pr_err("Failed to allocate memory for static_call!\n"); 484 BUG(); 485 } 486 487 static_call_initialized = true; 488 489 #ifdef CONFIG_MODULES 490 register_module_notifier(&static_call_module_nb); 491 #endif 492 return 0; 493 } 494 early_initcall(static_call_init); 495 496 long __static_call_return0(void) 497 { 498 return 0; 499 } 500 501 #ifdef CONFIG_STATIC_CALL_SELFTEST 502 503 static int func_a(int x) 504 { 505 return x+1; 506 } 507 508 static int func_b(int x) 509 { 510 return x+2; 511 } 512 513 DEFINE_STATIC_CALL(sc_selftest, func_a); 514 515 static struct static_call_data { 516 int (*func)(int); 517 int val; 518 int expect; 519 } static_call_data [] __initdata = { 520 { NULL, 2, 3 }, 521 { func_b, 2, 4 }, 522 { func_a, 2, 3 } 523 }; 524 525 static int __init test_static_call_init(void) 526 { 527 int i; 528 529 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { 530 struct static_call_data *scd = &static_call_data[i]; 531 532 if (scd->func) 533 static_call_update(sc_selftest, scd->func); 534 535 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); 536 } 537 538 return 0; 539 } 540 early_initcall(test_static_call_init); 541 542 #endif /* CONFIG_STATIC_CALL_SELFTEST */ 543