twilight_gateway/
ratelimiter.rs

1//! Rate limit events sent to the Gateway.
2//!
3//! See <https://discord.com/developers/docs/topics/gateway#rate-limiting>
4//!
5//! # Algorithm
6//!
7//! [`CommandRatelimiter`] is implemented as a sliding window log. This is the
8//! only ratelimit algorithm that supports burst requests and guarantees that
9//! the (t - [`PERIOD`], t] window is never exceeded. See
10//! <https://hechao.li/posts/Rate-Limiter-Part1/> for an overview of it and
11//! other alternative algorithms.
12
13#![allow(clippy::cast_possible_truncation)]
14
15use std::{
16    collections::VecDeque,
17    future::Future,
18    pin::Pin,
19    task::{Context, Poll, ready},
20};
21use tokio::time::{Duration, Instant, Sleep};
22
/// Duration until an acquired permit is released.
///
/// This is the length of the sliding window: a permit acquired at instant `t`
/// is tracked as releasing at `t + PERIOD`.
const PERIOD: Duration = Duration::from_secs(60);

/// Number of permits per [`PERIOD`].
///
/// This is the total budget before any capacity is set aside for heartbeats;
/// see [`nonreserved_commands_per_reset`] for the share left to commands.
const PERMITS: u8 = 120;
28
/// Ratelimiter for sending commands over the gateway to Discord.
///
/// Acquired permits are tracked individually: the oldest one is represented
/// by [`Self::delay`] and the remaining ones by [`Self::queue`], so the total
/// number of permits is the queue's capacity plus one.
#[derive(Debug)]
pub struct CommandRatelimiter {
    /// Future that completes when the next permit is released.
    ///
    /// Counts as an acquired permit if pending.
    ///
    /// Its deadline is also the base instant that the offsets stored in
    /// [`Self::queue`] are measured from.
    delay: Pin<Box<Sleep>>,
    /// Ordered queue of timestamps relative to [`Self::delay`] in milliseconds
    /// when permits release.
    ///
    /// Offsets never exceed [`PERIOD`] (60 000 ms), which is why a `u16` is
    /// wide enough. The capacity is fixed at construction and is used as the
    /// permit limit, so the queue must never grow past it.
    queue: VecDeque<u16>,
}
40
41impl CommandRatelimiter {
42    /// Create a new ratelimiter with some capacity reserved for heartbeating.
43    pub(crate) fn new(heartbeat_interval: Duration) -> Self {
44        let capacity = usize::from(nonreserved_commands_per_reset(heartbeat_interval)) - 1;
45
46        let mut queue = VecDeque::with_capacity(capacity);
47        if queue.capacity() != capacity {
48            queue.resize(capacity, 0);
49            // `into_boxed_slice().into_vec()` guarantees len == capacity.
50            let vec = Vec::from(queue).into_boxed_slice().into_vec();
51            // This is guaranteed to not allocate.
52            queue = VecDeque::from(vec);
53            queue.clear();
54        }
55
56        Self {
57            delay: Box::pin(tokio::time::sleep_until(Instant::now())),
58            queue,
59        }
60    }
61
62    /// Number of available permits.
63    pub fn available(&self) -> u8 {
64        let now = Instant::now();
65        let acquired = if now >= self.delay.deadline() {
66            self.next_acquired_position(now)
67                .map_or(0, |released_count| self.queue.len() - released_count)
68        } else {
69            self.queue.len() + 1
70        };
71
72        self.max() - acquired as u8
73    }
74
75    /// Maximum number of available permits.
76    pub fn max(&self) -> u8 {
77        self.queue.capacity() as u8 + 1
78    }
79
80    /// Duration until the next permit is available.
81    pub fn next_available(&self) -> Duration {
82        self.delay
83            .deadline()
84            .saturating_duration_since(Instant::now())
85    }
86
87    /// Attempts to acquire a permit.
88    ///
89    /// # Returns
90    ///
91    /// * `Poll::Pending` if no permit is available
92    /// * `Poll::Ready` if a permit is acquired.
93    pub(crate) fn poll_acquire(&mut self, cx: &mut Context<'_>) -> Poll<()> {
94        ready!(self.poll_available(cx));
95
96        let now = Instant::now();
97        if now >= self.delay.deadline() {
98            if let Some(new_deadline_idx) = self.next_acquired_position(now) {
99                self.rebase(new_deadline_idx);
100            } else {
101                self.queue.clear();
102                self.delay.as_mut().reset(now + PERIOD);
103
104                return Poll::Ready(());
105            }
106        }
107
108        let releases = (now + PERIOD) - self.delay.deadline();
109        debug_assert_ne!(self.queue.capacity(), self.queue.len());
110        self.queue.push_back(releases.as_millis() as u16);
111
112        if self.queue.len() == self.queue.capacity() {
113            tracing::debug!(duration = ?(self.delay.deadline() - now), "ratelimited");
114        }
115
116        Poll::Ready(())
117    }
118
119    /// Checks whether a permit is available.
120    ///
121    /// # Returns
122    ///
123    /// * `Poll::Pending` if no permit is available
124    /// * `Poll::Ready` if a permit is available.
125    pub(crate) fn poll_available(&mut self, cx: &mut Context<'_>) -> Poll<()> {
126        if self.queue.len() < self.queue.capacity() {
127            return Poll::Ready(());
128        }
129
130        self.delay.as_mut().poll(cx)
131    }
132
133    /// Searches for the first acquired timestamp, returning its index.
134    ///
135    /// If every timestamp is released, it returns `None`.
136    fn next_acquired_position(&self, now: Instant) -> Option<usize> {
137        self.queue
138            .iter()
139            .map(|&m| self.delay.deadline() + Duration::from_millis(m.into()))
140            .position(|deadline| deadline > now)
141    }
142
143    /// Resets to a new deadline and updates acquired permits' relative timestamp.
144    fn rebase(&mut self, new_deadline_idx: usize) {
145        let duration = Duration::from_millis(self.queue[new_deadline_idx].into());
146        let new_deadline = self.delay.deadline() + duration;
147
148        self.queue.drain(..=new_deadline_idx);
149
150        for timestamp in &mut self.queue {
151            let deadline = self.delay.deadline() + Duration::from_millis((*timestamp).into());
152            *timestamp = (deadline - new_deadline).as_millis() as u16;
153        }
154
155        self.delay.as_mut().reset(new_deadline);
156    }
157}
158
159/// Calculates the number of non reserved commands for heartbeating (which
160/// bypasses the ratelimiter) in a [`PERIOD`].
161///
162/// Reserves capacity for an additional gateway event to guard against Discord
163/// sending [`OpCode::Heartbeat`]s (which requires sending a heartbeat back
164/// immediately).
165///
166/// [`OpCode::Heartbeat`]: twilight_model::gateway::OpCode::Heartbeat
167fn nonreserved_commands_per_reset(heartbeat_interval: Duration) -> u8 {
168    /// Guard against faulty gateways specifying low heartbeat intervals by
169    /// maximally reserving this many heartbeats per [`PERIOD`].
170    const MAX_NONRESERVED_COMMANDS_PER_PERIOD: u8 = PERMITS - 10;
171
172    // Calculate the amount of heartbeats per heartbeat interval.
173    let heartbeats_per_reset = PERIOD.as_secs_f32() / heartbeat_interval.as_secs_f32();
174
175    // Round up to be on the safe side.
176    #[allow(clippy::cast_sign_loss)]
177    let heartbeats_per_reset = heartbeats_per_reset.ceil() as u8;
178
179    // Reserve an extra heartbeat just in case.
180    let heartbeats_per_reset = heartbeats_per_reset.saturating_add(1);
181
182    // Subtract the reserved heartbeats from the total available events.
183    let nonreserved_commands_per_reset = PERMITS.saturating_sub(heartbeats_per_reset);
184
185    // Take the larger value between this and the guard value.
186    nonreserved_commands_per_reset.max(MAX_NONRESERVED_COMMANDS_PER_PERIOD)
187}
188
#[cfg(test)]
mod tests {
    #![allow(clippy::unchecked_time_subtraction)]

