xref: /wasmtime-44.0.1/crates/wizer/src/rewrite.rs (revision e6937050)
1 //! Final rewrite pass.
2 
3 use crate::{FuncRenames, SnapshotVal, Wizer, info::ModuleContext, snapshot::Snapshot};
4 use std::cell::Cell;
5 use std::convert::TryFrom;
6 use wasm_encoder::reencode::{Reencode, RoundtripReencoder};
7 use wasm_encoder::{ConstExpr, SectionId};
8 
9 impl Wizer {
10     /// Given the initialized snapshot, rewrite the Wasm so that it is already
11     /// initialized.
12     ///
rewrite( &self, module: &mut ModuleContext<'_>, snapshot: &Snapshot, renames: &FuncRenames, remove_wasi_initialize: bool, ) -> Vec<u8>13     pub(crate) fn rewrite(
14         &self,
15         module: &mut ModuleContext<'_>,
16         snapshot: &Snapshot,
17         renames: &FuncRenames,
18         remove_wasi_initialize: bool,
19     ) -> Vec<u8> {
20         log::debug!("Rewriting input Wasm to pre-initialized state");
21 
22         let mut encoder = wasm_encoder::Module::new();
23         let has_wasi_initialize = module.has_wasi_initialize();
24 
25         // Encode the initialized data segments from the snapshot rather
26         // than the original, uninitialized data segments.
27         let add_data_segments = |data_section: &mut wasm_encoder::DataSection| {
28             for seg in &snapshot.data_segments {
29                 let offset = if seg.is64 {
30                     ConstExpr::i64_const(seg.offset.cast_signed())
31                 } else {
32                     ConstExpr::i32_const(u32::try_from(seg.offset).unwrap().cast_signed())
33                 };
34                 data_section.active(seg.memory_index, &offset, seg.data.iter().copied());
35             }
36         };
37 
38         // There are multiple places were we potentially need to check whether
39         // we've added the data section already and if we haven't yet, then do
40         // so. For example, the original Wasm might not have a data section at
41         // all, and so we have to potentially add it at the end of iterating
42         // over the original sections. This closure encapsulates all that
43         // add-it-if-we-haven't-already logic in one place.
44         let added_data_section = Cell::new(false);
45 
46         let add_data_section = |encoder: &mut wasm_encoder::Module| {
47             if added_data_section.get() {
48                 return;
49             }
50             added_data_section.set(true);
51             let mut data_section = wasm_encoder::DataSection::new();
52             add_data_segments(&mut data_section);
53             encoder.section(&data_section);
54         };
55 
56         for section in module.raw_sections() {
57             match section {
58                 // Some tools expect the name custom section to come last, even
59                 // though custom sections are allowed in any order. Therefore,
60                 // make sure we've added our data section by now.
61                 s if is_name_section(s) => {
62                     add_data_section(&mut encoder);
63                     encoder.section(s);
64                 }
65 
66                 // For the memory section, we update the minimum size of each
67                 // defined memory to the snapshot's initialized size for that
68                 // memory.
69                 s if s.id == u8::from(SectionId::Memory) => {
70                     let mut memories = wasm_encoder::MemorySection::new();
71                     assert_eq!(module.defined_memories_len(), snapshot.memory_mins.len());
72                     for ((_, mem), new_min) in module
73                         .defined_memories()
74                         .zip(snapshot.memory_mins.iter().copied())
75                     {
76                         let mut mem = RoundtripReencoder.memory_type(mem).unwrap();
77                         mem.minimum = new_min;
78                         memories.memory(mem);
79                     }
80                     encoder.section(&memories);
81                 }
82 
83                 // Encode the initialized global values from the snapshot,
84                 // rather than the original values.
85                 s if s.id == u8::from(SectionId::Global) => {
86                     let original_globals = wasmparser::GlobalSectionReader::new(
87                         wasmparser::BinaryReader::new(s.data, 0),
88                     )
89                     .unwrap();
90                     let mut globals = wasm_encoder::GlobalSection::new();
91                     let mut snapshot = snapshot.globals.iter();
92                     for ((_, glob_ty, export_name), global) in
93                         module.defined_globals().zip(original_globals)
94                     {
95                         let global = global.unwrap();
96                         if export_name.is_some() {
97                             // This is a mutable global and it was present in
98                             // the snapshot, so translate the snapshot value to
99                             // a constant expression and insert it.
100                             assert!(glob_ty.mutable);
101                             let (_, val) = snapshot.next().unwrap();
102                             let init = match val {
103                                 SnapshotVal::I32(x) => ConstExpr::i32_const(*x),
104                                 SnapshotVal::I64(x) => ConstExpr::i64_const(*x),
105                                 SnapshotVal::F32(x) => {
106                                     ConstExpr::f32_const(wasm_encoder::Ieee32::new(*x))
107                                 }
108                                 SnapshotVal::F64(x) => {
109                                     ConstExpr::f64_const(wasm_encoder::Ieee64::new(*x))
110                                 }
111                                 SnapshotVal::V128(x) => ConstExpr::v128_const(x.cast_signed()),
112                             };
113                             let glob_ty = RoundtripReencoder.global_type(glob_ty).unwrap();
114                             globals.global(glob_ty, &init);
115                         } else {
116                             // This global isn't mutable so preserve its value
117                             // as-is.
118                             assert!(!glob_ty.mutable);
119                             RoundtripReencoder
120                                 .parse_global(&mut globals, global)
121                                 .unwrap();
122                         };
123                     }
124                     encoder.section(&globals);
125                 }
126 
127                 // Remove exports for the wizer initialization
128                 // function and WASI reactor _initialize function,
129                 // then perform any requested renames.
130                 s if s.id == u8::from(SectionId::Export) => {
131                     let mut exports = wasm_encoder::ExportSection::new();
132                     for export in module.exports() {
133                         if (export.name == self.get_init_func() && !self.get_keep_init_func())
134                             || (remove_wasi_initialize
135                                 && has_wasi_initialize
136                                 && export.name == "_initialize")
137                         {
138                             continue;
139                         }
140 
141                         if !renames.rename_src_to_dst.contains_key(export.name)
142                             && renames.rename_dsts.contains(export.name)
143                         {
144                             // A rename overwrites this export, and it is not
145                             // renamed to another export, so skip it.
146                             continue;
147                         }
148 
149                         let field = renames
150                             .rename_src_to_dst
151                             .get(export.name)
152                             .map_or(export.name, |f| f.as_str());
153 
154                         let kind = RoundtripReencoder.export_kind(export.kind).unwrap();
155                         exports.export(field, kind, export.index);
156                     }
157                     encoder.section(&exports);
158                 }
159 
160                 // Skip the `start` function -- it's already been run!
161                 s if s.id == u8::from(SectionId::Start) => {
162                     continue;
163                 }
164 
165                 // Add the data segments that are being added for the snapshot
166                 // to the data count section, if present.
167                 s if s.id == u8::from(SectionId::DataCount) => {
168                     let mut data = wasmparser::BinaryReader::new(s.data, 0);
169                     let prev = data.read_var_u32().unwrap();
170                     assert!(data.eof());
171                     encoder.section(&wasm_encoder::DataCountSection {
172                         count: prev + u32::try_from(snapshot.data_segments.len()).unwrap(),
173                     });
174                 }
175 
176                 s if s.id == u8::from(SectionId::Data) => {
177                     let mut section = wasm_encoder::DataSection::new();
178                     let data = wasmparser::BinaryReader::new(s.data, 0);
179                     for data in wasmparser::DataSectionReader::new(data).unwrap() {
180                         let data = data.unwrap();
181                         match data.kind {
182                             // Active data segments, by definition in wasm, are
183                             // truncated after instantiation. That means that
184                             // for the snapshot all active data segments, which
185                             // are already applied, are all turned into empty
186                             // passive segments instead.
187                             wasmparser::DataKind::Active { .. } => {
188                                 section.passive([]);
189                             }
190 
191                             // Passive segments are plumbed through as-is.
192                             wasmparser::DataKind::Passive => {
193                                 section.passive(data.data.iter().copied());
194                             }
195                         }
196                     }
197 
198                     // Append all the initializer data segments before adding
199                     // the section.
200                     add_data_segments(&mut section);
201                     encoder.section(&section);
202                     added_data_section.set(true);
203                 }
204 
205                 s => {
206                     encoder.section(s);
207                 }
208             }
209         }
210 
211         // Make sure that we've added our data section to the module.
212         add_data_section(&mut encoder);
213         encoder.finish()
214     }
215 }
216 
is_name_section(s: &wasm_encoder::RawSection) -> bool217 fn is_name_section(s: &wasm_encoder::RawSection) -> bool {
218     s.id == u8::from(SectionId::Custom) && {
219         let mut reader = wasmparser::BinaryReader::new(s.data, 0);
220         matches!(reader.read_string(), Ok("name"))
221     }
222 }
223