1 use image::{DynamicImage, RgbImage};
2 use ndarray::Array;
3 use std::{fs, time::Instant};
4
main()5 pub fn main() {
6 // Load model from a file.
7 let graph =
8 wasi_nn::GraphBuilder::new(wasi_nn::GraphEncoding::Onnx, wasi_nn::ExecutionTarget::CPU)
9 .build_from_files(["fixture/mobilenet.onnx"])
10 .unwrap();
11
12 let mut context = graph.init_execution_context().unwrap();
13 println!("Created an execution context.");
14
15 // Read image from file and convert it to tensor data.
16 let image_data = fs::read("fixture/kitten.png").unwrap();
17
18 // Preprocessing. Normalize data based on model requirements https://github.com/onnx/models/tree/main/validated/vision/classification/mobilenet#preprocessing
19 let tensor_data = preprocess(
20 image_data.as_slice(),
21 224,
22 224,
23 &[0.485, 0.456, 0.406],
24 &[0.229, 0.224, 0.225],
25 );
26 println!("Read input tensor, size in bytes: {}", tensor_data.len());
27
28 context
29 .set_input(0, wasi_nn::TensorType::F32, &[1, 3, 224, 224], &tensor_data)
30 .unwrap();
31
32 // Execute the inference.
33 let before_compute = Instant::now();
34 context.compute().unwrap();
35 println!(
36 "Executed graph inference, took {} ms.",
37 before_compute.elapsed().as_millis()
38 );
39
40 // Retrieve the output.
41 let mut output_buffer = vec![0f32; 1000];
42 context.get_output(0, &mut output_buffer[..]).unwrap();
43
44 // Postprocessing. Calculating the softmax probability scores.
45 let result = postprocess(output_buffer);
46
47 // Load labels for classification
48 let labels_file = fs::read("fixture/synset.txt").unwrap();
49 let labels_str = String::from_utf8(labels_file).unwrap();
50 let labels: Vec<String> = labels_str
51 .lines()
52 .map(|line| {
53 let words: Vec<&str> = line.split_whitespace().collect();
54 words[1..].join(" ")
55 })
56 .collect();
57
58 println!(
59 "Found results, sorted top 5: {:?}",
60 &sort_results(&result, &labels)[..5]
61 )
62 }
63
64 // Sort the buffer of probabilities. The graph places the match probability for each class at the
65 // index for that class (e.g. the probability of class 42 is placed at buffer[42]). Here we convert
66 // to a wrapping InferenceResult and sort the results.
sort_results(buffer: &[f32], labels: &Vec<String>) -> Vec<InferenceResult>67 fn sort_results(buffer: &[f32], labels: &Vec<String>) -> Vec<InferenceResult> {
68 let mut results: Vec<InferenceResult> = buffer
69 .iter()
70 .enumerate()
71 .map(|(c, p)| InferenceResult(labels[c].clone(), *p))
72 .collect();
73 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
74 results
75 }
76
77 // Resize image to height x width, and then converts the pixel precision to FP32, normalize with
78 // given mean and std. The resulting RGB pixel vector is then returned.
preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8>79 fn preprocess(image: &[u8], height: u32, width: u32, mean: &[f32], std: &[f32]) -> Vec<u8> {
80 let dyn_img: DynamicImage = image::load_from_memory(image).unwrap().resize_exact(
81 width,
82 height,
83 image::imageops::Triangle,
84 );
85 let rgb_img: RgbImage = dyn_img.to_rgb8();
86
87 // Get an array of the pixel values
88 let raw_u8_arr: &[u8] = &rgb_img.as_raw()[..];
89
90 // Create an array to hold the f32 value of those pixels
91 let bytes_required = raw_u8_arr.len() * 4;
92 let mut u8_f32_arr: Vec<u8> = vec![0; bytes_required];
93
94 // Read the number as a f32 and break it into u8 bytes
95 for i in 0..raw_u8_arr.len() {
96 let u8_f32: f32 = raw_u8_arr[i] as f32;
97 let rgb_iter = i % 3;
98
99 // Normalize the pixel
100 let norm_u8_f32: f32 = (u8_f32 / 255.0 - mean[rgb_iter]) / std[rgb_iter];
101
102 // Convert it to u8 bytes and write it with new shape
103 let u8_bytes = norm_u8_f32.to_ne_bytes();
104 for j in 0..4 {
105 u8_f32_arr[(raw_u8_arr.len() * 4 * rgb_iter / 3) + (i / 3) * 4 + j] = u8_bytes[j];
106 }
107 }
108
109 return u8_f32_arr;
110 }
111
/// Converts raw model outputs (logits) into softmax probabilities.
///
/// The result has the same length and order as `output_tensor`; for finite
/// inputs the values are non-negative and sum to one (up to rounding). An
/// empty input yields an empty output.
///
/// The largest logit is subtracted from every element before
/// exponentiation. This leaves the mathematical result unchanged but keeps
/// `exp` from overflowing to infinity (and the division from producing NaN)
/// when logits are large.
fn postprocess(output_tensor: Vec<f32>) -> Vec<f32> {
    let max_logit = output_tensor
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);

    let exps: Vec<f32> = output_tensor
        .iter()
        .map(|&logit| (logit - max_logit).exp())
        .collect();
    let sum_exps: f32 = exps.iter().sum();

    exps.into_iter().map(|e| e / sum_exps).collect()
}
122
/// Reinterprets a byte buffer as a sequence of native-endian `f32` values.
///
/// Every consecutive group of four bytes becomes one `f32`, in order. An
/// empty buffer yields an empty vector.
///
/// # Panics
///
/// Panics if `data.len()` is not a multiple of four.
pub fn bytes_to_f32_vec(data: Vec<u8>) -> Vec<f32> {
    assert!(
        data.len() % 4 == 0,
        "byte buffer length {} is not a multiple of 4",
        data.len()
    );

    // `chunks_exact` guarantees four-byte chunks, so the indexing below
    // cannot go out of bounds.
    data.chunks_exact(4)
        .map(|chunk| f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}
132
/// A single classification result: the class label (field `0`) paired with
/// its match probability (field `1`).
#[derive(Debug, PartialEq)]
struct InferenceResult(String, f32);
136