    use super::{CommandRatelimiter, PERIOD, nonreserved_commands_per_reset};
    use static_assertions::assert_impl_all;
    use std::{fmt::Debug, future::poll_fn, time::Duration};
    use tokio::time;

    assert_impl_all!(CommandRatelimiter: Debug, Send, Sync);

    /// Heartbeat interval used by every ratelimiter in the async tests.
    const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(60);

    /// Acquires `count` permits one after another, waiting whenever none is
    /// available.
    async fn acquire(ratelimiter: &mut CommandRatelimiter, count: u8) {
        for _ in 0..count {
            poll_fn(|cx| ratelimiter.poll_acquire(cx)).await;
        }
    }

    #[test]
    fn nonreserved_commands() {
        let cases = [
            (Duration::from_secs(u64::MAX), 118),
            (Duration::from_secs(60), 118),
            (Duration::from_millis(42_500), 117),
            (Duration::from_secs(30), 117),
            (Duration::from_millis(29_999), 116),
            (Duration::ZERO, 110),
        ];

        for (interval, expected) in cases {
            assert_eq!(
                expected,
                nonreserved_commands_per_reset(interval),
                "heartbeat interval: {interval:?}"
            );
        }
    }

    #[tokio::test(start_paused = true)]
    async fn full_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = ratelimiter.max();
        let margin = Duration::from_millis(100);

