1 use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder};
2 use crate::codec::EncodeBuf;
3 use crate::Status;
4 use prost::Message;
5 use std::marker::PhantomData;
6
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
///
/// `T` is the message type this codec encodes and `U` is the message type it
/// decodes. The codec itself stores no state: it only carries the two type
/// parameters. Use [`ProstCodec::raw_encoder`] / [`ProstCodec::raw_decoder`]
/// when an encoder or decoder with explicit [`BufferSettings`] is needed.
#[derive(Debug, Clone)]
pub struct ProstCodec<T, U> {
    _pd: PhantomData<(T, U)>,
}

impl<T, U> ProstCodec<T, U> {
    /// Creates a new `ProstCodec`.
    ///
    /// This constructor takes no configuration. To control how encode/decode
    /// buffers are allocated and grow per RPC, build the encoder or decoder
    /// with [`ProstCodec::raw_encoder`] / [`ProstCodec::raw_decoder`] and pass
    /// explicit [`BufferSettings`].
    pub fn new() -> Self {
        Self { _pd: PhantomData }
    }
}

impl<T, U> Default for ProstCodec<T, U> {
    /// Equivalent to [`ProstCodec::new`].
    fn default() -> Self {
        Self::new()
    }
}
26
27 impl<T, U> ProstCodec<T, U>
28 where
29 T: Message + Send + 'static,
30 U: Message + Default + Send + 'static,
31 {
32 /// A tool for building custom codecs based on prost encoding and decoding.
33 /// See the codec_buffers example for one possible way to use this.
raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder34 pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder {
35 ProstEncoder {
36 _pd: PhantomData,
37 buffer_settings,
38 }
39 }
40
41 /// A tool for building custom codecs based on prost encoding and decoding.
42 /// See the codec_buffers example for one possible way to use this.
raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder43 pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder {
44 ProstDecoder {
45 _pd: PhantomData,
46 buffer_settings,
47 }
48 }
49 }
50
51 impl<T, U> Codec for ProstCodec<T, U>
52 where
53 T: Message + Send + 'static,
54 U: Message + Default + Send + 'static,
55 {
56 type Encode = T;
57 type Decode = U;
58
59 type Encoder = ProstEncoder<T>;
60 type Decoder = ProstDecoder<U>;
61
encoder(&mut self) -> Self::Encoder62 fn encoder(&mut self) -> Self::Encoder {
63 ProstEncoder {
64 _pd: PhantomData,
65 buffer_settings: BufferSettings::default(),
66 }
67 }
68
decoder(&mut self) -> Self::Decoder69 fn decoder(&mut self) -> Self::Decoder {
70 ProstDecoder {
71 _pd: PhantomData,
72 buffer_settings: BufferSettings::default(),
73 }
74 }
75 }
76
77 /// A [`Encoder`] that knows how to encode `T`.
78 #[derive(Debug, Clone, Default)]
79 pub struct ProstEncoder<T> {
80 _pd: PhantomData<T>,
81 buffer_settings: BufferSettings,
82 }
83
84 impl<T> ProstEncoder<T> {
85 /// Get a new encoder with explicit buffer settings
new(buffer_settings: BufferSettings) -> Self86 pub fn new(buffer_settings: BufferSettings) -> Self {
87 Self {
88 _pd: PhantomData,
89 buffer_settings,
90 }
91 }
92 }
93
94 impl<T: Message> Encoder for ProstEncoder<T> {
95 type Item = T;
96 type Error = Status;
97
encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error>98 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
99 item.encode(buf)
100 .expect("Message only errors if not enough space");
101
102 Ok(())
103 }
104
buffer_settings(&self) -> BufferSettings105 fn buffer_settings(&self) -> BufferSettings {
106 self.buffer_settings
107 }
108 }
109
110 /// A [`Decoder`] that knows how to decode `U`.
111 #[derive(Debug, Clone, Default)]
112 pub struct ProstDecoder<U> {
113 _pd: PhantomData<U>,
114 buffer_settings: BufferSettings,
115 }
116
117 impl<U> ProstDecoder<U> {
118 /// Get a new decoder with explicit buffer settings
new(buffer_settings: BufferSettings) -> Self119 pub fn new(buffer_settings: BufferSettings) -> Self {
120 Self {
121 _pd: PhantomData,
122 buffer_settings,
123 }
124 }
125 }
126
127 impl<U: Message + Default> Decoder for ProstDecoder<U> {
128 type Item = U;
129 type Error = Status;
130
decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error>131 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
132 let item = Message::decode(buf)
133 .map(Option::Some)
134 .map_err(from_decode_error)?;
135
136 Ok(item)
137 }
138
buffer_settings(&self) -> BufferSettings139 fn buffer_settings(&self) -> BufferSettings {
140 self.buffer_settings
141 }
142 }
143
from_decode_error(error: prost::DecodeError) -> crate::Status144 fn from_decode_error(error: prost::DecodeError) -> crate::Status {
145 // Map Protobuf parse errors to an INTERNAL status code, as per
146 // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
147 Status::internal(error.to_string())
148 }
149
#[cfg(test)]
mod tests {
    use crate::codec::compression::SingleMessageCompressionOverride;
    use crate::codec::{
        DecodeBuf, Decoder, EncodeBody, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
    };
    use crate::Status;
    use bytes::{Buf, BufMut, BytesMut};
    use http_body::Body;
    use http_body_util::BodyExt as _;
    use std::pin::pin;

