1 use super::HashMap;
2 use crate::frontend::FunctionBuilder;
3 use alloc::vec::Vec;
4 use cranelift_codegen::ir::condcodes::IntCC;
5 use cranelift_codegen::ir::*;
6 
/// Integer type used for switch entry indices.
///
/// It is 128 bits wide so that an index for any switch value type up to `I128` can be stored.
type EntryIndex = u128;
8 
/// A collection of `index => block` cases that is lowered to branches and/or jump tables.
///
/// Unlike with `br_table`, `Switch` cases may be sparse or non-0-based.
/// They emit efficient code using branches, jump tables, or a combination of both.
///
/// # Example
///
/// ```rust
/// # use cranelift_codegen::ir::types::*;
/// # use cranelift_codegen::ir::{UserFuncName, Function, Signature, InstBuilder};
/// # use cranelift_codegen::isa::CallConv;
/// # use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Switch};
/// #
/// # let mut sig = Signature::new(CallConv::SystemV);
/// # let mut fn_builder_ctx = FunctionBuilderContext::new();
/// # let mut func = Function::with_name_signature(UserFuncName::user(0, 0), sig);
/// # let mut builder = FunctionBuilder::new(&mut func, &mut fn_builder_ctx);
/// #
/// # let entry = builder.create_block();
/// # builder.switch_to_block(entry);
/// #
/// let block0 = builder.create_block();
/// let block1 = builder.create_block();
/// let block2 = builder.create_block();
/// let fallback = builder.create_block();
///
/// let val = builder.ins().iconst(I32, 1);
///
/// let mut switch = Switch::new();
/// switch.set_entry(0, block0);
/// switch.set_entry(1, block1);
/// switch.set_entry(7, block2);
/// switch.emit(&mut builder, val, fallback);
/// ```
#[derive(Debug, Default)]
pub struct Switch {
    /// Maps every entry index registered through `set_entry` to the block it jumps to.
    cases: HashMap<EntryIndex, Block>,
}
45 
46 impl Switch {
47     /// Create a new empty switch
new() -> Self48     pub fn new() -> Self {
49         Self {
50             cases: HashMap::new(),
51         }
52     }
53 
54     /// Set a switch entry
set_entry(&mut self, index: EntryIndex, block: Block)55     pub fn set_entry(&mut self, index: EntryIndex, block: Block) {
56         let prev = self.cases.insert(index, block);
57         assert!(prev.is_none(), "Tried to set the same entry {index} twice");
58     }
59 
60     /// Get a reference to all existing entries
entries(&self) -> &HashMap<EntryIndex, Block>61     pub fn entries(&self) -> &HashMap<EntryIndex, Block> {
62         &self.cases
63     }
64 
65     /// Turn the `cases` `HashMap` into a list of `ContiguousCaseRange`s.
66     ///
67     /// # Postconditions
68     ///
69     /// * Every entry will be represented.
70     /// * The `ContiguousCaseRange`s will not overlap.
71     /// * Between two `ContiguousCaseRange`s there will be at least one entry index.
72     /// * No `ContiguousCaseRange`s will be empty.
collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange>73     fn collect_contiguous_case_ranges(self) -> Vec<ContiguousCaseRange> {
74         log::trace!("build_contiguous_case_ranges before: {:#?}", self.cases);
75         let mut cases = self.cases.into_iter().collect::<Vec<(_, _)>>();
76         cases.sort_by_key(|&(index, _)| index);
77 
78         let mut contiguous_case_ranges: Vec<ContiguousCaseRange> = vec![];
79         let mut last_index = None;
80         for (index, block) in cases {
81             match last_index {
82                 None => contiguous_case_ranges.push(ContiguousCaseRange::new(index)),
83                 Some(last_index) => {
84                     if index > last_index + 1 {
85                         contiguous_case_ranges.push(ContiguousCaseRange::new(index));
86                     }
87                 }
88             }
89             contiguous_case_ranges
90                 .last_mut()
91                 .unwrap()
92                 .blocks
93                 .push(block);
94             last_index = Some(index);
95         }
96 
97         log::trace!("build_contiguous_case_ranges after: {contiguous_case_ranges:#?}");
98 
99         contiguous_case_ranges
100     }
101 
102     /// Binary search for the right `ContiguousCaseRange`.
build_search_tree<'a>( bx: &mut FunctionBuilder, val: Value, otherwise: Block, contiguous_case_ranges: &'a [ContiguousCaseRange], )103     fn build_search_tree<'a>(
104         bx: &mut FunctionBuilder,
105         val: Value,
106         otherwise: Block,
107         contiguous_case_ranges: &'a [ContiguousCaseRange],
108     ) {
109         // If no switch cases were added to begin with, we can just emit `jump otherwise`.
110         if contiguous_case_ranges.is_empty() {
111             bx.ins().jump(otherwise, &[]);
112             return;
113         }
114 
115         // Avoid allocation in the common case
116         if contiguous_case_ranges.len() <= 3 {
117             Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
118             return;
119         }
120 
121         let mut stack = Vec::new();
122         stack.push((None, contiguous_case_ranges));
123 
124         while let Some((block, contiguous_case_ranges)) = stack.pop() {
125             if let Some(block) = block {
126                 bx.switch_to_block(block);
127             }
128 
129             if contiguous_case_ranges.len() <= 3 {
130                 Self::build_search_branches(bx, val, otherwise, contiguous_case_ranges);
131             } else {
132                 let split_point = contiguous_case_ranges.len() / 2;
133                 let (left, right) = contiguous_case_ranges.split_at(split_point);
134 
135                 let left_block = bx.create_block();
136                 let right_block = bx.create_block();
137 
138                 let first_index = right[0].first_index;
139                 let should_take_right_side =
140                     icmp_imm_u128(bx, IntCC::UnsignedGreaterThanOrEqual, val, first_index);
141                 bx.ins()
142                     .brif(should_take_right_side, right_block, &[], left_block, &[]);
143 
144                 bx.seal_block(left_block);
145                 bx.seal_block(right_block);
146 
147                 stack.push((Some(left_block), left));
148                 stack.push((Some(right_block), right));
149             }
150         }
151     }
152 
153     /// Linear search for the right `ContiguousCaseRange`.
build_search_branches<'a>( bx: &mut FunctionBuilder, val: Value, otherwise: Block, contiguous_case_ranges: &'a [ContiguousCaseRange], )154     fn build_search_branches<'a>(
155         bx: &mut FunctionBuilder,
156         val: Value,
157         otherwise: Block,
158         contiguous_case_ranges: &'a [ContiguousCaseRange],
159     ) {
160         for (ix, range) in contiguous_case_ranges.iter().enumerate().rev() {
161             let alternate = if ix == 0 {
162                 otherwise
163             } else {
164                 bx.create_block()
165             };
166 
167             if range.first_index == 0 {
168                 assert_eq!(alternate, otherwise);
169 
170                 if let Some(block) = range.single_block() {
171                     bx.ins().brif(val, otherwise, &[], block, &[]);
172                 } else {
173                     Self::build_jump_table(bx, val, otherwise, 0, &range.blocks);
174                 }
175             } else {
176                 if let Some(block) = range.single_block() {
177                     let is_good_val = icmp_imm_u128(bx, IntCC::Equal, val, range.first_index);
178                     bx.ins().brif(is_good_val, block, &[], alternate, &[]);
179                 } else {
180                     let is_good_val = icmp_imm_u128(
181                         bx,
182                         IntCC::UnsignedGreaterThanOrEqual,
183                         val,
184                         range.first_index,
185                     );
186                     let jt_block = bx.create_block();
187                     bx.ins().brif(is_good_val, jt_block, &[], alternate, &[]);
188                     bx.seal_block(jt_block);
189                     bx.switch_to_block(jt_block);
190                     Self::build_jump_table(bx, val, otherwise, range.first_index, &range.blocks);
191                 }
192             }
193 
194             if alternate != otherwise {
195                 bx.seal_block(alternate);
196                 bx.switch_to_block(alternate);
197             }
198         }
199     }
200 
build_jump_table( bx: &mut FunctionBuilder, val: Value, otherwise: Block, first_index: EntryIndex, blocks: &[Block], )201     fn build_jump_table(
202         bx: &mut FunctionBuilder,
203         val: Value,
204         otherwise: Block,
205         first_index: EntryIndex,
206         blocks: &[Block],
207     ) {
208         // There are currently no 128bit systems supported by rustc, but once we do ensure that
209         // we don't silently ignore a part of the jump table for 128bit integers on 128bit systems.
210         assert!(
211             u32::try_from(blocks.len()).is_ok(),
212             "Jump tables bigger than 2^32-1 are not yet supported"
213         );
214 
215         let jt_data = JumpTableData::new(
216             bx.func.dfg.block_call(otherwise, &[]),
217             &blocks
218                 .iter()
219                 .map(|block| bx.func.dfg.block_call(*block, &[]))
220                 .collect::<Vec<_>>(),
221         );
222         let jump_table = bx.create_jump_table(jt_data);
223 
224         let discr = if first_index == 0 {
225             val
226         } else {
227             if let Ok(first_index) = u64::try_from(first_index) {
228                 bx.ins().iadd_imm(val, (first_index as i64).wrapping_neg())
229             } else {
230                 let (lsb, msb) = (first_index as u64, (first_index >> 64) as u64);
231                 let lsb = bx.ins().iconst(types::I64, lsb as i64);
232                 let msb = bx.ins().iconst(types::I64, msb as i64);
233                 let index = bx.ins().iconcat(lsb, msb);
234                 bx.ins().isub(val, index)
235             }
236         };
237 
238         let discr = match bx.func.dfg.value_type(discr).bits() {
239             bits if bits > 32 => {
240                 // Check for overflow of cast to u32. This is the max supported jump table entries.
241                 let new_block = bx.create_block();
242                 let bigger_than_u32 =
243                     bx.ins()
244                         .icmp_imm(IntCC::UnsignedGreaterThan, discr, u32::MAX as i64);
245                 bx.ins()
246                     .brif(bigger_than_u32, otherwise, &[], new_block, &[]);
247                 bx.seal_block(new_block);
248                 bx.switch_to_block(new_block);
249 
250                 // Cast to i32, as br_table is not implemented for i64/i128
251                 bx.ins().ireduce(types::I32, discr)
252             }
253             bits if bits < 32 => bx.ins().uextend(types::I32, discr),
254             _ => discr,
255         };
256 
257         bx.ins().br_table(discr, jump_table);
258     }
259 
260     /// Build the switch
261     ///
262     /// # Arguments
263     ///
264     /// * The function builder to emit to
265     /// * The value to switch on
266     /// * The default block
emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block)267     pub fn emit(self, bx: &mut FunctionBuilder, val: Value, otherwise: Block) {
268         // Validate that the type of `val` is sufficiently wide to address all cases.
269         let max = self.cases.keys().max().copied().unwrap_or(0);
270         let val_ty = bx.func.dfg.value_type(val);
271         let val_ty_max = val_ty.bounds(false).1;
272         if max > val_ty_max {
273             panic!("The index type {val_ty} does not fit the maximum switch entry of {max}");
274         }
275 
276         let contiguous_case_ranges = self.collect_contiguous_case_ranges();
277         Self::build_search_tree(bx, val, otherwise, &contiguous_case_ranges);
278     }
279 }
280 
icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value281 fn icmp_imm_u128(bx: &mut FunctionBuilder, cond: IntCC, x: Value, y: u128) -> Value {
282     if bx.func.dfg.value_type(x) != types::I128 {
283         assert!(u64::try_from(y).is_ok());
284         bx.ins().icmp_imm(cond, x, y as i64)
285     } else if let Ok(index) = i64::try_from(y) {
286         bx.ins().icmp_imm(cond, x, index)
287     } else {
288         let (lsb, msb) = (y as u64, (y >> 64) as u64);
289         let lsb = bx.ins().iconst(types::I64, lsb as i64);
290         let msb = bx.ins().iconst(types::I64, msb as i64);
291         let index = bx.ins().iconcat(lsb, msb);
292         bx.ins().icmp(cond, x, index)
293     }
294 }
295 
/// This represents a contiguous range of cases to switch on.
///
/// For example 10 => block1, 11 => block2, 12 => block7 will be represented as:
///
/// ```plain
/// ContiguousCaseRange {
///     first_index: 10,
///     blocks: vec![Block::from_u32(1), Block::from_u32(2), Block::from_u32(7)]
/// }
/// ```
#[derive(Debug)]
struct ContiguousCaseRange {
    /// The entry index of the first case. E.g. 10 when the entry indexes are 10, 11, 12 and 13.
    first_index: EntryIndex,

