1 //! Stratification of call graphs for parallel bottom-up inlining.
2 //!
//! This module takes a call graph and constructs a `Strata`, which is
//! essentially a parallel execution plan. A `Strata` consists of an ordered
//! sequence of layers, and each layer consists of an unordered set of
//! functions. The `i`th layer must be processed before the `i + 1`th layer, but
//! functions within the same layer may be processed in any order (and in
//! parallel).
8 //!
9 //! For example, when given the following tree-like call graph:
10 //!
11 //! ```text
12 //! +---+   +---+   +---+
13 //! | a |-->| b |-->| c |
14 //! +---+   +---+   +---+
15 //!   |       |
16 //!   |       |     +---+
17 //!   |       '---->| d |
18 //!   |             +---+
19 //!   |
20 //!   |     +---+   +---+
21 //!   '---->| e |-->| f |
22 //!         +---+   +---+
23 //!           |
24 //!           |     +---+
25 //!           '---->| g |
26 //!                 +---+
27 //! ```
28 //!
29 //! then stratification will produce these layers:
30 //!
31 //! ```text
32 //! [
33 //!     {c, d, f, g},
34 //!     {b, e},
35 //!     {a},
36 //! ]
37 //! ```
38 //!
//! Our goal in constructing the layers is to maximize potential parallelism at
//! each layer. Logically, we do this by finding the strongly-connected
//! components of the input call graph and peeling off all of the leaves of
//! the SCCs' condensation (i.e. the DAG that the SCCs form; see the
//! documentation for the `StronglyConnectedComponents::evaporation` method for
//! details). These leaves become the strata's first layer. The layer's
//! components are removed from the condensation graph, and we repeat the
//! process, so that the condensation's new leaves become the strata's second
//! layer, and so on, until the condensation graph is empty and all components
//! have been processed. In practice, we don't actually mutate the condensation
//! graph or remove its nodes. Instead, we count how many unprocessed
//! dependencies each component has, and a component is ready for inclusion in a
//! layer once its unprocessed-dependencies count reaches zero.
52 
53 use super::*;
54 use std::{fmt::Debug, ops::Range};
55 use wasmtime_environ::{
56     EntityRef, SecondaryMap,
57     graphs::{Graph, Scc, StronglyConnectedComponents},
58 };
59 
/// The result of stratifying a call graph: an ordered list of layers that
/// serves as a parallel-execution plan for bottom-up inlining.
///
/// See the module doc comment for more details.
pub struct Strata<Node> {
    /// One half-open index range per layer, in processing order. Each range
    /// selects that layer's nodes out of `layer_elems`.
    layers: Vec<Range<u32>>,

