xref: /wasmtime-44.0.1/winch/codegen/src/stack.rs (revision b93e1bc0)
1 use crate::{codegen::CodeGenError, isa::reg::Reg, masm::StackSlot};
2 use anyhow::{anyhow, Result};
3 use smallvec::SmallVec;
4 use wasmparser::{Ieee32, Ieee64};
5 use wasmtime_environ::WasmValType;
6 
7 /// A typed register value used to track register values in the value
8 /// stack.
9 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
10 pub struct TypedReg {
11     /// The physical register.
12     pub reg: Reg,
13     /// The type associated to the physical register.
14     pub ty: WasmValType,
15 }
16 
17 impl TypedReg {
18     /// Create a new [`TypedReg`].
19     pub fn new(ty: WasmValType, reg: Reg) -> Self {
20         Self { ty, reg }
21     }
22 
23     /// Create an i64 [`TypedReg`].
24     pub fn i64(reg: Reg) -> Self {
25         Self {
26             ty: WasmValType::I64,
27             reg,
28         }
29     }
30 
31     /// Create an i32 [`TypedReg`].
32     pub fn i32(reg: Reg) -> Self {
33         Self {
34             ty: WasmValType::I32,
35             reg,
36         }
37     }
38 
39     /// Create an f64 [`TypedReg`].
40     pub fn f64(reg: Reg) -> Self {
41         Self {
42             ty: WasmValType::F64,
43             reg,
44         }
45     }
46 
47     /// Create an f32 [`TypedReg`].
48     pub fn f32(reg: Reg) -> Self {
49         Self {
50             ty: WasmValType::F32,
51             reg,
52         }
53     }
54 }
55 
56 impl From<TypedReg> for Reg {
57     fn from(tr: TypedReg) -> Self {
58         tr.reg
59     }
60 }
61 
62 /// A local value.
63 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
64 pub struct Local {
65     /// The index of the local.
66     pub index: u32,
67     /// The type of the local.
68     pub ty: WasmValType,
69 }
70 
71 /// A memory value.
72 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
73 pub struct Memory {
74     /// The type associated with the memory offset.
75     pub ty: WasmValType,
76     /// The stack slot corresponding to the memory value.
77     pub slot: StackSlot,
78 }
79 
80 /// Value definition to be used within the shadow stack.
81 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
82 pub(crate) enum Val {
83     /// I32 Constant.
84     I32(i32),
85     /// I64 Constant.
86     I64(i64),
87     /// F32 Constant.
88     F32(Ieee32),
89     /// F64 Constant.
90     F64(Ieee64),
91     /// V128 Constant.
92     V128(i128),
93     /// A register value.
94     Reg(TypedReg),
95     /// A local slot.
96     Local(Local),
97     /// Offset to a memory location.
98     Memory(Memory),
99 }
100 
101 impl From<TypedReg> for Val {
102     fn from(tr: TypedReg) -> Self {
103         Val::Reg(tr)
104     }
105 }
106 
107 impl From<Local> for Val {
108     fn from(local: Local) -> Self {
109         Val::Local(local)
110     }
111 }
112 
113 impl From<Memory> for Val {
114     fn from(mem: Memory) -> Self {
115         Val::Memory(mem)
116     }
117 }
118 
119 impl TryFrom<u32> for Val {
120     type Error = anyhow::Error;
121     fn try_from(value: u32) -> Result<Self, Self::Error> {
122         i32::try_from(value).map(Val::i32).map_err(Into::into)
123     }
124 }
125 
126 impl Val {
127     /// Create a new I32 constant value.
128     pub fn i32(v: i32) -> Self {
129         Self::I32(v)
130     }
131 
132     /// Create a new I64 constant value.
133     pub fn i64(v: i64) -> Self {
134         Self::I64(v)
135     }
136 
137     /// Create a new F32 constant value.
138     pub fn f32(v: Ieee32) -> Self {
139         Self::F32(v)
140     }
141 
142     pub fn f64(v: Ieee64) -> Self {
143         Self::F64(v)
144     }
145 
146     /// Create a new V128 constant value.
147     pub fn v128(v: i128) -> Self {
148         Self::V128(v)
149     }
150 
151     /// Create a new Reg value.
152     pub fn reg(reg: Reg, ty: WasmValType) -> Self {
153         Self::Reg(TypedReg { reg, ty })
154     }
155 
156     /// Create a new Local value.
157     pub fn local(index: u32, ty: WasmValType) -> Self {
158         Self::Local(Local { index, ty })
159     }
160 
161     /// Create a Memory value.
162     pub fn mem(ty: WasmValType, slot: StackSlot) -> Self {
163         Self::Memory(Memory { ty, slot })
164     }
165 
166     /// Check whether the value is a register.
167     pub fn is_reg(&self) -> bool {
168         match *self {
169             Self::Reg(_) => true,
170             _ => false,
171         }
172     }
173 
174     /// Check whether the value is a memory offset.
175     pub fn is_mem(&self) -> bool {
176         match *self {
177             Self::Memory(_) => true,
178             _ => false,
179         }
180     }
181 
182     /// Check whether the value is a constant.
183     pub fn is_const(&self) -> bool {
184         match *self {
185             Val::I32(_) | Val::I64(_) | Val::F32(_) | Val::F64(_) | Val::V128(_) => true,
186             _ => false,
187         }
188     }
189 
190     /// Check whether the value is local with a particular index.
191     pub fn is_local_at_index(&self, index: u32) -> bool {
192         match *self {
193             Self::Local(Local { index: i, .. }) if i == index => true,
194             _ => false,
195         }
196     }
197 
198     /// Get the register representation of the value.
199     ///
200     /// # Panics
201     /// This method will panic if the value is not a register.
202     pub fn unwrap_reg(&self) -> TypedReg {
203         match self {
204             Self::Reg(tr) => *tr,
205             v => panic!("expected value {v:?} to be a register"),
206         }
207     }
208 
209     /// Get the integer representation of the value.
210     ///
211     /// # Panics
212     /// This method will panic if the value is not an i32.
213     pub fn unwrap_i32(&self) -> i32 {
214         match self {
215             Self::I32(v) => *v,
216             v => panic!("expected value {v:?} to be i32"),
217         }
218     }
219 
220     /// Get the integer representation of the value.
221     ///
222     /// # Panics
223     /// This method will panic if the value is not an i64.
224     pub fn unwrap_i64(&self) -> i64 {
225         match self {
226             Self::I64(v) => *v,
227             v => panic!("expected value {v:?} to be i64"),
228         }
229     }
230 
231     /// Returns the underlying memory value if it is one, panics otherwise.
232     pub fn unwrap_mem(&self) -> Memory {
233         match self {
234             Self::Memory(m) => *m,
235             v => panic!("expected value {v:?} to be a Memory"),
236         }
237     }
238 
239     /// Check whether the value is an i32 constant.
240     pub fn is_i32_const(&self) -> bool {
241         match *self {
242             Self::I32(_) => true,
243             _ => false,
244         }
245     }
246 
247     /// Check whether the value is an i64 constant.
248     pub fn is_i64_const(&self) -> bool {
249         match *self {
250             Self::I64(_) => true,
251             _ => false,
252         }
253     }
254 
255     /// Get the type of the value.
256     pub fn ty(&self) -> WasmValType {
257         match self {
258             Val::I32(_) => WasmValType::I32,
259             Val::I64(_) => WasmValType::I64,
260             Val::F32(_) => WasmValType::F32,
261             Val::F64(_) => WasmValType::F64,
262             Val::V128(_) => WasmValType::V128,
263             Val::Reg(r) => r.ty,
264             Val::Memory(m) => m.ty,
265             Val::Local(l) => l.ty,
266         }
267     }
268 }
269 
270 /// The shadow stack used for compilation.
271 #[derive(Default, Debug)]
272 pub(crate) struct Stack {
273     // NB: The 64 is chosen arbitrarily. We can adjust as we see fit.
274     inner: SmallVec<[Val; 64]>,
275 }
276 
277 impl Stack {
278     /// Allocate a new stack.
279     pub fn new() -> Self {
280         Self {
281             inner: Default::default(),
282         }
283     }
284 
285     /// Ensures that there are at least `n` elements in the value stack,
286     /// and returns the index calculated by: stack length minus `n`.
287     pub fn ensure_index_at(&self, n: usize) -> Result<usize> {
288         if self.len() >= n {
289             Ok(self.len() - n)
290         } else {
291             Err(anyhow!(CodeGenError::missing_values_in_stack()))
292         }
293     }
294 
295     /// Returns true if the stack contains a local with the provided index
296     /// except if the only time the local appears is the top element.
297     pub fn contains_latent_local(&self, index: u32) -> bool {
298         self.inner
299             .iter()
300             // Iterate top-to-bottom so we can skip the top element and stop
301             // when we see a memory element.
302             .rev()
303             // The local is not latent if it's the top element because the top
304             // element will be popped next which materializes the local.
305             .skip(1)
306             // Stop when we see a memory element because that marks where we
307             // spilled up to so there will not be any locals past this point.
308             .take_while(|v| !v.is_mem())
309             .any(|v| v.is_local_at_index(index))
310     }
311 
312     /// Extend the stack with the given elements.
313     pub fn extend(&mut self, values: impl IntoIterator<Item = Val>) {
314         self.inner.extend(values);
315     }
316 
317     /// Inserts many values at the given index.
318     pub fn insert_many(&mut self, at: usize, values: &[Val]) {
319         debug_assert!(at <= self.len());
320 
321         if at == self.len() {
322             self.inner.extend_from_slice(values);
323         } else {
324             self.inner.insert_from_slice(at, values);
325         }
326     }
327 
328     /// Get the length of the stack.
329     pub fn len(&self) -> usize {
330         self.inner.len()
331     }
332 
333     /// Push a value to the stack.
334     pub fn push(&mut self, val: Val) {
335         self.inner.push(val);
336     }
337 
338     /// Peek into the top in the stack.
339     pub fn peek(&self) -> Option<&Val> {
340         self.inner.last()
341     }
342 
343     /// Returns an iterator referencing the last n items of the stack,
344     /// in bottom-most to top-most order.
345     pub fn peekn(&self, n: usize) -> impl Iterator<Item = &Val> + '_ {
346         let len = self.len();
347         assert!(n <= len);
348 
349         let partition = len - n;
350         self.inner[partition..].into_iter()
351     }
352 
353     /// Pops the top element of the stack, if any.
354     pub fn pop(&mut self) -> Option<Val> {
355         self.inner.pop()
356     }
357 
358     /// Pops the element at the top of the stack if it is an i32 const;
359     /// returns `None` otherwise.
360     pub fn pop_i32_const(&mut self) -> Option<i32> {
361         match self.peek() {
362             Some(v) => v.is_i32_const().then(|| self.pop().unwrap().unwrap_i32()),
363             _ => None,
364         }
365     }
366 
367     /// Pops the element at the top of the stack if it is an i64 const;
368     /// returns `None` otherwise.
369     pub fn pop_i64_const(&mut self) -> Option<i64> {
370         match self.peek() {
371             Some(v) => v.is_i64_const().then(|| self.pop().unwrap().unwrap_i64()),
372             _ => None,
373         }
374     }
375 
376     /// Pops the element at the top of the stack if it is a register;
377     /// returns `None` otherwise.
378     pub fn pop_reg(&mut self) -> Option<TypedReg> {
379         match self.peek() {
380             Some(v) => v.is_reg().then(|| self.pop().unwrap().unwrap_reg()),
381             _ => None,
382         }
383     }
384 
385     /// Pops the given register if it is at the top of the stack;
386     /// returns `None` otherwise.
387     pub fn pop_named_reg(&mut self, reg: Reg) -> Option<TypedReg> {
388         match self.peek() {
389             Some(v) => {
390                 (v.is_reg() && v.unwrap_reg().reg == reg).then(|| self.pop().unwrap().unwrap_reg())
391             }
392             _ => None,
393         }
394     }
395 
396     /// Get a mutable reference to the inner stack representation.
397     pub fn inner_mut(&mut self) -> &mut SmallVec<[Val; 64]> {
398         &mut self.inner
399     }
400 
401     /// Get a reference to the inner stack representation.
402     pub fn inner(&self) -> &SmallVec<[Val; 64]> {
403         &self.inner
404     }
405 
406     /// Calculates the size of, in bytes, of the top n [Memory] entries
407     /// in the value stack.
408     pub fn sizeof(&self, top: usize) -> u32 {
409         self.peekn(top).fold(0, |acc, v| {
410             if v.is_mem() {
411                 acc + v.unwrap_mem().slot.size
412             } else {
413                 acc
414             }
415         })
416     }
417 }
418 
#[cfg(test)]
mod tests {
    use super::{Stack, Val};
    use crate::isa::reg::Reg;
    use wasmtime_environ::WasmValType;

