1 use std::convert::TryInto;
2 use std::fs;
3 
main()4 pub fn main() {
5     let xml = fs::read_to_string("fixture/model.xml").unwrap();
6     println!("Read graph XML, first 50 characters: {}", &xml[..50]);
7 
8     let weights = fs::read("fixture/model.bin").unwrap();
9     println!("Read graph weights, size in bytes: {}", weights.len());
10 
11     let graph = unsafe {
12         wasi_nn::load(
13             &[&xml.into_bytes(), &weights],
14             wasi_nn::GRAPH_ENCODING_OPENVINO,
15             wasi_nn::EXECUTION_TARGET_CPU,
16         )
17         .unwrap()
18     };
19     println!("Loaded graph into wasi-nn with ID: {graph}");
20 
21     let context = unsafe { wasi_nn::init_execution_context(graph).unwrap() };
22     println!("Created wasi-nn execution context with ID: {context}");
23 
24     // Load a tensor that precisely matches the graph input tensor (see
25     // `fixture/frozen_inference_graph.xml`).
26     let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
27     println!("Read input tensor, size in bytes: {}", tensor_data.len());
28     let tensor = wasi_nn::Tensor {
29         dimensions: &[1, 3, 224, 224],
30         r#type: wasi_nn::TENSOR_TYPE_F32,
31         data: &tensor_data,
32     };
33     unsafe {
34         wasi_nn::set_input(context, 0, tensor).unwrap();
35     }
36 
37     // Execute the inference.
38     unsafe {
39         wasi_nn::compute(context).unwrap();
40     }
41     println!("Executed graph inference");
42 
43     // Retrieve the output.
44     let mut output_buffer = vec![0f32; 1001];
45     unsafe {
46         wasi_nn::get_output(
47             context,
48             0,
49             &mut output_buffer[..] as *mut [f32] as *mut u8,
50             (output_buffer.len() * 4).try_into().unwrap(),
51         )
52         .unwrap();
53     }
54     println!(
55         "Found results, sorted top 5: {:?}",
56         &sort_results(&output_buffer)[..5]
57     )
58 }
59 
/// Converts the graph's raw per-slot scores into `InferenceResult`s sorted from most to least
/// likely.
///
/// The graph writes the probability for each output slot at the matching index of `buffer`.
/// Slot 0 is dropped, so the class ID of a result is its buffer index minus one: result class
/// `c` comes from `buffer[c + 1]`. This shift is empirically needed for sensible labels
/// (e.g. 763 = "revolver" vs 762 = "restaurant"). NOTE(review): presumably slot 0 is a
/// "background" class prepended to the 1000 ImageNet labels — confirm against the model.
///
/// Ties keep ascending class order because the sort is stable. NaN scores are ordered after
/// every real score instead of making the comparison panic. An empty or single-element buffer
/// yields an empty result.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(class, &score)| InferenceResult(class, score))
        .collect();
    results.sort_by(|a, b| match (a.1.is_nan(), b.1.is_nan()) {
        // Both real numbers: descending by score.
        (false, false) => b
            .1
            .partial_cmp(&a.1)
            .expect("non-NaN f32 values are always comparable"),
        // At least one NaN: real scores sort first, NaNs last, and two NaNs compare equal.
        // This keeps the comparator a total order.
        (a_nan, b_nan) => a_nan.cmp(&b_nan),
    });
    results
}

/// A class ID paired with the match probability the graph assigned to it.
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
79