twilight_http_ratelimiting/
lib.rs1#![doc = include_str!("../README.md")]
2#![warn(
3 clippy::missing_const_for_fn,
4 clippy::missing_docs_in_private_items,
5 clippy::pedantic,
6 missing_docs,
7 unsafe_code
8)]
9#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)]
10
11mod actor;
12
13use std::{
14 future::Future,
15 hash::{Hash as _, Hasher},
16 pin::Pin,
17 task::{Context, Poll},
18 time::{Duration, Instant},
19};
20use tokio::sync::{mpsc, oneshot};
21
/// Length of one global rate limit period: one second.
pub const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_millis(1_000);
25
/// Message used when a channel to or from the actor task is found closed.
///
/// A closed channel is treated as evidence that the actor task panicked.
const ACTOR_PANIC_MESSAGE: &'static str =
    "actor task panicked: report its panic message to the maintainers";
29
/// HTTP request method of an endpoint.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Method {
    /// The `DELETE` method.
    Delete,
    /// The `GET` method.
    Get,
    /// The `PATCH` method.
    Patch,
    /// The `POST` method.
    Post,
    /// The `PUT` method.
    Put,
}

impl Method {
    /// Returns the method's name in uppercase, e.g. `"GET"`.
    pub const fn name(self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Get => "GET",
            Self::Patch => "PATCH",
            Self::Post => "POST",
            Self::Put => "PUT",
        }
    }
}
60
61#[derive(Clone, Debug, Eq, Hash, PartialEq)]
88pub struct Endpoint {
89 pub method: Method,
91 pub path: String,
95}
96
97impl Endpoint {
98 pub(crate) fn is_valid(&self) -> bool {
100 !self.path.as_bytes().starts_with(b"/") && !self.path.as_bytes().contains(&b'?')
101 }
102
103 pub(crate) fn is_interaction(&self) -> bool {
105 self.path.as_bytes().starts_with(b"webhooks")
106 || self.path.as_bytes().starts_with(b"interactions")
107 }
108
109 pub(crate) fn hash_resources(&self, state: &mut impl Hasher) {
119 let mut segments = self.path.as_bytes().split(|&s| s == b'/');
120 match segments.next().unwrap_or_default() {
121 b"channels" => {
122 if let Some(s) = segments.next() {
123 "channels".hash(state);
124 s.hash(state);
125 }
126 }
127 b"guilds" => {
128 if let Some(s) = segments.next() {
129 "guilds".hash(state);
130 s.hash(state);
131 }
132 }
133 b"webhooks" => {
134 if let Some(s) = segments.next() {
135 "webhooks".hash(state);
136 s.hash(state);
137 }
138 if let Some(s) = segments.next() {
139 s.hash(state);
140 }
141 }
142 _ => {}
143 }
144 }
145}
146
/// Rate limit information associated with a response.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct RateLimitHeaders {
    /// Bucket identifier bytes (presumably the raw `x-ratelimit-bucket` value — confirm).
    pub bucket: Vec<u8>,
    /// Request limit of the bucket.
    pub limit: u16,
    /// Requests remaining in the bucket.
    pub remaining: u16,
    /// Moment at which the bucket resets.
    pub reset_at: Instant,
}

impl RateLimitHeaders {
    /// Name of the header carrying the bucket identifier.
    pub const BUCKET: &'static str = "x-ratelimit-bucket";

    /// Name of the header carrying the bucket's request limit.
    pub const LIMIT: &'static str = "x-ratelimit-limit";

    /// Name of the header carrying the number of remaining requests.
    pub const REMAINING: &'static str = "x-ratelimit-remaining";

    /// Name of the header carrying the time until the bucket resets.
    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";

    /// Name of the header carrying the rate limit scope.
    pub const SCOPE: &'static str = "x-ratelimit-scope";

