1 // SPDX-License-Identifier: GPL-2.0 2 #include <linux/init.h> 3 #include <linux/static_call.h> 4 #include <linux/bug.h> 5 #include <linux/smp.h> 6 #include <linux/sort.h> 7 #include <linux/slab.h> 8 #include <linux/module.h> 9 #include <linux/cpu.h> 10 #include <linux/processor.h> 11 #include <asm/sections.h> 12 13 extern struct static_call_site __start_static_call_sites[], 14 __stop_static_call_sites[]; 15 extern struct static_call_tramp_key __start_static_call_tramp_key[], 16 __stop_static_call_tramp_key[]; 17 18 static bool static_call_initialized; 19 20 /* mutex to protect key modules/sites */ 21 static DEFINE_MUTEX(static_call_mutex); 22 23 static void static_call_lock(void) 24 { 25 mutex_lock(&static_call_mutex); 26 } 27 28 static void static_call_unlock(void) 29 { 30 mutex_unlock(&static_call_mutex); 31 } 32 33 static inline void *static_call_addr(struct static_call_site *site) 34 { 35 return (void *)((long)site->addr + (long)&site->addr); 36 } 37 38 static inline unsigned long __static_call_key(const struct static_call_site *site) 39 { 40 return (long)site->key + (long)&site->key; 41 } 42 43 static inline struct static_call_key *static_call_key(const struct static_call_site *site) 44 { 45 return (void *)(__static_call_key(site) & ~STATIC_CALL_SITE_FLAGS); 46 } 47 48 /* These assume the key is word-aligned. */ 49 static inline bool static_call_is_init(struct static_call_site *site) 50 { 51 return __static_call_key(site) & STATIC_CALL_SITE_INIT; 52 } 53 54 static inline bool static_call_is_tail(struct static_call_site *site) 55 { 56 return __static_call_key(site) & STATIC_CALL_SITE_TAIL; 57 } 58 59 static inline void static_call_set_init(struct static_call_site *site) 60 { 61 site->key = (__static_call_key(site) | STATIC_CALL_SITE_INIT) - 62 (long)&site->key; 63 } 64 65 static int static_call_site_cmp(const void *_a, const void *_b) 66 { 67 const struct static_call_site *a = _a; 68 const struct static_call_site *b = _b; 69 const struct static_call_key *key_a = static_call_key(a); 70 const struct static_call_key *key_b = static_call_key(b); 71 72 if (key_a < key_b) 73 return -1; 74 75 if (key_a > key_b) 76 return 1; 77 78 return 0; 79 } 80 81 static void static_call_site_swap(void *_a, void *_b, int size) 82 { 83 long delta = (unsigned long)_a - (unsigned long)_b; 84 struct static_call_site *a = _a; 85 struct static_call_site *b = _b; 86 struct static_call_site tmp = *a; 87 88 a->addr = b->addr - delta; 89 a->key = b->key - delta; 90 91 b->addr = tmp.addr + delta; 92 b->key = tmp.key + delta; 93 } 94 95 static inline void static_call_sort_entries(struct static_call_site *start, 96 struct static_call_site *stop) 97 { 98 sort(start, stop - start, sizeof(struct static_call_site), 99 static_call_site_cmp, static_call_site_swap); 100 } 101 102 static inline bool static_call_key_has_mods(struct static_call_key *key) 103 { 104 return !(key->type & 1); 105 } 106 107 static inline struct static_call_mod *static_call_key_next(struct static_call_key *key) 108 { 109 if (!static_call_key_has_mods(key)) 110 return NULL; 111 112 return key->mods; 113 } 114 115 static inline struct static_call_site *static_call_key_sites(struct static_call_key *key) 116 { 117 if (static_call_key_has_mods(key)) 118 return NULL; 119 120 return (struct static_call_site *)(key->type & ~1); 121 } 122 123 void __static_call_update(struct static_call_key *key, void *tramp, void *func) 124 { 125 struct static_call_site *site, *stop; 126 struct static_call_mod *site_mod, first; 127 128 cpus_read_lock(); 129 static_call_lock(); 130 131 if (key->func == func) 132 goto done; 133 134 key->func = func; 135 136 arch_static_call_transform(NULL, tramp, func, false); 137 138 /* 139 * If uninitialized, we'll not update the callsites, but they still 140 * point to the trampoline and we just patched that. 141 */ 142 if (WARN_ON_ONCE(!static_call_initialized)) 143 goto done; 144 145 first = (struct static_call_mod){ 146 .next = static_call_key_next(key), 147 .mod = NULL, 148 .sites = static_call_key_sites(key), 149 }; 150 151 for (site_mod = &first; site_mod; site_mod = site_mod->next) { 152 bool init = system_state < SYSTEM_RUNNING; 153 struct module *mod = site_mod->mod; 154 155 if (!site_mod->sites) { 156 /* 157 * This can happen if the static call key is defined in 158 * a module which doesn't use it. 159 * 160 * It also happens in the has_mods case, where the 161 * 'first' entry has no sites associated with it. 162 */ 163 continue; 164 } 165 166 stop = __stop_static_call_sites; 167 168 #ifdef CONFIG_MODULES 169 if (mod) { 170 stop = mod->static_call_sites + 171 mod->num_static_call_sites; 172 init = mod->state == MODULE_STATE_COMING; 173 } 174 #endif 175 176 for (site = site_mod->sites; 177 site < stop && static_call_key(site) == key; site++) { 178 void *site_addr = static_call_addr(site); 179 180 if (!init && static_call_is_init(site)) 181 continue; 182 183 if (!kernel_text_address((unsigned long)site_addr)) { 184 WARN_ONCE(1, "can't patch static call site at %pS", 185 site_addr); 186 continue; 187 } 188 189 arch_static_call_transform(site_addr, NULL, func, 190 static_call_is_tail(site)); 191 } 192 } 193 194 done: 195 static_call_unlock(); 196 cpus_read_unlock(); 197 } 198 EXPORT_SYMBOL_GPL(__static_call_update); 199 200 static int __static_call_init(struct module *mod, 201 struct static_call_site *start, 202 struct static_call_site *stop) 203 { 204 struct static_call_site *site; 205 struct static_call_key *key, *prev_key = NULL; 206 struct static_call_mod *site_mod; 207 208 if (start == stop) 209 return 0; 210 211 static_call_sort_entries(start, stop); 212 213 for (site = start; site < stop; site++) { 214 void *site_addr = static_call_addr(site); 215 216 if ((mod && within_module_init((unsigned long)site_addr, mod)) || 217 (!mod && init_section_contains(site_addr, 1))) 218 static_call_set_init(site); 219 220 key = static_call_key(site); 221 if (key != prev_key) { 222 prev_key = key; 223 224 /* 225 * For vmlinux (!mod) avoid the allocation by storing 226 * the sites pointer in the key itself. Also see 227 * __static_call_update()'s @first. 228 * 229 * This allows architectures (eg. x86) to call 230 * static_call_init() before memory allocation works. 231 */ 232 if (!mod) { 233 key->sites = site; 234 key->type |= 1; 235 goto do_transform; 236 } 237 238 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 239 if (!site_mod) 240 return -ENOMEM; 241 242 /* 243 * When the key has a direct sites pointer, extract 244 * that into an explicit struct static_call_mod, so we 245 * can have a list of modules. 246 */ 247 if (static_call_key_sites(key)) { 248 site_mod->mod = NULL; 249 site_mod->next = NULL; 250 site_mod->sites = static_call_key_sites(key); 251 252 key->mods = site_mod; 253 254 site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL); 255 if (!site_mod) 256 return -ENOMEM; 257 } 258 259 site_mod->mod = mod; 260 site_mod->sites = site; 261 site_mod->next = static_call_key_next(key); 262 key->mods = site_mod; 263 } 264 265 do_transform: 266 arch_static_call_transform(site_addr, NULL, key->func, 267 static_call_is_tail(site)); 268 } 269 270 return 0; 271 } 272 273 static int addr_conflict(struct static_call_site *site, void *start, void *end) 274 { 275 unsigned long addr = (unsigned long)static_call_addr(site); 276 277 if (addr <= (unsigned long)end && 278 addr + CALL_INSN_SIZE > (unsigned long)start) 279 return 1; 280 281 return 0; 282 } 283 284 static int __static_call_text_reserved(struct static_call_site *iter_start, 285 struct static_call_site *iter_stop, 286 void *start, void *end) 287 { 288 struct static_call_site *iter = iter_start; 289 290 while (iter < iter_stop) { 291 if (addr_conflict(iter, start, end)) 292 return 1; 293 iter++; 294 } 295 296 return 0; 297 } 298 299 #ifdef CONFIG_MODULES 300 301 static int __static_call_mod_text_reserved(void *start, void *end) 302 { 303 struct module *mod; 304 int ret; 305 306 preempt_disable(); 307 mod = __module_text_address((unsigned long)start); 308 WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod); 309 if (!try_module_get(mod)) 310 mod = NULL; 311 preempt_enable(); 312 313 if (!mod) 314 return 0; 315 316 ret = __static_call_text_reserved(mod->static_call_sites, 317 mod->static_call_sites + mod->num_static_call_sites, 318 start, end); 319 320 module_put(mod); 321 322 return ret; 323 } 324 325 static unsigned long tramp_key_lookup(unsigned long addr) 326 { 327 struct static_call_tramp_key *start = __start_static_call_tramp_key; 328 struct static_call_tramp_key *stop = __stop_static_call_tramp_key; 329 struct static_call_tramp_key *tramp_key; 330 331 for (tramp_key = start; tramp_key != stop; tramp_key++) { 332 unsigned long tramp; 333 334 tramp = (long)tramp_key->tramp + (long)&tramp_key->tramp; 335 if (tramp == addr) 336 return (long)tramp_key->key + (long)&tramp_key->key; 337 } 338 339 return 0; 340 } 341 342 static int static_call_add_module(struct module *mod) 343 { 344 struct static_call_site *start = mod->static_call_sites; 345 struct static_call_site *stop = start + mod->num_static_call_sites; 346 struct static_call_site *site; 347 348 for (site = start; site != stop; site++) { 349 unsigned long s_key = __static_call_key(site); 350 unsigned long addr = s_key & ~STATIC_CALL_SITE_FLAGS; 351 unsigned long key; 352 353 /* 354 * Is the key is exported, 'addr' points to the key, which 355 * means modules are allowed to call static_call_update() on 356 * it. 357 * 358 * Otherwise, the key isn't exported, and 'addr' points to the 359 * trampoline so we need to lookup the key. 360 * 361 * We go through this dance to prevent crazy modules from 362 * abusing sensitive static calls. 363 */ 364 if (!kernel_text_address(addr)) 365 continue; 366 367 key = tramp_key_lookup(addr); 368 if (!key) { 369 pr_warn("Failed to fixup __raw_static_call() usage at: %ps\n", 370 static_call_addr(site)); 371 return -EINVAL; 372 } 373 374 key |= s_key & STATIC_CALL_SITE_FLAGS; 375 site->key = key - (long)&site->key; 376 } 377 378 return __static_call_init(mod, start, stop); 379 } 380 381 static void static_call_del_module(struct module *mod) 382 { 383 struct static_call_site *start = mod->static_call_sites; 384 struct static_call_site *stop = mod->static_call_sites + 385 mod->num_static_call_sites; 386 struct static_call_key *key, *prev_key = NULL; 387 struct static_call_mod *site_mod, **prev; 388 struct static_call_site *site; 389 390 for (site = start; site < stop; site++) { 391 key = static_call_key(site); 392 if (key == prev_key) 393 continue; 394 395 prev_key = key; 396 397 for (prev = &key->mods, site_mod = key->mods; 398 site_mod && site_mod->mod != mod; 399 prev = &site_mod->next, site_mod = site_mod->next) 400 ; 401 402 if (!site_mod) 403 continue; 404 405 *prev = site_mod->next; 406 kfree(site_mod); 407 } 408 } 409 410 static int static_call_module_notify(struct notifier_block *nb, 411 unsigned long val, void *data) 412 { 413 struct module *mod = data; 414 int ret = 0; 415 416 cpus_read_lock(); 417 static_call_lock(); 418 419 switch (val) { 420 case MODULE_STATE_COMING: 421 ret = static_call_add_module(mod); 422 if (ret) { 423 WARN(1, "Failed to allocate memory for static calls"); 424 static_call_del_module(mod); 425 } 426 break; 427 case MODULE_STATE_GOING: 428 static_call_del_module(mod); 429 break; 430 } 431 432 static_call_unlock(); 433 cpus_read_unlock(); 434 435 return notifier_from_errno(ret); 436 } 437 438 static struct notifier_block static_call_module_nb = { 439 .notifier_call = static_call_module_notify, 440 }; 441 442 #else 443 444 static inline int __static_call_mod_text_reserved(void *start, void *end) 445 { 446 return 0; 447 } 448 449 #endif /* CONFIG_MODULES */ 450 451 int static_call_text_reserved(void *start, void *end) 452 { 453 int ret = __static_call_text_reserved(__start_static_call_sites, 454 __stop_static_call_sites, start, end); 455 456 if (ret) 457 return ret; 458 459 return __static_call_mod_text_reserved(start, end); 460 } 461 462 int __init static_call_init(void) 463 { 464 int ret; 465 466 if (static_call_initialized) 467 return 0; 468 469 cpus_read_lock(); 470 static_call_lock(); 471 ret = __static_call_init(NULL, __start_static_call_sites, 472 __stop_static_call_sites); 473 static_call_unlock(); 474 cpus_read_unlock(); 475 476 if (ret) { 477 pr_err("Failed to allocate memory for static_call!\n"); 478 BUG(); 479 } 480 481 static_call_initialized = true; 482 483 #ifdef CONFIG_MODULES 484 register_module_notifier(&static_call_module_nb); 485 #endif 486 return 0; 487 } 488 early_initcall(static_call_init); 489 490 long __static_call_return0(void) 491 { 492 return 0; 493 } 494 495 #ifdef CONFIG_STATIC_CALL_SELFTEST 496 497 static int func_a(int x) 498 { 499 return x+1; 500 } 501 502 static int func_b(int x) 503 { 504 return x+2; 505 } 506 507 DEFINE_STATIC_CALL(sc_selftest, func_a); 508 509 static struct static_call_data { 510 int (*func)(int); 511 int val; 512 int expect; 513 } static_call_data [] __initdata = { 514 { NULL, 2, 3 }, 515 { func_b, 2, 4 }, 516 { func_a, 2, 3 } 517 }; 518 519 static int __init test_static_call_init(void) 520 { 521 int i; 522 523 for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) { 524 struct static_call_data *scd = &static_call_data[i]; 525 526 if (scd->func) 527 static_call_update(sc_selftest, scd->func); 528 529 WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect); 530 } 531 532 return 0; 533 } 534 early_initcall(test_static_call_init); 535 536 #endif /* CONFIG_STATIC_CALL_SELFTEST */ 537