1 #![cfg_attr(asan, allow(dead_code))]
2 
3 use super::index_allocator::{SimpleIndexAllocator, SlotId};
4 use crate::prelude::*;
5 use crate::runtime::vm::sys::vm::commit_pages;
6 use crate::runtime::vm::{
7     HostAlignedByteCount, Mmap, PoolingInstanceAllocatorConfig, mmap::AlignedLength,
8 };
9 
/// Represents a pool of execution stacks (used for the async fiber implementation).
///
/// Each index into the pool represents a single execution stack. The maximum
/// number of stacks is taken from `limits.total_stacks` in the pooling
/// allocator configuration.
///
/// All stacks live back to back in one memory mapping. As stacks grow
/// downwards, each stack starts (lowest address) with a guard page that can be
/// used to detect stack overflow.
///
/// The top of the stack (starting stack pointer) is returned when a stack is allocated
/// from the pool.
#[derive(Debug)]
pub struct StackPool {
    // Single mapping of `stack_size * max_stacks` bytes holding every slot.
    mapping: Mmap<AlignedLength>,
    // Size of one slot, guard page included. Zero when the configured stack
    // size is zero, in which case no stack can be allocated.
    stack_size: HostAlignedByteCount,
    // Number of slots in the pool.
    max_stacks: usize,
    // Host page size; also the size of each slot's guard page.
    page_size: HostAlignedByteCount,
    // Hands out and reclaims slot indices.
    index_allocator: SimpleIndexAllocator,
    // Whether `zero_stack` resets stack contents at all.
    async_stack_zeroing: bool,
    // Upper bound on the bytes at the top of a stack that `zero_stack` zeroes
    // by hand instead of passing to the decommit callback.
    async_stack_keep_resident: HostAlignedByteCount,
}
30 
31 impl StackPool {
    /// Test-only probe that unconditionally reports this stack pool
    /// implementation as enabled.
    #[cfg(test)]
    pub fn enabled() -> bool {
        true
    }
36 
new(config: &PoolingInstanceAllocatorConfig) -> Result<Self>37     pub fn new(config: &PoolingInstanceAllocatorConfig) -> Result<Self> {
38         use rustix::mm::{MprotectFlags, mprotect};
39 
40         let page_size = HostAlignedByteCount::host_page_size();
41 
42         // Add a page to the stack size for the guard page when using fiber stacks
43         let stack_size = if config.stack_size == 0 {
44             HostAlignedByteCount::ZERO
45         } else {
46             HostAlignedByteCount::new_rounded_up(config.stack_size)
47                 .and_then(|size| size.checked_add(HostAlignedByteCount::host_page_size()))
48                 .context("stack size exceeds addressable memory")?
49         };
50 
51         let max_stacks = usize::try_from(config.limits.total_stacks).unwrap();
52 
53         let allocation_size = stack_size
54             .checked_mul(max_stacks)
55             .context("total size of execution stacks exceeds addressable memory")?;
56 
57         let mapping = Mmap::accessible_reserved(allocation_size, allocation_size)
58             .context("failed to create stack pool mapping")?;
59 
60         // Set up the stack guard pages.
61         if !allocation_size.is_zero() {
62             unsafe {
63                 for i in 0..max_stacks {
64                     // Safety: i < max_stacks and we've already checked that
65                     // stack_size * max_stacks is valid.
66                     let offset = stack_size.unchecked_mul(i);
67                     // Make the stack guard page inaccessible.
68                     let bottom_of_stack = mapping.as_ptr().add(offset.byte_count()).cast_mut();
69                     mprotect(
70                         bottom_of_stack.cast(),
71                         page_size.byte_count(),
72                         MprotectFlags::empty(),
73                     )
74                     .context("failed to protect stack guard page")?;
75                 }
76             }
77         }
78 
79         Ok(Self {
80             mapping,
81             stack_size,
82             max_stacks,
83             page_size,
84             async_stack_zeroing: config.async_stack_zeroing,
85             async_stack_keep_resident: HostAlignedByteCount::new_rounded_up(
86                 config.async_stack_keep_resident,
87             )?,
88             index_allocator: SimpleIndexAllocator::new(config.limits.total_stacks),
89         })
90     }
91 
    /// Are there zero slots in use right now?
    ///
    /// Returns whatever the pool's index allocator reports from its own
    /// `is_empty`.
    pub fn is_empty(&self) -> bool {
        self.index_allocator.is_empty()
    }
96 
97     /// Allocate a new fiber.
allocate(&self) -> Result<wasmtime_fiber::FiberStack>98     pub fn allocate(&self) -> Result<wasmtime_fiber::FiberStack> {
99         if self.stack_size.is_zero() {
100             bail!("pooling allocator not configured to enable fiber stack allocation");
101         }
102 
103         let index = self
104             .index_allocator
105             .alloc()
106             .ok_or_else(|| super::PoolConcurrencyLimitError::new(self.max_stacks, "fibers"))?
107             .index();
108 
109         assert!(index < self.max_stacks);
110 
111         unsafe {
112             // Remove the guard page from the size
113             let size_without_guard = self.stack_size.checked_sub(self.page_size).expect(
114                 "self.stack_size is host-page-aligned and is > 0,\
115                  so it must be >= self.page_size",
116             );
117 
118             let bottom_of_stack = self
119                 .mapping
120                 .as_ptr()
121                 .add(self.stack_size.unchecked_mul(index).byte_count())
122                 .cast_mut();
123 
124             commit_pages(bottom_of_stack, size_without_guard.byte_count())?;
125 
126             let stack = wasmtime_fiber::FiberStack::from_raw_parts(
127                 bottom_of_stack,
128                 self.page_size.byte_count(),
129                 size_without_guard.byte_count(),
130             )?;
131             Ok(stack)
132         }
133     }
134 
135     /// Zero the given stack, if we are configured to do so.
136     ///
137     /// This will call the given `decommit` function for each region of memory
138     /// that should be decommitted. It is the caller's responsibility to ensure
139     /// that those decommits happen before this stack is reused.
140     ///
141     /// # Panics
142     ///
143     /// `zero_stack` panics if the passed in `stack` was not created by
144     /// [`Self::allocate`].
145     ///
146     /// # Safety
147     ///
148     /// The stack must no longer be in use, and ready for returning to the pool
149     /// after it is zeroed and decommitted.
zero_stack( &self, stack: &mut wasmtime_fiber::FiberStack, mut decommit: impl FnMut(*mut u8, usize), ) -> usize150     pub unsafe fn zero_stack(
151         &self,
152         stack: &mut wasmtime_fiber::FiberStack,
153         mut decommit: impl FnMut(*mut u8, usize),
154     ) -> usize {
155         assert!(stack.is_from_raw_parts());
156         assert!(
157             !self.stack_size.is_zero(),
158             "pooling allocator not configured to enable fiber stack allocation \
159              (Self::allocate should have returned an error)"
160         );
161 
162         if !self.async_stack_zeroing {
163             return 0;
164         }
165 
166         let top = stack
167             .top()
168             .expect("fiber stack not allocated from the pool") as usize;
169 
170         let base = self.mapping.as_ptr() as usize;
171         let len = self.mapping.len();
172         assert!(
173             top > base && top <= (base + len),
174             "fiber stack top pointer not in range"
175         );
176 
177         // Remove the guard page from the size.
178         let stack_size = self.stack_size.checked_sub(self.page_size).expect(
179             "self.stack_size is host-page-aligned and is > 0,\
180              so it must be >= self.page_size",
181         );
182         let bottom_of_stack = top - stack_size.byte_count();
183         let start_of_stack = bottom_of_stack - self.page_size.byte_count();
184         assert!(start_of_stack >= base && start_of_stack < (base + len));
185         assert!((start_of_stack - base) % self.stack_size.byte_count() == 0);
186 
187         // Manually zero the top of the stack to keep the pages resident in
188         // memory and avoid future page faults. Use the system to deallocate
189         // pages past this. This hopefully strikes a reasonable balance between:
190         //
191         // * memset for the whole range is probably expensive
192         // * madvise for the whole range incurs expensive future page faults
193         // * most threads probably don't use most of the stack anyway
194         let size_to_memset = stack_size.min(self.async_stack_keep_resident);
195         let rest = stack_size
196             .checked_sub(size_to_memset)
197             .expect("stack_size >= size_to_memset");
198 
199         // SAFETY: this function's own contract requires that the stack is not
200         // in use so it's safe to pave over part of it with zero.
201         unsafe {
202             std::ptr::write_bytes(
203                 (bottom_of_stack + rest.byte_count()) as *mut u8,
204                 0,
205                 size_to_memset.byte_count(),
206             );
207         }
208 
209         // Use the system to reset remaining stack pages to zero.
210         decommit(bottom_of_stack as _, rest.byte_count());
211 
212         size_to_memset.byte_count()
213     }
214 
215     /// Deallocate a previously-allocated fiber.
216     ///
217     /// # Safety
218     ///
219     /// The fiber must have been allocated by this pool, must be in an allocated
220     /// state, and must never be used again.
221     ///
222     /// The caller must have already called `zero_stack` on the fiber stack and
223     /// flushed any enqueued decommits for this stack's memory.
deallocate(&self, stack: wasmtime_fiber::FiberStack, bytes_resident: usize)224     pub unsafe fn deallocate(&self, stack: wasmtime_fiber::FiberStack, bytes_resident: usize) {
225         assert!(stack.is_from_raw_parts());
226 
227         let top = stack
228             .top()
229             .expect("fiber stack not allocated from the pool") as usize;
230 
231         let base = self.mapping.as_ptr() as usize;
232         let len = self.mapping.len();
233         assert!(
234             top > base && top <= (base + len),
235             "fiber stack top pointer not in range"
236         );
237 
238         // Remove the guard page from the size
239         let stack_size = self.stack_size.byte_count() - self.page_size.byte_count();
240         let bottom_of_stack = top - stack_size;
241         let start_of_stack = bottom_of_stack - self.page_size.byte_count();
242         assert!(start_of_stack >= base && start_of_stack < (base + len));
243         assert!((start_of_stack - base) % self.stack_size.byte_count() == 0);
244 
245         let index = (start_of_stack - base) / self.stack_size.byte_count();
246         assert!(index < self.max_stacks);
247         let index = u32::try_from(index).unwrap();
248 
249         self.index_allocator.free(SlotId(index), bytes_resident);
250     }
251 
    /// Returns the index allocator's count of unused warm slots.
    ///
    // NOTE(review): "warm" presumably means slots that were handed out before
    // and are currently free; confirm against `SimpleIndexAllocator`.
    pub fn unused_warm_slots(&self) -> u32 {
        self.index_allocator.unused_warm_slots()
    }
255 
unused_bytes_resident(&self) -> Option<usize>256     pub fn unused_bytes_resident(&self) -> Option<usize> {
257         if self.async_stack_zeroing {
258             Some(self.index_allocator.unused_bytes_resident())
259         } else {
260             None
261         }
262     }
263 }
264 
#[cfg(all(test, unix, feature = "async", not(miri), not(asan)))]
mod tests {
    use super::*;
    use crate::runtime::vm::InstanceLimits;

