xref: /tonic/tonic/src/codec/compression.rs (revision 79a06cc8)
1 use crate::{metadata::MetadataValue, Status};
2 use bytes::{Buf, BufMut, BytesMut};
3 #[cfg(feature = "gzip")]
4 use flate2::read::{GzDecoder, GzEncoder};
5 #[cfg(feature = "deflate")]
6 use flate2::read::{ZlibDecoder, ZlibEncoder};
7 use std::fmt;
8 #[cfg(feature = "zstd")]
9 use zstd::stream::read::{Decoder, Encoder};
10 
/// Name of the header stating which encoding a message body was compressed with.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";

/// Name of the header listing the encodings a peer is willing to receive.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13 
14 /// Struct used to configure which encodings are enabled on a server or channel.
15 ///
16 /// Represents an ordered list of compression encodings that are enabled.
17 #[derive(Debug, Default, Clone, Copy)]
18 pub struct EnabledCompressionEncodings {
19     inner: [Option<CompressionEncoding>; 3],
20 }
21 
22 impl EnabledCompressionEncodings {
23     /// Enable a [`CompressionEncoding`].
24     ///
25     /// Adds the new encoding to the end of the encoding list.
enable(&mut self, encoding: CompressionEncoding)26     pub fn enable(&mut self, encoding: CompressionEncoding) {
27         for e in self.inner.iter_mut() {
28             match e {
29                 Some(e) if *e == encoding => return,
30                 None => {
31                     *e = Some(encoding);
32                     return;
33                 }
34                 _ => continue,
35             }
36         }
37     }
38 
39     /// Remove the last [`CompressionEncoding`].
pop(&mut self) -> Option<CompressionEncoding>40     pub fn pop(&mut self) -> Option<CompressionEncoding> {
41         self.inner
42             .iter_mut()
43             .rev()
44             .find(|entry| entry.is_some())?
45             .take()
46     }
47 
into_accept_encoding_header_value(self) -> Option<http::HeaderValue>48     pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49         let mut value = BytesMut::new();
50         for encoding in self.inner.into_iter().flatten() {
51             value.put_slice(encoding.as_str().as_bytes());
52             value.put_u8(b',');
53         }
54 
55         if value.is_empty() {
56             return None;
57         }
58 
59         value.put_slice(b"identity");
60         Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61     }
62 
63     /// Check if a [`CompressionEncoding`] is enabled.
is_enabled(&self, encoding: CompressionEncoding) -> bool64     pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65         self.inner.contains(&Some(encoding))
66     }
67 
68     /// Check if any [`CompressionEncoding`]s are enabled.
is_empty(&self) -> bool69     pub fn is_empty(&self) -> bool {
70         self.inner.iter().all(|e| e.is_none())
71     }
72 }
73 
74 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
75 pub(crate) struct CompressionSettings {
76     pub(crate) encoding: CompressionEncoding,
77     /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
78     /// The default buffer growth interval is 8 kilobytes.
79     pub(crate) buffer_growth_interval: usize,
80 }
81 
/// The compression encodings Tonic supports.
///
/// Each variant exists only when its matching cargo feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// gzip compression, identified on the wire as `gzip`. Requires the `gzip` feature.
    #[cfg(feature = "gzip")]
    Gzip,
    /// zlib-wrapped deflate compression, identified on the wire as `deflate`. Requires the
    /// `deflate` feature.
    #[cfg(feature = "deflate")]
    Deflate,
    /// Zstandard compression, identified on the wire as `zstd`. Requires the `zstd` feature.
    #[cfg(feature = "zstd")]
    Zstd,
}
96 
97 impl CompressionEncoding {
98     pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99         #[cfg(feature = "gzip")]
100         CompressionEncoding::Gzip,
101         #[cfg(feature = "deflate")]
102         CompressionEncoding::Deflate,
103         #[cfg(feature = "zstd")]
104         CompressionEncoding::Zstd,
105     ];
106 
107     /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
from_accept_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Option<Self>108     pub(crate) fn from_accept_encoding_header(
109         map: &http::HeaderMap,
110         enabled_encodings: EnabledCompressionEncodings,
111     ) -> Option<Self> {
112         if enabled_encodings.is_empty() {
113             return None;
114         }
115 
116         let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117         let header_value_str = header_value.to_str().ok()?;
118 
119         split_by_comma(header_value_str).find_map(|value| match value {
120             #[cfg(feature = "gzip")]
121             "gzip" => Some(CompressionEncoding::Gzip),
122             #[cfg(feature = "deflate")]
123             "deflate" => Some(CompressionEncoding::Deflate),
124             #[cfg(feature = "zstd")]
125             "zstd" => Some(CompressionEncoding::Zstd),
126             _ => None,
127         })
128     }
129 
130     /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
from_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Result<Option<Self>, Status>131     pub(crate) fn from_encoding_header(
132         map: &http::HeaderMap,
133         enabled_encodings: EnabledCompressionEncodings,
134     ) -> Result<Option<Self>, Status> {
135         let Some(header_value) = map.get(ENCODING_HEADER) else {
136             return Ok(None);
137         };
138 
139         match header_value.as_bytes() {
140             #[cfg(feature = "gzip")]
141             b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142                 Ok(Some(CompressionEncoding::Gzip))
143             }
144             #[cfg(feature = "deflate")]
145             b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146                 Ok(Some(CompressionEncoding::Deflate))
147             }
148             #[cfg(feature = "zstd")]
149             b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150                 Ok(Some(CompressionEncoding::Zstd))
151             }
152             b"identity" => Ok(None),
153             other => {
154                 // NOTE: Workaround for lifetime limitation. Resolved at Rust 1.79.
155                 // https://blog.rust-lang.org/2024/06/13/Rust-1.79.0.html#extending-automatic-temporary-lifetime-extension
156                 let other_debug_string;
157 
158                 let mut status = Status::unimplemented(format!(
159                     "Content is compressed with `{}` which isn't supported",
160                     match std::str::from_utf8(other) {
161                         Ok(s) => s,
162                         Err(_) => {
163                             other_debug_string = format!("{other:?}");
164                             &other_debug_string
165                         }
166                     }
167                 ));
168 
169                 let header_value = enabled_encodings
170                     .into_accept_encoding_header_value()
171                     .map(MetadataValue::unchecked_from_header_value)
172                     .unwrap_or_else(|| MetadataValue::from_static("identity"));
173                 status
174                     .metadata_mut()
175                     .insert(ACCEPT_ENCODING_HEADER, header_value);
176 
177                 Err(status)
178             }
179         }
180     }
181 
as_str(self) -> &'static str182     pub(crate) fn as_str(self) -> &'static str {
183         match self {
184             #[cfg(feature = "gzip")]
185             CompressionEncoding::Gzip => "gzip",
186             #[cfg(feature = "deflate")]
187             CompressionEncoding::Deflate => "deflate",
188             #[cfg(feature = "zstd")]
189             CompressionEncoding::Zstd => "zstd",
190         }
191     }
192 
193     #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
into_header_value(self) -> http::HeaderValue194     pub(crate) fn into_header_value(self) -> http::HeaderValue {
195         http::HeaderValue::from_static(self.as_str())
196     }
197 }
198 
199 impl fmt::Display for CompressionEncoding {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result200     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201         f.write_str(self.as_str())
202     }
203 }
204 
/// Split a comma-separated header list into its items, trimming whitespace around each one.
///
/// Empty items (for example from a trailing comma) are yielded as empty strings.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
208 
209 /// Compress `len` bytes from `decompressed_buf` into `out_buf`.
210 /// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
211 #[allow(unused_variables, unreachable_code)]
compress( settings: CompressionSettings, decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>212 pub(crate) fn compress(
213     settings: CompressionSettings,
214     decompressed_buf: &mut BytesMut,
215     out_buf: &mut BytesMut,
216     len: usize,
217 ) -> Result<(), std::io::Error> {
218     let buffer_growth_interval = settings.buffer_growth_interval;
219     let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
220     out_buf.reserve(capacity);
221 
222     #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
223     let mut out_writer = out_buf.writer();
224 
225     match settings.encoding {
226         #[cfg(feature = "gzip")]
227         CompressionEncoding::Gzip => {
228             let mut gzip_encoder = GzEncoder::new(
229                 &decompressed_buf[0..len],
230                 // FIXME: support customizing the compression level
231                 flate2::Compression::new(6),
232             );
233             std::io::copy(&mut gzip_encoder, &mut out_writer)?;
234         }
235         #[cfg(feature = "deflate")]
236         CompressionEncoding::Deflate => {
237             let mut deflate_encoder = ZlibEncoder::new(
238                 &decompressed_buf[0..len],
239                 // FIXME: support customizing the compression level
240                 flate2::Compression::new(6),
241             );
242             std::io::copy(&mut deflate_encoder, &mut out_writer)?;
243         }
244         #[cfg(feature = "zstd")]
245         CompressionEncoding::Zstd => {
246             let mut zstd_encoder = Encoder::new(
247                 &decompressed_buf[0..len],
248                 // FIXME: support customizing the compression level
249                 zstd::DEFAULT_COMPRESSION_LEVEL,
250             )?;
251             std::io::copy(&mut zstd_encoder, &mut out_writer)?;
252         }
253     }
254 
255     decompressed_buf.advance(len);
256 
257     Ok(())
258 }
259 
260 /// Decompress `len` bytes from `compressed_buf` into `out_buf`.
261 #[allow(unused_variables, unreachable_code)]
decompress( settings: CompressionSettings, compressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>262 pub(crate) fn decompress(
263     settings: CompressionSettings,
264     compressed_buf: &mut BytesMut,
265     out_buf: &mut BytesMut,
266     len: usize,
267 ) -> Result<(), std::io::Error> {
268     let buffer_growth_interval = settings.buffer_growth_interval;
269     let estimate_decompressed_len = len * 2;
270     let capacity =
271         ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
272     out_buf.reserve(capacity);
273 
274     #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
275     let mut out_writer = out_buf.writer();
276 
277     match settings.encoding {
278         #[cfg(feature = "gzip")]
279         CompressionEncoding::Gzip => {
280             let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
281             std::io::copy(&mut gzip_decoder, &mut out_writer)?;
282         }
283         #[cfg(feature = "deflate")]
284         CompressionEncoding::Deflate => {
285             let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
286             std::io::copy(&mut deflate_decoder, &mut out_writer)?;
287         }
288         #[cfg(feature = "zstd")]
289         CompressionEncoding::Zstd => {
290             let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
291             std::io::copy(&mut zstd_decoder, &mut out_writer)?;
292         }
293     }
294 
295     compressed_buf.advance(len);
296 
297     Ok(())
298 }
299 
/// Per-message override of the compression configured on a stream.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
311 
#[cfg(test)]
mod tests {
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    use http::HeaderValue;

    use super::*;

    /// Render the accept-encoding header for a hand-built slot array. This lets the tests
    /// place encodings at arbitrary positions, including after empty slots.
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    fn accept_header_for(inner: [Option<CompressionEncoding>; 3]) -> HeaderValue {
        EnabledCompressionEncodings { inner }
            .into_accept_encoding_header_value()
            .unwrap()
    }

    #[test]
    fn convert_none_into_header_value() {
        let header = EnabledCompressionEncodings::default().into_accept_encoding_header_value();

        assert!(header.is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");
        let gzip = Some(CompressionEncoding::Gzip);

        assert_eq!(accept_header_for([gzip, None, None]), expected);
        assert_eq!(accept_header_for([None, None, gzip]), expected);
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");
        let zstd = Some(CompressionEncoding::Zstd);

        assert_eq!(accept_header_for([zstd, None, None]), expected);
        assert_eq!(accept_header_for([None, None, zstd]), expected);
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        use super::CompressionEncoding::{Deflate, Gzip, Zstd};

        assert_eq!(
            accept_header_for([Some(Gzip), Some(Deflate), Some(Zstd)]),
            HeaderValue::from_static("gzip,deflate,zstd,identity"),
        );
        assert_eq!(
            accept_header_for([Some(Zstd), Some(Deflate), Some(Gzip)]),
            HeaderValue::from_static("zstd,deflate,gzip,identity"),
        );
    }
}
392