1 //! This module attempts to paper over the differences between the two
2 //! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and
3 //! the up-to-date WIT version (`mod wit`). Since the tests are mainly a simple
4 //! classifier, this exposes a high-level `classify` function to go along with
5 //! `load`, etc.
6 //!
//! This module exists solely for convenience--e.g., to reduce test
//! duplication. In the future, it can be safely removed or altered as more
//! tests are added.
9 
10 /// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the
11 /// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests.
12 pub mod wit {
13     use anyhow::{Result, anyhow};
14     use std::time::Instant;
15 
16     // Generate the wasi-nn bindings based on the `*.wit` files.
17     wit_bindgen::generate!({
18         path: "../wasi-nn/wit",
19         world: "ml",
20         default_bindings_module: "test_programs::ml"
21     });
22     use self::wasi::nn::errors;
23     use self::wasi::nn::graph::{self, Graph};
24     pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests.
25     use self::wasi::nn::tensor::{Tensor, TensorType};
26 
27     /// Load a wasi-nn graph from a set of bytes.
load( bytes: &[Vec<u8>], encoding: GraphEncoding, target: ExecutionTarget, ) -> Result<Graph>28     pub fn load(
29         bytes: &[Vec<u8>],
30         encoding: GraphEncoding,
31         target: ExecutionTarget,
32     ) -> Result<Graph> {
33         graph::load(bytes, encoding, target).map_err(err_as_anyhow)
34     }
35 
36     /// Load a wasi-nn graph by name.
load_by_name(name: &str) -> Result<Graph>37     pub fn load_by_name(name: &str) -> Result<Graph> {
38         graph::load_by_name(name).map_err(err_as_anyhow)
39     }
40 
41     /// Run a wasi-nn inference using a simple classifier model (single input,
42     /// single output).
classify(graph: Graph, input: (&str, Vec<u8>)) -> Result<Vec<f32>>43     pub fn classify(graph: Graph, input: (&str, Vec<u8>)) -> Result<Vec<f32>> {
44         let context = graph.init_execution_context().map_err(err_as_anyhow)?;
45         println!("[nn] created wasi-nn execution context with ID: {context:?}");
46 
47         // Many classifiers have a single input; currently, this test suite also
48         // uses tensors of the same shape, though this is not usually the case.
49         let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1);
50         println!("[nn] input tensor: {} bytes", input.1.len());
51 
52         let before = Instant::now();
53         let input_tuple = (input.0.to_string(), tensor);
54         let output_tensors = context.compute(vec![input_tuple]).unwrap();
55         println!(
56             "[nn] executed graph inference in {} ms",
57             before.elapsed().as_millis()
58         );
59 
60         // Many classifiers emit probabilities as floating point values; here we
61         // convert the raw bytes to `f32` knowing all models used here use that
62         // type.
63         let output = &output_tensors[0].1;
64         println!(
65             "[nn] retrieved output tensor: {} bytes",
66             output.data().len()
67         );
68         let output: Vec<f32> = output
69             .data()
70             .chunks(4)
71             .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
72             .collect();
73         Ok(output)
74     }
75 
err_as_anyhow(e: errors::Error) -> anyhow::Error76     fn err_as_anyhow(e: errors::Error) -> anyhow::Error {
77         anyhow!("error: {e:?}")
78     }
79 }
80 
81 /// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based
82 /// tooling. This older API has been deprecated for the newer WIT-based API but
83 /// retained for backwards compatibility testing--i.e., `bin/nn_witx_*.rs`
84 /// tests.
85 pub mod witx {
86     use anyhow::Result;
87     use std::time::Instant;
88     pub use wasi_nn::{ExecutionTarget, GraphEncoding};
89     use wasi_nn::{Graph, GraphBuilder, TensorType};
90 
91     /// Load a wasi-nn graph from a set of bytes.
load( bytes: &[&[u8]], encoding: GraphEncoding, target: ExecutionTarget, ) -> Result<Graph>92     pub fn load(
93         bytes: &[&[u8]],
94         encoding: GraphEncoding,
95         target: ExecutionTarget,
96     ) -> Result<Graph> {
97         Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?)
98     }
99 
100     /// Load a wasi-nn graph by name.
load_by_name( name: &str, encoding: GraphEncoding, target: ExecutionTarget, ) -> Result<Graph>101     pub fn load_by_name(
102         name: &str,
103         encoding: GraphEncoding,
104         target: ExecutionTarget,
105     ) -> Result<Graph> {
106         Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?)
107     }
108 
109     /// Run a wasi-nn inference using a simple classifier model (single input,
110     /// single output).
classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>>111     pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
112         let mut context = graph.init_execution_context()?;
113         println!("[nn] created wasi-nn execution context with ID: {context}");
114 
115         // Many classifiers have a single input; currently, this test suite also
116         // uses tensors of the same shape, though this is not usually the case.
117         context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
118         println!("[nn] set input tensor: {} bytes", tensor.len());
119 
120         let before = Instant::now();
121         context.compute()?;
122         println!(
123             "[nn] executed graph inference in {} ms",
124             before.elapsed().as_millis()
125         );
126 
127         // Many classifiers emit probabilities as floating point values; here we
128         // convert the raw bytes to `f32` knowing all models used here use that
129         // type.
130         let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()];
131         let num_bytes = context.get_output(0, &mut output_buffer)?;
132         println!("[nn] retrieved output tensor: {num_bytes} bytes");
133         let output: Vec<f32> = output_buffer[..num_bytes]
134             .chunks(4)
135             .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
136             .collect();
137         Ok(output)
138     }
139 }
140 
141 /// Sort some classification probabilities.
142 ///
143 /// Many classification models output a buffer of probabilities for each class,
144 /// placing the match probability for each class at the index for that class
145 /// (the probability of class `N` is stored at `probabilities[N]`).
sort_results(probabilities: &[f32]) -> Vec<InferenceResult>146 pub fn sort_results(probabilities: &[f32]) -> Vec<InferenceResult> {
147     let mut results: Vec<InferenceResult> = probabilities
148         .iter()
149         .enumerate()
150         .map(|(c, p)| InferenceResult(c, *p))
151         .collect();
152     results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
153     results
154 }
155 
/// A wrapper pairing a class ID with the match probability for that class.
#[derive(Debug, PartialEq)]
pub struct InferenceResult(usize, f32);

impl InferenceResult {
    /// The class ID stored in this result (the first field).
    pub fn class_id(&self) -> usize {
        let Self(class_id, _) = *self;
        class_id
    }

    /// The match probability stored in this result (the second field).
    pub fn probability(&self) -> f32 {
        let Self(_, probability) = *self;
        probability
    }
}
167