    /// Builds headers for `bucket` that is exhausted until `retry_after`
    /// seconds from now.
    ///
    /// Both `limit` and `remaining` are set to zero.
    /// NOTE(review): presumably intended for `shared`-scoped rate limits; confirm.
    pub fn shared(bucket: Vec<u8>, retry_after: u16) -> Self {
        let reset_at = Instant::now() + Duration::from_secs(u64::from(retry_after));

        Self {
            reset_at,
            remaining: 0,
            limit: 0,
            bucket,
        }
    }
}
199
200#[derive(Debug)]
202#[must_use = "dropping the permit immediately cancels itself"]
203pub struct Permit(oneshot::Sender<Option<RateLimitHeaders>>);
204
205impl Permit {
206 #[allow(clippy::missing_panics_doc)]
211 pub fn complete(self, headers: Option<RateLimitHeaders>) {
212 assert!(self.0.send(headers).is_ok(), "{ACTOR_PANIC_MESSAGE}");
213 }
214}
215
216#[derive(Debug)]
218#[must_use = "futures do nothing unless you `.await` or poll them"]
219pub struct PermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
220
221impl Future for PermitFuture {
222 type Output = Permit;
223
224 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
225 #[allow(clippy::match_wild_err_arm)]
226 Pin::new(&mut self.0).poll(cx).map(|r| match r {
227 Ok(sender) => Permit(sender),
228 Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
229 })
230 }
231}
232
233#[derive(Debug)]
235#[must_use = "futures do nothing unless you `.await` or poll them"]
236pub struct MaybePermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
237
238impl Future for MaybePermitFuture {
239 type Output = Option<Permit>;
240
241 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
242 Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit))
243 }
244}
245
/// Snapshot of a rate limit bucket's state.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub struct Bucket {
    /// Request limit of the bucket.
    pub limit: u16,
    /// Requests remaining in the bucket.
    pub remaining: u16,
    /// Moment at which the bucket resets.
    pub reset_at: Instant,
}
258
259type Predicate = Box<dyn FnOnce(Option<Bucket>) -> bool + Send>;
261
262#[derive(Clone, Debug)]
270pub struct RateLimiter {
271 tx: mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
273}
274
275impl RateLimiter {
276 pub fn new(global_limit: u16) -> Self {
278 let (tx, rx) = mpsc::unbounded_channel();
279 tokio::spawn(actor::runner(global_limit, rx));
280
281 Self { tx }
282 }
283
284 #[allow(clippy::missing_panics_doc)]
288 pub fn acquire(&self, endpoint: Endpoint) -> PermitFuture {
289 let (notifier, rx) = oneshot::channel();
290 let message = actor::Message { endpoint, notifier };
291 assert!(
292 self.tx.send((message, None)).is_ok(),
293 "{ACTOR_PANIC_MESSAGE}"
294 );
295
296 PermitFuture(rx)
297 }
298
299 #[allow(clippy::missing_panics_doc)]
331 pub fn acquire_if<P>(&self, endpoint: Endpoint, predicate: P) -> MaybePermitFuture
332 where
333 P: FnOnce(Option<Bucket>) -> bool + Send + 'static,
334 {
335 fn acquire_if(
336 tx: &mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
337 endpoint: Endpoint,
338 predicate: Predicate,
339 ) -> MaybePermitFuture {
340 let (notifier, rx) = oneshot::channel();
341 let message = actor::Message { endpoint, notifier };
342 assert!(
343 tx.send((message, Some(predicate))).is_ok(),
344 "{ACTOR_PANIC_MESSAGE}"
345 );
346
347 MaybePermitFuture(rx)
348 }
349
350 acquire_if(&self.tx, endpoint, Box::new(predicate))
351 }
352
353 #[allow(clippy::missing_panics_doc)]
357 pub async fn bucket(&self, endpoint: Endpoint) -> Option<Bucket> {
358 let (tx, rx) = oneshot::channel();
359 self.acquire_if(endpoint, |bucket| {
360 _ = tx.send(bucket);
362 false
363 })
364 .await;
365
366 #[allow(clippy::match_wild_err_arm)]
367 match rx.await {
368 Ok(bucket) => bucket,
369 Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
370 }
371 }
372}
373
374impl Default for RateLimiter {
375 fn default() -> Self {
379 Self::new(50)
380 }
381}
382
#[cfg(test)]
mod tests {
    use super::{
        Bucket, Endpoint, MaybePermitFuture, Method, Permit, PermitFuture, RateLimitHeaders,
        RateLimiter,
    };
    use static_assertions::assert_impl_all;
    use std::{
        fmt::Debug,
        future::Future,
        hash::{DefaultHasher, Hash, Hasher as _},
        time::{Duration, Instant},
    };
    use tokio::task;

