xref: /tonic/tonic/src/codec/encode.rs (revision 3c900ebd)
1 use super::compression::{
2     compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride,
3 };
4 use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
5 use crate::Status;
6 use bytes::{BufMut, Bytes, BytesMut};
7 use http::HeaderMap;
8 use http_body::{Body, Frame};
9 use pin_project::pin_project;
10 use std::{
11     pin::Pin,
12     task::{ready, Context, Poll},
13 };
14 use tokio_stream::{adapters::Fuse, Stream, StreamExt};
15 
16 /// Combinator for efficient encoding of messages into reasonably sized buffers.
17 /// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
18 /// splitting off and yielding a buffer when either:
19 ///  * The delegate stream polls as not ready, or
20 ///  * The encoded buffer surpasses YIELD_THRESHOLD.
21 #[pin_project(project = EncodedBytesProj)]
22 #[derive(Debug)]
23 struct EncodedBytes<T, U> {
24     #[pin]
25     source: Fuse<U>,
26     encoder: T,
27     compression_encoding: Option<CompressionEncoding>,
28     max_message_size: Option<usize>,
29     buf: BytesMut,
30     uncompression_buf: BytesMut,
31     error: Option<Status>,
32 }
33 
34 impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
new( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> Self35     fn new(
36         encoder: T,
37         source: U,
38         compression_encoding: Option<CompressionEncoding>,
39         compression_override: SingleMessageCompressionOverride,
40         max_message_size: Option<usize>,
41     ) -> Self {
42         let buffer_settings = encoder.buffer_settings();
43         let buf = BytesMut::with_capacity(buffer_settings.buffer_size);
44 
45         let compression_encoding =
46             if compression_override == SingleMessageCompressionOverride::Disable {
47                 None
48             } else {
49                 compression_encoding
50             };
51 
52         let uncompression_buf = if compression_encoding.is_some() {
53             BytesMut::with_capacity(buffer_settings.buffer_size)
54         } else {
55             BytesMut::new()
56         };
57 
58         Self {
59             source: source.fuse(),
60             encoder,
61             compression_encoding,
62             max_message_size,
63             buf,
64             uncompression_buf,
65             error: None,
66         }
67     }
68 }
69 
70 impl<T, U> Stream for EncodedBytes<T, U>
71 where
72     T: Encoder<Error = Status>,
73     U: Stream<Item = Result<T::Item, Status>>,
74 {
75     type Item = Result<Bytes, Status>;
76 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>77     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78         let EncodedBytesProj {
79             mut source,
80             encoder,
81             compression_encoding,
82             max_message_size,
83             buf,
84             uncompression_buf,
85             error,
86         } = self.project();
87         let buffer_settings = encoder.buffer_settings();
88 
89         if let Some(status) = error.take() {
90             return Poll::Ready(Some(Err(status)));
91         }
92 
93         loop {
94             match source.as_mut().poll_next(cx) {
95                 Poll::Pending if buf.is_empty() => {
96                     return Poll::Pending;
97                 }
98                 Poll::Ready(None) if buf.is_empty() => {
99                     return Poll::Ready(None);
100                 }
101                 Poll::Pending | Poll::Ready(None) => {
102                     return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
103                 }
104                 Poll::Ready(Some(Ok(item))) => {
105                     if let Err(status) = encode_item(
106                         encoder,
107                         buf,
108                         uncompression_buf,
109                         *compression_encoding,
110                         *max_message_size,
111                         buffer_settings,
112                         item,
113                     ) {
114                         return Poll::Ready(Some(Err(status)));
115                     }
116 
117                     if buf.len() >= buffer_settings.yield_threshold {
118                         return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
119                     }
120                 }
121                 Poll::Ready(Some(Err(status))) => {
122                     if buf.is_empty() {
123                         return Poll::Ready(Some(Err(status)));
124                     }
125                     *error = Some(status);
126                     return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
127                 }
128             }
129         }
130     }
131 }
132 
encode_item<T>( encoder: &mut T, buf: &mut BytesMut, uncompression_buf: &mut BytesMut, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buffer_settings: BufferSettings, item: T::Item, ) -> Result<(), Status> where T: Encoder<Error = Status>,133 fn encode_item<T>(
134     encoder: &mut T,
135     buf: &mut BytesMut,
136     uncompression_buf: &mut BytesMut,
137     compression_encoding: Option<CompressionEncoding>,
138     max_message_size: Option<usize>,
139     buffer_settings: BufferSettings,
140     item: T::Item,
141 ) -> Result<(), Status>
142 where
143     T: Encoder<Error = Status>,
144 {
145     let offset = buf.len();
146 
147     buf.reserve(HEADER_SIZE);
148     unsafe {
149         buf.advance_mut(HEADER_SIZE);
150     }
151 
152     if let Some(encoding) = compression_encoding {
153         uncompression_buf.clear();
154 
155         encoder
156             .encode(item, &mut EncodeBuf::new(uncompression_buf))
157             .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
158 
159         let uncompressed_len = uncompression_buf.len();
160 
161         compress(
162             CompressionSettings {
163                 encoding,
164                 buffer_growth_interval: buffer_settings.buffer_size,
165             },
166             uncompression_buf,
167             buf,
168             uncompressed_len,
169         )
170         .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
171     } else {
172         encoder
173             .encode(item, &mut EncodeBuf::new(buf))
174             .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
175     }
176 
177     // now that we know length, we can write the header
178     finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
179 }
180 
finish_encoding( compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buf: &mut [u8], ) -> Result<(), Status>181 fn finish_encoding(
182     compression_encoding: Option<CompressionEncoding>,
183     max_message_size: Option<usize>,
184     buf: &mut [u8],
185 ) -> Result<(), Status> {
186     let len = buf.len() - HEADER_SIZE;
187     let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
188     if len > limit {
189         return Err(Status::out_of_range(format!(
190             "Error, encoded message length too large: found {} bytes, the limit is: {} bytes",
191             len, limit
192         )));
193     }
194 
195     if len > u32::MAX as usize {
196         return Err(Status::resource_exhausted(format!(
197             "Cannot return body with more than 4GB of data but got {len} bytes"
198         )));
199     }
200     {
201         let mut buf = &mut buf[..HEADER_SIZE];
202         buf.put_u8(compression_encoding.is_some() as u8);
203         buf.put_u32(len as u32);
204     }
205 
206     Ok(())
207 }
208 
/// Which side of the connection an [`EncodeBody`] is encoding for.
///
/// The role decides how the body ends: a server body finishes with a trailers frame
/// carrying the final status, a client body sends no trailers and surfaces stream
/// errors as body errors.
#[derive(Debug)]
enum Role {
    /// Body of a request sent by a client.
    Client,
    /// Body of a response sent by a server.
    Server,
}
214 
215 /// A specialized implementation of [Body] for encoding [Result<Bytes, Status>].
216 #[pin_project]
217 #[derive(Debug)]
218 pub struct EncodeBody<T, U> {
219     #[pin]
220     inner: EncodedBytes<T, U>,
221     state: EncodeState,
222 }
223 
224 #[derive(Debug)]
225 struct EncodeState {
226     error: Option<Status>,
227     role: Role,
228     is_end_stream: bool,
229 }
230 
231 impl<T: Encoder, U: Stream> EncodeBody<T, U> {
232     /// Turns a stream of grpc messages into [EncodeBody] which is used by grpc clients for
233     /// turning the messages into http frames for sending over the network.
new_client( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self234     pub fn new_client(
235         encoder: T,
236         source: U,
237         compression_encoding: Option<CompressionEncoding>,
238         max_message_size: Option<usize>,
239     ) -> Self {
240         Self {
241             inner: EncodedBytes::new(
242                 encoder,
243                 source,
244                 compression_encoding,
245                 SingleMessageCompressionOverride::default(),
246                 max_message_size,
247             ),
248             state: EncodeState {
249                 error: None,
250                 role: Role::Client,
251                 is_end_stream: false,
252             },
253         }
254     }
255 
256     /// Turns a stream of grpc results (message or error status) into [EncodeBody] which is used by grpc
257     /// servers for turning the messages into http frames for sending over the network.
new_server( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> Self258     pub fn new_server(
259         encoder: T,
260         source: U,
261         compression_encoding: Option<CompressionEncoding>,
262         compression_override: SingleMessageCompressionOverride,
263         max_message_size: Option<usize>,
264     ) -> Self {
265         Self {
266             inner: EncodedBytes::new(
267                 encoder,
268                 source,
269                 compression_encoding,
270                 compression_override,
271                 max_message_size,
272             ),
273             state: EncodeState {
274                 error: None,
275                 role: Role::Server,
276                 is_end_stream: false,
277             },
278         }
279     }
280 }
281 
282 impl EncodeState {
trailers(&mut self) -> Option<Result<HeaderMap, Status>>283     fn trailers(&mut self) -> Option<Result<HeaderMap, Status>> {
284         match self.role {
285             Role::Client => None,
286             Role::Server => {
287                 if self.is_end_stream {
288                     return None;
289                 }
290 
291                 self.is_end_stream = true;
292                 let status = if let Some(status) = self.error.take() {
293                     status
294                 } else {
295                     Status::ok("")
296                 };
297                 Some(status.to_header_map())
298             }
299         }
300     }
301 }
302 
303 impl<T, U> Body for EncodeBody<T, U>
304 where
305     T: Encoder<Error = Status>,
306     U: Stream<Item = Result<T::Item, Status>>,
307 {
308     type Data = Bytes;
309     type Error = Status;
310 
is_end_stream(&self) -> bool311     fn is_end_stream(&self) -> bool {
312         self.state.is_end_stream
313     }
314 
poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>>315     fn poll_frame(
316         self: Pin<&mut Self>,
317         cx: &mut Context<'_>,
318     ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
319         let self_proj = self.project();
320         match ready!(self_proj.inner.poll_next(cx)) {
321             Some(Ok(d)) => Some(Ok(Frame::data(d))).into(),
322             Some(Err(status)) => match self_proj.state.role {
323                 Role::Client => Some(Err(status)).into(),
324                 Role::Server => {
325                     self_proj.state.is_end_stream = true;
326                     Some(Ok(Frame::trailers(status.to_header_map()?))).into()
327                 }
328             },
329             None => self_proj
330                 .state
331                 .trailers()
332                 .map(|t| t.map(Frame::trailers))
333                 .into(),
334         }
335     }
336 }
337