    /// The nodes of all layers, stored in a single flat vector and addressed
    /// through the ranges in `layers`.
    layer_elems: Vec<Node>,
}
68 
69 impl<Node: Debug> Debug for Strata<Node> {
fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result70     fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
71         struct Layers<'a, Node>(&'a Strata<Node>);
72 
73         impl<'a, Node: Debug> Debug for Layers<'a, Node> {
74             fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
75                 let mut f = f.debug_list();
76                 for layer in self.0.layers() {
77                     f.entry(&layer);
78                 }
79                 f.finish()
80             }
81         }
82 
83         f.debug_struct("Strata")
84             .field("layers", &Layers(self))
85             .finish()
86     }
87 }
88 
89 impl<Node> Strata<Node> {
90     /// Stratify the given call graph, yielding a `Strata` parallel-execution
91     /// plan.
new<G>(call_graph: &G) -> Self where Node: EntityRef + Debug, G: Debug + Graph<Node>,92     pub fn new<G>(call_graph: &G) -> Self
93     where
94         Node: EntityRef + Debug,
95         G: Debug + Graph<Node>,
96     {
97         log::trace!("Stratifying {call_graph:#?}");
98 
99         let components = StronglyConnectedComponents::new(call_graph);
100         let evaporation = components.evaporation(call_graph);
101 
102         // A map from each component to the count of how many call-graph
103         // dependencies to other components it has that have not been fulfilled
104         // yet. These counts are decremented as we assign a component's dependencies
105         // to layers.
106         let mut unfulfilled_deps_count = SecondaryMap::<Scc, u32>::with_capacity(components.len());
107         for to_component in components.keys() {
108             for from_component in evaporation.reverse_edges(to_component) {
109                 unfulfilled_deps_count[*from_component] += 1;
110             }
111         }
112 
113         // Build the strata.
114         //
115         // The first layer is formed by searching through all components for those
116         // that have a zero unfulfilled-deps count. When we finish a layer, we
117         // iterate over each of component in that layer and decrement the
118         // unfulfilled-deps count of every other component that depends on the
119         // newly-assigned-to-a-layer component. Any component that then reaches a
120         // zero unfulfilled-dep count is added to the next layer. This proceeds to a
121         // fixed point, similarly to GC tracing and ref-count decrementing.
122 
123         let mut layers: Vec<Range<u32>> = vec![];
124         let (min, max) = call_graph.nodes().size_hint();
125         let cap = max.unwrap_or(min);
126         let mut layer_elems: Vec<Node> = Vec::with_capacity(cap);
127 
128         let mut current_layer: Vec<Scc> = components
129             .keys()
130             .filter(|scc| unfulfilled_deps_count[*scc] == 0)
131             .collect();
132         debug_assert!(
133             !current_layer.is_empty() || call_graph.nodes().next().is_none(),
134             "the first layer can only be empty when the call graph itself is empty"
135         );
136 
137         let mut next_layer = vec![];
138 
139         while !current_layer.is_empty() {
140             debug_assert!(next_layer.is_empty());
141 
142             for dependee in &current_layer {
143                 for depender in evaporation.reverse_edges(*dependee) {
144                     debug_assert!(unfulfilled_deps_count[*depender] > 0);
145                     unfulfilled_deps_count[*depender] -= 1;
146                     if unfulfilled_deps_count[*depender] == 0 {
147                         next_layer.push(*depender);
148                     }
149                 }
150             }
151 
152             layers.push(extend_with_range(
153                 &mut layer_elems,
154                 current_layer
155                     .drain(..)
156                     .flat_map(|scc| components.nodes(scc).iter().copied()),
157             ));
158 
159             std::mem::swap(&mut next_layer, &mut current_layer);
160         }
161 
162         debug_assert!(
163             unfulfilled_deps_count.values().all(|c| *c == 0),
164             "after every component is assigned to a layer, all dependencies should be fulfilled"
165         );
166 
167         let result = Strata {
168             layers,
169             layer_elems,
170         };
171         log::trace!("  -> {result:#?}");
172         result
173     }
174 
175     /// Iterate over the layers of this `Strata`.
176     ///
177     /// The `i`th layer must be processed before the `i + 1`th layer, but the
178     /// functions within a layer may be processed in any order and in parallel.
layers(&self) -> impl ExactSizeIterator<Item = &[Node]>179     pub fn layers(&self) -> impl ExactSizeIterator<Item = &[Node]> {
180         self.layers.iter().map(|range| {
181             let start = usize::try_from(range.start).unwrap();
182             let end = usize::try_from(range.end).unwrap();
183             &self.layer_elems[start..end]
184         })
185     }
186 }
187 
188 #[cfg(test)]
189 mod tests {
190     use super::*;
191     use wasmtime_environ::graphs::Graph;
192 
193     #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
194     struct Function(u32);
195     wasmtime_environ::entity_impl!(Function);
196 
197     #[derive(Debug)]
198     struct Functions {
199         calls: SecondaryMap<Function, Vec<Function>>,
200     }
201 
202     impl Default for Functions {
default() -> Self203         fn default() -> Self {
204             let _ = env_logger::try_init();
205             Self {
206                 calls: Default::default(),
207             }
208         }
209     }
210 
211     impl Graph<Function> for Functions {
212         type NodesIter<'a>
213             = wasmtime_environ::Keys<Function>
214         where
215             Self: 'a;
216 
nodes(&self) -> Self::NodesIter<'_>217         fn nodes(&self) -> Self::NodesIter<'_> {
218             self.calls.keys()
219         }
220 
221         type SuccessorsIter<'a>
222             = core::iter::Copied<core::slice::Iter<'a, Function>>
223         where
224             Self: 'a;
225 
successors(&self, f: Function) -> Self::SuccessorsIter<'_>226         fn successors(&self, f: Function) -> Self::SuccessorsIter<'_> {
227             self.calls[f].iter().copied()
228         }
229     }
230 
231     impl Functions {
define_func(&mut self, f: u32) -> &mut Self232         fn define_func(&mut self, f: u32) -> &mut Self {
233             let f = Function::from_u32(f);
234             if self.calls.get(f).is_none() {
235                 self.calls[f] = vec![];
236             }
237             self
238         }
239 
define_call(&mut self, caller: u32, callee: u32) -> &mut Self240         fn define_call(&mut self, caller: u32, callee: u32) -> &mut Self {
241             self.define_func(caller);
242             self.define_func(callee);
243             let caller = Function::from_u32(caller);
244             let callee = Function::from_u32(callee);
245             self.calls[caller].push(callee);
246             self
247         }
248 
define_calls( &mut self, caller: u32, callees: impl IntoIterator<Item = u32>, ) -> &mut Self249         fn define_calls(
250             &mut self,
251             caller: u32,
252             callees: impl IntoIterator<Item = u32>,
253         ) -> &mut Self {
254             for callee in callees {
255                 self.define_call(caller, callee);
256             }
257             self
258         }
259 
stratify(&self) -> Strata<Function>260         fn stratify(&self) -> Strata<Function> {
261             Strata::new(self)
262         }
263 
assert_stratification(&self, mut expected: Vec<Vec<u32>>)264         fn assert_stratification(&self, mut expected: Vec<Vec<u32>>) {
265             for layer in &mut expected {
266                 layer.sort();
267             }
268             log::trace!("expected stratification = {expected:?}");
269 
270             let actual = self
271                 .stratify()
272                 .layers()
273                 .map(|layer| {
274                     let mut layer = layer.iter().map(|f| f.as_u32()).collect::<Vec<_>>();
275                     layer.sort();
276                     layer
277                 })
278                 .collect::<Vec<_>>();
279             log::trace!("actual stratification = {actual:?}");
280 
281             assert_eq!(expected.len(), actual.iter().len());
282             for (expected, actual) in expected.into_iter().zip(actual) {
283                 log::trace!("expected layer = {expected:?}");
284                 log::trace!("  actual layer = {expected:?}");
285 
286                 assert_eq!(expected.len(), actual.len());
287                 for (expected, actual) in expected.into_iter().zip(actual) {
288                     assert_eq!(expected, actual);
289                 }
290             }
291         }
292     }
293 
294     #[test]
test_disconnected_functions()295     fn test_disconnected_functions() {
296         // +---+   +---+   +---+
297         // | 0 |   | 1 |   | 2 |
298         // +---+   +---+   +---+
299         Functions::default()
300             .define_func(0)
301             .define_func(1)
302             .define_func(2)
303             .assert_stratification(vec![vec![0, 1, 2]]);
304     }
305 
306     #[test]
test_chained_functions()307     fn test_chained_functions() {
308         // +---+   +---+   +---+
309         // | 0 |-->| 1 |-->| 2 |
310         // +---+   +---+   +---+
311         Functions::default()
312             .define_call(0, 1)
313             .define_call(1, 2)
314             .assert_stratification(vec![vec![2], vec![1], vec![0]]);
315     }
316 
317     #[test]
test_cycle()318     fn test_cycle() {
319         //   ,---------------.
320         //   V               |
321         // +---+   +---+   +---+
322         // | 0 |-->| 1 |-->| 2 |
323         // +---+   +---+   +---+
324         Functions::default()
325             .define_call(0, 1)
326             .define_call(1, 2)
327             .define_call(2, 0)
328             .assert_stratification(vec![vec![0, 1, 2]]);
329     }
330 
331     #[test]
test_tree()332     fn test_tree() {
333         // +---+   +---+   +---+
334         // | 0 |-->| 1 |-->| 2 |
335         // +---+   +---+   +---+
336         //   |       |
337         //   |       |     +---+
338         //   |       '---->| 3 |
339         //   |             +---+
340         //   |
341         //   |     +---+   +---+
342         //   '---->| 4 |-->| 5 |
343         //         +---+   +---+
344         //           |
345         //           |     +---+
346         //           '---->| 6 |
347         //                 +---+
348         Functions::default()
349             .define_calls(0, [1, 4])
350             .define_calls(1, [2, 3])
351             .define_calls(4, [5, 6])
352             .assert_stratification(vec![vec![2, 3, 5, 6], vec![1, 4], vec![0]]);
353     }
354 
355     #[test]
test_chain_of_cycles()356     fn test_chain_of_cycles() {
357         //   ,-----.
358         //   |     |
359         //   V     |
360         // +---+   |
361         // | 0 |---'
362         // +---+
363         //   |
364         //   V
365         // +---+    +---+
366         // | 1 |<-->| 2 |
367         // +---+    +---+
368         //  |
369         //  | ,----------------.
370         //  | |                |
371         //  V |                V
372         // +---+    +---+    +---+
373         // | 3 |<---| 4 |<---| 5 |
374         // +---+    +---+    +---+
375         Functions::default()
376             .define_calls(0, [0, 1])
377             .define_calls(1, [2, 3])
378             .define_calls(2, [1])
379             .define_calls(3, [5])
380             .define_calls(4, [3])
381             .define_calls(5, [4])
382             .assert_stratification(vec![vec![3, 4, 5], vec![1, 2], vec![0]]);
383     }
384 
385     #[test]
test_multiple_edges_to_same_component()386     fn test_multiple_edges_to_same_component() {
387         // +---+           +---+
388         // | 0 |           | 1 |
389         // +---+           +---+
390         //   ^               ^
391         //   |               |
392         //   V               V
393         // +---+           +---+
394         // | 2 |           | 3 |
395         // +---+           +---+
396         //   |               |
397         //   `------. ,------'
398         //          | |
399         //          V V
400         //         +---+
401         //         | 4 |
402         //         +---+
403         //           ^
404         //           |
405         //           V
406         //         +---+
407         //         | 5 |
408         //         +---+
409         Functions::default()
410             .define_calls(0, [2])
411             .define_calls(1, [3])
412             .define_calls(2, [0, 4])
413             .define_calls(3, [1, 4])
414             .define_calls(4, [5])
415             .define_calls(5, [4])
416             .assert_stratification(vec![vec![4, 5], vec![0, 1, 2, 3]]);
417     }
418 }
419