twilight_gateway/
ratelimiter.rs

//! Rate limit events sent to the Gateway.
//!
//! See <https://discord.com/developers/docs/topics/gateway#rate-limiting>
//!
//! # Algorithm
//!
//! [`CommandRatelimiter`] is implemented as a sliding window log. Unlike
//! simpler ratelimit algorithms, it supports burst requests while guaranteeing
//! that the permit limit is never exceeded within any (t - [`PERIOD`], t]
//! window. See <https://hechao.li/posts/Rate-Limiter-Part1/> for an overview of
//! it and of alternative algorithms.
12
13#![allow(clippy::cast_possible_truncation)]
14
15use std::{
16    collections::VecDeque,
17    future::Future,
18    pin::Pin,
19    task::{ready, Context, Poll},
20};
21use tokio::time::{Duration, Instant, Sleep};
22
/// How long an acquired permit is held before it is released again.
const PERIOD: Duration = Duration::from_millis(60_000);

/// Permits that may be acquired within any single [`PERIOD`].
const PERMITS: u8 = 120;
28
/// Ratelimiter for sending commands over the gateway to Discord.
///
/// Tracks at most `queue.capacity() + 1` acquired permits: one represented by
/// [`Self::delay`] while it is pending, and the rest by the entries of
/// [`Self::queue`].
#[derive(Debug)]
pub struct CommandRatelimiter {
    /// Future that completes when the next permit is released.
    ///
    /// Counts as an acquired permit if pending.
    delay: Pin<Box<Sleep>>,
    /// Ordered queue of timestamps relative to [`Self::delay`] in milliseconds
    /// when permits release.
    ///
    /// Its capacity is sized exactly at construction and determines
    /// [`Self::max`].
    queue: VecDeque<u16>,
}
40
41impl CommandRatelimiter {
42    /// Create a new ratelimiter with some capacity reserved for heartbeating.
43    pub(crate) fn new(heartbeat_interval: Duration) -> Self {
44        let capacity = usize::from(nonreserved_commands_per_reset(heartbeat_interval)) - 1;
45
46        let mut queue = VecDeque::with_capacity(capacity);
47        if queue.capacity() != capacity {
48            queue.resize(capacity, 0);
49            // `into_boxed_slice().into_vec()` guarantees len == capacity.
50            let vec = Vec::from(queue).into_boxed_slice().into_vec();
51            // This is guaranteed to not allocate.
52            queue = VecDeque::from(vec);
53            queue.clear();
54        }
55
56        Self {
57            delay: Box::pin(tokio::time::sleep_until(Instant::now())),
58            queue,
59        }
60    }
61
62    /// Number of available permits.
63    pub fn available(&self) -> u8 {
64        let now = Instant::now();
65        let acquired = if now >= self.delay.deadline() {
66            self.next_acquired_position(now)
67                .map_or(0, |released_count| self.queue.len() - released_count)
68        } else {
69            self.queue.len() + 1
70        };
71
72        self.max() - acquired as u8
73    }
74
75    /// Maximum number of available permits.
76    pub fn max(&self) -> u8 {
77        self.queue.capacity() as u8 + 1
78    }
79
80    /// Duration until the next permit is available.
81    pub fn next_available(&self) -> Duration {
82        self.delay
83            .deadline()
84            .saturating_duration_since(Instant::now())
85    }
86
87    /// Attempts to acquire a permit.
88    ///
89    /// # Returns
90    ///
91    /// * `Poll::Pending` if no permit is available
92    /// * `Poll::Ready` if a permit is acquired.
93    pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> {
94        ready!(self.poll_available(cx));
95
96        let now = Instant::now();
97        if now >= self.delay.deadline() {
98            if let Some(new_deadline_idx) = self.next_acquired_position(now) {
99                self.rebase(new_deadline_idx);
100            } else {
101                self.queue.clear();
102                self.delay.as_mut().reset(now + PERIOD);
103
104                return Poll::Ready(());
105            }
106        }
107
108        let releases = (now + PERIOD) - self.delay.deadline();
109        debug_assert_ne!(self.queue.capacity(), self.queue.len());
110        self.queue.push_back(releases.as_millis() as u16);
111
112        if self.queue.len() == self.queue.capacity() {
113            tracing::debug!(duration = ?(self.delay.deadline() - now), "ratelimited");
114        }
115
116        Poll::Ready(())
117    }
118
119    /// Checks whether a permit is available.
120    ///
121    /// # Returns
122    ///
123    /// * `Poll::Pending` if no permit is available
124    /// * `Poll::Ready` if a permit is available.
125    pub(crate) fn poll_available(&mut self, cx: &mut Context<'_>) -> Poll<()> {
126        if self.queue.len() < self.queue.capacity() {
127            return Poll::Ready(());
128        }
129
130        self.delay.as_mut().poll(cx)
131    }
132
133    /// Searches for the first acquired timestamp, returning its index.
134    ///
135    /// If every timestamp is released, it returns `None`.
136    fn next_acquired_position(&self, now: Instant) -> Option<usize> {
137        self.queue
138            .iter()
139            .map(|&m| self.delay.deadline() + Duration::from_millis(m.into()))
140            .position(|deadline| deadline > now)
141    }
142
143    /// Resets to a new deadline and updates acquired permits' relative timestamp.
144    fn rebase(&mut self, new_deadline_idx: usize) {
145        let duration = Duration::from_millis(self.queue[new_deadline_idx].into());
146        let new_deadline = self.delay.deadline() + duration;
147
148        self.queue.drain(..=new_deadline_idx);
149
150        for timestamp in &mut self.queue {
151            let deadline = self.delay.deadline() + Duration::from_millis((*timestamp).into());
152            *timestamp = (deadline - new_deadline).as_millis() as u16;
153        }
154
155        self.delay.as_mut().reset(new_deadline);
156    }
157}
158
159/// Calculates the number of non reserved commands for heartbeating (which
160/// bypasses the ratelimiter) in a [`PERIOD`].
161///
162/// Reserves capacity for an additional gateway event to guard against Discord
163/// sending [`OpCode::Heartbeat`]s (which requires sending a heartbeat back
164/// immediately).
165///
166/// [`OpCode::Heartbeat`]: twilight_model::gateway::OpCode::Heartbeat
167fn nonreserved_commands_per_reset(heartbeat_interval: Duration) -> u8 {
168    /// Guard against faulty gateways specifying low heartbeat intervals by
169    /// maximally reserving this many heartbeats per [`PERIOD`].
170    const MAX_NONRESERVED_COMMANDS_PER_PERIOD: u8 = PERMITS - 10;
171
172    // Calculate the amount of heartbeats per heartbeat interval.
173    let heartbeats_per_reset = PERIOD.as_secs_f32() / heartbeat_interval.as_secs_f32();
174
175    // Round up to be on the safe side.
176    #[allow(clippy::cast_sign_loss)]
177    let heartbeats_per_reset = heartbeats_per_reset.ceil() as u8;
178
179    // Reserve an extra heartbeat just in case.
180    let heartbeats_per_reset = heartbeats_per_reset.saturating_add(1);
181
182    // Subtract the reserved heartbeats from the total available events.
183    let nonreserved_commands_per_reset = PERMITS.saturating_sub(heartbeats_per_reset);
184
185    // Take the larger value between this and the guard value.
186    nonreserved_commands_per_reset.max(MAX_NONRESERVED_COMMANDS_PER_PERIOD)
187}
188
#[cfg(test)]
mod tests {
    use super::{nonreserved_commands_per_reset, CommandRatelimiter, PERIOD};
    use static_assertions::assert_impl_all;
    use std::{fmt::Debug, future::poll_fn, time::Duration};
    use tokio::time;

