1 use crate::{metadata::MetadataValue, Status};
2 use bytes::{Buf, BufMut, BytesMut};
3 #[cfg(feature = "gzip")]
4 use flate2::read::{GzDecoder, GzEncoder};
5 #[cfg(feature = "deflate")]
6 use flate2::read::{ZlibDecoder, ZlibEncoder};
7 use std::fmt;
8 #[cfg(feature = "zstd")]
9 use zstd::stream::read::{Decoder, Encoder};
10
11 pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
12 pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13
14 /// Struct used to configure which encodings are enabled on a server or channel.
15 ///
16 /// Represents an ordered list of compression encodings that are enabled.
17 #[derive(Debug, Default, Clone, Copy)]
18 pub struct EnabledCompressionEncodings {
19 inner: [Option<CompressionEncoding>; 3],
20 }
21
22 impl EnabledCompressionEncodings {
23 /// Enable a [`CompressionEncoding`].
24 ///
25 /// Adds the new encoding to the end of the encoding list.
enable(&mut self, encoding: CompressionEncoding)26 pub fn enable(&mut self, encoding: CompressionEncoding) {
27 for e in self.inner.iter_mut() {
28 match e {
29 Some(e) if *e == encoding => return,
30 None => {
31 *e = Some(encoding);
32 return;
33 }
34 _ => continue,
35 }
36 }
37 }
38
39 /// Remove the last [`CompressionEncoding`].
pop(&mut self) -> Option<CompressionEncoding>40 pub fn pop(&mut self) -> Option<CompressionEncoding> {
41 self.inner
42 .iter_mut()
43 .rev()
44 .find(|entry| entry.is_some())?
45 .take()
46 }
47
into_accept_encoding_header_value(self) -> Option<http::HeaderValue>48 pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49 let mut value = BytesMut::new();
50 for encoding in self.inner.into_iter().flatten() {
51 value.put_slice(encoding.as_str().as_bytes());
52 value.put_u8(b',');
53 }
54
55 if value.is_empty() {
56 return None;
57 }
58
59 value.put_slice(b"identity");
60 Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61 }
62
63 /// Check if a [`CompressionEncoding`] is enabled.
is_enabled(&self, encoding: CompressionEncoding) -> bool64 pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65 self.inner.contains(&Some(encoding))
66 }
67
68 /// Check if any [`CompressionEncoding`]s are enabled.
is_empty(&self) -> bool69 pub fn is_empty(&self) -> bool {
70 self.inner.iter().all(|e| e.is_none())
71 }
72 }
73
74 #[derive(Clone, Copy, Debug, PartialEq, Eq)]
75 pub(crate) struct CompressionSettings {
76 pub(crate) encoding: CompressionEncoding,
77 /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
78 /// The default buffer growth interval is 8 kilobytes.
79 pub(crate) buffer_growth_interval: usize,
80 }
81
/// The compression encodings Tonic supports.
///
/// Each variant exists only when its matching Cargo feature is enabled.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum CompressionEncoding {
    /// `gzip` compression (requires the `gzip` feature).
    #[cfg(feature = "gzip")]
    Gzip,
    /// `deflate` compression (requires the `deflate` feature).
    #[cfg(feature = "deflate")]
    Deflate,
    /// `zstd` compression (requires the `zstd` feature).
    #[cfg(feature = "zstd")]
    Zstd,
}
96
97 impl CompressionEncoding {
98 pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99 #[cfg(feature = "gzip")]
100 CompressionEncoding::Gzip,
101 #[cfg(feature = "deflate")]
102 CompressionEncoding::Deflate,
103 #[cfg(feature = "zstd")]
104 CompressionEncoding::Zstd,
105 ];
106
107 /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
from_accept_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Option<Self>108 pub(crate) fn from_accept_encoding_header(
109 map: &http::HeaderMap,
110 enabled_encodings: EnabledCompressionEncodings,
111 ) -> Option<Self> {
112 if enabled_encodings.is_empty() {
113 return None;
114 }
115
116 let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117 let header_value_str = header_value.to_str().ok()?;
118
119 split_by_comma(header_value_str).find_map(|value| match value {
120 #[cfg(feature = "gzip")]
121 "gzip" => Some(CompressionEncoding::Gzip),
122 #[cfg(feature = "deflate")]
123 "deflate" => Some(CompressionEncoding::Deflate),
124 #[cfg(feature = "zstd")]
125 "zstd" => Some(CompressionEncoding::Zstd),
126 _ => None,
127 })
128 }
129
130 /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
from_encoding_header( map: &http::HeaderMap, enabled_encodings: EnabledCompressionEncodings, ) -> Result<Option<Self>, Status>131 pub(crate) fn from_encoding_header(
132 map: &http::HeaderMap,
133 enabled_encodings: EnabledCompressionEncodings,
134 ) -> Result<Option<Self>, Status> {
135 let Some(header_value) = map.get(ENCODING_HEADER) else {
136 return Ok(None);
137 };
138
139 match header_value.as_bytes() {
140 #[cfg(feature = "gzip")]
141 b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142 Ok(Some(CompressionEncoding::Gzip))
143 }
144 #[cfg(feature = "deflate")]
145 b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146 Ok(Some(CompressionEncoding::Deflate))
147 }
148 #[cfg(feature = "zstd")]
149 b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150 Ok(Some(CompressionEncoding::Zstd))
151 }
152 b"identity" => Ok(None),
153 other => {
154 // NOTE: Workaround for lifetime limitation. Resolved at Rust 1.79.
155 // https://blog.rust-lang.org/2024/06/13/Rust-1.79.0.html#extending-automatic-temporary-lifetime-extension
156 let other_debug_string;
157
158 let mut status = Status::unimplemented(format!(
159 "Content is compressed with `{}` which isn't supported",
160 match std::str::from_utf8(other) {
161 Ok(s) => s,
162 Err(_) => {
163 other_debug_string = format!("{other:?}");
164 &other_debug_string
165 }
166 }
167 ));
168
169 let header_value = enabled_encodings
170 .into_accept_encoding_header_value()
171 .map(MetadataValue::unchecked_from_header_value)
172 .unwrap_or_else(|| MetadataValue::from_static("identity"));
173 status
174 .metadata_mut()
175 .insert(ACCEPT_ENCODING_HEADER, header_value);
176
177 Err(status)
178 }
179 }
180 }
181
as_str(self) -> &'static str182 pub(crate) fn as_str(self) -> &'static str {
183 match self {
184 #[cfg(feature = "gzip")]
185 CompressionEncoding::Gzip => "gzip",
186 #[cfg(feature = "deflate")]
187 CompressionEncoding::Deflate => "deflate",
188 #[cfg(feature = "zstd")]
189 CompressionEncoding::Zstd => "zstd",
190 }
191 }
192
193 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
into_header_value(self) -> http::HeaderValue194 pub(crate) fn into_header_value(self) -> http::HeaderValue {
195 http::HeaderValue::from_static(self.as_str())
196 }
197 }
198
199 impl fmt::Display for CompressionEncoding {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result200 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201 f.write_str(self.as_str())
202 }
203 }
204
/// Splits a comma-separated header value into entries, trimming surrounding
/// whitespace from each one. Empty entries are yielded as `""`.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
208
209 /// Compress `len` bytes from `decompressed_buf` into `out_buf`.
210 /// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
211 #[allow(unused_variables, unreachable_code)]
compress( settings: CompressionSettings, decompressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>212 pub(crate) fn compress(
213 settings: CompressionSettings,
214 decompressed_buf: &mut BytesMut,
215 out_buf: &mut BytesMut,
216 len: usize,
217 ) -> Result<(), std::io::Error> {
218 let buffer_growth_interval = settings.buffer_growth_interval;
219 let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
220 out_buf.reserve(capacity);
221
222 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
223 let mut out_writer = out_buf.writer();
224
225 match settings.encoding {
226 #[cfg(feature = "gzip")]
227 CompressionEncoding::Gzip => {
228 let mut gzip_encoder = GzEncoder::new(
229 &decompressed_buf[0..len],
230 // FIXME: support customizing the compression level
231 flate2::Compression::new(6),
232 );
233 std::io::copy(&mut gzip_encoder, &mut out_writer)?;
234 }
235 #[cfg(feature = "deflate")]
236 CompressionEncoding::Deflate => {
237 let mut deflate_encoder = ZlibEncoder::new(
238 &decompressed_buf[0..len],
239 // FIXME: support customizing the compression level
240 flate2::Compression::new(6),
241 );
242 std::io::copy(&mut deflate_encoder, &mut out_writer)?;
243 }
244 #[cfg(feature = "zstd")]
245 CompressionEncoding::Zstd => {
246 let mut zstd_encoder = Encoder::new(
247 &decompressed_buf[0..len],
248 // FIXME: support customizing the compression level
249 zstd::DEFAULT_COMPRESSION_LEVEL,
250 )?;
251 std::io::copy(&mut zstd_encoder, &mut out_writer)?;
252 }
253 }
254
255 decompressed_buf.advance(len);
256
257 Ok(())
258 }
259
260 /// Decompress `len` bytes from `compressed_buf` into `out_buf`.
261 #[allow(unused_variables, unreachable_code)]
decompress( settings: CompressionSettings, compressed_buf: &mut BytesMut, out_buf: &mut BytesMut, len: usize, ) -> Result<(), std::io::Error>262 pub(crate) fn decompress(
263 settings: CompressionSettings,
264 compressed_buf: &mut BytesMut,
265 out_buf: &mut BytesMut,
266 len: usize,
267 ) -> Result<(), std::io::Error> {
268 let buffer_growth_interval = settings.buffer_growth_interval;
269 let estimate_decompressed_len = len * 2;
270 let capacity =
271 ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
272 out_buf.reserve(capacity);
273
274 #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
275 let mut out_writer = out_buf.writer();
276
277 match settings.encoding {
278 #[cfg(feature = "gzip")]
279 CompressionEncoding::Gzip => {
280 let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
281 std::io::copy(&mut gzip_decoder, &mut out_writer)?;
282 }
283 #[cfg(feature = "deflate")]
284 CompressionEncoding::Deflate => {
285 let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
286 std::io::copy(&mut deflate_decoder, &mut out_writer)?;
287 }
288 #[cfg(feature = "zstd")]
289 CompressionEncoding::Zstd => {
290 let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
291 std::io::copy(&mut zstd_decoder, &mut out_writer)?;
292 }
293 }
294
295 compressed_buf.advance(len);
296
297 Ok(())
298 }
299
/// Per-message override of the compression configured on a stream.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
311
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
    use http::HeaderValue;

    /// With nothing enabled there is no accept-encoding header to send.
    #[test]
    fn convert_none_into_header_value() {
        let header = EnabledCompressionEncodings::default().into_accept_encoding_header_value();
        assert!(header.is_none());
    }

    /// A single enabled gzip slot yields the same header wherever it sits.
    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        const EXPECTED: HeaderValue = HeaderValue::from_static("gzip,identity");

        for inner in [
            [Some(CompressionEncoding::Gzip), None, None],
            [None, None, Some(CompressionEncoding::Gzip)],
        ] {
            let header = EnabledCompressionEncodings { inner }.into_accept_encoding_header_value();
            assert_eq!(header.unwrap(), EXPECTED);
        }
    }

    /// A single enabled zstd slot yields the same header wherever it sits.
    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        const EXPECTED: HeaderValue = HeaderValue::from_static("zstd,identity");

        for inner in [
            [Some(CompressionEncoding::Zstd), None, None],
            [None, None, Some(CompressionEncoding::Zstd)],
        ] {
            let header = EnabledCompressionEncodings { inner }.into_accept_encoding_header_value();
            assert_eq!(header.unwrap(), EXPECTED);
        }
    }

    /// Multiple encodings are listed in slot order, followed by `identity`.
    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        let cases = [
            (
                [
                    Some(CompressionEncoding::Gzip),
                    Some(CompressionEncoding::Deflate),
                    Some(CompressionEncoding::Zstd),
                ],
                "gzip,deflate,zstd,identity",
            ),
            (
                [
                    Some(CompressionEncoding::Zstd),
                    Some(CompressionEncoding::Deflate),
                    Some(CompressionEncoding::Gzip),
                ],
                "zstd,deflate,gzip,identity",
            ),
        ];

        for (inner, expected) in cases {
            let header = EnabledCompressionEncodings { inner }.into_accept_encoding_header_value();
            assert_eq!(header.unwrap(), HeaderValue::from_static(expected));
        }
    }
}
392