    /// The blocks to jump to, sorted in ascending order of entry index: `blocks[i]` is the
    /// target for the entry index `first_index + i`.
    blocks: Vec<Block>,
}
314 
315 impl ContiguousCaseRange {
new(first_index: EntryIndex) -> Self316     fn new(first_index: EntryIndex) -> Self {
317         Self {
318             first_index,
319             blocks: Vec::new(),
320         }
321     }
322 
323     /// Returns `Some` block when there is only a single block in this range.
single_block(&self) -> Option<Block>324     fn single_block(&self) -> Option<Block> {
325         if self.blocks.len() == 1 {
326             Some(self.blocks[0])
327         } else {
328             None
329         }
330     }
331 }
332 
333 #[cfg(test)]
334 mod tests {
335     use super::*;
336     use crate::frontend::FunctionBuilderContext;
337     use alloc::string::ToString;
338 
    /// Builds a function that switches on an `i8` constant `0` and returns its textual form.
    ///
    /// `$default` is the number of the block used as the default target, and a fresh block is
    /// created and registered for every `$index`. The surrounding `function u0:0() fast {` /
    /// `}` lines are trimmed from the returned string so that tests only compare the body.
    macro_rules! setup {
        ($default:expr, [$($index:expr,)*]) => {{
            let mut func = Function::new();
            let mut func_ctx = FunctionBuilderContext::new();
            {
                let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
                let block = bx.create_block();
                bx.switch_to_block(block);
                let val = bx.ins().iconst(types::I8, 0);
                let mut switch = Switch::new();
                // NOTE(review): presumably keeps `mut` from being reported as unused when the
                // index list is empty and `set_entry` is never expanded.
                let _ = &mut switch;
                $(
                    let block = bx.create_block();
                    switch.set_entry($index, block);
                )*
                switch.emit(&mut bx, val, Block::with_number($default).unwrap());
            }
            func
                .to_string()
                .trim_start_matches("function u0:0() fast {\n")
                .trim_end_matches("\n}\n")
                .to_string()
        }};
    }
363 
    /// A switch without any entry is emitted as a plain jump to the default block.
    #[test]
    fn switch_empty() {
        let func = setup!(42, []);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    jump block42"
        );
    }
