twilight_gateway/
ratelimiter.rs1#![allow(clippy::cast_possible_truncation)]
14
15use std::{
16 collections::VecDeque,
17 future::Future,
18 pin::Pin,
19 task::{Context, Poll, ready},
20};
21use tokio::time::{Duration, Instant, Sleep};
22
/// Length of the ratelimit window: an acquired permit is released again this
/// long after it was acquired.
const PERIOD: Duration = Duration::from_secs(60);

/// Total number of permits per [`PERIOD`], before any are set aside for
/// heartbeats.
const PERMITS: u8 = 120;
28
29#[derive(Debug)]
31pub struct CommandRatelimiter {
32 delay: Pin<Box<Sleep>>,
36 queue: VecDeque<u16>,
39}
40
41impl CommandRatelimiter {
42 pub(crate) fn new(heartbeat_interval: Duration) -> Self {
44 let capacity = usize::from(nonreserved_commands_per_reset(heartbeat_interval)) - 1;
45
46 let mut queue = VecDeque::with_capacity(capacity);
47 if queue.capacity() != capacity {
48 queue.resize(capacity, 0);
49 let vec = Vec::from(queue).into_boxed_slice().into_vec();
51 queue = VecDeque::from(vec);
53 queue.clear();
54 }
55
56 Self {
57 delay: Box::pin(tokio::time::sleep_until(Instant::now())),
58 queue,
59 }
60 }
61
62 pub fn available(&self) -> u8 {
64 let now = Instant::now();
65 let acquired = if now >= self.delay.deadline() {
66 self.next_acquired_position(now)
67 .map_or(0, |released_count| self.queue.len() - released_count)
68 } else {
69 self.queue.len() + 1
70 };
71
72 self.max() - acquired as u8
73 }
74
75 pub fn max(&self) -> u8 {
77 self.queue.capacity() as u8 + 1
78 }
79
80 pub fn next_available(&self) -> Duration {
82 self.delay
83 .deadline()
84 .saturating_duration_since(Instant::now())
85 }
86
87 pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> {
94 ready!(self.poll_available(cx));
95
96 let now = Instant::now();
97 if now >= self.delay.deadline() {
98 if let Some(new_deadline_idx) = self.next_acquired_position(now) {
99 self.rebase(new_deadline_idx);
100 } else {
101 self.queue.clear();
102 self.delay.as_mut().reset(now + PERIOD);
103
104 return Poll::Ready(());
105 }
106 }
107
108 let releases = (now + PERIOD) - self.delay.deadline();
109 debug_assert_ne!(self.queue.capacity(), self.queue.len());
110 self.queue.push_back(releases.as_millis() as u16);
111
112 if self.queue.len() == self.queue.capacity() {
113 tracing::debug!(duration = ?(self.delay.deadline() - now), "ratelimited");
114 }
115
116 Poll::Ready(())
117 }
118
119 pub(crate) fn poll_available(&mut self, cx: &mut Context<'_>) -> Poll<()> {
126 if self.queue.len() < self.queue.capacity() {
127 return Poll::Ready(());
128 }
129
130 self.delay.as_mut().poll(cx)
131 }
132
133 fn next_acquired_position(&self, now: Instant) -> Option<usize> {
137 self.queue
138 .iter()
139 .map(|&m| self.delay.deadline() + Duration::from_millis(m.into()))
140 .position(|deadline| deadline > now)
141 }
142
143 fn rebase(&mut self, new_deadline_idx: usize) {
145 let duration = Duration::from_millis(self.queue[new_deadline_idx].into());
146 let new_deadline = self.delay.deadline() + duration;
147
148 self.queue.drain(..=new_deadline_idx);
149
150 for timestamp in &mut self.queue {
151 let deadline = self.delay.deadline() + Duration::from_millis((*timestamp).into());
152 *timestamp = (deadline - new_deadline).as_millis() as u16;
153 }
154
155 self.delay.as_mut().reset(new_deadline);
156 }
157}
158
159fn nonreserved_commands_per_reset(heartbeat_interval: Duration) -> u8 {
168 const MAX_NONRESERVED_COMMANDS_PER_PERIOD: u8 = PERMITS - 10;
171
172 let heartbeats_per_reset = PERIOD.as_secs_f32() / heartbeat_interval.as_secs_f32();
174
175 #[allow(clippy::cast_sign_loss)]
177 let heartbeats_per_reset = heartbeats_per_reset.ceil() as u8;
178
179 let heartbeats_per_reset = heartbeats_per_reset.saturating_add(1);
181
182 let nonreserved_commands_per_reset = PERMITS.saturating_sub(heartbeats_per_reset);
184
185 nonreserved_commands_per_reset.max(MAX_NONRESERVED_COMMANDS_PER_PERIOD)
187}
188
#[cfg(test)]
mod tests {
    #![allow(clippy::unchecked_time_subtraction)]

    use super::{CommandRatelimiter, PERIOD, nonreserved_commands_per_reset};
    use static_assertions::assert_impl_all;
    use std::{fmt::Debug, future::poll_fn, time::Duration};
    use tokio::time;

    assert_impl_all!(CommandRatelimiter: Debug, Send, Sync);

    /// Heartbeat interval used to construct the ratelimiters under test.
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(60);

    /// Acquires `count` permits one after another.
    async fn acquire(ratelimiter: &mut CommandRatelimiter, count: u8) {
        for _ in 0..count {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
    }

    #[test]
    fn nonreserved_commands() {
        let cases = [
            (Duration::from_secs(u64::MAX), 118),
            (Duration::from_secs(60), 118),
            (Duration::from_millis(42_500), 117),
            (Duration::from_secs(30), 117),
            (Duration::from_millis(29_999), 116),
            (Duration::ZERO, 110),
        ];

        for (heartbeat_interval, expected) in cases {
            assert_eq!(
                expected,
                nonreserved_commands_per_reset(heartbeat_interval)
            );
        }
    }

    #[tokio::test(start_paused = true)]
    async fn full_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        let all = ratelimiter.max();
        acquire(&mut ratelimiter, all).await;
        assert_eq!(ratelimiter.available(), 0);

        // Just before the period elapses nothing has been released yet.
        time::advance(PERIOD - Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), 0);

        // Once the full period has elapsed every permit is released.
        time::advance(Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn half_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        let first_half = ratelimiter.max() / 2;
        acquire(&mut ratelimiter, first_half).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);

        time::advance(PERIOD / 2).await;

        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);
        let second_half = ratelimiter.max() / 2;
        acquire(&mut ratelimiter, second_half).await;
        assert_eq!(ratelimiter.available(), 0);

        // A full period after the first batch: only that batch is released.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);

        // A full period after the second batch: everything is released.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn constant_capacity() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = ratelimiter.max();

        acquire(&mut ratelimiter, max).await;
        assert_eq!(ratelimiter.available(), 0);

        // Acquiring beyond the limit must not change the reported maximum.
        acquire(&mut ratelimiter, 1).await;
        assert_eq!(max, ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn rebase() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        // Acquire five permits spaced 20 ms apart.
        for _ in 0..5 {
            time::advance(Duration::from_millis(20)).await;
            acquire(&mut ratelimiter, 1).await;
        }
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 5);

        // Exactly one period after the first acquisition: one permit is back.
        time::advance(PERIOD - Duration::from_millis(80)).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);

        // Each acquisition is offset by a release 20 ms later, so the number
        // of held permits stays at four.
        for _ in 0..4 {
            acquire(&mut ratelimiter, 1).await;
            time::advance(Duration::from_millis(20)).await;
            assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);
        }
    }
}