twilight_gateway/
ratelimiter.rs1#![allow(clippy::cast_possible_truncation)]
14
15use std::{
16 collections::VecDeque,
17 future::Future,
18 pin::Pin,
19 task::{ready, Context, Poll},
20};
21use tokio::time::{Duration, Instant, Sleep};
22
/// Time after which an acquired permit is released again.
const PERIOD: Duration = Duration::from_millis(60_000);

/// Total number of permits (including those reserved for heartbeats) per
/// [`PERIOD`].
const PERMITS: u8 = 120;
28
/// Ratelimiter handing out permits for sending commands.
///
/// The oldest acquired permit is tracked by `delay`; every further acquired
/// permit is stored in `queue`.
#[derive(Debug)]
pub struct CommandRatelimiter {
    /// Timer whose deadline is when the permit it tracks is released.
    delay: Pin<Box<Sleep>>,
    /// Release deadlines of the remaining acquired permits, stored as
    /// millisecond offsets from `delay`'s deadline. Its capacity is fixed when
    /// the ratelimiter is created and determines the maximum number of permits.
    queue: VecDeque<u16>,
}
40
41impl CommandRatelimiter {
42 pub(crate) fn new(heartbeat_interval: Duration) -> Self {
44 let capacity = usize::from(nonreserved_commands_per_reset(heartbeat_interval)) - 1;
45
46 let mut queue = VecDeque::with_capacity(capacity);
47 if queue.capacity() != capacity {
48 queue.resize(capacity, 0);
49 let vec = Vec::from(queue).into_boxed_slice().into_vec();
51 queue = VecDeque::from(vec);
53 queue.clear();
54 }
55
56 Self {
57 delay: Box::pin(tokio::time::sleep_until(Instant::now())),
58 queue,
59 }
60 }
61
62 pub fn available(&self) -> u8 {
64 let now = Instant::now();
65 let acquired = if now >= self.delay.deadline() {
66 self.next_acquired_position(now)
67 .map_or(0, |released_count| self.queue.len() - released_count)
68 } else {
69 self.queue.len() + 1
70 };
71
72 self.max() - acquired as u8
73 }
74
75 pub fn max(&self) -> u8 {
77 self.queue.capacity() as u8 + 1
78 }
79
80 pub fn next_available(&self) -> Duration {
82 self.delay
83 .deadline()
84 .saturating_duration_since(Instant::now())
85 }
86
87 pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> {
94 ready!(self.poll_available(cx));
95
96 let now = Instant::now();
97 if now >= self.delay.deadline() {
98 if let Some(new_deadline_idx) = self.next_acquired_position(now) {
99 self.rebase(new_deadline_idx);
100 } else {
101 self.queue.clear();
102 self.delay.as_mut().reset(now + PERIOD);
103
104 return Poll::Ready(());
105 }
106 }
107
108 let releases = (now + PERIOD) - self.delay.deadline();
109 debug_assert_ne!(self.queue.capacity(), self.queue.len());
110 self.queue.push_back(releases.as_millis() as u16);
111
112 if self.queue.len() == self.queue.capacity() {
113 tracing::debug!(duration = ?(self.delay.deadline() - now), "ratelimited");
114 }
115
116 Poll::Ready(())
117 }
118
119 pub(crate) fn poll_available(&mut self, cx: &mut Context<'_>) -> Poll<()> {
126 if self.queue.len() < self.queue.capacity() {
127 return Poll::Ready(());
128 }
129
130 self.delay.as_mut().poll(cx)
131 }
132
133 fn next_acquired_position(&self, now: Instant) -> Option<usize> {
137 self.queue
138 .iter()
139 .map(|&m| self.delay.deadline() + Duration::from_millis(m.into()))
140 .position(|deadline| deadline > now)
141 }
142
143 fn rebase(&mut self, new_deadline_idx: usize) {
145 let duration = Duration::from_millis(self.queue[new_deadline_idx].into());
146 let new_deadline = self.delay.deadline() + duration;
147
148 self.queue.drain(..=new_deadline_idx);
149
150 for timestamp in &mut self.queue {
151 let deadline = self.delay.deadline() + Duration::from_millis((*timestamp).into());
152 *timestamp = (deadline - new_deadline).as_millis() as u16;
153 }
154
155 self.delay.as_mut().reset(new_deadline);
156 }
157}
158
159fn nonreserved_commands_per_reset(heartbeat_interval: Duration) -> u8 {
168 const MAX_NONRESERVED_COMMANDS_PER_PERIOD: u8 = PERMITS - 10;
171
172 let heartbeats_per_reset = PERIOD.as_secs_f32() / heartbeat_interval.as_secs_f32();
174
175 #[allow(clippy::cast_sign_loss)]
177 let heartbeats_per_reset = heartbeats_per_reset.ceil() as u8;
178
179 let heartbeats_per_reset = heartbeats_per_reset.saturating_add(1);
181
182 let nonreserved_commands_per_reset = PERMITS.saturating_sub(heartbeats_per_reset);
184
185 nonreserved_commands_per_reset.max(MAX_NONRESERVED_COMMANDS_PER_PERIOD)
187}
188
#[cfg(test)]
mod tests {
    use super::{nonreserved_commands_per_reset, CommandRatelimiter, PERIOD};
    use static_assertions::assert_impl_all;
    use std::{fmt::Debug, future::poll_fn, time::Duration};
    use tokio::time;

