xref: /tonic/tonic/src/codec/prost.rs (revision 517b7fc9)
1 use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder};
2 use crate::codec::EncodeBuf;
3 use crate::Status;
4 use prost::Message;
5 use std::marker::PhantomData;
6 
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
///
/// `T` is the message type this codec encodes and `U` is the message type it
/// decodes. The struct holds no data; the type parameters are carried only
/// through `PhantomData`.
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}
12 
13 impl<T, U> ProstCodec<T, U> {
14     /// Configure a ProstCodec with encoder/decoder buffer settings. This is used to control
15     /// how memory is allocated and grows per RPC.
new() -> Self16     pub fn new() -> Self {
17         Self { _pd: PhantomData }
18     }
19 }
20 
21 impl<T, U> Default for ProstCodec<T, U> {
default() -> Self22     fn default() -> Self {
23         Self::new()
24     }
25 }
26 
27 impl<T, U> ProstCodec<T, U>
28 where
29     T: Message + Send + 'static,
30     U: Message + Default + Send + 'static,
31 {
32     /// A tool for building custom codecs based on prost encoding and decoding.
33     /// See the codec_buffers example for one possible way to use this.
raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder34     pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder {
35         ProstEncoder {
36             _pd: PhantomData,
37             buffer_settings,
38         }
39     }
40 
41     /// A tool for building custom codecs based on prost encoding and decoding.
42     /// See the codec_buffers example for one possible way to use this.
raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder43     pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder {
44         ProstDecoder {
45             _pd: PhantomData,
46             buffer_settings,
47         }
48     }
49 }
50 
51 impl<T, U> Codec for ProstCodec<T, U>
52 where
53     T: Message + Send + 'static,
54     U: Message + Default + Send + 'static,
55 {
56     type Encode = T;
57     type Decode = U;
58 
59     type Encoder = ProstEncoder<T>;
60     type Decoder = ProstDecoder<U>;
61 
encoder(&mut self) -> Self::Encoder62     fn encoder(&mut self) -> Self::Encoder {
63         ProstEncoder {
64             _pd: PhantomData,
65             buffer_settings: BufferSettings::default(),
66         }
67     }
68 
decoder(&mut self) -> Self::Decoder69     fn decoder(&mut self) -> Self::Decoder {
70         ProstDecoder {
71             _pd: PhantomData,
72             buffer_settings: BufferSettings::default(),
73         }
74     }
75 }
76 
77 /// A [`Encoder`] that knows how to encode `T`.
78 #[derive(Debug, Clone, Default)]
79 pub struct ProstEncoder<T> {
80     _pd: PhantomData<T>,
81     buffer_settings: BufferSettings,
82 }
83 
84 impl<T> ProstEncoder<T> {
85     /// Get a new encoder with explicit buffer settings
new(buffer_settings: BufferSettings) -> Self86     pub fn new(buffer_settings: BufferSettings) -> Self {
87         Self {
88             _pd: PhantomData,
89             buffer_settings,
90         }
91     }
92 }
93 
94 impl<T: Message> Encoder for ProstEncoder<T> {
95     type Item = T;
96     type Error = Status;
97 
encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error>98     fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
99         item.encode(buf)
100             .expect("Message only errors if not enough space");
101 
102         Ok(())
103     }
104 
buffer_settings(&self) -> BufferSettings105     fn buffer_settings(&self) -> BufferSettings {
106         self.buffer_settings
107     }
108 }
109 
110 /// A [`Decoder`] that knows how to decode `U`.
111 #[derive(Debug, Clone, Default)]
112 pub struct ProstDecoder<U> {
113     _pd: PhantomData<U>,
114     buffer_settings: BufferSettings,
115 }
116 
117 impl<U> ProstDecoder<U> {
118     /// Get a new decoder with explicit buffer settings
new(buffer_settings: BufferSettings) -> Self119     pub fn new(buffer_settings: BufferSettings) -> Self {
120         Self {
121             _pd: PhantomData,
122             buffer_settings,
123         }
124     }
125 }
126 
127 impl<U: Message + Default> Decoder for ProstDecoder<U> {
128     type Item = U;
129     type Error = Status;
130 
decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error>131     fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
132         let item = Message::decode(buf)
133             .map(Option::Some)
134             .map_err(from_decode_error)?;
135 
136         Ok(item)
137     }
138 
buffer_settings(&self) -> BufferSettings139     fn buffer_settings(&self) -> BufferSettings {
140         self.buffer_settings
141     }
142 }
143 
from_decode_error(error: prost::DecodeError) -> crate::Status144 fn from_decode_error(error: prost::DecodeError) -> crate::Status {
145     // Map Protobuf parse errors to an INTERNAL status code, as per
146     // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
147     Status::internal(error.to_string())
148 }
149 
#[cfg(test)]
mod tests {
    use crate::codec::compression::SingleMessageCompressionOverride;
    use crate::codec::{
        DecodeBuf, Decoder, EncodeBody, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
    };
    use crate::Status;
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;
    use http_body_util::BodyExt as _;
    use std::pin::pin;