    #[test]
    fn test_pop_i32_const() {
        let mut stack = Stack::new();
        stack.push(Val::i32(33i32));
        assert_eq!(33, stack.pop_i32_const().unwrap());

        stack.push(Val::local(10, WasmValType::I32));
        assert!(stack.pop_i32_const().is_none());
    }

    #[test]
    fn test_pop_i64_const() {
        let mut stack = Stack::new();
        stack.push(Val::i64(-7i64));
        assert_eq!(Some(-7), stack.pop_i64_const());
        assert_eq!(0, stack.len());

        // A non-i64 value on top must be left in place.
        stack.push(Val::i32(1));
        assert!(stack.pop_i64_const().is_none());
        assert_eq!(1, stack.len());
    }

    #[test]
    fn test_pop_reg() {
        let mut stack = Stack::new();
        let reg = Reg::int(2usize);
        stack.push(Val::reg(reg, WasmValType::I32));
        stack.push(Val::i32(4));

        assert_eq!(None, stack.pop_reg());
        let _ = stack.pop().unwrap();
        assert_eq!(reg, stack.pop_reg().unwrap().reg);
    }

    #[test]
    fn test_pop_named_reg() {
        let mut stack = Stack::new();
        let reg = Reg::int(2usize);
        stack.push(Val::reg(reg, WasmValType::I32));
        stack.push(Val::reg(Reg::int(4), WasmValType::I32));

        assert_eq!(None, stack.pop_named_reg(reg));
        let _ = stack.pop().unwrap();
        assert_eq!(reg, stack.pop_named_reg(reg).unwrap().reg);
    }

