1 use crate::component::ComponentContext;
2 use crate::component::info::RawSection;
3 use crate::component::snapshot::ComponentSnapshot;
4 use crate::{FuncRenames, Wizer};
5 use wasm_encoder::reencode::{Reencode, ReencodeComponent};
6 
7 impl Wizer {
8     /// Helper method which is the equivalent of [`Wizer::rewrite`], but for
9     /// components.
10     ///
11     /// This effectively plumbs through all non-module sections as-is and
12     /// updates module sections with whatever [`Wizer::rewrite`] returns.
rewrite_component( &self, component: &mut ComponentContext<'_>, snapshot: &ComponentSnapshot, ) -> Vec<u8>13     pub(crate) fn rewrite_component(
14         &self,
15         component: &mut ComponentContext<'_>,
16         snapshot: &ComponentSnapshot,
17     ) -> Vec<u8> {
18         let mut encoder = wasm_encoder::Component::new();
19         let mut reencoder = Reencoder {
20             funcs: 0,
21             removed_func: None,
22             wizer: self,
23         };
24 
25         let mut module_index = 0;
26         for section in component.sections.iter_mut() {
27             match section {
28                 RawSection::Module(module) => {
29                     let snapshot = snapshot
30                         .modules
31                         .iter()
32                         .find(|(i, _)| *i == module_index)
33                         .map(|(_, s)| s);
34                     module_index += 1;
35                     match snapshot {
36                         // This module's snapshot is used for [`Wizer::rewrite`]
37                         // and the results of that are spliced into the
38                         // component.
39                         Some(snapshot) => {
40                             let rewritten_wasm =
41                                 self.rewrite(module, snapshot, &FuncRenames::default(), false);
42                             encoder.section(&wasm_encoder::RawSection {
43                                 id: wasm_encoder::ComponentSectionId::CoreModule as u8,
44                                 data: &rewritten_wasm,
45                             });
46                         }
47 
48                         // This module wasn't instantiated and has no snapshot,
49                         // plumb it through as-is.
50                         None => {
51                             let mut module_encoder = wasm_encoder::Module::new();
52                             for section in module.raw_sections() {
53                                 module_encoder.section(section);
54                             }
55                             encoder.section(&wasm_encoder::ModuleSection(&module_encoder));
56                         }
57                     }
58                 }
59                 RawSection::Raw(s) => {
60                     reencoder.raw_section(&mut encoder, s);
61                 }
62             }
63         }
64 
65         encoder.finish()
66     }
67 }
68 
/// State carried while re-encoding the non-module sections of a component.
struct Reencoder<'a> {
    /// Number of component functions counted so far. Incremented for each
    /// function instance-export alias, canonical `lift`, function import,
    /// and retained function export that has been seen.
    funcs: u32,
    /// Component function index of the initialization function export that
    /// is being removed, if any. Once set, function indices above it are
    /// shifted down by one when re-encoded.
    removed_func: Option<u32>,
    /// Wizer configuration (consulted for the init function's name and
    /// whether it should be kept).
    wizer: &'a Wizer,
}
78 
79 impl Reencoder<'_> {
raw_section( &mut self, encoder: &mut wasm_encoder::Component, section: &wasm_encoder::RawSection, )80     fn raw_section(
81         &mut self,
82         encoder: &mut wasm_encoder::Component,
83         section: &wasm_encoder::RawSection,
84     ) {
85         match section.id {
86             // These can't define component functions so the sections are
87             // plumbed as-is.
88             id if id == wasm_encoder::ComponentSectionId::CoreCustom as u8
89                 || id == wasm_encoder::ComponentSectionId::CoreInstance as u8
90                 || id == wasm_encoder::ComponentSectionId::CoreType as u8
91                 || id == wasm_encoder::ComponentSectionId::Component as u8
92                 || id == wasm_encoder::ComponentSectionId::Type as u8 =>
93             {
94                 encoder.section(section);
95             }
96 
97             id if id == wasm_encoder::ComponentSectionId::CoreModule as u8 => {
98                 panic!("should happen in caller");
99             }
100             id if id == wasm_encoder::ComponentSectionId::Start as u8 => {
101                 // Component start sections aren't supported yet anyway
102                 todo!()
103             }
104 
105             // These sections all might affect or refer to component function
106             // indices so they're reencoded here, optionally updating function
107             // indices in case the index is higher than the one that we're
108             // removing.
109             id if id == wasm_encoder::ComponentSectionId::Instance as u8 => {
110                 self.rewrite(
111                     encoder,
112                     section.data,
113                     Self::parse_component_instance_section,
114                 );
115             }
116             id if id == wasm_encoder::ComponentSectionId::Alias as u8 => {
117                 self.rewrite(encoder, section.data, Self::parse_component_alias_section);
118             }
119             id if id == wasm_encoder::ComponentSectionId::CanonicalFunction as u8 => {
120                 self.rewrite(
121                     encoder,
122                     section.data,
123                     Self::parse_component_canonical_section,
124                 );
125             }
126             id if id == wasm_encoder::ComponentSectionId::Import as u8 => {
127                 self.rewrite(encoder, section.data, Self::parse_component_import_section);
128             }
129             id if id == wasm_encoder::ComponentSectionId::Export as u8 => {
130                 self.rewrite(encoder, section.data, Self::parse_component_export_section);
131             }
132             other => panic!("unexpected component section id: {other}"),
133         }
134     }
135 
rewrite<'a, T, S>( &mut self, encoder: &mut wasm_encoder::Component, data: &'a [u8], f: fn(&mut Self, dst: &mut S, wasmparser::SectionLimited<'a, T>) -> Result<(), Error>, ) where T: wasmparser::FromReader<'a>, S: Default + wasm_encoder::ComponentSection,136     fn rewrite<'a, T, S>(
137         &mut self,
138         encoder: &mut wasm_encoder::Component,
139         data: &'a [u8],
140         f: fn(&mut Self, dst: &mut S, wasmparser::SectionLimited<'a, T>) -> Result<(), Error>,
141     ) where
142         T: wasmparser::FromReader<'a>,
143         S: Default + wasm_encoder::ComponentSection,
144     {
145         let mut section = S::default();
146         f(
147             self,
148             &mut section,
149             wasmparser::SectionLimited::new(wasmparser::BinaryReader::new(data, 0)).unwrap(),
150         )
151         .unwrap();
152         encoder.section(&section);
153     }
154 }
155 
/// Uses the default `Reencode` methods; no user-defined error is ever
/// produced, hence `Infallible`.
impl Reencode for Reencoder<'_> {
    type Error = std::convert::Infallible;
}
/// Error type returned by the reencoding hooks in this file, with
/// `Infallible` as the user-defined error parameter.
type Error = wasm_encoder::reencode::Error<std::convert::Infallible>;
160 
161 impl ReencodeComponent for Reencoder<'_> {
component_func_index(&mut self, original_index: u32) -> u32162     fn component_func_index(&mut self, original_index: u32) -> u32 {
163         match self.removed_func {
164             None => original_index,
165             Some(removed) => {
166                 if original_index < removed {
167                     original_index
168                 } else if original_index == removed {
169                     panic!("referenced removed function")
170                 } else {
171                     original_index - 1
172                 }
173             }
174         }
175     }
176 
parse_component_alias_section( &mut self, aliases: &mut wasm_encoder::ComponentAliasSection, section: wasmparser::ComponentAliasSectionReader<'_>, ) -> Result<(), Error>177     fn parse_component_alias_section(
178         &mut self,
179         aliases: &mut wasm_encoder::ComponentAliasSection,
180         section: wasmparser::ComponentAliasSectionReader<'_>,
181     ) -> Result<(), Error> {
182         for alias in section.clone() {
183             let alias = alias?;
184             if let wasmparser::ComponentAlias::InstanceExport {
185                 kind: wasmparser::ComponentExternalKind::Func,
186                 ..
187             } = alias
188             {
189                 self.funcs += 1;
190             }
191         }
192 
193         wasm_encoder::reencode::component_utils::parse_component_alias_section(
194             self, aliases, section,
195         )
196     }
197 
parse_component_canonical_section( &mut self, canonicals: &mut wasm_encoder::CanonicalFunctionSection, section: wasmparser::ComponentCanonicalSectionReader<'_>, ) -> Result<(), Error>198     fn parse_component_canonical_section(
199         &mut self,
200         canonicals: &mut wasm_encoder::CanonicalFunctionSection,
201         section: wasmparser::ComponentCanonicalSectionReader<'_>,
202     ) -> Result<(), Error> {
203         for canonical in section.clone() {
204             let canonical = canonical?;
205             if let wasmparser::CanonicalFunction::Lift { .. } = canonical {
206                 self.funcs += 1;
207             }
208         }
209 
210         wasm_encoder::reencode::component_utils::parse_component_canonical_section(
211             self, canonicals, section,
212         )
213     }
214 
parse_component_import_section( &mut self, imports: &mut wasm_encoder::ComponentImportSection, section: wasmparser::ComponentImportSectionReader<'_>, ) -> Result<(), Error>215     fn parse_component_import_section(
216         &mut self,
217         imports: &mut wasm_encoder::ComponentImportSection,
218         section: wasmparser::ComponentImportSectionReader<'_>,
219     ) -> Result<(), Error> {
220         for import in section.clone() {
221             let import = import?;
222             if let wasmparser::ComponentExternalKind::Func = import.ty.kind() {
223                 self.funcs += 1;
224             }
225         }
226 
227         wasm_encoder::reencode::component_utils::parse_component_import_section(
228             self, imports, section,
229         )
230     }
231 
parse_component_export_section( &mut self, exports: &mut wasm_encoder::ComponentExportSection, section: wasmparser::ComponentExportSectionReader<'_>, ) -> Result<(), Error>232     fn parse_component_export_section(
233         &mut self,
234         exports: &mut wasm_encoder::ComponentExportSection,
235         section: wasmparser::ComponentExportSectionReader<'_>,
236     ) -> Result<(), Error> {
237         for export in section {
238             let export = export?;
239             if !self.wizer.get_keep_init_func() && export.name.0 == self.wizer.get_init_func() {
240                 self.removed_func = Some(self.funcs);
241             } else {
242                 if export.kind == wasmparser::ComponentExternalKind::Func {
243                     self.funcs += 1;
244                 }
245                 self.parse_component_export(exports, export)?;
246             }
247         }
248         Ok(())
249     }
250 }
251