Skip to main content

twilight_http_ratelimiting/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(
3    clippy::missing_const_for_fn,
4    clippy::missing_docs_in_private_items,
5    clippy::pedantic,
6    missing_docs,
7    unsafe_code
8)]
9#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)]
10
11mod actor;
12
13#[cfg(not(target_os = "wasi"))]
14use std::{
15    future::Future,
16    hash::{Hash as _, Hasher},
17    pin::Pin,
18    task::{Context, Poll},
19    time::{Duration, Instant},
20};
21#[cfg(not(target_os = "wasi"))]
22use tokio::sync::{mpsc, oneshot};
23
/// Length of the global rate limit window.
///
/// Once this much time has passed since the first globally limited request,
/// the remaining global count is restored to the global limit count.
#[cfg(not(target_os = "wasi"))]
pub const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_millis(1_000);
28
/// Panic message raised whenever communication with the actor task fails.
///
/// It tells the user what to do next: the actor's own panic message is the
/// part worth reporting.
#[cfg(not(target_os = "wasi"))]
const ACTOR_PANIC_MESSAGE: &str =
    "actor task panicked: report its panic message to the maintainers";
33
/// HTTP request [method] of a rate limited endpoint.
///
/// [method]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Methods
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    /// `DELETE`: delete a resource.
    Delete,
    /// `GET`: retrieve a resource.
    Get,
    /// `PATCH`: update a resource.
    Patch,
    /// `POST`: create a resource.
    Post,
    /// `PUT`: replace a resource.
    Put,
}
51
52impl Method {
53    /// Name of the method.
54    pub const fn name(self) -> &'static str {
55        match self {
56            Method::Delete => "DELETE",
57            Method::Get => "GET",
58            Method::Patch => "PATCH",
59            Method::Post => "POST",
60            Method::Put => "PUT",
61        }
62    }
63}
64
65/// Rate limited endpoint.
66///
67/// The rate limiter dynamically supports new or unknown API paths, but is consequently unable to
68/// catch invalid arguments. Invalidly structured endpoints may be permitted at an improper time.
69///
70/// # Example
71///
72/// ```no_run
73/// # let rt = tokio::runtime::Builder::new_current_thread()
74/// #     .enable_time()
75/// #     .build()
76/// #     .unwrap();
77/// # rt.block_on(async {
78/// # let rate_limiter = twilight_http_ratelimiting::RateLimiter::default();
79/// use twilight_http_ratelimiting::{Endpoint, Method};
80///
81/// let url = "https://discord.com/api/v10/guilds/745809834183753828/audit-logs?limit=10";
82/// let endpoint = Endpoint {
83///     method: Method::Get,
84///     path: String::from("guilds/745809834183753828/audit-logs"),
85/// };
86/// let permit = rate_limiter.acquire(endpoint).await;
87/// let headers = unimplemented!("GET {url}");
88/// permit.complete(headers);
89/// # });
90/// ```
91#[derive(Clone, Debug, Eq, Hash, PartialEq)]
92#[cfg(not(target_os = "wasi"))]
93pub struct Endpoint {
94    /// Method of the endpoint.
95    pub method: Method,
96    /// API path of the endpoint.
97    ///
98    /// Should not start with a slash (`/`) or include query parameters (`?key=value`).
99    pub path: String,
100}
101
102#[cfg(not(target_os = "wasi"))]
103impl Endpoint {
104    /// Whether the endpoint is properly structured.
105    pub(crate) fn is_valid(&self) -> bool {
106        !self.path.as_bytes().starts_with(b"/") && !self.path.as_bytes().contains(&b'?')
107    }
108
109    /// Whether the endpoint is an interaction.
110    pub(crate) fn is_interaction(&self) -> bool {
111        self.path.as_bytes().starts_with(b"webhooks")
112            || self.path.as_bytes().starts_with(b"interactions")
113    }
114
115    /// Feeds the top-level resources of this endpoint into the given [`Hasher`].
116    ///
117    /// Top-level resources represent the bucket namespace in which they are unique.
118    ///
119    /// Top-level resources are currently:
120    /// - `channels/<channel_id>`
121    /// - `guilds/<guild_id>`
122    /// - `webhooks/<webhook_id>`
123    /// - `webhooks/<webhook_id>/<webhook_token>`
124    pub(crate) fn hash_resources(&self, state: &mut impl Hasher) {
125        let mut segments = self.path.as_bytes().split(|&s| s == b'/');
126        match segments.next().unwrap_or_default() {
127            b"channels" => {
128                if let Some(s) = segments.next() {
129                    "channels".hash(state);
130                    s.hash(state);
131                }
132            }
133            b"guilds" => {
134                if let Some(s) = segments.next() {
135                    "guilds".hash(state);
136                    s.hash(state);
137                }
138            }
139            b"webhooks" => {
140                if let Some(s) = segments.next() {
141                    "webhooks".hash(state);
142                    s.hash(state);
143                }
144                if let Some(s) = segments.next() {
145                    s.hash(state);
146                }
147            }
148            _ => {}
149        }
150    }
151}
152
/// Rate limit headers parsed from a user response.
///
/// A `limit` of zero marks the [`Bucket`] as exhausted until `reset_at` elapses.
///
/// # Global limits
///
/// Please open an issue if the [`RateLimiter`] exceeded the global limit.
///
/// # Shared limits
///
/// Shared limits do not count towards the invalid request limit, so handling
/// them is optional. To preemptively exhaust the bucket until `Reset-After`,
/// complete the [`Permit`] with [`RateLimitHeaders::shared`].
#[cfg(not(target_os = "wasi"))]
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RateLimitHeaders {
    /// Identifier of the bucket.
    pub bucket: Vec<u8>,
    /// Number of requests the bucket allows in total before it is exhausted.
    pub limit: u16,
    /// Number of requests left before the bucket is exhausted.
    pub remaining: u16,
    /// Point in time at which the bucket resets.
    pub reset_at: Instant,
}
178
179#[cfg(not(target_os = "wasi"))]
180impl RateLimitHeaders {
181    /// Lowercased name for the bucket header.
182    pub const BUCKET: &'static str = "x-ratelimit-bucket";
183
184    /// Lowercased name for the limit header.
185    pub const LIMIT: &'static str = "x-ratelimit-limit";
186
187    /// Lowercased name for the remaining header.
188    pub const REMAINING: &'static str = "x-ratelimit-remaining";
189
190    /// Lowercased name for the reset-after header.
191    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";
192
193    /// Lowercased name for the scope header.
194    pub const SCOPE: &'static str = "x-ratelimit-scope";
195
196    /// Emulates a shared resource limit as a user limit by setting `limit` and
197    /// `remaining` to zero.
198    pub fn shared(bucket: Vec<u8>, retry_after: u16) -> Self {
199        Self {
200            bucket,
201            limit: 0,
202            remaining: 0,
203            reset_at: Instant::now() + Duration::from_secs(retry_after.into()),
204        }
205    }
206}
207
208/// Permit to send a Discord HTTP API request to the acquired endpoint.
209#[derive(Debug)]
210#[must_use = "dropping the permit immediately cancels itself"]
211#[cfg(not(target_os = "wasi"))]
212pub struct Permit(oneshot::Sender<Option<RateLimitHeaders>>);
213
214#[cfg(not(target_os = "wasi"))]
215impl Permit {
216    /// Update the [`RateLimiter`] based on the response headers.
217    ///
218    /// Non-completed permits are regarded as cancelled, so only call this
219    /// on receiving a response.
220    #[allow(clippy::missing_panics_doc)]
221    pub fn complete(self, headers: Option<RateLimitHeaders>) {
222        assert!(self.0.send(headers).is_ok(), "{ACTOR_PANIC_MESSAGE}");
223    }
224}
225
226/// Future that completes when a permit is ready.
227#[derive(Debug)]
228#[must_use = "futures do nothing unless you `.await` or poll them"]
229#[cfg(not(target_os = "wasi"))]
230pub struct PermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
231
232#[cfg(not(target_os = "wasi"))]
233impl Future for PermitFuture {
234    type Output = Permit;
235
236    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237        #[allow(clippy::match_wild_err_arm)]
238        Pin::new(&mut self.0).poll(cx).map(|r| match r {
239            Ok(sender) => Permit(sender),
240            Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
241        })
242    }
243}
244
245/// Future that completes when a permit is ready or cancelled.
246#[derive(Debug)]
247#[must_use = "futures do nothing unless you `.await` or poll them"]
248#[cfg(not(target_os = "wasi"))]
249pub struct MaybePermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
250
251#[cfg(not(target_os = "wasi"))]
252impl Future for MaybePermitFuture {
253    type Output = Option<Permit>;
254
255    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256        Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit))
257    }
258}
259
/// Rate limit state of one or more paths, derived from previously received
/// [`RateLimitHeaders`].
#[cfg(not(target_os = "wasi"))]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub struct Bucket {
    /// Number of permits the bucket allows in total before it is exhausted.
    pub limit: u16,
    /// Number of permits left before the bucket is exhausted.
    pub remaining: u16,
    /// Point in time at which the bucket resets.
    pub reset_at: Instant,
}
273
274/// Actor run closure pre-enqueue for early [`MaybePermitFuture`] cancellation.
275#[cfg(not(target_os = "wasi"))]
276type Predicate = Box<dyn FnOnce(Option<Bucket>) -> bool + Send>;
277
278/// Discord HTTP client API rate limiter.
279///
280/// The [`RateLimiter`] runs an associated actor task to concurrently handle permit
281/// requests and responses.
282///
283/// Cloning a [`RateLimiter`] increments just the amount of senders for the actor.
284/// The actor completes when there are no senders and non-completed permits left.
285#[derive(Clone, Debug)]
286#[cfg(not(target_os = "wasi"))]
287pub struct RateLimiter {
288    /// Actor message sender.
289    tx: mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
290}
291
292#[cfg(not(target_os = "wasi"))]
293impl RateLimiter {
294    /// Create a new [`RateLimiter`] with a custom global limit.
295    pub fn new(global_limit: u16) -> Self {
296        let (tx, rx) = mpsc::unbounded_channel();
297        tokio::spawn(actor::runner(global_limit, rx));
298
299        Self { tx }
300    }
301
302    /// Await a single permit for this endpoint.
303    ///
304    /// Permits are queued per endpoint in the order they were requested.
305    #[allow(clippy::missing_panics_doc)]
306    pub fn acquire(&self, endpoint: Endpoint) -> PermitFuture {
307        let (notifier, rx) = oneshot::channel();
308        let message = actor::Message { endpoint, notifier };
309        assert!(
310            self.tx.send((message, None)).is_ok(),
311            "{ACTOR_PANIC_MESSAGE}"
312        );
313
314        PermitFuture(rx)
315    }
316
317    /// Await a single permit for this endpoint, but only if the predicate evaluates
318    /// to `true`.
319    ///
320    /// Permits are queued per endpoint in the order they were requested.
321    ///
322    /// Note that the predicate is asynchronously called in the actor task.
323    ///
324    /// # Example
325    ///
326    /// ```no_run
327    /// # let rt = tokio::runtime::Builder::new_current_thread()
328    /// #     .enable_time()
329    /// #     .build()
330    /// #     .unwrap();
331    /// # rt.block_on(async {
332    /// # let rate_limiter = twilight_http_ratelimiting::RateLimiter::default();
333    /// use twilight_http_ratelimiting::{Endpoint, Method};
334    ///
335    /// let endpoint = Endpoint {
336    ///     method: Method::Get,
337    ///     path: String::from("applications/@me"),
338    /// };
339    /// if let Some(permit) = rate_limiter
340    ///     .acquire_if(endpoint, |b| b.is_none_or(|b| b.remaining > 10))
341    ///     .await
342    /// {
343    ///     let headers = unimplemented!("GET /applications/@me");
344    ///     permit.complete(headers);
345    /// }
346    /// # });
347    /// ```
348    #[allow(clippy::missing_panics_doc)]
349    pub fn acquire_if<P>(&self, endpoint: Endpoint, predicate: P) -> MaybePermitFuture
350    where
351        P: FnOnce(Option<Bucket>) -> bool + Send + 'static,
352    {
353        fn acquire_if(
354            tx: &mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
355            endpoint: Endpoint,
356            predicate: Predicate,
357        ) -> MaybePermitFuture {
358            let (notifier, rx) = oneshot::channel();
359            let message = actor::Message { endpoint, notifier };
360            assert!(
361                tx.send((message, Some(predicate))).is_ok(),
362                "{ACTOR_PANIC_MESSAGE}"
363            );
364
365            MaybePermitFuture(rx)
366        }
367
368        acquire_if(&self.tx, endpoint, Box::new(predicate))
369    }
370
371    /// Retrieve the [`Bucket`] for this endpoint.
372    ///
373    /// The bucket is internally retrieved via [`acquire_if`][Self::acquire_if].
374    #[allow(clippy::missing_panics_doc)]
375    pub async fn bucket(&self, endpoint: Endpoint) -> Option<Bucket> {
376        let (tx, rx) = oneshot::channel();
377        self.acquire_if(endpoint, |bucket| {
378            // Ignore cancellation error.
379            _ = tx.send(bucket);
380            false
381        })
382        .await;
383
384        #[allow(clippy::match_wild_err_arm)]
385        match rx.await {
386            Ok(bucket) => bucket,
387            Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
388        }
389    }
390}
391
392#[cfg(not(target_os = "wasi"))]
393impl Default for RateLimiter {
394    /// Create a new [`RateLimiter`] with Discord's default global limit.
395    ///
396    /// Currently this is `50`.
397    fn default() -> Self {
398        Self::new(50)
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::{
        Bucket, Endpoint, MaybePermitFuture, Method, Permit, PermitFuture, RateLimitHeaders,
        RateLimiter,
    };
    use static_assertions::assert_impl_all;
    use std::{
        fmt::Debug,
        future::Future,
        hash::{DefaultHasher, Hash, Hasher as _},
        time::{Duration, Instant},
    };
    use tokio::task;

    assert_impl_all!(Bucket: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Endpoint: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(MaybePermitFuture: Debug, Future<Output = Option<Permit>>);
    assert_impl_all!(Method: Clone, Copy, Debug, Eq, PartialEq);
    assert_impl_all!(Permit: Debug, Send, Sync);
    assert_impl_all!(PermitFuture: Debug, Future<Output = Permit>);
    assert_impl_all!(RateLimitHeaders: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(RateLimiter: Clone, Debug, Default, Send, Sync);

    /// Builds the endpoint shared by the asynchronous tests.
    const ENDPOINT: fn() -> Endpoint = || Endpoint {
        method: Method::Get,
        path: String::from("applications/@me"),
    };

    /// Returns the hash of everything `feed` writes into a fresh hasher.
    fn with_hasher(feed: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut hasher = DefaultHasher::new();
        feed(&mut hasher);
        hasher.finish()
    }

    #[tokio::test]
    async fn acquire_if() {
        let rate_limiter = RateLimiter::default();

        let rejected = rate_limiter.acquire_if(ENDPOINT(), |_| false).await;
        assert!(rejected.is_none());

        let granted = rate_limiter.acquire_if(ENDPOINT(), |_| true).await;
        assert!(granted.is_some());
    }

    #[tokio::test]
    async fn bucket() {
        let rate_limiter = RateLimiter::default();
        let (limit, remaining) = (2, 1);
        let reset_at = Instant::now() + Duration::from_secs(1);

        let permit = rate_limiter.acquire(ENDPOINT()).await;
        permit.complete(Some(RateLimitHeaders {
            bucket: vec![1, 2, 3],
            limit,
            remaining,
            reset_at,
        }));
        task::yield_now().await;

        let bucket = rate_limiter.bucket(ENDPOINT()).await.unwrap();
        assert_eq!(bucket.limit, limit);
        assert_eq!(bucket.remaining, remaining);

        // Allow up to a millisecond of difference in either direction.
        let tolerance = Duration::from_millis(1);
        assert!(bucket.reset_at.saturating_duration_since(reset_at) < tolerance);
        assert!(reset_at.saturating_duration_since(bucket.reset_at) < tolerance);
    }

    #[test]
    fn endpoint() {
        let invalid = Endpoint {
            method: Method::Get,
            path: String::from("/guilds/745809834183753828/audit-logs?limit=10"),
        };
        let delete_webhook = Endpoint {
            method: Method::Delete,
            path: String::from("webhooks/1"),
        };
        let interaction_response = Endpoint {
            method: Method::Post,
            path: String::from("interactions/1/abc/callback"),
        };

        assert!(!invalid.is_valid());
        assert!(delete_webhook.is_valid());
        assert!(interaction_response.is_valid());

        assert!(delete_webhook.is_interaction());
        assert!(interaction_response.is_interaction());

        let nothing_hashed = with_hasher(|_| {});
        assert_eq!(
            with_hasher(|state| invalid.hash_resources(state)),
            nothing_hashed
        );
        assert_eq!(
            with_hasher(|state| delete_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| interaction_response.hash_resources(state)),
            nothing_hashed
        );
    }

    #[test]
    fn method_conversions() {
        let expected = [
            (Method::Delete, "DELETE"),
            (Method::Get, "GET"),
            (Method::Patch, "PATCH"),
            (Method::Post, "POST"),
            (Method::Put, "PUT"),
        ];

        for (method, name) in expected {
            assert_eq!(name, method.name());
        }
    }
}