//! A NaN-canonicalizing rewriting pass. Patches floating point arithmetic
//! instructions that may return a NaN result with a sequence of operations
//! that will replace nondeterministic NaNs with a single canonical NaN value.
4 
5 use crate::cursor::{Cursor, FuncCursor};
6 use crate::ir::condcodes::FloatCC;
7 use crate::ir::immediates::{Ieee16, Ieee32, Ieee64, Ieee128};
8 use crate::ir::types::{self};
9 use crate::ir::{Function, Inst, InstBuilder, InstructionData, Opcode, Value};
10 use crate::opts::MemFlags;
11 use crate::timing;
12 
13 /// Perform the NaN canonicalization pass.
do_nan_canonicalization(func: &mut Function, has_vector_support: bool)14 pub fn do_nan_canonicalization(func: &mut Function, has_vector_support: bool) {
15     let _tt = timing::canonicalize_nans();
16     let mut pos = FuncCursor::new(func);
17     while let Some(_block) = pos.next_block() {
18         while let Some(inst) = pos.next_inst() {
19             if is_fp_arith(&mut pos, inst) {
20                 add_nan_canon_seq(&mut pos, inst, has_vector_support);
21             }
22         }
23     }
24 }
25 
26 /// Returns true/false based on whether the instruction is a floating-point
27 /// arithmetic operation. This ignores operations like `fneg`, `fabs`, or
28 /// `fcopysign` that only operate on the sign bit of a floating point value.
is_fp_arith(pos: &mut FuncCursor, inst: Inst) -> bool29 fn is_fp_arith(pos: &mut FuncCursor, inst: Inst) -> bool {
30     match pos.func.dfg.insts[inst] {
31         InstructionData::Unary { opcode, .. } => {
32             opcode == Opcode::Ceil
33                 || opcode == Opcode::Floor
34                 || opcode == Opcode::Nearest
35                 || opcode == Opcode::Sqrt
36                 || opcode == Opcode::Trunc
37                 || opcode == Opcode::Fdemote
38                 || opcode == Opcode::Fpromote
39                 || opcode == Opcode::FvpromoteLow
40                 || opcode == Opcode::Fvdemote
41         }
42         InstructionData::Binary { opcode, .. } => {
43             opcode == Opcode::Fadd
44                 || opcode == Opcode::Fdiv
45                 || opcode == Opcode::Fmax
46                 || opcode == Opcode::Fmin
47                 || opcode == Opcode::Fmul
48                 || opcode == Opcode::Fsub
49         }
50         InstructionData::Ternary { opcode, .. } => opcode == Opcode::Fma,
51         _ => false,
52     }
53 }
54 
55 /// Append a sequence of canonicalizing instructions after the given instruction.
add_nan_canon_seq(pos: &mut FuncCursor, inst: Inst, has_vector_support: bool)56 fn add_nan_canon_seq(pos: &mut FuncCursor, inst: Inst, has_vector_support: bool) {
57     // Select the instruction result, result type. Replace the instruction
58     // result and step forward before inserting the canonicalization sequence.
59     let val = pos.func.dfg.first_result(inst);
60     let val_type = pos.func.dfg.value_type(val);
61     let new_res = pos.func.dfg.replace_result(val, val_type);
62     let _next_inst = pos.next_inst().expect("block missing terminator!");
63 
64     // Insert a comparison instruction, to check if `inst_res` is NaN (comparing
65     // against NaN is always unordered). Select the canonical NaN value if `val`
66     // is NaN, assign the result to `inst`.
67     let comparison = FloatCC::Unordered;
68 
69     let vectorized_scalar_select = |pos: &mut FuncCursor, canon_nan: Value, ty: types::Type| {
70         let canon_nan = pos.ins().scalar_to_vector(ty, canon_nan);
71         let new_res = pos.ins().scalar_to_vector(ty, new_res);
72         let is_nan = pos.ins().fcmp(comparison, new_res, new_res);
73         let is_nan = pos.ins().bitcast(ty, MemFlags::new(), is_nan);
74         let simd_result = pos.ins().bitselect(is_nan, canon_nan, new_res);
75         pos.ins().with_result(val).extractlane(simd_result, 0);
76     };
77     let scalar_select = |pos: &mut FuncCursor, canon_nan: Value| {
78         let is_nan = pos.ins().fcmp(comparison, new_res, new_res);
79         pos.ins()
80             .with_result(val)
81             .select(is_nan, canon_nan, new_res);
82     };
83 
84     let vector_select = |pos: &mut FuncCursor, canon_nan: Value| {
85         let is_nan = pos.ins().fcmp(comparison, new_res, new_res);
86         let is_nan = pos.ins().bitcast(val_type, MemFlags::new(), is_nan);
87         pos.ins()
88             .with_result(val)
89             .bitselect(is_nan, canon_nan, new_res);
90     };
91 
92     match val_type {
93         types::F16 => {
94             let canon_nan = pos.ins().f16const(Ieee16::NAN);
95             scalar_select(pos, canon_nan);
96         }
97         types::F32 => {
98             let canon_nan = pos.ins().f32const(Ieee32::NAN);
99             if has_vector_support {
100                 vectorized_scalar_select(pos, canon_nan, types::F32X4);
101             } else {
102                 scalar_select(pos, canon_nan);
103             }
104         }
105         types::F64 => {
106             let canon_nan = pos.ins().f64const(Ieee64::NAN);
107             if has_vector_support {
108                 vectorized_scalar_select(pos, canon_nan, types::F64X2);
109             } else {
110                 scalar_select(pos, canon_nan);
111             }
112         }
113         types::F32X4 => {
114             let canon_nan = pos.ins().f32const(Ieee32::NAN);
115             let canon_nan = pos.ins().splat(types::F32X4, canon_nan);
116             vector_select(pos, canon_nan);
117         }
118         types::F64X2 => {
119             let canon_nan = pos.ins().f64const(Ieee64::NAN);
120             let canon_nan = pos.ins().splat(types::F64X2, canon_nan);
121             vector_select(pos, canon_nan);
122         }
123         types::F128 => {
124             let nan_const = pos.func.dfg.constants.insert(Ieee128::NAN.into());
125             let canon_nan = pos.ins().f128const(nan_const);
126             scalar_select(pos, canon_nan);
127         }
128         _ => {
129             // Panic if the type given was not an IEEE floating point type.
130             panic!("Could not canonicalize NaN: Unexpected result type found.");
131         }
132     }
133 
134     pos.prev_inst(); // Step backwards so the pass does not skip instructions.
135 }
136