374 
    /// A single entry at index 0 is emitted as one `brif` on the value itself, without a
    /// comparison.
    #[test]
    fn switch_zero() {
        let func = setup!(0, [0,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    brif v0, block0, block1  ; v0 = 0"
        );
    }
385 
    /// A single non-zero entry is emitted as an equality comparison followed by a `brif`.
    #[test]
    fn switch_single() {
        let func = setup!(0, [1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 1  ; v0 = 0
    brif v1, block1, block0"
        );
    }
397 
    /// Two adjacent entries starting at 0 are emitted as a single `br_table` on the extended
    /// value.
    #[test]
    fn switch_bool() {
        let func = setup!(0, [0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = uextend.i32 v0  ; v0 = 0
    br_table v1, block0, [block1, block2]"
        );
    }
409 
    /// Entries 0 and 2 are not adjacent, so each gets its own test; the higher index is tested
    /// first.
    #[test]
    fn switch_two_gap() {
        let func = setup!(0, [0, 2,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 2  ; v0 = 0
    brif v1, block2, block3

block3:
    brif.i8 v0, block0, block1  ; v0 = 0"
        );
    }
424 
    /// Seven entries forming four contiguous ranges are split by a `uge 7` comparison, and each
    /// half is then resolved with comparisons and jump tables.
    #[test]
    fn switch_many() {
        let func = setup!(0, [0, 1, 5, 7, 10, 11, 12,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm uge v0, 7  ; v0 = 0
    brif v1, block9, block8

block9:
    v2 = icmp_imm.i8 uge v0, 10  ; v0 = 0
    brif v2, block11, block10

block11:
    v3 = iadd_imm.i8 v0, -10  ; v0 = 0
    v4 = uextend.i32 v3
    br_table v4, block0, [block5, block6, block7]

block10:
    v5 = icmp_imm.i8 eq v0, 7  ; v0 = 0
    brif v5, block4, block0

block8:
    v6 = icmp_imm.i8 eq v0, 5  ; v0 = 0
    brif v6, block3, block12

block12:
    v7 = uextend.i32 v0  ; v0 = 0
    br_table v7, block0, [block1, block2]"
        );
    }
457 
    /// The bit pattern of `i8::MIN` (entry index 128) is accepted for an `i8` value and shows
    /// up as `-128` in the emitted comparison.
    #[test]
    fn switch_min_index_value() {
        let func = setup!(0, [i8::MIN as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -128  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        );
    }
473 
    /// The entry index `i8::MAX` (127) is accepted for an `i8` value and is tested before the
    /// lower entry.
    #[test]
    fn switch_max_index_value() {
        let func = setup!(0, [i8::MAX as u8 as u128, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, 127  ; v0 = 0
    brif v1, block1, block3

block3:
    v2 = icmp_imm.i8 eq v0, 1  ; v0 = 0
    brif v2, block2, block0"
        )
    }
489 
    /// The entry for the bit pattern of `-1i8` (index 255) gets a single comparison, while the
    /// adjacent entries 0 and 1 share one jump table.
    #[test]
    fn switch_optimal_codegen() {
        let func = setup!(0, [-1i8 as u8 as u128, 0, 1,]);
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i8 0
    v1 = icmp_imm eq v0, -1  ; v0 = 0
    brif v1, block1, block4

block4:
    v2 = uextend.i32 v0  ; v0 = 0
    br_table v2, block0, [block2, block3]"
        );
    }
505 
    /// Emitting a switch whose largest entry does not fit in the type of the switched value
    /// must panic.
    #[test]
    #[should_panic(
        expected = "The index type i8 does not fit the maximum switch entry of 4683743612477887600"
    )]
    fn switch_rejects_small_inputs() {
        // This is a regression test for a bug that we found where we would emit a cmp
        // with a type that was not able to fully represent a large index.
        //
        // See: https://github.com/bytecodealliance/wasmtime/pull/4502#issuecomment-1191961677
        setup!(1, [0x4100_0000_00bf_d470,]);
    }
517 
    /// Every block that `emit` creates on its own must be sealed: `finalize` is run for a small
    /// and a larger key set with each integer type from `I8` to `I128`.
    #[test]
    fn switch_seal_generated_blocks() {
        let cases = &[vec![0, 1, 2], vec![0, 1, 2, 10, 11, 12, 20, 30, 40, 50]];

        for case in cases {
            for typ in &[types::I8, types::I16, types::I32, types::I64, types::I128] {
                eprintln!("Testing {typ:?} with keys: {case:?}");
                do_case(case, *typ);
            }
        }

        // Builds a complete function that switches on a `typ` constant over `keys`.
        fn do_case(keys: &[u128], typ: Type) {
            let mut func = Function::new();
            let mut builder_ctx = FunctionBuilderContext::new();
            let mut builder = FunctionBuilder::new(&mut func, &mut builder_ctx);

            let root_block = builder.create_block();
            let default_block = builder.create_block();
            let mut switch = Switch::new();

            let case_blocks = keys
                .iter()
                .map(|key| {
                    let block = builder.create_block();
                    switch.set_entry(*key, block);
                    block
                })
                .collect::<Vec<_>>();

            builder.seal_block(root_block);
            builder.switch_to_block(root_block);

            let val = builder.ins().iconst(typ, 1);
            switch.emit(&mut builder, val, default_block);

            // Only the blocks created by this test are sealed here; the ones created inside
            // `emit` have to be sealed by `emit` itself.
            for &block in case_blocks.iter().chain(std::iter::once(&default_block)) {
                builder.seal_block(block);
                builder.switch_to_block(block);
                builder.ins().return_(&[]);
            }

            builder.finalize(); // Will panic if some blocks are not sealed
        }
    }
562 
    /// For an `i64` value the jump table is guarded by a `ugt 0xffff_ffff` check and the value
    /// is reduced to `i32` before the `br_table`.
    #[test]
    fn switch_64bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = icmp_imm ugt v0, 0xffff_ffff  ; v0 = 0
    brif v1, block3, block4

block4:
    v2 = ireduce.i32 v0  ; v0 = 0
    br_table v2, block3, [block2, block1]"
        );
    }
597 
    /// For an `i128` value the jump table is guarded by a `ugt 0xffff_ffff` check and the value
    /// is reduced to `i32` before the `br_table`.
    #[test]
    fn switch_128bit() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(1, block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = icmp_imm ugt v1, 0xffff_ffff
    brif v2, block3, block4

block4:
    v3 = ireduce.i32 v1
    br_table v3, block3, [block2, block1]"
        );
    }
634 
    /// On an `i128` value the entry index `u64::MAX` does not fit in an `i64` immediate, so the
    /// constant is built with `iconcat` and compared with `icmp`.
    #[test]
    fn switch_128bit_max_u64() {
        let mut func = Function::new();
        let mut func_ctx = FunctionBuilderContext::new();
        {
            let mut bx = FunctionBuilder::new(&mut func, &mut func_ctx);
            let block0 = bx.create_block();
            bx.switch_to_block(block0);
            let val = bx.ins().iconst(types::I64, 0);
            let val = bx.ins().uextend(types::I128, val);
            let mut switch = Switch::new();
            let block1 = bx.create_block();
            switch.set_entry(u64::MAX.into(), block1);
            let block2 = bx.create_block();
            switch.set_entry(0, block2);
            let block3 = bx.create_block();
            switch.emit(&mut bx, val, block3);
        }
        let func = func
            .to_string()
            .trim_start_matches("function u0:0() fast {\n")
            .trim_end_matches("\n}\n")
            .to_string();
        assert_eq_output!(
            func,
            "block0:
    v0 = iconst.i64 0
    v1 = uextend.i128 v0  ; v0 = 0
    v2 = iconst.i64 -1
    v3 = iconst.i64 0
    v4 = iconcat v2, v3  ; v2 = -1, v3 = 0
    v5 = icmp eq v1, v4
    brif v5, block1, block4

block4:
    brif.i128 v1, block3, block2"
        );
    }
673 }
674