        assert_eq!(ratelimiter.available(), max);
        acquire(&mut ratelimiter, max).await;
        assert_eq!(ratelimiter.available(), 0);

        // Should not refill until PERIOD has passed.
        time::advance(PERIOD - margin).await;
        assert_eq!(ratelimiter.available(), 0);

        // All should be refilled.
        time::advance(margin).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn half_reset() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let half = ratelimiter.max() / 2;

        assert_eq!(ratelimiter.available(), ratelimiter.max());
        acquire(&mut ratelimiter, half).await;
        assert_eq!(ratelimiter.available(), half);

        time::advance(PERIOD / 2).await;

        // Nothing is released yet; take the second half.
        assert_eq!(ratelimiter.available(), half);
        acquire(&mut ratelimiter, half).await;
        assert_eq!(ratelimiter.available(), 0);

        // Half should be refilled.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), half);

        // All should be refilled.
        time::advance(PERIOD / 2).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn constant_capacity() {
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);
        let max = ratelimiter.max();

        acquire(&mut ratelimiter, max).await;
        assert_eq!(ratelimiter.available(), 0);

        // Acquiring past the limit must wait instead of growing the queue.
        acquire(&mut ratelimiter, 1).await;
        assert_eq!(max, ratelimiter.max());
    }

    #[tokio::test(start_paused = true)]
    async fn rebase() {
        let step = Duration::from_millis(20);
        let mut ratelimiter = CommandRatelimiter::new(HEARTBEAT_INTERVAL);

        // Five permits, acquired 20 ms apart.
        for _ in 0..5 {
            time::advance(step).await;
            acquire(&mut ratelimiter, 1).await;
        }
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 5);

        // Move to the instant the first permit is released.
        time::advance(PERIOD - Duration::from_millis(80)).await;
        assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);

        // Each acquisition is offset by the next old permit releasing.
        for _ in 0..4 {
            acquire(&mut ratelimiter, 1).await;
            time::advance(step).await;
            assert_eq!(ratelimiter.available(), ratelimiter.max() - 4);
        }
    }
}