1 //! Run the wasi-nn tests in `crates/test-programs`.
2 //!
//! It may be difficult to run all tests on all platforms; we check the
4 //! pre-requisites for each test dynamically (see [`check`]). Using
5 //! `libtest-mimic` allows us then to dynamically ignore tests that cannot run
6 //! on the current machine.
7 //!
8 //! There are two modes these tests run in:
9 //! - "ignore if unavailable" mode: if the checks for a test fail (e.g., the
10 //!   backend is not installed, test artifacts cannot download, we're on the
11 //!   wrong platform), the test is ignored.
12 //! - "fail if unavailable" mode: when the `CI` or `FORCE_WASINN_TEST_CHECK`
13 //!   environment variables are set, any checks that fail cause the test to fail
14 //!   early.
15 
16 mod check;
17 mod exec;
18 
19 use libtest_mimic::{Arguments, Trial};
20 use std::{borrow::Cow, env};
21 use test_programs_artifacts::*;
22 use wasmtime::Result;
23 use wasmtime_wasi_nn::{Backend, backend};
24 
main() -> Result<()>25 fn main() -> Result<()> {
26     tracing_subscriber::fmt::init();
27 
28     if cfg!(miri) {
29         return Ok(());
30     }
31 
32     // Gather a list of the test-program names.
33     let mut programs = Vec::new();
34     macro_rules! add_to_list {
35         ($name:ident) => {
36             programs.push(stringify!($name));
37         };
38     }
39     foreach_nn!(add_to_list);
40 
41     // Make ignored tests turn into failures.
42     let error_on_failed_check =
43         env::var_os("CI").is_some() || env::var_os("FORCE_WASINN_TEST_CHECK").is_some();
44 
45     // Inform `libtest-mimic` how to run each test program.
46     let arguments = Arguments::from_args();
47     let mut trials = Vec::new();
48     for program in programs {
49         // Either ignore the test if it cannot run (i.e., downgrade `Fail` to
50         // `Ignore`) or preemptively fail it if `error_on_failed_check` is set.
51         let (run_test, mut check) = check_test_program(program);
52         if !error_on_failed_check {
53             check = check.downgrade_failure(); // Downgrade `Fail` to `Ignore`.
54         }
55         let should_ignore = check.is_ignore();
56         if arguments.nocapture && should_ignore {
57             println!("> ignoring {program}: {}", check.reason());
58         }
59         let trial = Trial::test(program, move || {
60             run_test().map_err(|e| format!("{e:?}").into())
61         })
62         .with_ignored_flag(should_ignore);
63         trials.push(trial);
64     }
65 
66     // Run the tests.
67     libtest_mimic::run(&arguments, trials).exit()
68 }
69 
70 /// Return the test program to run and a check that must pass for the test to
71 /// run.
check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck)72 fn check_test_program(name: &str) -> (fn() -> Result<()>, IgnoreCheck) {
73     match name {
74         // Legacy WITX-based tests:
75         "nn_witx_image_classification_openvino" => (
76             nn_witx_image_classification_openvino,
77             IgnoreCheck::for_openvino(),
78         ),
79         "nn_witx_image_classification_openvino_named" => (
80             nn_witx_image_classification_openvino_named,
81             IgnoreCheck::for_openvino(),
82         ),
83         "nn_witx_image_classification_onnx" => {
84             (nn_witx_image_classification_onnx, IgnoreCheck::for_onnx())
85         }
86         "nn_witx_image_classification_winml_named" => (
87             nn_witx_image_classification_winml_named,
88             IgnoreCheck::for_winml(),
89         ),
90         "nn_witx_image_classification_pytorch" => (
91             nn_witx_image_classification_pytorch,
92             IgnoreCheck::for_pytorch(),
93         ),
94         // WIT-based tests:
95         "nn_wit_image_classification_openvino" => (
96             nn_wit_image_classification_openvino,
97             IgnoreCheck::for_openvino(),
98         ),
99         "nn_wit_image_classification_openvino_named" => (
100             nn_wit_image_classification_openvino_named,
101             IgnoreCheck::for_openvino(),
102         ),
103         "nn_wit_image_classification_onnx" => {
104             (nn_wit_image_classification_onnx, IgnoreCheck::for_onnx())
105         }
106         "nn_wit_image_classification_winml_named" => (
107             nn_wit_image_classification_winml_named,
108             IgnoreCheck::for_winml(),
109         ),
110         "nn_wit_image_classification_pytorch" => (
111             nn_wit_image_classification_pytorch,
112             IgnoreCheck::for_pytorch(),
113         ),
114         _ => panic!("unknown test program: {name} (add to this `match`)"),
115     }
116 }
117 
nn_witx_image_classification_openvino() -> Result<()>118 fn nn_witx_image_classification_openvino() -> Result<()> {
119     check::openvino::is_installed()?;
120     check::openvino::are_artifacts_available()?;
121     let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
122     exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO, backend, false)
123 }
124 
nn_witx_image_classification_openvino_named() -> Result<()>125 fn nn_witx_image_classification_openvino_named() -> Result<()> {
126     check::openvino::is_installed()?;
127     check::openvino::are_artifacts_available()?;
128     let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
129     exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_OPENVINO_NAMED, backend, true)
130 }
131 
132 #[cfg(feature = "onnx")]
nn_witx_image_classification_onnx() -> Result<()>133 fn nn_witx_image_classification_onnx() -> Result<()> {
134     check::onnx::are_artifacts_available()?;
135     let backend = Backend::from(backend::onnx::OnnxBackend::default());
136     exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
137 }
138 #[cfg(not(feature = "onnx"))]
nn_witx_image_classification_onnx() -> Result<()>139 fn nn_witx_image_classification_onnx() -> Result<()> {
140     wasmtime::bail!("this test requires the `onnx` feature")
141 }
142 
143 #[cfg(all(feature = "winml", target_os = "windows"))]
nn_witx_image_classification_winml_named() -> Result<()>144 fn nn_witx_image_classification_winml_named() -> Result<()> {
145     check::winml::is_available()?;
146     check::onnx::are_artifacts_available()?;
147     let backend = Backend::from(backend::winml::WinMLBackend::default());
148     exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_ONNX, backend, false)
149 }
150 #[cfg(not(all(feature = "winml", target_os = "windows")))]
nn_witx_image_classification_winml_named() -> Result<()>151 fn nn_witx_image_classification_winml_named() -> Result<()> {
152     wasmtime::bail!("this test requires the `winml` feature and only runs on windows")
153 }
154 
155 #[cfg(feature = "pytorch")]
nn_witx_image_classification_pytorch() -> Result<()>156 fn nn_witx_image_classification_pytorch() -> Result<()> {
157     check::pytorch::are_artifacts_available()?;
158     let backend = Backend::from(backend::pytorch::PytorchBackend::default());
159     exec::witx::run(NN_WITX_IMAGE_CLASSIFICATION_PYTORCH, backend, false)
160 }
161 #[cfg(not(feature = "pytorch"))]
nn_witx_image_classification_pytorch() -> Result<()>162 fn nn_witx_image_classification_pytorch() -> Result<()> {
163     wasmtime::bail!("this test requires the `pytorch` feature")
164 }
165 
nn_wit_image_classification_openvino() -> Result<()>166 fn nn_wit_image_classification_openvino() -> Result<()> {
167     check::openvino::is_installed()?;
168     check::openvino::are_artifacts_available()?;
169     let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
170     exec::wit::run(
171         NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_COMPONENT,
172         backend,
173         false,
174     )
175 }
176 
nn_wit_image_classification_openvino_named() -> Result<()>177 fn nn_wit_image_classification_openvino_named() -> Result<()> {
178     check::openvino::is_installed()?;
179     check::openvino::are_artifacts_available()?;
180     let backend = Backend::from(backend::openvino::OpenvinoBackend::default());
181     exec::wit::run(
182         NN_WIT_IMAGE_CLASSIFICATION_OPENVINO_NAMED_COMPONENT,
183         backend,
184         true,
185     )
186 }
187 
188 #[cfg(feature = "onnx")]
nn_wit_image_classification_onnx() -> Result<()>189 fn nn_wit_image_classification_onnx() -> Result<()> {
190     check::onnx::are_artifacts_available()?;
191     let backend = Backend::from(backend::onnx::OnnxBackend::default());
192     exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
193 }
194 #[cfg(not(feature = "onnx"))]
nn_wit_image_classification_onnx() -> Result<()>195 fn nn_wit_image_classification_onnx() -> Result<()> {
196     wasmtime::bail!("this test requires the `onnx` feature")
197 }
198 
199 #[cfg(feature = "pytorch")]
nn_wit_image_classification_pytorch() -> Result<()>200 fn nn_wit_image_classification_pytorch() -> Result<()> {
201     check::pytorch::are_artifacts_available()?;
202     let backend = Backend::from(backend::pytorch::PytorchBackend::default());
203     exec::wit::run(
204         NN_WIT_IMAGE_CLASSIFICATION_PYTORCH_COMPONENT,
205         backend,
206         false,
207     )
208 }
209 #[cfg(not(feature = "pytorch"))]
nn_wit_image_classification_pytorch() -> Result<()>210 fn nn_wit_image_classification_pytorch() -> Result<()> {
211     wasmtime::bail!("this test requires the `pytorch` feature")
212 }
213 
214 #[cfg(all(feature = "winml", target_os = "windows"))]
nn_wit_image_classification_winml_named() -> Result<()>215 fn nn_wit_image_classification_winml_named() -> Result<()> {
216     check::winml::is_available()?;
217     check::onnx::are_artifacts_available()?;
218     let backend = Backend::from(backend::winml::WinMLBackend::default());
219     exec::wit::run(NN_WIT_IMAGE_CLASSIFICATION_ONNX_COMPONENT, backend, false)
220 }
221 #[cfg(not(all(feature = "winml", target_os = "windows")))]
nn_wit_image_classification_winml_named() -> Result<()>222 fn nn_wit_image_classification_winml_named() -> Result<()> {
223     wasmtime::bail!("this test requires the `winml` feature and only runs on windows")
224 }
225 
/// Helper for keeping track of what tests should do when pre-test checks fail.
///
/// `Debug` is derived so a check result can be printed with `{:?}` when
/// diagnosing why a test was ignored or run.
#[derive(Clone, Debug)]
enum IgnoreCheck {
    /// No pre-test check failed; the test should run.
    Run,
    /// A check failed and the test should be ignored; holds the reason.
    Ignore(Cow<'static, str>),
    /// A check failed; unless downgraded to `Ignore`, the test is still run.
    /// Holds the reason.
    Fail(Cow<'static, str>),
}
233 
234 impl IgnoreCheck {
reason(&self) -> &str235     fn reason(&self) -> &str {
236         match self {
237             IgnoreCheck::Run => panic!("cannot get reason for `Run`"),
238             IgnoreCheck::Ignore(reason) => reason,
239             IgnoreCheck::Fail(reason) => reason,
240         }
241     }
242 
downgrade_failure(self) -> Self243     fn downgrade_failure(self) -> Self {
244         if let IgnoreCheck::Fail(reason) = self {
245             IgnoreCheck::Ignore(reason)
246         } else {
247             self
248         }
249     }
250 
is_ignore(&self) -> bool251     fn is_ignore(&self) -> bool {
252         matches!(self, IgnoreCheck::Ignore(_))
253     }
254 }
255 
256 /// Some pre-test checks for various backends.
257 impl IgnoreCheck {
for_openvino() -> IgnoreCheck258     fn for_openvino() -> IgnoreCheck {
259         use IgnoreCheck::*;
260         if !cfg!(target_arch = "x86_64") {
261             Fail("requires x86_64".into())
262         } else if !cfg!(target_os = "linux") && !cfg!(target_os = "windows") {
263             Fail("requires linux or windows or macos".into())
264         } else if let Err(e) = check::openvino::is_installed() {
265             Fail(e.to_string().into())
266         } else {
267             Run
268         }
269     }
270 
for_onnx() -> Self271     fn for_onnx() -> Self {
272         use IgnoreCheck::*;
273         #[cfg(feature = "onnx")]
274         if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
275             Fail("requires x86_64 or aarch64".into())
276         } else if !cfg!(target_os = "linux")
277             && !cfg!(target_os = "windows")
278             && !cfg!(target_os = "macos")
279         {
280             Fail("requires linux, windows, or macos".into())
281         } else {
282             Run
283         }
284         #[cfg(not(feature = "onnx"))]
285         Ignore("requires the `onnx` feature".into())
286     }
287 
for_pytorch() -> Self288     fn for_pytorch() -> Self {
289         use IgnoreCheck::*;
290         #[cfg(feature = "pytorch")]
291         if !cfg!(target_arch = "x86_64") && !cfg!(target_arch = "aarch64") {
292             Fail("requires x86_64 or aarch64".into())
293         } else if !cfg!(target_os = "linux")
294             && !cfg!(target_os = "windows")
295             && !cfg!(target_os = "macos")
296         {
297             Fail("requires linux, windows, or macos".into())
298         } else {
299             Run
300         }
301         #[cfg(not(feature = "pytorch"))]
302         Ignore("requires the `pytorch` feature".into())
303     }
304 
for_winml() -> IgnoreCheck305     fn for_winml() -> IgnoreCheck {
306         use IgnoreCheck::*;
307         #[cfg(all(feature = "winml", target_os = "windows"))]
308         if !cfg!(target_arch = "x86_64") {
309             Fail("requires x86_64".into())
310         } else if !cfg!(target_os = "windows") {
311             Fail("requires windows".into())
312         } else if let Err(e) = check::winml::is_available() {
313             Fail(e.to_string().into())
314         } else {
315             Run
316         }
317         #[cfg(not(all(feature = "winml", target_os = "windows")))]
318         Ignore("requires the `winml` feature on windows".into())
319     }
320 }
321