//! This module attempts to paper over the differences between the two
//! implementations of wasi-nn: the legacy WITX-based version (`mod witx`) and
//! the up-to-date WIT version (`mod wit`). Since the tests mainly exercise a
//! simple classifier, this module exposes a high-level `classify` function to
//! go along with `load`, etc.
//!
//! This module exists solely for convenience--e.g., it reduces test
//! duplication. In the future, it can be safely disposed of or altered as more
//! tests are added.
9
10 /// Call `wasi-nn` functions from WebAssembly using the canonical ABI of the
11 /// component model via WIT-based tooling. Used by `bin/nn_wit_*.rs` tests.
12 pub mod wit {
13 use anyhow::{Result, anyhow};
14 use std::time::Instant;
15
16 // Generate the wasi-nn bindings based on the `*.wit` files.
17 wit_bindgen::generate!({
18 path: "../wasi-nn/wit",
19 world: "ml",
20 default_bindings_module: "test_programs::ml"
21 });
22 use self::wasi::nn::errors;
23 use self::wasi::nn::graph::{self, Graph};
24 pub use self::wasi::nn::graph::{ExecutionTarget, GraphEncoding}; // Used by tests.
25 use self::wasi::nn::tensor::{Tensor, TensorType};
26
27 /// Load a wasi-nn graph from a set of bytes.
load( bytes: &[Vec<u8>], encoding: GraphEncoding, target: ExecutionTarget, ) -> Result<Graph>28 pub fn load(
29 bytes: &[Vec<u8>],
30 encoding: GraphEncoding,
31 target: ExecutionTarget,
32 ) -> Result<Graph> {
33 graph::load(bytes, encoding, target).map_err(err_as_anyhow)
34 }
35
36 /// Load a wasi-nn graph by name.
load_by_name(name: &str) -> Result<Graph>37 pub fn load_by_name(name: &str) -> Result<Graph> {
38 graph::load_by_name(name).map_err(err_as_anyhow)
39 }
40
41 /// Run a wasi-nn inference using a simple classifier model (single input,
42 /// single output).
classify(graph: Graph, input: (&str, Vec<u8>)) -> Result<Vec<f32>>43 pub fn classify(graph: Graph, input: (&str, Vec<u8>)) -> Result<Vec<f32>> {
44 let context = graph.init_execution_context().map_err(err_as_anyhow)?;
45 println!("[nn] created wasi-nn execution context with ID: {context:?}");
46
47 // Many classifiers have a single input; currently, this test suite also
48 // uses tensors of the same shape, though this is not usually the case.
49 let tensor = Tensor::new(&vec![1, 3, 224, 224], TensorType::Fp32, &input.1);
50 println!("[nn] input tensor: {} bytes", input.1.len());
51
52 let before = Instant::now();
53 let input_tuple = (input.0.to_string(), tensor);
54 let output_tensors = context.compute(vec![input_tuple]).unwrap();
55 println!(
56 "[nn] executed graph inference in {} ms",
57 before.elapsed().as_millis()
58 );
59
60 // Many classifiers emit probabilities as floating point values; here we
61 // convert the raw bytes to `f32` knowing all models used here use that
62 // type.
63 let output = &output_tensors[0].1;
64 println!(
65 "[nn] retrieved output tensor: {} bytes",
66 output.data().len()
67 );
68 let output: Vec<f32> = output
69 .data()
70 .chunks(4)
71 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
72 .collect();
73 Ok(output)
74 }
75
err_as_anyhow(e: errors::Error) -> anyhow::Error76 fn err_as_anyhow(e: errors::Error) -> anyhow::Error {
77 anyhow!("error: {e:?}")
78 }
79 }
80
81 /// Call `wasi-nn` functions from WebAssembly using the legacy WITX-based
82 /// tooling. This older API has been deprecated for the newer WIT-based API but
83 /// retained for backwards compatibility testing--i.e., `bin/nn_witx_*.rs`
84 /// tests.
85 pub mod witx {
86 use anyhow::Result;
87 use std::time::Instant;
88 pub use wasi_nn::{ExecutionTarget, GraphEncoding};
89 use wasi_nn::{Graph, GraphBuilder, TensorType};
90
91 /// Load a wasi-nn graph from a set of bytes.
load( bytes: &[&[u8]], encoding: GraphEncoding, target: ExecutionTarget, ) -> Result<Graph>92 pub fn load(
93 bytes: &[&[u8]],
94 encoding: GraphEncoding,
95 target: ExecutionTarget,
96 ) -> Result<Graph> {
97 Ok(GraphBuilder::new(encoding, target).build_from_bytes(bytes)?)
98 }
99
100 /// Load a wasi-nn graph by name.
load_by_name( name: &str, encoding: GraphEncoding, target: ExecutionTarget, ) -> Result<Graph>101 pub fn load_by_name(
102 name: &str,
103 encoding: GraphEncoding,
104 target: ExecutionTarget,
105 ) -> Result<Graph> {
106 Ok(GraphBuilder::new(encoding, target).build_from_cache(name)?)
107 }
108
109 /// Run a wasi-nn inference using a simple classifier model (single input,
110 /// single output).
classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>>111 pub fn classify(graph: Graph, tensor: Vec<u8>) -> Result<Vec<f32>> {
112 let mut context = graph.init_execution_context()?;
113 println!("[nn] created wasi-nn execution context with ID: {context}");
114
115 // Many classifiers have a single input; currently, this test suite also
116 // uses tensors of the same shape, though this is not usually the case.
117 context.set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor)?;
118 println!("[nn] set input tensor: {} bytes", tensor.len());
119
120 let before = Instant::now();
121 context.compute()?;
122 println!(
123 "[nn] executed graph inference in {} ms",
124 before.elapsed().as_millis()
125 );
126
127 // Many classifiers emit probabilities as floating point values; here we
128 // convert the raw bytes to `f32` knowing all models used here use that
129 // type.
130 let mut output_buffer = vec![0u8; 1001 * std::mem::size_of::<f32>()];
131 let num_bytes = context.get_output(0, &mut output_buffer)?;
132 println!("[nn] retrieved output tensor: {num_bytes} bytes");
133 let output: Vec<f32> = output_buffer[..num_bytes]
134 .chunks(4)
135 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
136 .collect();
137 Ok(output)
138 }
139 }
140
/// Sort some classification probabilities.
///
/// Many classification models output a buffer of probabilities for each class,
/// placing the match probability for each class at the index for that class
/// (the probability of class `N` is stored at `probabilities[N]`).
///
/// Results are returned in descending order of probability. Classes with equal
/// probability keep their original (ascending class ID) order. NaN
/// probabilities are placed last instead of causing a panic.
pub fn sort_results(probabilities: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = probabilities
        .iter()
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    // `partial_cmp(..).unwrap()` panics on NaN. Instead, use a total order:
    // non-NaN values first, ordered descending via `total_cmp`, then NaNs.
    // `sort_by` is stable, so ties keep class-ID order.
    results.sort_by(|a, b| {
        a.1.is_nan()
            .cmp(&b.1.is_nan())
            .then_with(|| b.1.total_cmp(&a.1))
    });
    results
}

/// A wrapper pairing a class ID with its match probability.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct InferenceResult(usize, f32);

impl InferenceResult {
    /// The class ID, i.e. the index of this class in the model's output.
    pub fn class_id(&self) -> usize {
        self.0
    }

    /// The match probability reported for this class.
    pub fn probability(&self) -> f32 {
        self.1
    }
}
167