    assert_impl_all!(CommandRatelimiter: Debug, Send, Sync);

    #[test]
    fn nonreserved_commands() {
        assert_eq!(
            118,
            nonreserved_commands_per_reset(Duration::from_secs(u64::MAX))
        );
        assert_eq!(118, nonreserved_commands_per_reset(Duration::from_secs(60)));
        assert_eq!(
            117,
            nonreserved_commands_per_reset(Duration::from_millis(42_500))
        );
        assert_eq!(117, nonreserved_commands_per_reset(Duration::from_secs(30)));
        assert_eq!(
            116,
            nonreserved_commands_per_reset(Duration::from_millis(29_999))
        );
        // Degenerate intervals reserve at most 10 permits.
        assert_eq!(110, nonreserved_commands_per_reset(Duration::ZERO));
    }

    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(60);

    #[tokio::test(start_paused = true)]
    async fn full_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        for _ in 0..ratelimiter.max() {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
        assert_eq!(ratelimiter.available(), 0);

        // Nothing is released before a full `PERIOD` has elapsed.
        time::advance(PERIOD - Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), 0);

        // Everything is released once `PERIOD` has elapsed.
        time::advance(Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn half_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        for _ in 0..ratelimiter.max() / 2 {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);

        time::advance(PERIOD / 2).await;

        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);
        for _ in 0..ratelimiter.max() / 2 {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
        assert_eq!(ratelimiter.available(), 0);

        // The first half is released one `PERIOD` after it was acquired.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() / 2);

        // Followed by the second half.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn constant_capacity() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = ratelimiter.max();

        for _ in 0..max {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
        assert_eq!(ratelimiter.available(), 0);

        // Acquiring past the limit waits instead of growing the queue.
        poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        assert_eq!(max, ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn rebase() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        for _ in 0..5 {
            time::advance(Duration::from_millis(20)).await;
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 5);

        // Only the first permit has been released.
        time::advance(PERIOD - Duration::from_millis(80)).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);

        // Each newly acquired permit is offset by an older one being released.
        for _ in 0..4 {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
            time::advance(Duration::from_millis(20)).await;
            assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);
        }
    }

    #[tokio::test(start_paused = true)]
    async fn next_available() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        // No permit has been acquired yet.
        assert_eq!(ratelimiter.next_available(), Duration::ZERO);

        for _ in 0..ratelimiter.max() {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
        // All permits were acquired at the same instant, so the next one is
        // released a full `PERIOD` from now.
        assert_eq!(ratelimiter.available(), 0);
        assert_eq!(ratelimiter.next_available(), PERIOD);

        time::advance(PERIOD).await;
        assert_eq!(ratelimiter.next_available(), Duration::ZERO);
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }
}
298}