    const LEN: usize = 10000;
    // The maximum uncompressed size in bytes for a message. Set to 2MB.
    const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;

    /// Wraps `msg` in a gRPC length-prefixed frame.
    ///
    /// The frame is a one-byte compression flag (0 = uncompressed), then the
    /// payload length as a big-endian `u32`, then the payload itself.
    fn encode_frame(msg: &[u8]) -> BytesMut {
        let len = u32::try_from(msg.len()).expect("test payload fits in a u32 length prefix");
        let mut buf = BytesMut::with_capacity(msg.len() + HEADER_SIZE);
        buf.put_u8(0);
        buf.put_u32(len);
        buf.put_slice(msg);
        buf
    }

    #[tokio::test]
    async fn decode() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; LEN];
        let buf = encode_frame(&msg);

        // Deliver the whole frame (header + payload) in the first data chunk.
        let body = body::MockBody::new(&buf[..], buf.len(), 0);

        let mut stream = Streaming::new_request(decoder, body, None, None);

        let mut i = 0usize;
        while let Some(output_msg) = stream.message().await.unwrap() {
            assert_eq!(output_msg.len(), msg.len());
            i += 1;
        }
        assert_eq!(i, 1);
    }

    #[tokio::test]
    async fn decode_max_message_size_exceeded() {
        let decoder = MockDecoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
        let buf = encode_frame(&msg);

        let body = body::MockBody::new(&buf[..], buf.len(), 0);

        let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));

        let actual = stream.message().await.unwrap_err();

        let expected = Status::out_of_range(format!(
            "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
            msg.len(),
            MAX_MESSAGE_SIZE
        ));

        assert_eq!(actual.code(), expected.code());
        assert_eq!(actual.message(), expected.message());
    }

    #[tokio::test]
    async fn encode() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; 1024];

        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            None,
        ));

        while let Some(r) = body.frame().await {
            r.unwrap();
        }
    }

    #[tokio::test]
    async fn encode_max_message_size_exceeded() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(MAX_MESSAGE_SIZE),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "11"
        );
        assert!(body.is_end_stream());
    }

    // skip on windows because CI stumbles over our 4GB allocation
    #[cfg(not(target_family = "windows"))]
    #[tokio::test]
    async fn encode_too_big() {
        let encoder = MockEncoder::default();

        let msg = vec![0u8; u32::MAX as usize + 1];

        let messages = std::iter::once(Ok::<_, Status>(msg));
        let source = tokio_stream::iter(messages);

        let mut body = pin!(EncodeBody::new_server(
            encoder,
            source,
            None,
            SingleMessageCompressionOverride::default(),
            Some(usize::MAX),
        ));

        let frame = body
            .frame()
            .await
            .expect("at least one frame")
            .expect("no error polling frame");
        assert_eq!(
            frame
                .into_trailers()
                .expect("got trailers")
                .get(Status::GRPC_STATUS)
                .expect("grpc-status header"),
            "8"
        );
        assert!(body.is_end_stream());
    }

    /// Test encoder that copies the raw bytes of each `Vec<u8>` item into the buffer.
    #[derive(Debug, Clone, Default)]
    struct MockEncoder {}

    impl Encoder for MockEncoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
            buf.put(&item[..]);
            Ok(())
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            Default::default()
        }
    }

    /// Test decoder that returns the raw bytes of each message as a `Vec<u8>`.
    #[derive(Debug, Clone, Default)]
    struct MockDecoder {}

    impl Decoder for MockDecoder {
        type Item = Vec<u8>;
        type Error = Status;

        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
            // Consume exactly the bytes this message holds instead of assuming
            // every message is `LEN` bytes long.
            let mut out = vec![0u8; buf.remaining()];
            buf.copy_to_slice(&mut out);
            Ok(Some(out))
        }

        fn buffer_settings(&self) -> crate::codec::BufferSettings {
            Default::default()
        }
    }

    mod body {
        use crate::Status;
        use bytes::Bytes;
        use http_body::{Body, Frame};
        use std::{
            pin::Pin,
            task::{Context, Poll},
        };

        /// A body that yields `partial_len` bytes first, then the rest,
        /// interleaving `Poll::Pending` between data frames.
        #[derive(Debug)]
        pub(super) struct MockBody {
            data: Bytes,

            // the size of the partial message to send
            partial_len: usize,

            // the number of times we've sent
            count: usize,
        }

        impl MockBody {
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
                MockBody {
                    data: Bytes::copy_from_slice(b),
                    partial_len,
                    count,
                }
            }
        }

        impl Body for MockBody {
            type Data = Bytes;
            type Error = Status;

            fn poll_frame(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
                // every other call to poll_frame returns data
                let should_send = self.count % 2 == 0;
                let data_len = self.data.len();
                let partial_len = self.partial_len;
                let count = self.count;
                if data_len > 0 {
                    let result = if should_send {
                        let response = self
                            .data
                            .split_to(if count == 0 { partial_len } else { data_len });
                        Poll::Ready(Some(Ok(Frame::data(response))))
                    } else {
                        cx.waker().wake_by_ref();
                        Poll::Pending
                    };
                    // make some fake progress
                    self.count += 1;
                    result
                } else {
                    Poll::Ready(None)
                }
            }
        }
    }
}
412