xref: /wasmtime-44.0.1/winch/codegen/src/stack.rs (revision ba950f2c)
1 use crate::{codegen::CodeGenError, isa::reg::Reg, masm::StackSlot};
2 use anyhow::{anyhow, Result};
3 use smallvec::SmallVec;
4 use wasmparser::{Ieee32, Ieee64};
5 use wasmtime_environ::WasmValType;
6 
7 /// A typed register value used to track register values in the value
8 /// stack.
9 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
10 pub struct TypedReg {
11     /// The physical register.
12     pub reg: Reg,
13     /// The type associated to the physical register.
14     pub ty: WasmValType,
15 }
16 
17 impl TypedReg {
18     /// Create a new [`TypedReg`].
19     pub fn new(ty: WasmValType, reg: Reg) -> Self {
20         Self { ty, reg }
21     }
22 
23     /// Create an i64 [`TypedReg`].
24     pub fn i64(reg: Reg) -> Self {
25         Self {
26             ty: WasmValType::I64,
27             reg,
28         }
29     }
30 
31     /// Create an i32 [`TypedReg`].
32     pub fn i32(reg: Reg) -> Self {
33         Self {
34             ty: WasmValType::I32,
35             reg,
36         }
37     }
38 
39     /// Create an f64 [`TypedReg`].
40     pub fn f64(reg: Reg) -> Self {
41         Self {
42             ty: WasmValType::F64,
43             reg,
44         }
45     }
46 
47     /// Create an f32 [`TypedReg`].
48     pub fn f32(reg: Reg) -> Self {
49         Self {
50             ty: WasmValType::F32,
51             reg,
52         }
53     }
54 
55     /// Create a v128 [`TypedReg`].
56     pub fn v128(reg: Reg) -> Self {
57         Self {
58             ty: WasmValType::V128,
59             reg,
60         }
61     }
62 }
63 
64 impl From<TypedReg> for Reg {
65     fn from(tr: TypedReg) -> Self {
66         tr.reg
67     }
68 }
69 
70 /// A local value.
71 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
72 pub struct Local {
73     /// The index of the local.
74     pub index: u32,
75     /// The type of the local.
76     pub ty: WasmValType,
77 }
78 
79 /// A memory value.
80 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
81 pub struct Memory {
82     /// The type associated with the memory offset.
83     pub ty: WasmValType,
84     /// The stack slot corresponding to the memory value.
85     pub slot: StackSlot,
86 }
87 
88 /// Value definition to be used within the shadow stack.
89 #[derive(Debug, Eq, PartialEq, Copy, Clone)]
90 pub(crate) enum Val {
91     /// I32 Constant.
92     I32(i32),
93     /// I64 Constant.
94     I64(i64),
95     /// F32 Constant.
96     F32(Ieee32),
97     /// F64 Constant.
98     F64(Ieee64),
99     /// V128 Constant.
100     V128(i128),
101     /// A register value.
102     Reg(TypedReg),
103     /// A local slot.
104     Local(Local),
105     /// Offset to a memory location.
106     Memory(Memory),
107 }
108 
109 impl From<TypedReg> for Val {
110     fn from(tr: TypedReg) -> Self {
111         Val::Reg(tr)
112     }
113 }
114 
115 impl From<Local> for Val {
116     fn from(local: Local) -> Self {
117         Val::Local(local)
118     }
119 }
120 
121 impl From<Memory> for Val {
122     fn from(mem: Memory) -> Self {
123         Val::Memory(mem)
124     }
125 }
126 
127 impl TryFrom<u32> for Val {
128     type Error = anyhow::Error;
129     fn try_from(value: u32) -> Result<Self, Self::Error> {
130         i32::try_from(value).map(Val::i32).map_err(Into::into)
131     }
132 }
133 
134 impl Val {
135     /// Create a new I32 constant value.
136     pub fn i32(v: i32) -> Self {
137         Self::I32(v)
138     }
139 
140     /// Create a new I64 constant value.
141     pub fn i64(v: i64) -> Self {
142         Self::I64(v)
143     }
144 
145     /// Create a new F32 constant value.
146     pub fn f32(v: Ieee32) -> Self {
147         Self::F32(v)
148     }
149 
150     pub fn f64(v: Ieee64) -> Self {
151         Self::F64(v)
152     }
153 
154     /// Create a new V128 constant value.
155     pub fn v128(v: i128) -> Self {
156         Self::V128(v)
157     }
158 
159     /// Create a new Reg value.
160     pub fn reg(reg: Reg, ty: WasmValType) -> Self {
161         Self::Reg(TypedReg { reg, ty })
162     }
163 
164     /// Create a new Local value.
165     pub fn local(index: u32, ty: WasmValType) -> Self {
166         Self::Local(Local { index, ty })
167     }
168 
169     /// Create a Memory value.
170     pub fn mem(ty: WasmValType, slot: StackSlot) -> Self {
171         Self::Memory(Memory { ty, slot })
172     }
173 
174     /// Check whether the value is a register.
175     pub fn is_reg(&self) -> bool {
176         match *self {
177             Self::Reg(_) => true,
178             _ => false,
179         }
180     }
181 
182     /// Check whether the value is a memory offset.
183     pub fn is_mem(&self) -> bool {
184         match *self {
185             Self::Memory(_) => true,
186             _ => false,
187         }
188     }
189 
190     /// Check whether the value is a constant.
191     pub fn is_const(&self) -> bool {
192         match *self {
193             Val::I32(_) | Val::I64(_) | Val::F32(_) | Val::F64(_) | Val::V128(_) => true,
194             _ => false,
195         }
196     }
197 
198     /// Check whether the value is local with a particular index.
199     pub fn is_local_at_index(&self, index: u32) -> bool {
200         match *self {
201             Self::Local(Local { index: i, .. }) if i == index => true,
202             _ => false,
203         }
204     }
205 
206     /// Get the register representation of the value.
207     ///
208     /// # Panics
209     /// This method will panic if the value is not a register.
210     pub fn unwrap_reg(&self) -> TypedReg {
211         match self {
212             Self::Reg(tr) => *tr,
213             v => panic!("expected value {v:?} to be a register"),
214         }
215     }
216 
217     /// Get the integer representation of the value.
218     ///
219     /// # Panics
220     /// This method will panic if the value is not an i32.
221     pub fn unwrap_i32(&self) -> i32 {
222         match self {
223             Self::I32(v) => *v,
224             v => panic!("expected value {v:?} to be i32"),
225         }
226     }
227 
228     /// Get the integer representation of the value.
229     ///
230     /// # Panics
231     /// This method will panic if the value is not an i64.
232     pub fn unwrap_i64(&self) -> i64 {
233         match self {
234             Self::I64(v) => *v,
235             v => panic!("expected value {v:?} to be i64"),
236         }
237     }
238 
239     /// Returns the underlying memory value if it is one, panics otherwise.
240     pub fn unwrap_mem(&self) -> Memory {
241         match self {
242             Self::Memory(m) => *m,
243             v => panic!("expected value {v:?} to be a Memory"),
244         }
245     }
246 
247     /// Check whether the value is an i32 constant.
248     pub fn is_i32_const(&self) -> bool {
249         match *self {
250             Self::I32(_) => true,
251             _ => false,
252         }
253     }
254 
255     /// Check whether the value is an i64 constant.
256     pub fn is_i64_const(&self) -> bool {
257         match *self {
258             Self::I64(_) => true,
259             _ => false,
260         }
261     }
262 
263     /// Get the type of the value.
264     pub fn ty(&self) -> WasmValType {
265         match self {
266             Val::I32(_) => WasmValType::I32,
267             Val::I64(_) => WasmValType::I64,
268             Val::F32(_) => WasmValType::F32,
269             Val::F64(_) => WasmValType::F64,
270             Val::V128(_) => WasmValType::V128,
271             Val::Reg(r) => r.ty,
272             Val::Memory(m) => m.ty,
273             Val::Local(l) => l.ty,
274         }
275     }
276 }
277 
278 /// The shadow stack used for compilation.
279 #[derive(Default, Debug)]
280 pub(crate) struct Stack {
281     // NB: The 64 is chosen arbitrarily. We can adjust as we see fit.
282     inner: SmallVec<[Val; 64]>,
283 }
284 
285 impl Stack {
286     /// Allocate a new stack.
287     pub fn new() -> Self {
288         Self {
289             inner: Default::default(),
290         }
291     }
292 
293     /// Ensures that there are at least `n` elements in the value stack,
294     /// and returns the index calculated by: stack length minus `n`.
295     pub fn ensure_index_at(&self, n: usize) -> Result<usize> {
296         if self.len() >= n {
297             Ok(self.len() - n)
298         } else {
299             Err(anyhow!(CodeGenError::missing_values_in_stack()))
300         }
301     }
302 
303     /// Returns true if the stack contains a local with the provided index
304     /// except if the only time the local appears is the top element.
305     pub fn contains_latent_local(&self, index: u32) -> bool {
306         self.inner
307             .iter()
308             // Iterate top-to-bottom so we can skip the top element and stop
309             // when we see a memory element.
310             .rev()
311             // The local is not latent if it's the top element because the top
312             // element will be popped next which materializes the local.
313             .skip(1)
314             // Stop when we see a memory element because that marks where we
315             // spilled up to so there will not be any locals past this point.
316             .take_while(|v| !v.is_mem())
317             .any(|v| v.is_local_at_index(index))
318     }
319 
320     /// Extend the stack with the given elements.
321     pub fn extend(&mut self, values: impl IntoIterator<Item = Val>) {
322         self.inner.extend(values);
323     }
324 
325     /// Inserts many values at the given index.
326     pub fn insert_many(&mut self, at: usize, values: &[Val]) {
327         debug_assert!(at <= self.len());
328 
329         if at == self.len() {
330             self.inner.extend_from_slice(values);
331         } else {
332             self.inner.insert_from_slice(at, values);
333         }
334     }
335 
336     /// Get the length of the stack.
337     pub fn len(&self) -> usize {
338         self.inner.len()
339     }
340 
341     /// Push a value to the stack.
342     pub fn push(&mut self, val: Val) {
343         self.inner.push(val);
344     }
345 
346     /// Peek into the top in the stack.
347     pub fn peek(&self) -> Option<&Val> {
348         self.inner.last()
349     }
350 
351     /// Returns an iterator referencing the last n items of the stack,
352     /// in bottom-most to top-most order.
353     pub fn peekn(&self, n: usize) -> impl Iterator<Item = &Val> + '_ {
354         let len = self.len();
355         assert!(n <= len);
356 
357         let partition = len - n;
358         self.inner[partition..].into_iter()
359     }
360 
361     /// Pops the top element of the stack, if any.
362     pub fn pop(&mut self) -> Option<Val> {
363         self.inner.pop()
364     }
365 
366     /// Pops the element at the top of the stack if it is an i32 const;
367     /// returns `None` otherwise.
368     pub fn pop_i32_const(&mut self) -> Option<i32> {
369         match self.peek() {
370             Some(v) => v.is_i32_const().then(|| self.pop().unwrap().unwrap_i32()),
371             _ => None,
372         }
373     }
374 
375     /// Pops the element at the top of the stack if it is an i64 const;
376     /// returns `None` otherwise.
377     pub fn pop_i64_const(&mut self) -> Option<i64> {
378         match self.peek() {
379             Some(v) => v.is_i64_const().then(|| self.pop().unwrap().unwrap_i64()),
380             _ => None,
381         }
382     }
383 
384     /// Pops the element at the top of the stack if it is a register;
385     /// returns `None` otherwise.
386     pub fn pop_reg(&mut self) -> Option<TypedReg> {
387         match self.peek() {
388             Some(v) => v.is_reg().then(|| self.pop().unwrap().unwrap_reg()),
389             _ => None,
390         }
391     }
392 
393     /// Pops the given register if it is at the top of the stack;
394     /// returns `None` otherwise.
395     pub fn pop_named_reg(&mut self, reg: Reg) -> Option<TypedReg> {
396         match self.peek() {
397             Some(v) => {
398                 (v.is_reg() && v.unwrap_reg().reg == reg).then(|| self.pop().unwrap().unwrap_reg())
399             }
400             _ => None,
401         }
402     }
403 
404     /// Get a mutable reference to the inner stack representation.
405     pub fn inner_mut(&mut self) -> &mut SmallVec<[Val; 64]> {
406         &mut self.inner
407     }
408 
409     /// Get a reference to the inner stack representation.
410     pub fn inner(&self) -> &SmallVec<[Val; 64]> {
411         &self.inner
412     }
413 
414     /// Calculates the size of, in bytes, of the top n [Memory] entries
415     /// in the value stack.
416     pub fn sizeof(&self, top: usize) -> u32 {
417         self.peekn(top).fold(0, |acc, v| {
418             if v.is_mem() {
419                 acc + v.unwrap_mem().slot.size
420             } else {
421                 acc
422             }
423         })
424     }
425 }
426 
#[cfg(test)]
mod tests {
    use super::{Stack, Val};
    use crate::isa::reg::Reg;
    use wasmtime_environ::WasmValType;