    // Payload length of the message used by the `decode` test.
    const LEN: usize = 10000;
    // The maximum uncompressed size in bytes for a message. Set to 2MB.
    const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;

    /// Builds one length-prefixed frame for `msg`.
    ///
    /// The frame is a `0` flag byte, then `msg.len()` as a big-endian `u32`,
    /// then the payload bytes.
    fn grpc_frame(msg: &[u8]) -> BytesMut {
        let mut buf = BytesMut::with_capacity(msg.len() + HEADER_SIZE);
        buf.put_u8(0);
        buf.put_u32(msg.len() as u32);
        buf.put_slice(msg);
        buf
    }

    #[tokio::test]
    async fn decode() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; LEN];
        let buf = grpc_frame(&msg);

        // The first data chunk carries the whole frame (header + payload).
        let body = body::MockBody::new(&buf[..], LEN + HEADER_SIZE, 0);

        let mut stream = Streaming::new_request(decoder, body, None, None);

        let mut i = 0usize;
        while let Some(output_msg) = stream.message().await.unwrap() {
            assert_eq!(output_msg.len(), msg.len());
            i += 1;
        }
        assert_eq!(i, 1);
    }

    #[tokio::test]
    async fn decode_max_message_size_exceeded() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
        let buf = grpc_frame(&msg);

        let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0);

        let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));

        let actual = stream.message().await.unwrap_err();

        let expected = Status::out_of_range(format!(
            "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
            msg.len(),
            MAX_MESSAGE_SIZE
        ));

        assert_eq!(actual.code(), expected.code());
        assert_eq!(actual.message(), expected.message());
    }

    #[tokio::test]
    async fn encode() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; 1024];

        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            None,
        ));

        while let Some(r) = body.frame().await {
            r.unwrap();
        }
    }

    #[tokio::test]
    async fn encode_max_message_size_exceeded() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(MAX_MESSAGE_SIZE),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "11"
        );
        assert!(body.is_end_stream());
    }

    // skip on windows because CI stumbles over our 4GB allocation
    #[cfg(not(target_family = "windows"))]
    #[tokio::test]
    async fn encode_too_big() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; u32::MAX as usize + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(usize::MAX),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "8"
        );
        assert!(body.is_end_stream());
    }

    /// Encoder that copies the raw bytes of each item into the output buffer.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder {}

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(&item[..]);
            Ok(())
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            Default::default()
        }
    }

    /// Decoder that returns the raw bytes of each message.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder {}

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            // Consume exactly the bytes of this message rather than assuming a
            // fixed length, so the mock works for any payload size.
            let mut out = vec![0u8; buf.remaining()];
            buf.copy_to_slice(&mut out);
            Ok(Some(out))
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            Default::default()
        }
    }

    mod body {
        use crate::Status;
        use bytes::Bytes;
        use http_body::{Body, Frame};
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        /// A body that yields `data` in at most two chunks.
        ///
        /// On every poll with an even `count`, it sends data. On odd counts it
        /// returns `Poll::Pending` after waking itself. The first chunk (sent
        /// when `count == 0`) is `partial_len` bytes; a later chunk sends
        /// everything that remains.
        #[derive(Debug)]
        pub(super) struct MockBody {
            data: Bytes,

            // the size of the partial message to send first
            partial_len: usize,

            // the number of times we've been polled while data remained
            count: usize,
        }

        impl MockBody {
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                // every other call to poll_frame returns data
                let should_send = self.count % 2 == 0;
                let data_len = self.data.len();
                let partial_len = self.partial_len;
                let count = self.count;
                if data_len > 0 {
                    let result = if should_send {
                        let response =
                            self.data
                                .split_to(if count == 0 { partial_len } else { data_len });
                        Poll::Ready(Some(Ok(Frame::data(response))))
                    } else {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    };
                    // make some fake progress
                    self.count += 1;
                    result
                } else {
                    Poll::Ready(None)
                }
            }
        }
    }
}
412