// SPDX-License-Identifier: GPL-2.0-only
/*
 * ntsync.c - Kernel driver for NT synchronization primitives
 *
 * Copyright (C) 2024 Elizabeth Figura <[email protected]>
 */

#include <linux/anon_inodes.h>
#include <linux/atomic.h>
#include <linux/file.h>
#include <linux/fs.h>
#include <linux/hrtimer.h>
#include <linux/ktime.h>
#include <linux/miscdevice.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/overflow.h>
#include <linux/sched.h>
#include <linux/sched/signal.h>
#include <linux/slab.h>
#include <linux/spinlock.h>
#include <uapi/linux/ntsync.h>

#define NTSYNC_NAME	"ntsync"

enum ntsync_type {
	NTSYNC_TYPE_SEM,
	NTSYNC_TYPE_MUTEX,
};

/*
 * Individual synchronization primitives are represented by
 * struct ntsync_obj, and each primitive is backed by a file.
 *
 * The whole namespace is represented by a struct ntsync_device,
 * which is also backed by a file.
 *
 * Both rely on struct file for reference counting. Individual
 * ntsync_obj objects take a reference to the device when created.
 * Wait operations take a reference to each object being waited on for
 * the duration of the wait.
 */

struct ntsync_obj {
	spinlock_t lock;
	int dev_locked;

	enum ntsync_type type;

	struct file *file;
	struct ntsync_device *dev;

	/* The following fields are protected by the object lock. */
	union {
		struct {
			__u32 count;
			__u32 max;
		} sem;
		struct {
			__u32 count;
			pid_t owner;
			bool ownerdead;
		} mutex;
	} u;

	/*
	 * any_waiters is protected by the object lock, but all_waiters is
	 * protected by the device wait_all_lock.
	 */
	struct list_head any_waiters;
	struct list_head all_waiters;

	/*
	 * Hint describing how many tasks are queued on this object in a
	 * wait-all operation.
	 *
	 * Any time we do a wake, we may need to wake "all" waiters as well as
	 * "any" waiters. In order to atomically wake "all" waiters, we must
	 * lock all of the objects, and that means grabbing the wait_all_lock
	 * below (and, due to lock ordering rules, before locking this object).
	 * However, wait-all is a rare operation, and grabbing the wait-all
	 * lock for every wake would create unnecessary contention.
	 * Therefore we first check whether all_hint is zero, and, if it is,
	 * we skip trying to wake "all" waiters.
	 *
	 * Since wait requests must originate from user-space threads, we're
	 * limited here by PID_MAX_LIMIT, so there's no risk of overflow.
	 */
	atomic_t all_hint;
};

struct ntsync_q_entry {
	struct list_head node;
	struct ntsync_q *q;
	struct ntsync_obj *obj;
	__u32 index;
};

struct ntsync_q {
	struct task_struct *task;
	__u32 owner;

	/*
	 * Protected via atomic_try_cmpxchg(). Only the thread that wins the
	 * compare-and-swap may actually change object states and wake this
	 * task.
	 */
	atomic_t signaled;

	bool all;
	bool ownerdead;
	__u32 count;
	struct ntsync_q_entry entries[];
};

struct ntsync_device {
	/*
	 * Wait-all operations must atomically grab all objects, and be totally
	 * ordered with respect to each other and to wait-any operations.
	 * While one thread is trying to acquire several objects, no other
	 * thread may touch any of those objects at the same time.
	 *
	 * This device-wide lock is used to serialize wait-for-all
	 * operations, and operations on an object that is involved in a
	 * wait-for-all.
	 */
	struct mutex wait_all_lock;

	struct file *file;
};

/*
 * Single objects are locked using obj->lock.
 *
 * Multiple objects are 'locked' while holding dev->wait_all_lock.
 * In this case, however, individual objects are not locked by holding
 * obj->lock, but by setting obj->dev_locked.
 *
 * This means that locking a single object is slightly more complicated
 * than usual: after acquiring obj->lock we must check obj->dev_locked,
 * and if it is set, drop the lock and acquire dev->wait_all_lock in
 * order to serialize against the multi-object operation.
 */

static void dev_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
{
	lockdep_assert_held(&dev->wait_all_lock);
	lockdep_assert(obj->dev == dev);
	spin_lock(&obj->lock);
	/*
	 * Setting obj->dev_locked while holding obj->lock ensures that
	 * anyone who subsequently acquires obj->lock observes the value.
	 */
	obj->dev_locked = 1;
	spin_unlock(&obj->lock);
}

static void dev_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
{
	lockdep_assert_held(&dev->wait_all_lock);
	lockdep_assert(obj->dev == dev);
	spin_lock(&obj->lock);
	obj->dev_locked = 0;
	spin_unlock(&obj->lock);
}

static void obj_lock(struct ntsync_obj *obj)
{
	struct ntsync_device *dev = obj->dev;

	for (;;) {
		spin_lock(&obj->lock);
		if (likely(!obj->dev_locked))
			break;

		spin_unlock(&obj->lock);
		mutex_lock(&dev->wait_all_lock);
		spin_lock(&obj->lock);
		/*
		 * obj->dev_locked is set and cleared within a single
		 * wait_all_lock section; since we now hold that lock, the
		 * flag must be clear.
		 */
		lockdep_assert(!obj->dev_locked);
		spin_unlock(&obj->lock);
		mutex_unlock(&dev->wait_all_lock);
	}
}

static void obj_unlock(struct ntsync_obj *obj)
{
	spin_unlock(&obj->lock);
}

static bool ntsync_lock_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
{
	bool all;

	obj_lock(obj);
	all = atomic_read(&obj->all_hint);
	if (unlikely(all)) {
		obj_unlock(obj);
		mutex_lock(&dev->wait_all_lock);
		dev_lock_obj(dev, obj);
	}

	return all;
}

static void ntsync_unlock_obj(struct ntsync_device *dev, struct ntsync_obj *obj, bool all)
{
	if (all) {
		dev_unlock_obj(dev, obj);
		mutex_unlock(&dev->wait_all_lock);
	} else {
		obj_unlock(obj);
	}
}

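/*
 * A sketch of the typical usage pattern for the helpers above, as
 * followed by the ioctl handlers later in this file (illustrative
 * only; "dev" and "obj" stand for any device/object pair):
 *
 *	bool all;
 *
 *	all = ntsync_lock_obj(dev, obj);
 *	... read or modify obj->u; on a state change, call
 *	    try_wake_all_obj() if "all" is set, then the matching
 *	    try_wake_any_*() helper ...
 *	ntsync_unlock_obj(dev, obj, all);
 */
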
#define ntsync_assert_held(obj) \
	lockdep_assert((lockdep_is_held(&(obj)->lock) != LOCK_STATE_NOT_HELD) || \
		       ((lockdep_is_held(&(obj)->dev->wait_all_lock) != LOCK_STATE_NOT_HELD) && \
			(obj)->dev_locked))

static bool is_signaled(struct ntsync_obj *obj, __u32 owner)
{
	ntsync_assert_held(obj);

	switch (obj->type) {
	case NTSYNC_TYPE_SEM:
		return !!obj->u.sem.count;
	case NTSYNC_TYPE_MUTEX:
		if (obj->u.mutex.owner && obj->u.mutex.owner != owner)
			return false;
		return obj->u.mutex.count < UINT_MAX;
	}

	WARN(1, "bad object type %#x\n", obj->type);
	return false;
}

/*
 * "locked_obj" is an optional pointer to an object which is already locked and
 * should not be locked again. This is necessary so that changing an object's
 * state and waking it can be a single atomic operation.
 */
static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
			 struct ntsync_obj *locked_obj)
{
	__u32 count = q->count;
	bool can_wake = true;
	int signaled = -1;
	__u32 i;

	lockdep_assert_held(&dev->wait_all_lock);
	if (locked_obj)
		lockdep_assert(locked_obj->dev_locked);

	for (i = 0; i < count; i++) {
		if (q->entries[i].obj != locked_obj)
			dev_lock_obj(dev, q->entries[i].obj);
	}

	for (i = 0; i < count; i++) {
		if (!is_signaled(q->entries[i].obj, q->owner)) {
			can_wake = false;
			break;
		}
	}

	if (can_wake && atomic_try_cmpxchg(&q->signaled, &signaled, 0)) {
		for (i = 0; i < count; i++) {
			struct ntsync_obj *obj = q->entries[i].obj;

			switch (obj->type) {
			case NTSYNC_TYPE_SEM:
				obj->u.sem.count--;
				break;
			case NTSYNC_TYPE_MUTEX:
				if (obj->u.mutex.ownerdead)
					q->ownerdead = true;
				obj->u.mutex.ownerdead = false;
				obj->u.mutex.count++;
				obj->u.mutex.owner = q->owner;
				break;
			}
		}
		wake_up_process(q->task);
	}

	for (i = 0; i < count; i++) {
		if (q->entries[i].obj != locked_obj)
			dev_unlock_obj(dev, q->entries[i].obj);
	}
}

static void try_wake_all_obj(struct ntsync_device *dev, struct ntsync_obj *obj)
{
	struct ntsync_q_entry *entry;

	lockdep_assert_held(&dev->wait_all_lock);
	lockdep_assert(obj->dev_locked);

	list_for_each_entry(entry, &obj->all_waiters, node)
		try_wake_all(dev, entry->q, obj);
}

static void try_wake_any_sem(struct ntsync_obj *sem)
{
	struct ntsync_q_entry *entry;

	ntsync_assert_held(sem);
	lockdep_assert(sem->type == NTSYNC_TYPE_SEM);

	list_for_each_entry(entry, &sem->any_waiters, node) {
		struct ntsync_q *q = entry->q;
		int signaled = -1;

		if (!sem->u.sem.count)
			break;

		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
			sem->u.sem.count--;
			wake_up_process(q->task);
		}
	}
}

static void try_wake_any_mutex(struct ntsync_obj *mutex)
{
	struct ntsync_q_entry *entry;

	ntsync_assert_held(mutex);
	lockdep_assert(mutex->type == NTSYNC_TYPE_MUTEX);

	list_for_each_entry(entry, &mutex->any_waiters, node) {
		struct ntsync_q *q = entry->q;
		int signaled = -1;

		if (mutex->u.mutex.count == UINT_MAX)
			break;
		if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner)
			continue;

		if (atomic_try_cmpxchg(&q->signaled, &signaled, entry->index)) {
			if (mutex->u.mutex.ownerdead)
				q->ownerdead = true;
			mutex->u.mutex.ownerdead = false;
			mutex->u.mutex.count++;
			mutex->u.mutex.owner = q->owner;
			wake_up_process(q->task);
		}
	}
}

/*
 * Actually change the semaphore state, returning -EOVERFLOW if it is made
 * invalid.
 */
static int release_sem_state(struct ntsync_obj *sem, __u32 count)
{
	__u32 sum;

	ntsync_assert_held(sem);

	if (check_add_overflow(sem->u.sem.count, count, &sum) ||
	    sum > sem->u.sem.max)
		return -EOVERFLOW;

	sem->u.sem.count = sum;
	return 0;
}

static int ntsync_sem_release(struct ntsync_obj *sem, void __user *argp)
{
	struct ntsync_device *dev = sem->dev;
	__u32 __user *user_args = argp;
	__u32 prev_count;
	__u32 args;
	bool all;
	int ret;

	if (copy_from_user(&args, argp, sizeof(args)))
		return -EFAULT;

	if (sem->type != NTSYNC_TYPE_SEM)
		return -EINVAL;

	all = ntsync_lock_obj(dev, sem);

	prev_count = sem->u.sem.count;
	ret = release_sem_state(sem, args);
	if (!ret) {
		if (all)
			try_wake_all_obj(dev, sem);
		try_wake_any_sem(sem);
	}

	ntsync_unlock_obj(dev, sem, all);

	if (!ret && put_user(prev_count, user_args))
		ret = -EFAULT;

	return ret;
}

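/*
 * A minimal userspace sketch of NTSYNC_IOC_SEM_RELEASE (illustrative
 * only; assumes <sys/ioctl.h> and <linux/ntsync.h>, and that "sem" is
 * a semaphore fd returned by NTSYNC_IOC_CREATE_SEM; error handling
 * omitted):
 *
 *	__u32 count = 1;
 *
 *	if (ioctl(sem, NTSYNC_IOC_SEM_RELEASE, &count) < 0)
 *		perror("NTSYNC_IOC_SEM_RELEASE");
 *
 * On success the argument is overwritten with the pre-release count;
 * the ioctl fails with EOVERFLOW if the new count would exceed the
 * semaphore's maximum.
 */
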
/*
 * Actually change the mutex state, returning -EPERM if not the owner.
 */
static int unlock_mutex_state(struct ntsync_obj *mutex,
			      const struct ntsync_mutex_args *args)
{
	ntsync_assert_held(mutex);

	if (mutex->u.mutex.owner != args->owner)
		return -EPERM;

	if (!--mutex->u.mutex.count)
		mutex->u.mutex.owner = 0;
	return 0;
}

static int ntsync_mutex_unlock(struct ntsync_obj *mutex, void __user *argp)
{
	struct ntsync_mutex_args __user *user_args = argp;
	struct ntsync_device *dev = mutex->dev;
	struct ntsync_mutex_args args;
	__u32 prev_count;
	bool all;
	int ret;

	if (copy_from_user(&args, argp, sizeof(args)))
		return -EFAULT;
	if (!args.owner)
		return -EINVAL;

	if (mutex->type != NTSYNC_TYPE_MUTEX)
		return -EINVAL;

	all = ntsync_lock_obj(dev, mutex);

	prev_count = mutex->u.mutex.count;
	ret = unlock_mutex_state(mutex, &args);
	if (!ret) {
		if (all)
			try_wake_all_obj(dev, mutex);
		try_wake_any_mutex(mutex);
	}

	ntsync_unlock_obj(dev, mutex, all);

	if (!ret && put_user(prev_count, &user_args->count))
		ret = -EFAULT;

	return ret;
}

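/*
 * A minimal userspace sketch of NTSYNC_IOC_MUTEX_UNLOCK (illustrative
 * only; assumes "mutex" is a mutex fd and "tid" is the nonzero owner
 * identifier under which the mutex was acquired; error handling
 * omitted):
 *
 *	struct ntsync_mutex_args args = { .owner = tid };
 *
 *	if (ioctl(mutex, NTSYNC_IOC_MUTEX_UNLOCK, &args) < 0)
 *		perror("NTSYNC_IOC_MUTEX_UNLOCK");
 *
 * On success args.count holds the recursion count before the unlock;
 * the ioctl fails with EPERM if "tid" does not own the mutex.
 */
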
/*
 * Actually change the mutex state to mark its owner as dead,
 * returning -EPERM if not the owner.
 */
static int kill_mutex_state(struct ntsync_obj *mutex, __u32 owner)
{
	ntsync_assert_held(mutex);

	if (mutex->u.mutex.owner != owner)
		return -EPERM;

	mutex->u.mutex.ownerdead = true;
	mutex->u.mutex.owner = 0;
	mutex->u.mutex.count = 0;
	return 0;
}

static int ntsync_mutex_kill(struct ntsync_obj *mutex, void __user *argp)
{
	struct ntsync_device *dev = mutex->dev;
	__u32 owner;
	bool all;
	int ret;

	if (get_user(owner, (__u32 __user *)argp))
		return -EFAULT;
	if (!owner)
		return -EINVAL;

	if (mutex->type != NTSYNC_TYPE_MUTEX)
		return -EINVAL;

	all = ntsync_lock_obj(dev, mutex);

	ret = kill_mutex_state(mutex, owner);
	if (!ret) {
		if (all)
			try_wake_all_obj(dev, mutex);
		try_wake_any_mutex(mutex);
	}

	ntsync_unlock_obj(dev, mutex, all);

	return ret;
}

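/*
 * A minimal userspace sketch of NTSYNC_IOC_MUTEX_KILL, as might be
 * issued when the thread identified by "tid" dies while holding the
 * mutex (illustrative only; error handling omitted):
 *
 *	__u32 owner = tid;
 *
 *	if (ioctl(mutex, NTSYNC_IOC_MUTEX_KILL, &owner) < 0)
 *		perror("NTSYNC_IOC_MUTEX_KILL");
 *
 * A subsequent wait that acquires the mutex still transfers ownership
 * but reports EOWNERDEAD to the new owner.
 */
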
static int ntsync_obj_release(struct inode *inode, struct file *file)
{
	struct ntsync_obj *obj = file->private_data;

	fput(obj->dev->file);
	kfree(obj);

	return 0;
}

static long ntsync_obj_ioctl(struct file *file, unsigned int cmd,
			     unsigned long parm)
{
	struct ntsync_obj *obj = file->private_data;
	void __user *argp = (void __user *)parm;

	switch (cmd) {
	case NTSYNC_IOC_SEM_RELEASE:
		return ntsync_sem_release(obj, argp);
	case NTSYNC_IOC_MUTEX_UNLOCK:
		return ntsync_mutex_unlock(obj, argp);
	case NTSYNC_IOC_MUTEX_KILL:
		return ntsync_mutex_kill(obj, argp);
	default:
		return -ENOIOCTLCMD;
	}
}

static const struct file_operations ntsync_obj_fops = {
	.owner		= THIS_MODULE,
	.release	= ntsync_obj_release,
	.unlocked_ioctl	= ntsync_obj_ioctl,
	.compat_ioctl	= compat_ptr_ioctl,
};

static struct ntsync_obj *ntsync_alloc_obj(struct ntsync_device *dev,
					   enum ntsync_type type)
{
	struct ntsync_obj *obj;

	obj = kzalloc(sizeof(*obj), GFP_KERNEL);
	if (!obj)
		return NULL;
	obj->type = type;
	obj->dev = dev;
	get_file(dev->file);
	spin_lock_init(&obj->lock);
	INIT_LIST_HEAD(&obj->any_waiters);
	INIT_LIST_HEAD(&obj->all_waiters);
	atomic_set(&obj->all_hint, 0);

	return obj;
}

static int ntsync_obj_get_fd(struct ntsync_obj *obj)
{
	struct file *file;
	int fd;

	fd = get_unused_fd_flags(O_CLOEXEC);
	if (fd < 0)
		return fd;
	file = anon_inode_getfile("ntsync", &ntsync_obj_fops, obj, O_RDWR);
	if (IS_ERR(file)) {
		put_unused_fd(fd);
		return PTR_ERR(file);
	}
	obj->file = file;
	fd_install(fd, file);

	return fd;
}

static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
{
	struct ntsync_sem_args args;
	struct ntsync_obj *sem;
	int fd;

	if (copy_from_user(&args, argp, sizeof(args)))
		return -EFAULT;

	if (args.count > args.max)
		return -EINVAL;

	sem = ntsync_alloc_obj(dev, NTSYNC_TYPE_SEM);
	if (!sem)
		return -ENOMEM;
	sem->u.sem.count = args.count;
	sem->u.sem.max = args.max;
	fd = ntsync_obj_get_fd(sem);
	if (fd < 0) {
		/* Drop the device reference taken in ntsync_alloc_obj(). */
		fput(dev->file);
		kfree(sem);
	}

	return fd;
}

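/*
 * A minimal userspace sketch of creating a semaphore (illustrative
 * only; assumes <fcntl.h>, <sys/ioctl.h> and <linux/ntsync.h>; error
 * handling omitted):
 *
 *	struct ntsync_sem_args args = { .count = 0, .max = 2 };
 *	int dev, sem;
 *
 *	dev = open("/dev/ntsync", O_RDWR);
 *	sem = ioctl(dev, NTSYNC_IOC_CREATE_SEM, &args);
 *
 * On success the ioctl return value is a new file descriptor
 * representing the semaphore, released with close().
 */
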
static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp)
{
	struct ntsync_mutex_args args;
	struct ntsync_obj *mutex;
	int fd;

	if (copy_from_user(&args, argp, sizeof(args)))
		return -EFAULT;

	if (!args.owner != !args.count)
		return -EINVAL;

	mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX);
	if (!mutex)
		return -ENOMEM;
	mutex->u.mutex.count = args.count;
	mutex->u.mutex.owner = args.owner;
	fd = ntsync_obj_get_fd(mutex);
	if (fd < 0) {
		/* Drop the device reference taken in ntsync_alloc_obj(). */
		fput(dev->file);
		kfree(mutex);
	}

	return fd;
}

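/*
 * A minimal userspace sketch of creating a mutex that is initially
 * owned by the calling thread (illustrative only; "tid" is whatever
 * nonzero identifier the caller uses for itself, conventionally its
 * thread id; error handling omitted):
 *
 *	struct ntsync_mutex_args args = { .owner = tid, .count = 1 };
 *	int mutex;
 *
 *	mutex = ioctl(dev, NTSYNC_IOC_CREATE_MUTEX, &args);
 *
 * Passing owner == 0 and count == 0 creates an unowned mutex; any
 * combination where exactly one of the two is zero is rejected with
 * -EINVAL.
 */
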
static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd)
{
	struct file *file = fget(fd);
	struct ntsync_obj *obj;

	if (!file)
		return NULL;

	if (file->f_op != &ntsync_obj_fops) {
		fput(file);
		return NULL;
	}

	obj = file->private_data;
	if (obj->dev != dev) {
		fput(file);
		return NULL;
	}

	return obj;
}

static void put_obj(struct ntsync_obj *obj)
{
	fput(obj->file);
}

static int ntsync_schedule(const struct ntsync_q *q, const struct ntsync_wait_args *args)
{
	ktime_t timeout = ns_to_ktime(args->timeout);
	clockid_t clock = CLOCK_MONOTONIC;
	ktime_t *timeout_ptr;
	int ret = 0;

	timeout_ptr = (args->timeout == U64_MAX ? NULL : &timeout);

	if (args->flags & NTSYNC_WAIT_REALTIME)
		clock = CLOCK_REALTIME;

	do {
		if (signal_pending(current)) {
			ret = -ERESTARTSYS;
			break;
		}

		set_current_state(TASK_INTERRUPTIBLE);
		if (atomic_read(&q->signaled) != -1) {
			ret = 0;
			break;
		}
		ret = schedule_hrtimeout_range_clock(timeout_ptr, 0, HRTIMER_MODE_ABS, clock);
	} while (ret < 0);
	__set_current_state(TASK_RUNNING);

	return ret;
}

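/*
 * Note that the timeout in struct ntsync_wait_args is an absolute
 * time in nanoseconds on the chosen clock (CLOCK_MONOTONIC by
 * default, CLOCK_REALTIME with NTSYNC_WAIT_REALTIME), with U64_MAX
 * meaning "wait forever". A userspace sketch of building a 100ms
 * relative timeout (illustrative only; assumes <time.h>):
 *
 *	struct timespec now;
 *	__u64 timeout;
 *
 *	clock_gettime(CLOCK_MONOTONIC, &now);
 *	timeout = now.tv_sec * 1000000000ULL + now.tv_nsec
 *		  + 100 * 1000000ULL;
 */
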
/*
 * Allocate and initialize the ntsync_q structure, but do not queue us yet.
 */
static int setup_wait(struct ntsync_device *dev,
		      const struct ntsync_wait_args *args, bool all,
		      struct ntsync_q **ret_q)
{
	const __u32 count = args->count;
	int fds[NTSYNC_MAX_WAIT_COUNT];
	struct ntsync_q *q;
	__u32 i, j;

	if (args->pad[0] || args->pad[1] || (args->flags & ~NTSYNC_WAIT_REALTIME))
		return -EINVAL;

	if (args->count > NTSYNC_MAX_WAIT_COUNT)
		return -EINVAL;

	if (copy_from_user(fds, u64_to_user_ptr(args->objs),
			   array_size(count, sizeof(*fds))))
		return -EFAULT;

	q = kmalloc(struct_size(q, entries, count), GFP_KERNEL);
	if (!q)
		return -ENOMEM;
	q->task = current;
	q->owner = args->owner;
	atomic_set(&q->signaled, -1);
	q->all = all;
	q->ownerdead = false;
	q->count = count;

	for (i = 0; i < count; i++) {
		struct ntsync_q_entry *entry = &q->entries[i];
		struct ntsync_obj *obj = get_obj(dev, fds[i]);

		if (!obj)
			goto err;

		if (all) {
			/* Check that the objects are all distinct. */
			for (j = 0; j < i; j++) {
				if (obj == q->entries[j].obj) {
					put_obj(obj);
					goto err;
				}
			}
		}

		entry->obj = obj;
		entry->q = q;
		entry->index = i;
	}

	*ret_q = q;
	return 0;

err:
	for (j = 0; j < i; j++)
		put_obj(q->entries[j].obj);
	kfree(q);
	return -EINVAL;
}

static void try_wake_any_obj(struct ntsync_obj *obj)
{
	switch (obj->type) {
	case NTSYNC_TYPE_SEM:
		try_wake_any_sem(obj);
		break;
	case NTSYNC_TYPE_MUTEX:
		try_wake_any_mutex(obj);
		break;
	}
}

static int ntsync_wait_any(struct ntsync_device *dev, void __user *argp)
{
	struct ntsync_wait_args args;
	struct ntsync_q *q;
	int signaled;
	bool all;
	__u32 i;
	int ret;

	if (copy_from_user(&args, argp, sizeof(args)))
		return -EFAULT;

	ret = setup_wait(dev, &args, false, &q);
	if (ret < 0)
		return ret;

	/* queue ourselves */

	for (i = 0; i < args.count; i++) {
		struct ntsync_q_entry *entry = &q->entries[i];
		struct ntsync_obj *obj = entry->obj;

		all = ntsync_lock_obj(dev, obj);
		list_add_tail(&entry->node, &obj->any_waiters);
		ntsync_unlock_obj(dev, obj, all);
	}

	/* check if we are already signaled */

	for (i = 0; i < args.count; i++) {
		struct ntsync_obj *obj = q->entries[i].obj;

		if (atomic_read(&q->signaled) != -1)
			break;

		all = ntsync_lock_obj(dev, obj);
		try_wake_any_obj(obj);
		ntsync_unlock_obj(dev, obj, all);
	}

	/* sleep */

	ret = ntsync_schedule(q, &args);

	/* and finally, unqueue */

	for (i = 0; i < args.count; i++) {
		struct ntsync_q_entry *entry = &q->entries[i];
		struct ntsync_obj *obj = entry->obj;

		all = ntsync_lock_obj(dev, obj);
		list_del(&entry->node);
		ntsync_unlock_obj(dev, obj, all);

		put_obj(obj);
	}

	signaled = atomic_read(&q->signaled);
	if (signaled != -1) {
		struct ntsync_wait_args __user *user_args = argp;

		/* even if we caught a signal, we need to communicate success */
		ret = q->ownerdead ? -EOWNERDEAD : 0;

		if (put_user(signaled, &user_args->index))
			ret = -EFAULT;
	} else if (!ret) {
		ret = -ETIMEDOUT;
	}

	kfree(q);
	return ret;
}

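/*
 * A minimal userspace sketch of NTSYNC_IOC_WAIT_ANY, waiting on a
 * semaphore and a mutex at once (illustrative only; "sem", "mutex",
 * "tid" and "timeout" as in the sketches above; error handling
 * omitted):
 *
 *	int objs[2] = { sem, mutex };
 *	struct ntsync_wait_args wait = {
 *		.timeout = timeout,
 *		.objs = (uintptr_t)objs,
 *		.count = 2,
 *		.owner = tid,
 *	};
 *
 *	if (ioctl(dev, NTSYNC_IOC_WAIT_ANY, &wait) == 0)
 *		printf("object %u was signaled\n", wait.index);
 *
 * On success wait.index identifies which object was consumed; failure
 * with errno ETIMEDOUT means the timeout expired, and EOWNERDEAD is
 * reported when an acquired mutex's previous owner died.
 */
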
static int ntsync_wait_all(struct ntsync_device *dev, void __user *argp)
{
	struct ntsync_wait_args args;
	struct ntsync_q *q;
	int signaled;
	__u32 i;
	int ret;

	if (copy_from_user(&args, argp, sizeof(args)))
		return -EFAULT;

	ret = setup_wait(dev, &args, true, &q);
	if (ret < 0)
		return ret;

	/* queue ourselves */

	mutex_lock(&dev->wait_all_lock);

	for (i = 0; i < args.count; i++) {
		struct ntsync_q_entry *entry = &q->entries[i];
		struct ntsync_obj *obj = entry->obj;

		atomic_inc(&obj->all_hint);

		/*
		 * obj->all_waiters is protected by dev->wait_all_lock rather
		 * than obj->lock, so there is no need to acquire obj->lock
		 * here.
		 */
		list_add_tail(&entry->node, &obj->all_waiters);
	}

	/* check if we are already signaled */

	try_wake_all(dev, q, NULL);

	mutex_unlock(&dev->wait_all_lock);

	/* sleep */

	ret = ntsync_schedule(q, &args);

	/* and finally, unqueue */

	mutex_lock(&dev->wait_all_lock);

	for (i = 0; i < args.count; i++) {
		struct ntsync_q_entry *entry = &q->entries[i];
		struct ntsync_obj *obj = entry->obj;

		/*
		 * obj->all_waiters is protected by dev->wait_all_lock rather
		 * than obj->lock, so there is no need to acquire it here.
		 */
		list_del(&entry->node);

		atomic_dec(&obj->all_hint);

		put_obj(obj);
	}

	mutex_unlock(&dev->wait_all_lock);

	signaled = atomic_read(&q->signaled);
	if (signaled != -1) {
		struct ntsync_wait_args __user *user_args = argp;

		/* even if we caught a signal, we need to communicate success */
		ret = q->ownerdead ? -EOWNERDEAD : 0;

		if (put_user(signaled, &user_args->index))
			ret = -EFAULT;
	} else if (!ret) {
		ret = -ETIMEDOUT;
	}

	kfree(q);
	return ret;
}

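/*
 * NTSYNC_IOC_WAIT_ALL takes the same struct ntsync_wait_args as the
 * wait-any sketch above, but returns only once every listed object is
 * signaled, consuming all of them atomically; the objects must be
 * distinct, and wait.index does not identify a single object on
 * success. For example (illustrative only):
 *
 *	if (ioctl(dev, NTSYNC_IOC_WAIT_ALL, &wait) == 0)
 *		... all objects acquired ...
 */
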
static int ntsync_char_open(struct inode *inode, struct file *file)
{
	struct ntsync_device *dev;

	dev = kzalloc(sizeof(*dev), GFP_KERNEL);
	if (!dev)
		return -ENOMEM;

	mutex_init(&dev->wait_all_lock);

	file->private_data = dev;
	dev->file = file;
	return nonseekable_open(inode, file);
}

static int ntsync_char_release(struct inode *inode, struct file *file)
{
	struct ntsync_device *dev = file->private_data;

	kfree(dev);

	return 0;
}

static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
			      unsigned long parm)
{
	struct ntsync_device *dev = file->private_data;
	void __user *argp = (void __user *)parm;

	switch (cmd) {
	case NTSYNC_IOC_CREATE_MUTEX:
		return ntsync_create_mutex(dev, argp);
	case NTSYNC_IOC_CREATE_SEM:
		return ntsync_create_sem(dev, argp);
	case NTSYNC_IOC_WAIT_ALL:
		return ntsync_wait_all(dev, argp);
	case NTSYNC_IOC_WAIT_ANY:
		return ntsync_wait_any(dev, argp);
	default:
		return -ENOIOCTLCMD;
	}
}

static const struct file_operations ntsync_fops = {
	.owner		= THIS_MODULE,
	.open		= ntsync_char_open,
	.release	= ntsync_char_release,
	.unlocked_ioctl	= ntsync_char_ioctl,
	.compat_ioctl	= compat_ptr_ioctl,
};

static struct miscdevice ntsync_misc = {
	.minor		= MISC_DYNAMIC_MINOR,
	.name		= NTSYNC_NAME,
	.fops		= &ntsync_fops,
};

module_misc_device(ntsync_misc);

MODULE_AUTHOR("Elizabeth Figura <[email protected]>");
MODULE_DESCRIPTION("Kernel driver for NT synchronization primitives");
MODULE_LICENSE("GPL");