1*84837cf6SBenno Lossin // SPDX-License-Identifier: Apache-2.0 OR MIT
2*84837cf6SBenno Lossin 
3*84837cf6SBenno Lossin #![allow(clippy::undocumented_unsafe_blocks)]
4*84837cf6SBenno Lossin #![cfg_attr(feature = "alloc", feature(allocator_api))]
5*84837cf6SBenno Lossin 
6*84837cf6SBenno Lossin use core::{
7*84837cf6SBenno Lossin     cell::{Cell, UnsafeCell},
8*84837cf6SBenno Lossin     mem::MaybeUninit,
9*84837cf6SBenno Lossin     ops,
10*84837cf6SBenno Lossin     pin::Pin,
11*84837cf6SBenno Lossin     time::Duration,
12*84837cf6SBenno Lossin };
13*84837cf6SBenno Lossin use pin_init::*;
14*84837cf6SBenno Lossin use std::{
15*84837cf6SBenno Lossin     sync::Arc,
16*84837cf6SBenno Lossin     thread::{sleep, Builder},
17*84837cf6SBenno Lossin };
18*84837cf6SBenno Lossin 
19*84837cf6SBenno Lossin #[expect(unused_attributes)]
20*84837cf6SBenno Lossin mod mutex;
21*84837cf6SBenno Lossin use mutex::*;
22*84837cf6SBenno Lossin 
23*84837cf6SBenno Lossin pub struct StaticInit<T, I> {
24*84837cf6SBenno Lossin     cell: UnsafeCell<MaybeUninit<T>>,
25*84837cf6SBenno Lossin     init: Cell<Option<I>>,
26*84837cf6SBenno Lossin     lock: SpinLock,
27*84837cf6SBenno Lossin     present: Cell<bool>,
28*84837cf6SBenno Lossin }
29*84837cf6SBenno Lossin 
30*84837cf6SBenno Lossin unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
31*84837cf6SBenno Lossin unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
32*84837cf6SBenno Lossin 
33*84837cf6SBenno Lossin impl<T, I: PinInit<T>> StaticInit<T, I> {
new(init: I) -> Self34*84837cf6SBenno Lossin     pub const fn new(init: I) -> Self {
35*84837cf6SBenno Lossin         Self {
36*84837cf6SBenno Lossin             cell: UnsafeCell::new(MaybeUninit::uninit()),
37*84837cf6SBenno Lossin             init: Cell::new(Some(init)),
38*84837cf6SBenno Lossin             lock: SpinLock::new(),
39*84837cf6SBenno Lossin             present: Cell::new(false),
40*84837cf6SBenno Lossin         }
41*84837cf6SBenno Lossin     }
42*84837cf6SBenno Lossin }
43*84837cf6SBenno Lossin 
44*84837cf6SBenno Lossin impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
45*84837cf6SBenno Lossin     type Target = T;
deref(&self) -> &Self::Target46*84837cf6SBenno Lossin     fn deref(&self) -> &Self::Target {
47*84837cf6SBenno Lossin         if self.present.get() {
48*84837cf6SBenno Lossin             unsafe { (*self.cell.get()).assume_init_ref() }
49*84837cf6SBenno Lossin         } else {
50*84837cf6SBenno Lossin             println!("acquire spinlock on static init");
51*84837cf6SBenno Lossin             let _guard = self.lock.acquire();
52*84837cf6SBenno Lossin             println!("rechecking present...");
53*84837cf6SBenno Lossin             std::thread::sleep(std::time::Duration::from_millis(200));
54*84837cf6SBenno Lossin             if self.present.get() {
55*84837cf6SBenno Lossin                 return unsafe { (*self.cell.get()).assume_init_ref() };
56*84837cf6SBenno Lossin             }
57*84837cf6SBenno Lossin             println!("doing init");
58*84837cf6SBenno Lossin             let ptr = self.cell.get().cast::<T>();
59*84837cf6SBenno Lossin             match self.init.take() {
60*84837cf6SBenno Lossin                 Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
61*84837cf6SBenno Lossin                 None => unsafe { core::hint::unreachable_unchecked() },
62*84837cf6SBenno Lossin             }
63*84837cf6SBenno Lossin             self.present.set(true);
64*84837cf6SBenno Lossin             unsafe { (*self.cell.get()).assume_init_ref() }
65*84837cf6SBenno Lossin         }
66*84837cf6SBenno Lossin     }
67*84837cf6SBenno Lossin }
68*84837cf6SBenno Lossin 
69*84837cf6SBenno Lossin pub struct CountInit;
70*84837cf6SBenno Lossin 
71*84837cf6SBenno Lossin unsafe impl PinInit<CMutex<usize>> for CountInit {
__pinned_init( self, slot: *mut CMutex<usize>, ) -> Result<(), core::convert::Infallible>72*84837cf6SBenno Lossin     unsafe fn __pinned_init(
73*84837cf6SBenno Lossin         self,
74*84837cf6SBenno Lossin         slot: *mut CMutex<usize>,
75*84837cf6SBenno Lossin     ) -> Result<(), core::convert::Infallible> {
76*84837cf6SBenno Lossin         let init = CMutex::new(0);
77*84837cf6SBenno Lossin         std::thread::sleep(std::time::Duration::from_millis(1000));
78*84837cf6SBenno Lossin         unsafe { init.__pinned_init(slot) }
79*84837cf6SBenno Lossin     }
80*84837cf6SBenno Lossin }
81*84837cf6SBenno Lossin 
82*84837cf6SBenno Lossin pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
83*84837cf6SBenno Lossin 
/// Fallback entry point when neither `std` nor `alloc` is enabled; the example needs `Arc` and
/// threads, so it does nothing.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

/// Spawns `thread_count` workers that increment both a heap-allocated pinned mutex (`mtx`) and
/// the lazily initialized static `COUNT`, then checks both final counts.
///
/// Per worker, the first loop adds 2 to `COUNT` and 1 to `mtx` per iteration, and the second loop
/// adds 1 to `mtx` per iteration. Both therefore end at `2 * workload * thread_count`. Checking
/// `COUNT` as well verifies that its initializer ran exactly once; a second initialization would
/// reset it to 0 and lose increments.
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = 1_000;
    let mut handles = Vec::with_capacity(thread_count);
    for i in 0..thread_count {
        // Cheap reference-count bump; every worker shares the same mutex.
        let mtx = Pin::clone(&mtx);
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    // Stagger the workers before the second phase.
                    sleep(Duration::from_millis((i as u64) * 10));
                    for _ in 0..workload {
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    println!("{:?}, {:?}", &*mtx.lock(), &*COUNT.lock());
    let expected = workload * thread_count * 2;
    assert_eq!(*mtx.lock(), expected);
    assert_eq!(*COUNT.lock(), expected);
}
123