1 use super::compression::{
2 compress, CompressionEncoding, CompressionSettings, SingleMessageCompressionOverride,
3 };
4 use super::{BufferSettings, EncodeBuf, Encoder, DEFAULT_MAX_SEND_MESSAGE_SIZE, HEADER_SIZE};
5 use crate::Status;
6 use bytes::{BufMut, Bytes, BytesMut};
7 use http::HeaderMap;
8 use http_body::{Body, Frame};
9 use pin_project::pin_project;
10 use std::{
11 pin::Pin,
12 task::{ready, Context, Poll},
13 };
14 use tokio_stream::{adapters::Fuse, Stream, StreamExt};
15
16 /// Combinator for efficient encoding of messages into reasonably sized buffers.
17 /// EncodedBytes encodes ready messages from its delegate stream into a BytesMut,
18 /// splitting off and yielding a buffer when either:
19 /// * The delegate stream polls as not ready, or
20 /// * The encoded buffer surpasses YIELD_THRESHOLD.
21 #[pin_project(project = EncodedBytesProj)]
22 #[derive(Debug)]
23 struct EncodedBytes<T, U> {
24 #[pin]
25 source: Fuse<U>,
26 encoder: T,
27 compression_encoding: Option<CompressionEncoding>,
28 max_message_size: Option<usize>,
29 buf: BytesMut,
30 uncompression_buf: BytesMut,
31 error: Option<Status>,
32 }
33
34 impl<T: Encoder, U: Stream> EncodedBytes<T, U> {
new( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> Self35 fn new(
36 encoder: T,
37 source: U,
38 compression_encoding: Option<CompressionEncoding>,
39 compression_override: SingleMessageCompressionOverride,
40 max_message_size: Option<usize>,
41 ) -> Self {
42 let buffer_settings = encoder.buffer_settings();
43 let buf = BytesMut::with_capacity(buffer_settings.buffer_size);
44
45 let compression_encoding =
46 if compression_override == SingleMessageCompressionOverride::Disable {
47 None
48 } else {
49 compression_encoding
50 };
51
52 let uncompression_buf = if compression_encoding.is_some() {
53 BytesMut::with_capacity(buffer_settings.buffer_size)
54 } else {
55 BytesMut::new()
56 };
57
58 Self {
59 source: source.fuse(),
60 encoder,
61 compression_encoding,
62 max_message_size,
63 buf,
64 uncompression_buf,
65 error: None,
66 }
67 }
68 }
69
70 impl<T, U> Stream for EncodedBytes<T, U>
71 where
72 T: Encoder<Error = Status>,
73 U: Stream<Item = Result<T::Item, Status>>,
74 {
75 type Item = Result<Bytes, Status>;
76
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>77 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
78 let EncodedBytesProj {
79 mut source,
80 encoder,
81 compression_encoding,
82 max_message_size,
83 buf,
84 uncompression_buf,
85 error,
86 } = self.project();
87 let buffer_settings = encoder.buffer_settings();
88
89 if let Some(status) = error.take() {
90 return Poll::Ready(Some(Err(status)));
91 }
92
93 loop {
94 match source.as_mut().poll_next(cx) {
95 Poll::Pending if buf.is_empty() => {
96 return Poll::Pending;
97 }
98 Poll::Ready(None) if buf.is_empty() => {
99 return Poll::Ready(None);
100 }
101 Poll::Pending | Poll::Ready(None) => {
102 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
103 }
104 Poll::Ready(Some(Ok(item))) => {
105 if let Err(status) = encode_item(
106 encoder,
107 buf,
108 uncompression_buf,
109 *compression_encoding,
110 *max_message_size,
111 buffer_settings,
112 item,
113 ) {
114 return Poll::Ready(Some(Err(status)));
115 }
116
117 if buf.len() >= buffer_settings.yield_threshold {
118 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
119 }
120 }
121 Poll::Ready(Some(Err(status))) => {
122 if buf.is_empty() {
123 return Poll::Ready(Some(Err(status)));
124 }
125 *error = Some(status);
126 return Poll::Ready(Some(Ok(buf.split_to(buf.len()).freeze())));
127 }
128 }
129 }
130 }
131 }
132
encode_item<T>( encoder: &mut T, buf: &mut BytesMut, uncompression_buf: &mut BytesMut, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buffer_settings: BufferSettings, item: T::Item, ) -> Result<(), Status> where T: Encoder<Error = Status>,133 fn encode_item<T>(
134 encoder: &mut T,
135 buf: &mut BytesMut,
136 uncompression_buf: &mut BytesMut,
137 compression_encoding: Option<CompressionEncoding>,
138 max_message_size: Option<usize>,
139 buffer_settings: BufferSettings,
140 item: T::Item,
141 ) -> Result<(), Status>
142 where
143 T: Encoder<Error = Status>,
144 {
145 let offset = buf.len();
146
147 buf.reserve(HEADER_SIZE);
148 unsafe {
149 buf.advance_mut(HEADER_SIZE);
150 }
151
152 if let Some(encoding) = compression_encoding {
153 uncompression_buf.clear();
154
155 encoder
156 .encode(item, &mut EncodeBuf::new(uncompression_buf))
157 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
158
159 let uncompressed_len = uncompression_buf.len();
160
161 compress(
162 CompressionSettings {
163 encoding,
164 buffer_growth_interval: buffer_settings.buffer_size,
165 },
166 uncompression_buf,
167 buf,
168 uncompressed_len,
169 )
170 .map_err(|err| Status::internal(format!("Error compressing: {}", err)))?;
171 } else {
172 encoder
173 .encode(item, &mut EncodeBuf::new(buf))
174 .map_err(|err| Status::internal(format!("Error encoding: {}", err)))?;
175 }
176
177 // now that we know length, we can write the header
178 finish_encoding(compression_encoding, max_message_size, &mut buf[offset..])
179 }
180
finish_encoding( compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, buf: &mut [u8], ) -> Result<(), Status>181 fn finish_encoding(
182 compression_encoding: Option<CompressionEncoding>,
183 max_message_size: Option<usize>,
184 buf: &mut [u8],
185 ) -> Result<(), Status> {
186 let len = buf.len() - HEADER_SIZE;
187 let limit = max_message_size.unwrap_or(DEFAULT_MAX_SEND_MESSAGE_SIZE);
188 if len > limit {
189 return Err(Status::out_of_range(format!(
190 "Error, encoded message length too large: found {} bytes, the limit is: {} bytes",
191 len, limit
192 )));
193 }
194
195 if len > u32::MAX as usize {
196 return Err(Status::resource_exhausted(format!(
197 "Cannot return body with more than 4GB of data but got {len} bytes"
198 )));
199 }
200 {
201 let mut buf = &mut buf[..HEADER_SIZE];
202 buf.put_u8(compression_encoding.is_some() as u8);
203 buf.put_u32(len as u32);
204 }
205
206 Ok(())
207 }
208
/// Which side of a gRPC call an `EncodeBody` is encoding for.
#[derive(Debug)]
enum Role {
    /// Server response: stream errors and the final status are sent as trailers.
    Server,
    /// Client request: stream errors surface as body errors and no trailers are sent.
    Client,
}
214
215 /// A specialized implementation of [Body] for encoding [Result<Bytes, Status>].
216 #[pin_project]
217 #[derive(Debug)]
218 pub struct EncodeBody<T, U> {
219 #[pin]
220 inner: EncodedBytes<T, U>,
221 state: EncodeState,
222 }
223
/// Bookkeeping for `EncodeBody`'s `Body` implementation.
#[derive(Debug)]
struct EncodeState {
    /// Status that `trailers()` reports; `Status::ok("")` is used when this is `None`.
    /// NOTE(review): the constructors in this file initialise it to `None` and nothing here
    /// sets it afterwards — confirm whether it is still needed.
    error: Option<Status>,
    /// Whether this body is a client request or a server response.
    role: Role,
    /// Set once a server's trailers have been produced; returned by `Body::is_end_stream`.
    is_end_stream: bool,
}
230
231 impl<T: Encoder, U: Stream> EncodeBody<T, U> {
232 /// Turns a stream of grpc messages into [EncodeBody] which is used by grpc clients for
233 /// turning the messages into http frames for sending over the network.
new_client( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, max_message_size: Option<usize>, ) -> Self234 pub fn new_client(
235 encoder: T,
236 source: U,
237 compression_encoding: Option<CompressionEncoding>,
238 max_message_size: Option<usize>,
239 ) -> Self {
240 Self {
241 inner: EncodedBytes::new(
242 encoder,
243 source,
244 compression_encoding,
245 SingleMessageCompressionOverride::default(),
246 max_message_size,
247 ),
248 state: EncodeState {
249 error: None,
250 role: Role::Client,
251 is_end_stream: false,
252 },
253 }
254 }
255
256 /// Turns a stream of grpc results (message or error status) into [EncodeBody] which is used by grpc
257 /// servers for turning the messages into http frames for sending over the network.
new_server( encoder: T, source: U, compression_encoding: Option<CompressionEncoding>, compression_override: SingleMessageCompressionOverride, max_message_size: Option<usize>, ) -> Self258 pub fn new_server(
259 encoder: T,
260 source: U,
261 compression_encoding: Option<CompressionEncoding>,
262 compression_override: SingleMessageCompressionOverride,
263 max_message_size: Option<usize>,
264 ) -> Self {
265 Self {
266 inner: EncodedBytes::new(
267 encoder,
268 source,
269 compression_encoding,
270 compression_override,
271 max_message_size,
272 ),
273 state: EncodeState {
274 error: None,
275 role: Role::Server,
276 is_end_stream: false,
277 },
278 }
279 }
280 }
281
282 impl EncodeState {
trailers(&mut self) -> Option<Result<HeaderMap, Status>>283 fn trailers(&mut self) -> Option<Result<HeaderMap, Status>> {
284 match self.role {
285 Role::Client => None,
286 Role::Server => {
287 if self.is_end_stream {
288 return None;
289 }
290
291 self.is_end_stream = true;
292 let status = if let Some(status) = self.error.take() {
293 status
294 } else {
295 Status::ok("")
296 };
297 Some(status.to_header_map())
298 }
299 }
300 }
301 }
302
303 impl<T, U> Body for EncodeBody<T, U>
304 where
305 T: Encoder<Error = Status>,
306 U: Stream<Item = Result<T::Item, Status>>,
307 {
308 type Data = Bytes;
309 type Error = Status;
310
is_end_stream(&self) -> bool311 fn is_end_stream(&self) -> bool {
312 self.state.is_end_stream
313 }
314
poll_frame( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>>315 fn poll_frame(
316 self: Pin<&mut Self>,
317 cx: &mut Context<'_>,
318 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
319 let self_proj = self.project();
320 match ready!(self_proj.inner.poll_next(cx)) {
321 Some(Ok(d)) => Some(Ok(Frame::data(d))).into(),
322 Some(Err(status)) => match self_proj.state.role {
323 Role::Client => Some(Err(status)).into(),
324 Role::Server => {
325 self_proj.state.is_end_stream = true;
326 Some(Ok(Frame::trailers(status.to_header_map()?))).into()
327 }
328 },
329 None => self_proj
330 .state
331 .trailers()
332 .map(|t| t.map(Frame::trailers))
333 .into(),
334 }
335 }
336 }
337