    /// Walks a ten-slot pool through its whole lifecycle: check the computed
    /// layout, fill every slot, observe exhaustion, then release everything
    /// and check the resulting freelist.
    #[test]
    fn test_stack_pool() -> Result<()> {
        let limits = InstanceLimits {
            total_stacks: 10,
            ..Default::default()
        };
        let pool = StackPool::new(&PoolingInstanceAllocatorConfig {
            limits,
            stack_size: 1,
            async_stack_zeroing: true,
            ..PoolingInstanceAllocatorConfig::default()
        })?;

        // A one-byte request rounds up to a page, plus one page of guard.
        let host_page = crate::runtime::vm::host_page_size();
        assert_eq!(pool.stack_size, 2 * host_page);
        assert_eq!(pool.max_stacks, 10);
        assert_eq!(pool.page_size, host_page);

        assert_eq!(pool.index_allocator.testing_freelist(), []);

        let base = pool.mapping.as_ptr() as usize;
        let stride = pool.stack_size.byte_count();

        // Slots are expected to be handed out in ascending index order.
        let stacks: Vec<_> = (0..10usize)
            .map(|i| {
                let stack = pool.allocate().expect("allocation should succeed");
                let top = stack.top().unwrap() as usize;
                assert_eq!(((top - base) / stride) - 1, i);
                stack
            })
            .collect();

        assert_eq!(pool.index_allocator.testing_freelist(), []);

        assert!(pool.allocate().is_err(), "allocation should fail");

        stacks.into_iter().for_each(|stack| unsafe {
            pool.deallocate(stack, 0);
        });

        let expected: [SlotId; 10] = std::array::from_fn(|i| SlotId(u32::try_from(i).unwrap()));
        assert_eq!(pool.index_allocator.testing_freelist(), expected);

        Ok(())
    }
}
331