    assert_impl_all!(CommandRatelimiter: Debug, Send, Sync);

    /// Heartbeat interval shared by every ratelimiter test.
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(60);

    /// Acquires `count` permits one after another.
    async fn acquire_n(ratelimiter: &mut CommandRatelimiter, count: u8) {
        for _ in 0..count {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
    }

    #[test]
    fn nonreserved_commands() {
        let cases: [(Duration, u8); 6] = [
            (Duration::from_secs(u64::MAX), 118),
            (Duration::from_secs(60), 118),
            (Duration::from_millis(42_500), 117),
            (Duration::from_secs(30), 117),
            (Duration::from_millis(29_999), 116),
            (Duration::ZERO, 110),
        ];

        for &(interval, expected) in &cases {
            assert_eq!(
                expected,
                nonreserved_commands_per_reset(interval),
                "interval: {:?}",
                interval
            );
        }
    }

    #[tokio::test(start_paused = true)]
    async fn full_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = ratelimiter.max();

        assert_eq!(ratelimiter.available(), max);
        acquire_n(&mut ratelimiter, max).await;
        assert_eq!(ratelimiter.available(), 0);

        // Nothing is refilled before a whole PERIOD has passed...
        time::advance(PERIOD - Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), 0);

        // ...and everything is refilled once it has.
        time::advance(Duration::from_millis(100)).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn half_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let half = ratelimiter.max() / 2;

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        acquire_n(&mut ratelimiter, half).await;
        assert_eq!(ratelimiter.available(), half);

        time::advance(PERIOD / 2).await;

        assert_eq!(ratelimiter.available(), half);
        acquire_n(&mut ratelimiter, half).await;
        assert_eq!(ratelimiter.available(), 0);

        // The first batch is released...
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), half);

        // ...followed by the second.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn constant_capacity() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = ratelimiter.max();

        acquire_n(&mut ratelimiter, max).await;
        assert_eq!(ratelimiter.available(), 0);

        // Waiting for and acquiring one more permit must not grow the queue.
        acquire_n(&mut ratelimiter, 1).await;
        assert_eq!(max, ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn rebase() {
        const STEP: Duration = Duration::from_millis(20);

        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        for _ in 0..5 {
            time::advance(STEP).await;
            acquire_n(&mut ratelimiter, 1).await;
        }
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 5);

        // Exactly one PERIOD after the first acquire: only it is released.
        time::advance(PERIOD - STEP * 4).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);

        // Every new acquire is offset by one release a STEP later.
        for _ in 0..4 {
            acquire_n(&mut ratelimiter, 1).await;
            time::advance(STEP).await;
            assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);
        }
    }
}