1 use image::{DynamicImage, RgbImage};
2 use ndarray::Array;
3 use std::{fs, time::Instant};
4 
main()5 pub fn main() {
6     // Load model from a file.
7     let graph =
8         wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Onnx, wasi_nn::ExecutionTarget::CPU)
9             .build_from_files(["fixture/mobilenet.onnx"])
10             .unwrap();
11 
12     let mut context = graph.init_execution_context().unwrap();
13     println!("Created an execution context.");
14 
15     // Read image from file and convert it to tensor data.
16     let image_data = fs::read("fixture/kitten.png").unwrap();
17 
18     // Preprocessing. Normalize data based on model requirements https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet#preprocessing
19     let tensor_data = preprocess(
20         image_data.as_slice(),
21         224,
22         224,
23         &[0.485, 0.456, 0.406],
24         &[0.229, 0.224, 0.225],
25     );
26     println!("Read input tensor, size in bytes: {}", tensor_data.len());
27 
28     context
29         .set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &tensor_data)
30         .unwrap();
31 
32     // Execute the inference.
33     let before_compute = Instant::now();
34     context.compute().unwrap();
35     println!(
36         "Executed graph inference, took {} ms.",
37         before_compute.elapsed().as_millis()
38     );
39 
40     // Retrieve the output.
41     let mut output_buffer = vec![0f32; 1000];
42     context.get_output(0, &mut output_buffer[..]).unwrap();
43 
44     // Postprocessing. Calculating the softmax probability scores.
45     let result = postprocess(output_buffer);
46 
47     // Load labels for classification
48     let labels_file = fs::read("fixture/synset.txt").unwrap();
49     let labels_str = String::from_utf8(labels_file).unwrap();
50     let labels: Vec<String> = labels_str
51         .lines()
52         .map(|line| {
53             let words: Vec<&str> = line.split_whitespace().collect();
54             words[1..].join(" ")
55         })
56         .collect();
57 
58     println!(
59         "Found results, sorted top 5: {:?}",
60         &sort_results(&result, &labels)[..5]
61     )
62 }
63 
64 // Sort the buffer of probabilities. The graph places the match probability for each class at the
65 // index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
66 // to a wrapping InferenceResult and sort the results.
sort_results(buffer: &[f32], labels: &Vec<String>) -> Vec<InferenceResult>67 fn sort_results(buffer: &[f32], labels: &Vec<String>) -> Vec<InferenceResult> {
68     let mut results: Vec<InferenceResult> = buffer
69         .iter()
70         .enumerate()
71         .map(|(c, p)| InferenceResult(labels[c].clone(), *p))
72         .collect();
73     results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
74     results
75 }
76 
77 // Resize image to height x width, and then converts the pixel precision to FP32, normalize with
78 // given mean and std. The resulting RGB pixel vector is then returned.
preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8>79 fn preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8> {
80     let dyn_img: DynamicImage = image::load_from_memory(image).unwrap().resize_exact(
81         width,
82         height,
83         image::imageops::Triangle,
84     );
85     let rgb_img: RgbImage = dyn_img.to_rgb8();
86 
87     // Get an array of the pixel values
88     let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];
89 
90     // Create an array to hold the f32 value of those pixels
91     let bytes_required = raw_u8_arr.len() * 4;
92     let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];
93 
94     // Read the number as a f32 and break it into u8 bytes
95     for i in 0..raw_u8_arr.len() {
96         let u8_f32: f32 = raw_u8_arr[i] as f32;
97         let rgb_iter = i % 3;
98 
99         // Normalize the pixel
100         let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];
101 
102         // Convert it to u8 bytes and write it with new shape
103         let u8_bytes = norm_u8_f32.to_ne_bytes();
104         for j in 0..4 {
105             u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
106         }
107     }
108 
109     return u8_f32_arr;
110 }
111 
/// Converts the model's raw output scores into softmax probabilities.
///
/// The largest score is subtracted before exponentiating. Softmax is unchanged by a constant
/// shift, but without this step any score above roughly 88 overflows `f32::exp` to infinity and
/// the division yields NaN. The output has the same length and order as the input, for any number
/// of classes. An empty input returns an empty vector.
fn postprocess(mut output_tensor: Vec<f32>) -> Vec<f32> {
    let max_score = output_tensor
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    // Exponentiate in place to reuse the input allocation; the max element maps to exp(0) = 1,
    // so the sum below is at least 1 for non-empty finite input.
    for score in output_tensor.iter_mut() {
        *score = (*score - max_score).exp();
    }
    let total: f32 = output_tensor.iter().sum();
    for score in output_tensor.iter_mut() {
        *score /= total;
    }
    output_tensor
}
122 
/// Reinterprets a byte buffer as consecutive native-endian `f32` values.
///
/// # Panics
///
/// Panics if `data.len()` is not a multiple of 4 (the size of an `f32`).
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    const F32_BYTES: usize = 4;
    assert!(
        data.len() % F32_BYTES == 0,
        "byte buffer of length {} does not hold a whole number of f32 values",
        data.len()
    );
    data.chunks_exact(F32_BYTES)
        .map(|b| f32::from_ne_bytes([b[0], b[1], b[2], b[3]]))
        .collect()
}
132 
/// One classification candidate: a class label paired with the match probability the model
/// assigned to that class.
#[derive(PartialEq, Debug)]
struct InferenceResult(
    /// Human-readable label of the class.
    String,
    /// Match probability for the class.
    f32,
);
136