1 //! Run tests concurrently.
2 //!
3 //! This module provides the `ConcurrentRunner` struct which uses a pool of threads to run tests
4 //! concurrently.
5 
6 use crate::runone;
7 use cranelift_codegen::dbg::LOG_FILENAME_PREFIX;
8 use cranelift_codegen::timing;
9 use log::error;
10 use std::panic::catch_unwind;
11 use std::path::{Path, PathBuf};
12 use std::sync::mpsc::{Receiver, Sender, channel};
13 use std::sync::{Arc, Mutex};
14 use std::thread;
15 use std::time::Duration;
16 
/// Request sent to worker threads.
///
/// Field `0` is the job id, which the worker echoes back in its replies, and field `1` is the
/// path that the worker hands to `runone::run`.
#[derive(Debug)]
struct Request(usize, PathBuf);
19 
20 /// Reply from worker thread,
21 pub enum Reply {
22     Starting {
23         jobid: usize,
24     },
25     Done {
26         jobid: usize,
27         result: anyhow::Result<Duration>,
28     },
29     Tick,
30 }
31 
32 /// Manage threads that run test jobs concurrently.
33 pub struct ConcurrentRunner {
34     /// Channel for sending requests to the worker threads.
35     /// The workers are sharing the receiver with an `Arc<Mutex<Receiver>>`.
36     /// This is `None` when shutting down.
37     request_tx: Option<Sender<Request>>,
38 
39     /// Channel for receiving replies from the workers.
40     /// Workers have their own `Sender`.
41     reply_rx: Receiver<Reply>,
42 
43     handles: Vec<thread::JoinHandle<timing::PassTimes>>,
44 }
45 
46 impl ConcurrentRunner {
47     /// Create a new `ConcurrentRunner` with threads spun up.
new() -> Self48     pub fn new() -> Self {
49         let (request_tx, request_rx) = channel();
50         let request_mutex = Arc::new(Mutex::new(request_rx));
51         let (reply_tx, reply_rx) = channel();
52 
53         heartbeat_thread(reply_tx.clone());
54 
55         let num_threads = std::env::var("CRANELIFT_FILETESTS_THREADS")
56             .ok()
57             .map(|s| {
58                 use std::str::FromStr;
59                 let n = usize::from_str(&s).unwrap();
60                 assert!(n > 0);
61                 n
62             })
63             .unwrap_or_else(|| num_cpus::get());
64         let handles = (0..num_threads)
65             .map(|num| worker_thread(num, request_mutex.clone(), reply_tx.clone()))
66             .collect();
67 
68         Self {
69             request_tx: Some(request_tx),
70             reply_rx,
71             handles,
72         }
73     }
74 
75     /// Shut down worker threads orderly. They will finish any queued jobs first.
shutdown(&mut self)76     pub fn shutdown(&mut self) {
77         self.request_tx = None;
78     }
79 
80     /// Join all the worker threads.
81     /// Transfer pass timings from the worker threads to the current thread.
join(&mut self) -> timing::PassTimes82     pub fn join(&mut self) -> timing::PassTimes {
83         assert!(self.request_tx.is_none(), "must shutdown before join");
84         let mut pass_times = timing::PassTimes::default();
85         for h in self.handles.drain(..) {
86             match h.join() {
87                 Ok(t) => pass_times.add(&t),
88                 Err(e) => println!("worker panicked: {e:?}"),
89             }
90         }
91         pass_times
92     }
93 
94     /// Add a new job to the queues.
put(&mut self, jobid: usize, path: &Path)95     pub fn put(&mut self, jobid: usize, path: &Path) {
96         self.request_tx
97             .as_ref()
98             .expect("cannot push after shutdown")
99             .send(Request(jobid, path.to_owned()))
100             .expect("all the worker threads are gone");
101     }
102 
103     /// Get a job reply without blocking.
try_get(&mut self) -> Option<Reply>104     pub fn try_get(&mut self) -> Option<Reply> {
105         self.reply_rx.try_recv().ok()
106     }
107 
108     /// Get a job reply, blocking until one is available.
get(&mut self) -> Option<Reply>109     pub fn get(&mut self) -> Option<Reply> {
110         self.reply_rx.recv().ok()
111     }
112 }
113 
114 /// Spawn a heartbeat thread which sends ticks down the reply channel every second.
115 /// This lets us implement timeouts without the not yet stable `recv_timeout`.
heartbeat_thread(replies: Sender<Reply>) -> thread::JoinHandle<()>116 fn heartbeat_thread(replies: Sender<Reply>) -> thread::JoinHandle<()> {
117     thread::Builder::new()
118         .name("heartbeat".to_string())
119         .spawn(move || {
120             file_per_thread_logger::initialize(LOG_FILENAME_PREFIX);
121             while replies.send(Reply::Tick).is_ok() {
122                 thread::sleep(Duration::from_secs(1));
123             }
124         })
125         .unwrap()
126 }
127 
128 /// Spawn a worker thread running tests.
worker_thread( thread_num: usize, requests: Arc<Mutex<Receiver<Request>>>, replies: Sender<Reply>, ) -> thread::JoinHandle<timing::PassTimes>129 fn worker_thread(
130     thread_num: usize,
131     requests: Arc<Mutex<Receiver<Request>>>,
132     replies: Sender<Reply>,
133 ) -> thread::JoinHandle<timing::PassTimes> {
134     thread::Builder::new()
135         .name(format!("worker #{thread_num}"))
136         .spawn(move || {
137             file_per_thread_logger::initialize(LOG_FILENAME_PREFIX);
138             loop {
139                 // Lock the mutex only long enough to extract a request.
140                 let Request(jobid, path) = match requests.lock().unwrap().recv() {
141                     Err(..) => break, // TX end shut down. exit thread.
142                     Ok(req) => req,
143                 };
144 
145                 // Tell them we're starting this job.
146                 // The receiver should always be present for this as long as we have jobs.
147                 replies.send(Reply::Starting { jobid }).unwrap();
148 
149                 let result = catch_unwind(|| runone::run(path.as_path(), None, None))
150                     .unwrap_or_else(|e| {
151                         // The test panicked, leaving us a `Box<Any>`.
152                         // Panics are usually strings.
153                         if let Some(msg) = e.downcast_ref::<String>() {
154                             anyhow::bail!("panicked in worker #{thread_num}: {msg}")
155                         } else if let Some(msg) = e.downcast_ref::<&'static str>() {
156                             anyhow::bail!("panicked in worker #{thread_num}: {msg}")
157                         } else {
158                             anyhow::bail!("panicked in worker #{thread_num}")
159                         }
160                     });
161 
162                 if let Err(ref msg) = result {
163                     error!("FAIL: {msg}");
164                 }
165 
166                 replies.send(Reply::Done { jobid, result }).unwrap();
167             }
168 
169             // Timing is accumulated independently per thread.
170             // Timings from this worker thread will be aggregated by `ConcurrentRunner::join()`.
171             timing::take_current()
172         })
173         .unwrap()
174 }
175