1 //! Cost functions for egraph representation.
2 
3 use crate::ir::Opcode;
4 
/// A cost of computing some value in the program.
///
/// Costs are measured in an arbitrary unit that we represent in a
/// `u32`. The ordering is meant to be meaningful, but the value of a
/// single unit is arbitrary (and "not to scale"). We use a collection
/// of heuristics to try to make this approximation at least usable.
///
/// We start by defining costs for each opcode (see `Cost::of_opcode`
/// below). The cost of computing some value is the cost of its
/// opcode, plus the cost of computing its inputs (see
/// `Cost::of_pure_op`).
///
/// NOTE(review): this comment previously said that costs are then
/// multiplied by 1024 per loop-nest level (capped at two levels, i.e.
/// 2^20). Nothing in this file performs that scaling -- confirm
/// whether it is applied elsewhere before relying on it.
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping.
///
/// Finally, the highest value, `u32::MAX`, doubles as a sentinel that
/// means "infinite" (see `Cost::infinity`). Because addition
/// saturates at `u32::MAX`, a sum that overflows, or that has an
/// infinite operand, is itself infinite. An infinite cost is used to
/// represent a value that cannot be computed, or otherwise serve as a
/// sentinel when performing search for the lowest-cost representation
/// of a value.
///
/// NOTE(review): this comment previously referred to a `finite()`
/// method that saturated just *below* infinity; no such method exists
/// in this file, and the unit tests here expect overflow to produce
/// `Cost::infinity()`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Cost(u32);
35 
36 impl core::fmt::Debug for Cost {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result37     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
38         if *self == Cost::infinity() {
39             write!(f, "Cost::Infinite")
40         } else {
41             f.debug_tuple("Cost::Finite").field(&self.cost()).finish()
42         }
43     }
44 }
45 
46 impl Cost {
infinity() -> Cost47     pub(crate) fn infinity() -> Cost {
48         // 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
49         // only for heuristics and always saturate so this suffices!)
50         Cost(u32::MAX)
51     }
52 
zero() -> Cost53     pub(crate) fn zero() -> Cost {
54         Cost(0)
55     }
56 
57     /// Construct a new `Cost`.
new(cost: u32) -> Cost58     fn new(cost: u32) -> Cost {
59         Cost(cost)
60     }
61 
cost(&self) -> u3262     fn cost(&self) -> u32 {
63         self.0
64     }
65 
66     /// Return the cost of an opcode.
of_opcode(op: Opcode) -> Cost67     fn of_opcode(op: Opcode) -> Cost {
68         match op {
69             // Constants.
70             Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1),
71 
72             // Extends/reduces.
73             Opcode::Uextend
74             | Opcode::Sextend
75             | Opcode::Ireduce
76             | Opcode::Iconcat
77             | Opcode::Isplit => Cost::new(1),
78 
79             // "Simple" arithmetic.
80             Opcode::Iadd
81             | Opcode::Isub
82             | Opcode::Band
83             | Opcode::Bor
84             | Opcode::Bxor
85             | Opcode::Bnot
86             | Opcode::Ishl
87             | Opcode::Ushr
88             | Opcode::Sshr => Cost::new(3),
89 
90             // "Expensive" arithmetic.
91             Opcode::Imul => Cost::new(10),
92 
93             // Everything else.
94             _ => {
95                 // By default, be slightly more expensive than "simple"
96                 // arithmetic.
97                 let mut c = Cost::new(4);
98 
99                 // And then get more expensive as the opcode does more side
100                 // effects.
101                 if op.can_trap() || op.other_side_effects() {
102                     c = c + Cost::new(10);
103                 }
104                 if op.can_load() {
105                     c = c + Cost::new(20);
106                 }
107                 if op.can_store() {
108                     c = c + Cost::new(50);
109                 }
110 
111                 c
112             }
113         }
114     }
115 
116     /// Compute the cost of the operation and its given operands.
117     ///
118     /// Caller is responsible for checking that the opcode came from an instruction
119     /// that satisfies `inst_predicates::is_pure_for_egraph()`.
of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self120     pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
121         let c = Self::of_opcode(op) + operand_costs.into_iter().sum();
122         Cost::new(c.cost())
123     }
124 
125     /// Compute the cost of an operation in the side-effectful skeleton.
of_skeleton_op(op: Opcode, arity: usize) -> Self126     pub(crate) fn of_skeleton_op(op: Opcode, arity: usize) -> Self {
127         Cost::of_opcode(op) + Cost::new(u32::try_from(arity).unwrap())
128     }
129 }
130 
131 impl core::iter::Sum<Cost> for Cost {
sum<I: Iterator<Item = Cost>>(iter: I) -> Self132     fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
133         iter.fold(Self::zero(), |a, b| a + b)
134     }
135 }
136 
137 impl core::default::Default for Cost {
default() -> Cost138     fn default() -> Cost {
139         Cost::zero()
140     }
141 }
142 
143 impl core::ops::Add<Cost> for Cost {
144     type Output = Cost;
145 
add(self, other: Cost) -> Cost146     fn add(self, other: Cost) -> Cost {
147         Cost::new(self.cost().saturating_add(other.cost()))
148     }
149 }
150 
#[cfg(test)]
mod tests {
    use super::*;

    /// Two small finite costs add up to their plain sum, in either order.
    #[test]
    fn add_cost() {
        let (five, thirty_seven) = (Cost::new(5), Cost::new(37));
        let expected = Cost::new(42);
        assert_eq!(five + thirty_seven, expected);
        assert_eq!(thirty_seven + five, expected);
    }

    /// Adding anything to an infinite cost stays infinite, in either order.
    #[test]
    fn add_infinity() {
        let finite = Cost::new(5);
        let infinite = Cost::infinity();
        assert_eq!(finite + infinite, Cost::infinity());
        assert_eq!(infinite + finite, Cost::infinity());
    }

    /// A sum that would exceed `u32::MAX` saturates to the infinite cost
    /// rather than wrapping around.
    #[test]
    fn op_cost_saturates_to_infinity() {
        let nearly_max = Cost::new(u32::MAX - 10);
        let overshoot = Cost::new(11);
        assert_eq!(nearly_max + overshoot, Cost::infinity());
        assert_eq!(overshoot + nearly_max, Cost::infinity());
    }
}
179