xref: /webrtc/util/src/replay_detector/mod.rs (revision ffe74184)
1 #[cfg(test)]
2 mod replay_detector_test;
3 
4 use super::fixed_big_int::*;
5 
/// Interface of a sequence-number replay detector.
///
/// Detection is a two-step protocol: first ask [`check`](ReplayDetector::check)
/// whether an incoming sequence number is acceptable, then, once the packet has
/// been processed successfully, call [`accept`](ReplayDetector::accept) so the
/// detector marks it as received.
pub trait ReplayDetector {
    /// Returns `true` if `seq` is not a replay.
    ///
    /// Call [`accept`](ReplayDetector::accept) afterwards to mark the packet as
    /// properly received.
    fn check(&mut self, seq: u64) -> bool;

    /// Marks the packet that passed the preceding [`check`](ReplayDetector::check)
    /// as received.
    fn accept(&mut self);
}
13 
14 pub struct SlidingWindowDetector {
15     accepted: bool,
16     seq: u64,
17     latest_seq: u64,
18     max_seq: u64,
19     window_size: usize,
20     mask: FixedBigInt,
21 }
22 
23 impl SlidingWindowDetector {
24     // New creates ReplayDetector.
25     // Created ReplayDetector doesn't allow wrapping.
26     // It can handle monotonically increasing sequence number up to
27     // full 64bit number. It is suitable for DTLS replay protection.
new(window_size: usize, max_seq: u64) -> Self28     pub fn new(window_size: usize, max_seq: u64) -> Self {
29         SlidingWindowDetector {
30             accepted: false,
31             seq: 0,
32             latest_seq: 0,
33             max_seq,
34             window_size,
35             mask: FixedBigInt::new(window_size),
36         }
37     }
38 }
39 
40 impl ReplayDetector for SlidingWindowDetector {
check(&mut self, seq: u64) -> bool41     fn check(&mut self, seq: u64) -> bool {
42         self.accepted = false;
43 
44         if seq > self.max_seq {
45             // Exceeded upper limit.
46             return false;
47         }
48 
49         if seq <= self.latest_seq {
50             if self.latest_seq >= self.window_size as u64 + seq {
51                 return false;
52             }
53             if self.mask.bit((self.latest_seq - seq) as usize) != 0 {
54                 // The sequence number is duplicated.
55                 return false;
56             }
57         }
58 
59         self.accepted = true;
60         self.seq = seq;
61         true
62     }
63 
accept(&mut self)64     fn accept(&mut self) {
65         if !self.accepted {
66             return;
67         }
68 
69         if self.seq > self.latest_seq {
70             // Update the head of the window.
71             self.mask.lsh((self.seq - self.latest_seq) as usize);
72             self.latest_seq = self.seq;
73         }
74         let diff = (self.latest_seq - self.seq) % self.max_seq;
75         self.mask.set_bit(diff as usize);
76     }
77 }
78 
79 pub struct WrappedSlidingWindowDetector {
80     accepted: bool,
81     seq: u64,
82     latest_seq: u64,
83     max_seq: u64,
84     window_size: usize,
85     mask: FixedBigInt,
86     init: bool,
87 }
88 
89 impl WrappedSlidingWindowDetector {
90     // WithWrap creates ReplayDetector allowing sequence wrapping.
91     // This is suitable for short bitwidth counter like SRTP and SRTCP.
new(window_size: usize, max_seq: u64) -> Self92     pub fn new(window_size: usize, max_seq: u64) -> Self {
93         WrappedSlidingWindowDetector {
94             accepted: false,
95             seq: 0,
96             latest_seq: 0,
97             max_seq,
98             window_size,
99             mask: FixedBigInt::new(window_size),
100             init: false,
101         }
102     }
103 }
104 
105 impl ReplayDetector for WrappedSlidingWindowDetector {
check(&mut self, seq: u64) -> bool106     fn check(&mut self, seq: u64) -> bool {
107         self.accepted = false;
108 
109         if seq > self.max_seq {
110             // Exceeded upper limit.
111             return false;
112         }
113         if !self.init {
114             if seq != 0 {
115                 self.latest_seq = seq - 1;
116             } else {
117                 self.latest_seq = self.max_seq;
118             }
119             self.init = true;
120         }
121 
122         let mut diff = self.latest_seq as i64 - seq as i64;
123         // Wrap the number.
124         if diff > self.max_seq as i64 / 2 {
125             diff -= (self.max_seq + 1) as i64;
126         } else if diff <= -(self.max_seq as i64 / 2) {
127             diff += (self.max_seq + 1) as i64;
128         }
129 
130         if diff >= self.window_size as i64 {
131             // Too old.
132             return false;
133         }
134         if diff >= 0 && self.mask.bit(diff as usize) != 0 {
135             // The sequence number is duplicated.
136             return false;
137         }
138 
139         self.accepted = true;
140         self.seq = seq;
141         true
142     }
143 
accept(&mut self)144     fn accept(&mut self) {
145         if !self.accepted {
146             return;
147         }
148 
149         let mut diff = self.latest_seq as i64 - self.seq as i64;
150         // Wrap the number.
151         if diff > self.max_seq as i64 / 2 {
152             diff -= (self.max_seq + 1) as i64;
153         } else if diff <= -(self.max_seq as i64 / 2) {
154             diff += (self.max_seq + 1) as i64;
155         }
156 
157         assert!(diff < self.window_size as i64);
158 
159         if diff < 0 {
160             // Update the head of the window.
161             self.mask.lsh((-diff) as usize);
162             self.latest_seq = self.seq;
163         }
164         self.mask
165             .set_bit((self.latest_seq as isize - self.seq as isize) as usize);
166     }
167 }
168 
169 #[derive(Default)]
170 pub struct NoOpReplayDetector;
171 
172 impl ReplayDetector for NoOpReplayDetector {
check(&mut self, _: u64) -> bool173     fn check(&mut self, _: u64) -> bool {
174         true
175     }
accept(&mut self)176     fn accept(&mut self) {}
177 }
178