    assert_impl_all!(Bucket: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Endpoint: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(MaybePermitFuture: Debug, Future<Output = Option<Permit>>);
    // `Method` derives `Hash` and is a plain fieldless enum; pin those impls
    // like the other public data types do.
    assert_impl_all!(Method: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Permit: Debug, Send, Sync);
    assert_impl_all!(PermitFuture: Debug, Future<Output = Permit>);
    assert_impl_all!(RateLimitHeaders: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(RateLimiter: Clone, Debug, Default, Send, Sync);

    /// Endpoint shared by the actor-backed tests.
    fn sample_endpoint() -> Endpoint {
        Endpoint {
            method: Method::Get,
            path: String::from("applications/@me"),
        }
    }

    /// Runs `f` against a fresh [`DefaultHasher`] and returns its digest.
    fn with_hasher(f: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut hasher = DefaultHasher::new();
        f(&mut hasher);
        hasher.finish()
    }

    #[tokio::test]
    async fn acquire_if() {
        let rate_limiter = RateLimiter::default();

        assert!(rate_limiter
            .acquire_if(sample_endpoint(), |_| false)
            .await
            .is_none());
        assert!(rate_limiter
            .acquire_if(sample_endpoint(), |_| true)
            .await
            .is_some());
    }

    #[tokio::test]
    async fn bucket() {
        let rate_limiter = RateLimiter::default();

        let limit = 2;
        let remaining = 1;
        let reset_at = Instant::now() + Duration::from_secs(1);
        let headers = RateLimitHeaders {
            bucket: vec![1, 2, 3],
            limit,
            remaining,
            reset_at,
        };

        rate_limiter
            .acquire(sample_endpoint())
            .await
            .complete(Some(headers));
        task::yield_now().await;

        let bucket = rate_limiter.bucket(sample_endpoint()).await.unwrap();
        assert_eq!(bucket.limit, limit);
        assert_eq!(bucket.remaining, remaining);
        assert!(
            bucket.reset_at.saturating_duration_since(reset_at) < Duration::from_millis(1)
                && reset_at.saturating_duration_since(bucket.reset_at) < Duration::from_millis(1)
        );
    }

    #[test]
    fn endpoint() {
        let invalid = Endpoint {
            method: Method::Get,
            path: String::from("/guilds/745809834183753828/audit-logs?limit=10"),
        };
        let delete_webhook = Endpoint {
            method: Method::Delete,
            path: String::from("webhooks/1"),
        };
        let execute_webhook = Endpoint {
            method: Method::Post,
            path: String::from("webhooks/1/abc"),
        };
        let interaction_response = Endpoint {
            method: Method::Post,
            path: String::from("interactions/1/abc/callback"),
        };
        let create_message = Endpoint {
            method: Method::Post,
            path: String::from("channels/1/messages"),
        };

        assert!(!invalid.is_valid());
        assert!(delete_webhook.is_valid());
        assert!(interaction_response.is_valid());
        assert!(create_message.is_valid());

        assert!(delete_webhook.is_interaction());
        assert!(interaction_response.is_interaction());
        assert!(!invalid.is_interaction());
        assert!(!create_message.is_interaction());

        assert_eq!(
            with_hasher(|state| invalid.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| delete_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
            })
        );
        // The segment after a webhook id (its token) is hashed too.
        assert_eq!(
            with_hasher(|state| execute_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
                b"abc".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| create_message.hash_resources(state)),
            with_hasher(|state| {
                "channels".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| interaction_response.hash_resources(state)),
            with_hasher(|_| {})
        );
    }

    #[test]
    fn method_conversions() {
        assert_eq!("DELETE", Method::Delete.name());
        assert_eq!("GET", Method::Get.name());
        assert_eq!("PATCH", Method::Patch.name());
        assert_eq!("POST", Method::Post.name());
        assert_eq!("PUT", Method::Put.name());
    }

    #[test]
    fn shared_headers() {
        let before = Instant::now();
        let headers = RateLimitHeaders::shared(vec![4, 5, 6], 2);
        let after = Instant::now();

        assert_eq!(headers.bucket, vec![4, 5, 6]);
        assert_eq!(headers.limit, 0);
        assert_eq!(headers.remaining, 0);
        assert!(headers.reset_at >= before + Duration::from_secs(2));
        assert!(headers.reset_at <= after + Duration::from_secs(2));
    }
}