1 use std::fs;
2 use wasi_nn::*;
3 
main()4 pub fn main() {
5     let graph = GraphBuilder::new(GraphEncoding::Openvino, ExecutionTarget::CPU)
6         .build_from_cache("mobilenet")
7         .unwrap();
8     println!("Loaded a graph: {:?}", graph);
9 
10     let mut context = graph.init_execution_context().unwrap();
11     println!("Created an execution context: {:?}", context);
12 
13     // Load a tensor that precisely matches the graph input tensor (see
14     // `fixture/frozen_inference_graph.xml`).
15     let tensor_data = fs::read("fixture/tensor.bgr").unwrap();
16     println!("Read input tensor, size in bytes: {}", tensor_data.len());
17     context
18         .set_input(0, TensorType::F32, &[1, 3, 224, 224], &tensor_data)
19         .unwrap();
20 
21     // Execute the inference.
22     context.compute().unwrap();
23     println!("Executed graph inference");
24 
25     // Retrieve the output.
26     let mut output_buffer = vec![0f32; 1001];
27     context.get_output(0, &mut output_buffer[..]).unwrap();
28 
29     println!(
30         "Found results, sorted top 5: {:?}",
31         &sort_results(&output_buffer)[..5]
32     )
33 }
34 
/// Sort a buffer of per-class probabilities from most to least likely.
///
/// The graph places the match probability for each class at the index for that class (e.g. the
/// probability of class 42 is placed at `buffer[42]`). Each entry is wrapped in an
/// [`InferenceResult`] of `(class, probability)`, and the results are sorted in descending order
/// of probability. The sort is stable, so equal probabilities keep their class order.
///
/// The first entry of `buffer` is skipped and the remaining entries are renumbered from 0. Without
/// this, the MobileNet output indices appear "off by one"; skipping gives results that make sense
/// (e.g. 763 = "revolver" vs 762 = "restaurant"). NOTE(review): presumably slot 0 is a
/// "background" class, since the caller allocates 1001 output slots (one more than the 1000
/// ImageNet classes). TODO confirm against the model's label file.
///
/// NaN probabilities are ordered after every real number rather than causing a panic. An empty
/// or single-element `buffer` yields an empty `Vec`.
fn sort_results(buffer: &[f32]) -> Vec<InferenceResult> {
    let mut results: Vec<InferenceResult> = buffer
        .iter()
        .skip(1)
        .enumerate()
        .map(|(c, p)| InferenceResult(c, *p))
        .collect();
    // `partial_cmp` returns `None` only when a NaN is involved. Ordering NaNs after all real
    // numbers (and equal to each other) keeps the comparator a total order, so the sort cannot
    // panic on a NaN the way `.unwrap()` did.
    results.sort_by(|a, b| {
        b.1.partial_cmp(&a.1)
            .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
    });
    results
}

/// A wrapper for a class ID (`.0`) and its match probability (`.1`).
#[derive(Debug, PartialEq)]
struct InferenceResult(usize, f32);
54