1 //! Cost functions for egraph representation. 2 3 use crate::ir::Opcode; 4 5 /// A cost of computing some value in the program. 6 /// 7 /// Costs are measured in an arbitrary union that we represent in a 8 /// `u32`. The ordering is meant to be meaningful, but the value of a 9 /// single unit is arbitrary (and "not to scale"). We use a collection 10 /// of heuristics to try to make this approximation at least usable. 11 /// 12 /// We start by defining costs for each opcode (see `pure_op_cost` 13 /// below). The cost of computing some value, initially, is the cost 14 /// of its opcode, plus the cost of computing its inputs. 15 /// 16 /// We then adjust the cost according to loop nests: for each 17 /// loop-nest level, we multiply by 1024. Because we only have 32 18 /// bits, we limit this scaling to a loop-level of two (i.e., multiply 19 /// by 2^20 ~= 1M). 20 /// 21 /// Arithmetic on costs is always saturating: we don't want to wrap 22 /// around and return to a tiny cost when adding the costs of two very 23 /// expensive operations. It is better to approximate and lose some 24 /// precision than to lose the ordering by wrapping. 25 /// 26 /// Finally, we reserve the highest value, `u32::MAX`, as a sentinel 27 /// that means "infinite". This is separate from the finite costs and 28 /// not reachable by doing arithmetic on them (even when overflowing) 29 /// -- we saturate just *below* infinity. (This is done by the 30 /// `finite()` method.) An infinite cost is used to represent a value 31 /// that cannot be computed, or otherwise serve as a sentinel when 32 /// performing search for the lowest-cost representation of a value. 33 #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] 34 pub(crate) struct Cost(u32); 35 36 impl core::fmt::Debug for Cost { fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result37 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 38 if *self == Cost::infinity() { 39 write!(f, "Cost::Infinite") 40 } else { 41 f.debug_tuple("Cost::Finite").field(&self.cost()).finish() 42 } 43 } 44 } 45 46 impl Cost { infinity() -> Cost47 pub(crate) fn infinity() -> Cost { 48 // 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost` 49 // only for heuristics and always saturate so this suffices!) 50 Cost(u32::MAX) 51 } 52 zero() -> Cost53 pub(crate) fn zero() -> Cost { 54 Cost(0) 55 } 56 57 /// Construct a new `Cost`. new(cost: u32) -> Cost58 fn new(cost: u32) -> Cost { 59 Cost(cost) 60 } 61 cost(&self) -> u3262 fn cost(&self) -> u32 { 63 self.0 64 } 65 66 /// Return the cost of an opcode. of_opcode(op: Opcode) -> Cost67 fn of_opcode(op: Opcode) -> Cost { 68 match op { 69 // Constants. 70 Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1), 71 72 // Extends/reduces. 73 Opcode::Uextend 74 | Opcode::Sextend 75 | Opcode::Ireduce 76 | Opcode::Iconcat 77 | Opcode::Isplit => Cost::new(1), 78 79 // "Simple" arithmetic. 80 Opcode::Iadd 81 | Opcode::Isub 82 | Opcode::Band 83 | Opcode::Bor 84 | Opcode::Bxor 85 | Opcode::Bnot 86 | Opcode::Ishl 87 | Opcode::Ushr 88 | Opcode::Sshr => Cost::new(3), 89 90 // "Expensive" arithmetic. 91 Opcode::Imul => Cost::new(10), 92 93 // Everything else. 94 _ => { 95 // By default, be slightly more expensive than "simple" 96 // arithmetic. 97 let mut c = Cost::new(4); 98 99 // And then get more expensive as the opcode does more side 100 // effects. 101 if op.can_trap() || op.other_side_effects() { 102 c = c + Cost::new(10); 103 } 104 if op.can_load() { 105 c = c + Cost::new(20); 106 } 107 if op.can_store() { 108 c = c + Cost::new(50); 109 } 110 111 c 112 } 113 } 114 } 115 116 /// Compute the cost of the operation and its given operands. 117 /// 118 /// Caller is responsible for checking that the opcode came from an instruction 119 /// that satisfies `inst_predicates::is_pure_for_egraph()`. of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self120 pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self { 121 let c = Self::of_opcode(op) + operand_costs.into_iter().sum(); 122 Cost::new(c.cost()) 123 } 124 125 /// Compute the cost of an operation in the side-effectful skeleton. of_skeleton_op(op: Opcode, arity: usize) -> Self126 pub(crate) fn of_skeleton_op(op: Opcode, arity: usize) -> Self { 127 Cost::of_opcode(op) + Cost::new(u32::try_from(arity).unwrap()) 128 } 129 } 130 131 impl core::iter::Sum<Cost> for Cost { sum<I: Iterator<Item = Cost>>(iter: I) -> Self132 fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self { 133 iter.fold(Self::zero(), |a, b| a + b) 134 } 135 } 136 137 impl core::default::Default for Cost { default() -> Cost138 fn default() -> Cost { 139 Cost::zero() 140 } 141 } 142 143 impl core::ops::Add<Cost> for Cost { 144 type Output = Cost; 145 add(self, other: Cost) -> Cost146 fn add(self, other: Cost) -> Cost { 147 Cost::new(self.cost().saturating_add(other.cost())) 148 } 149 } 150 151 #[cfg(test)] 152 mod tests { 153 use super::*; 154 155 #[test] add_cost()156 fn add_cost() { 157 let a = Cost::new(5); 158 let b = Cost::new(37); 159 assert_eq!(a + b, Cost::new(42)); 160 assert_eq!(b + a, Cost::new(42)); 161 } 162 163 #[test] add_infinity()164 fn add_infinity() { 165 let a = Cost::new(5); 166 let b = Cost::infinity(); 167 assert_eq!(a + b, Cost::infinity()); 168 assert_eq!(b + a, Cost::infinity()); 169 } 170 171 #[test] op_cost_saturates_to_infinity()172 fn op_cost_saturates_to_infinity() { 173 let a = Cost::new(u32::MAX - 10); 174 let b = Cost::new(11); 175 assert_eq!(a + b, Cost::infinity()); 176 assert_eq!(b + a, Cost::infinity()); 177 } 178 } 179