    /// An i32 constant on top of the stack is popped; a local is not.
    #[test]
    fn test_pop_i32_const() {
        let mut stack = Stack::new();

        stack.push(Val::i32(33i32));
        assert_eq!(Some(33), stack.pop_i32_const());

        stack.push(Val::local(10, WasmValType::I32));
        assert_eq!(None, stack.pop_i32_const());
    }

    /// `pop_reg` only succeeds once the register is the top element.
    #[test]
    fn test_pop_reg() {
        let reg = Reg::int(2usize);
        let mut stack = Stack::new();
        stack.push(Val::reg(reg, WasmValType::I32));
        stack.push(Val::i32(4));

        // A constant is on top, so no register can be popped yet.
        assert!(stack.pop_reg().is_none());

        // Drop the constant to expose the register.
        let _ = stack.pop().unwrap();
        let popped = stack.pop_reg().unwrap();
        assert_eq!(reg, popped.reg);
    }

    /// `pop_named_reg` only succeeds when the named register is on top.
    #[test]
    fn test_pop_named_reg() {
        let reg = Reg::int(2usize);
        let mut stack = Stack::new();
        stack.push(Val::reg(reg, WasmValType::I32));
        stack.push(Val::reg(Reg::int(4), WasmValType::I32));

        // A different register is on top.
        assert!(stack.pop_named_reg(reg).is_none());

        // Drop the other register to expose the named one.
        let _ = stack.pop().unwrap();
        let popped = stack.pop_named_reg(reg).unwrap();
        assert_eq!(reg, popped.reg);
    }
}
467