    #[test]
    fn test_ensure_index_at() {
        let mut stack = Stack::new();
        stack.extend([Val::i32(1), Val::i32(2)]);

        assert_eq!(0, stack.ensure_index_at(2).unwrap());
        assert_eq!(1, stack.ensure_index_at(1).unwrap());
        assert_eq!(2, stack.ensure_index_at(0).unwrap());
        assert!(stack.ensure_index_at(3).is_err());
    }

    #[test]
    fn test_peekn() {
        let mut stack = Stack::new();
        stack.extend([Val::i32(1), Val::i32(2), Val::i32(3)]);

        let top: Vec<i32> = stack.peekn(2).map(Val::unwrap_i32).collect();
        assert_eq!(vec![2, 3], top);
        assert_eq!(0, stack.peekn(0).count());
        // Peeking must not consume values.
        assert_eq!(3, stack.len());
    }

    #[test]
    fn test_insert_many() {
        let mut stack = Stack::new();
        stack.extend([Val::i32(1), Val::i32(4)]);

        // Insert in the middle, then at the very end.
        stack.insert_many(1, &[Val::i32(2), Val::i32(3)]);
        stack.insert_many(4, &[Val::i32(5)]);

        let all: Vec<i32> = stack.peekn(5).map(Val::unwrap_i32).collect();
        assert_eq!(vec![1, 2, 3, 4, 5], all);
    }

    #[test]
    fn test_contains_latent_local() {
        let mut stack = Stack::new();
        stack.push(Val::local(0, WasmValType::I32));
        // The only occurrence is the top element, so it is not latent.
        assert!(!stack.contains_latent_local(0));

        stack.push(Val::i32(1));
        assert!(stack.contains_latent_local(0));
        assert!(!stack.contains_latent_local(1));
    }

    #[test]
    fn test_sizeof_without_memory_entries() {
        let mut stack = Stack::new();
        stack.extend([Val::i32(1), Val::i64(2)]);
        assert_eq!(0, stack.sizeof(2));
        assert_eq!(0, stack.sizeof(0));
    }
}
459