xref: /linux-6.15/kernel/static_call.c (revision 5b06fd3b)
1 // SPDX-License-Identifier: GPL-2.0
2 #include <linux/init.h>
3 #include <linux/static_call.h>
4 #include <linux/bug.h>
5 #include <linux/smp.h>
6 #include <linux/sort.h>
7 #include <linux/slab.h>
8 #include <linux/module.h>
9 #include <linux/cpu.h>
10 #include <linux/processor.h>
11 #include <asm/sections.h>
12 
13 extern struct static_call_site __start_static_call_sites[],
14 			       __stop_static_call_sites[];
15 
16 static bool static_call_initialized;
17 
18 /* mutex to protect key modules/sites */
19 static DEFINE_MUTEX(static_call_mutex);
20 
21 static void static_call_lock(void)
22 {
23 	mutex_lock(&static_call_mutex);
24 }
25 
26 static void static_call_unlock(void)
27 {
28 	mutex_unlock(&static_call_mutex);
29 }
30 
31 static inline void *static_call_addr(struct static_call_site *site)
32 {
33 	return (void *)((long)site->addr + (long)&site->addr);
34 }
35 
36 
37 static inline struct static_call_key *static_call_key(const struct static_call_site *site)
38 {
39 	return (struct static_call_key *)
40 		(((long)site->key + (long)&site->key) & ~STATIC_CALL_SITE_FLAGS);
41 }
42 
43 /* These assume the key is word-aligned. */
44 static inline bool static_call_is_init(struct static_call_site *site)
45 {
46 	return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_INIT;
47 }
48 
49 static inline bool static_call_is_tail(struct static_call_site *site)
50 {
51 	return ((long)site->key + (long)&site->key) & STATIC_CALL_SITE_TAIL;
52 }
53 
54 static inline void static_call_set_init(struct static_call_site *site)
55 {
56 	site->key = ((long)static_call_key(site) | STATIC_CALL_SITE_INIT) -
57 		    (long)&site->key;
58 }
59 
/* sort() comparator: order sites by their flag-stripped key address. */
static int static_call_site_cmp(const void *_a, const void *_b)
{
	const struct static_call_site *site_a = _a, *site_b = _b;
	const struct static_call_key *lhs = static_call_key(site_a);
	const struct static_call_key *rhs = static_call_key(site_b);

	if (lhs == rhs)
		return 0;

	return lhs < rhs ? -1 : 1;
}
75 
76 static void static_call_site_swap(void *_a, void *_b, int size)
77 {
78 	long delta = (unsigned long)_a - (unsigned long)_b;
79 	struct static_call_site *a = _a;
80 	struct static_call_site *b = _b;
81 	struct static_call_site tmp = *a;
82 
83 	a->addr = b->addr  - delta;
84 	a->key  = b->key   - delta;
85 
86 	b->addr = tmp.addr + delta;
87 	b->key  = tmp.key  + delta;
88 }
89 
90 static inline void static_call_sort_entries(struct static_call_site *start,
91 					    struct static_call_site *stop)
92 {
93 	sort(start, stop - start, sizeof(struct static_call_site),
94 	     static_call_site_cmp, static_call_site_swap);
95 }
96 
97 void __static_call_update(struct static_call_key *key, void *tramp, void *func)
98 {
99 	struct static_call_site *site, *stop;
100 	struct static_call_mod *site_mod;
101 
102 	cpus_read_lock();
103 	static_call_lock();
104 
105 	if (key->func == func)
106 		goto done;
107 
108 	key->func = func;
109 
110 	arch_static_call_transform(NULL, tramp, func, false);
111 
112 	/*
113 	 * If uninitialized, we'll not update the callsites, but they still
114 	 * point to the trampoline and we just patched that.
115 	 */
116 	if (WARN_ON_ONCE(!static_call_initialized))
117 		goto done;
118 
119 	for (site_mod = key->mods; site_mod; site_mod = site_mod->next) {
120 		struct module *mod = site_mod->mod;
121 
122 		if (!site_mod->sites) {
123 			/*
124 			 * This can happen if the static call key is defined in
125 			 * a module which doesn't use it.
126 			 */
127 			continue;
128 		}
129 
130 		stop = __stop_static_call_sites;
131 
132 #ifdef CONFIG_MODULES
133 		if (mod) {
134 			stop = mod->static_call_sites +
135 			       mod->num_static_call_sites;
136 		}
137 #endif
138 
139 		for (site = site_mod->sites;
140 		     site < stop && static_call_key(site) == key; site++) {
141 			void *site_addr = static_call_addr(site);
142 
143 			if (static_call_is_init(site)) {
144 				/*
145 				 * Don't write to call sites which were in
146 				 * initmem and have since been freed.
147 				 */
148 				if (!mod && system_state >= SYSTEM_RUNNING)
149 					continue;
150 				if (mod && !within_module_init((unsigned long)site_addr, mod))
151 					continue;
152 			}
153 
154 			if (!kernel_text_address((unsigned long)site_addr)) {
155 				WARN_ONCE(1, "can't patch static call site at %pS",
156 					  site_addr);
157 				continue;
158 			}
159 
160 			arch_static_call_transform(site_addr, NULL, func,
161 				static_call_is_tail(site));
162 		}
163 	}
164 
165 done:
166 	static_call_unlock();
167 	cpus_read_unlock();
168 }
169 EXPORT_SYMBOL_GPL(__static_call_update);
170 
171 static int __static_call_init(struct module *mod,
172 			      struct static_call_site *start,
173 			      struct static_call_site *stop)
174 {
175 	struct static_call_site *site;
176 	struct static_call_key *key, *prev_key = NULL;
177 	struct static_call_mod *site_mod;
178 
179 	if (start == stop)
180 		return 0;
181 
182 	static_call_sort_entries(start, stop);
183 
184 	for (site = start; site < stop; site++) {
185 		void *site_addr = static_call_addr(site);
186 
187 		if ((mod && within_module_init((unsigned long)site_addr, mod)) ||
188 		    (!mod && init_section_contains(site_addr, 1)))
189 			static_call_set_init(site);
190 
191 		key = static_call_key(site);
192 		if (key != prev_key) {
193 			prev_key = key;
194 
195 			site_mod = kzalloc(sizeof(*site_mod), GFP_KERNEL);
196 			if (!site_mod)
197 				return -ENOMEM;
198 
199 			site_mod->mod = mod;
200 			site_mod->sites = site;
201 			site_mod->next = key->mods;
202 			key->mods = site_mod;
203 		}
204 
205 		arch_static_call_transform(site_addr, NULL, key->func,
206 				static_call_is_tail(site));
207 	}
208 
209 	return 0;
210 }
211 
212 static int addr_conflict(struct static_call_site *site, void *start, void *end)
213 {
214 	unsigned long addr = (unsigned long)static_call_addr(site);
215 
216 	if (addr <= (unsigned long)end &&
217 	    addr + CALL_INSN_SIZE > (unsigned long)start)
218 		return 1;
219 
220 	return 0;
221 }
222 
223 static int __static_call_text_reserved(struct static_call_site *iter_start,
224 				       struct static_call_site *iter_stop,
225 				       void *start, void *end)
226 {
227 	struct static_call_site *iter = iter_start;
228 
229 	while (iter < iter_stop) {
230 		if (addr_conflict(iter, start, end))
231 			return 1;
232 		iter++;
233 	}
234 
235 	return 0;
236 }
237 
238 #ifdef CONFIG_MODULES
239 
240 static int __static_call_mod_text_reserved(void *start, void *end)
241 {
242 	struct module *mod;
243 	int ret;
244 
245 	preempt_disable();
246 	mod = __module_text_address((unsigned long)start);
247 	WARN_ON_ONCE(__module_text_address((unsigned long)end) != mod);
248 	if (!try_module_get(mod))
249 		mod = NULL;
250 	preempt_enable();
251 
252 	if (!mod)
253 		return 0;
254 
255 	ret = __static_call_text_reserved(mod->static_call_sites,
256 			mod->static_call_sites + mod->num_static_call_sites,
257 			start, end);
258 
259 	module_put(mod);
260 
261 	return ret;
262 }
263 
264 static int static_call_add_module(struct module *mod)
265 {
266 	return __static_call_init(mod, mod->static_call_sites,
267 				  mod->static_call_sites + mod->num_static_call_sites);
268 }
269 
270 static void static_call_del_module(struct module *mod)
271 {
272 	struct static_call_site *start = mod->static_call_sites;
273 	struct static_call_site *stop = mod->static_call_sites +
274 					mod->num_static_call_sites;
275 	struct static_call_key *key, *prev_key = NULL;
276 	struct static_call_mod *site_mod, **prev;
277 	struct static_call_site *site;
278 
279 	for (site = start; site < stop; site++) {
280 		key = static_call_key(site);
281 		if (key == prev_key)
282 			continue;
283 
284 		prev_key = key;
285 
286 		for (prev = &key->mods, site_mod = key->mods;
287 		     site_mod && site_mod->mod != mod;
288 		     prev = &site_mod->next, site_mod = site_mod->next)
289 			;
290 
291 		if (!site_mod)
292 			continue;
293 
294 		*prev = site_mod->next;
295 		kfree(site_mod);
296 	}
297 }
298 
299 static int static_call_module_notify(struct notifier_block *nb,
300 				     unsigned long val, void *data)
301 {
302 	struct module *mod = data;
303 	int ret = 0;
304 
305 	cpus_read_lock();
306 	static_call_lock();
307 
308 	switch (val) {
309 	case MODULE_STATE_COMING:
310 		ret = static_call_add_module(mod);
311 		if (ret) {
312 			WARN(1, "Failed to allocate memory for static calls");
313 			static_call_del_module(mod);
314 		}
315 		break;
316 	case MODULE_STATE_GOING:
317 		static_call_del_module(mod);
318 		break;
319 	}
320 
321 	static_call_unlock();
322 	cpus_read_unlock();
323 
324 	return notifier_from_errno(ret);
325 }
326 
327 static struct notifier_block static_call_module_nb = {
328 	.notifier_call = static_call_module_notify,
329 };
330 
331 #else
332 
/* Without module support there are no module call sites to overlap. */
static inline int __static_call_mod_text_reserved(void *start, void *end)
{
	int reserved = 0;

	return reserved;
}
337 
338 #endif /* CONFIG_MODULES */
339 
340 int static_call_text_reserved(void *start, void *end)
341 {
342 	int ret = __static_call_text_reserved(__start_static_call_sites,
343 			__stop_static_call_sites, start, end);
344 
345 	if (ret)
346 		return ret;
347 
348 	return __static_call_mod_text_reserved(start, end);
349 }
350 
351 static void __init static_call_init(void)
352 {
353 	int ret;
354 
355 	if (static_call_initialized)
356 		return;
357 
358 	cpus_read_lock();
359 	static_call_lock();
360 	ret = __static_call_init(NULL, __start_static_call_sites,
361 				 __stop_static_call_sites);
362 	static_call_unlock();
363 	cpus_read_unlock();
364 
365 	if (ret) {
366 		pr_err("Failed to allocate memory for static_call!\n");
367 		BUG();
368 	}
369 
370 	static_call_initialized = true;
371 
372 #ifdef CONFIG_MODULES
373 	register_module_notifier(&static_call_module_nb);
374 #endif
375 }
376 early_initcall(static_call_init);
377 
378 #ifdef CONFIG_STATIC_CALL_SELFTEST
379 
/* Selftest target: returns @x plus one. */
static int func_a(int x)
{
	return 1 + x;
}
384 
/* Selftest target: returns @x plus two. */
static int func_b(int x)
{
	return 2 + x;
}
389 
/* sc_selftest initially points at func_a. */
DEFINE_STATIC_CALL(sc_selftest, func_a);

/*
 * Selftest steps, applied in order. A step first retargets sc_selftest to
 * @func (NULL keeps the current target), then calls it with @val and
 * expects @expect back.
 */
static struct static_call_data {
	int (*func)(int);
	int val;
	int expect;
} static_call_data[] __initdata = {
	{ .func = NULL,   .val = 2, .expect = 3 },
	{ .func = func_b, .val = 2, .expect = 4 },
	{ .func = func_a, .val = 2, .expect = 3 },
};
401 
402 static int __init test_static_call_init(void)
403 {
404       int i;
405 
406       for (i = 0; i < ARRAY_SIZE(static_call_data); i++ ) {
407 	      struct static_call_data *scd = &static_call_data[i];
408 
409               if (scd->func)
410                       static_call_update(sc_selftest, scd->func);
411 
412               WARN_ON(static_call(sc_selftest)(scd->val) != scd->expect);
413       }
414 
415       return 0;
416 }
417 early_initcall(test_static_call_init);
418 
419 #endif /* CONFIG_STATIC_CALL_SELFTEST */
420