twilight_http_ratelimiting/
lib.rs1#![doc = include_str!("../README.md")]
2#![warn(
3 clippy::missing_const_for_fn,
4 clippy::missing_docs_in_private_items,
5 clippy::pedantic,
6 missing_docs,
7 unsafe_code
8)]
9#![allow(clippy::module_name_repetitions, clippy::must_use_candidate)]
10
11mod actor;
12
13#[cfg(not(target_os = "wasi"))]
14use std::{
15 future::Future,
16 hash::{Hash as _, Hasher},
17 pin::Pin,
18 task::{Context, Poll},
19 time::{Duration, Instant},
20};
21#[cfg(not(target_os = "wasi"))]
22use tokio::sync::{mpsc, oneshot};
23
/// Length of the period the global request limit applies to.
///
/// Presumably the actor permits up to the `global_limit` passed to `RateLimiter::new` requests
/// per period of this length; the enforcing code lives in the `actor` module, not in this file.
#[cfg(not(target_os = "wasi"))]
pub const GLOBAL_LIMIT_PERIOD: Duration = Duration::from_millis(1_000);

/// Message used for every panic raised when the actor task can no longer be reached, a
/// condition this crate treats as the actor having panicked.
#[cfg(not(target_os = "wasi"))]
const ACTOR_PANIC_MESSAGE: &str =
    "actor task panicked: report its panic message to the maintainers";
33
/// HTTP method of a rate limited request.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[non_exhaustive]
pub enum Method {
    /// The `DELETE` method.
    Delete,
    /// The `GET` method.
    Get,
    /// The `PATCH` method.
    Patch,
    /// The `POST` method.
    Post,
    /// The `PUT` method.
    Put,
}

impl Method {
    /// Uppercase name of the method, e.g. `"GET"` for [`Method::Get`].
    pub const fn name(self) -> &'static str {
        match self {
            Self::Delete => "DELETE",
            Self::Get => "GET",
            Self::Patch => "PATCH",
            Self::Post => "POST",
            Self::Put => "PUT",
        }
    }
}
64
65#[derive(Clone, Debug, Eq, Hash, PartialEq)]
92#[cfg(not(target_os = "wasi"))]
93pub struct Endpoint {
94 pub method: Method,
96 pub path: String,
100}
101
102#[cfg(not(target_os = "wasi"))]
103impl Endpoint {
104 pub(crate) fn is_valid(&self) -> bool {
106 !self.path.as_bytes().starts_with(b"/") && !self.path.as_bytes().contains(&b'?')
107 }
108
109 pub(crate) fn is_interaction(&self) -> bool {
111 self.path.as_bytes().starts_with(b"webhooks")
112 || self.path.as_bytes().starts_with(b"interactions")
113 }
114
115 pub(crate) fn hash_resources(&self, state: &mut impl Hasher) {
125 let mut segments = self.path.as_bytes().split(|&s| s == b'/');
126 match segments.next().unwrap_or_default() {
127 b"channels" => {
128 if let Some(s) = segments.next() {
129 "channels".hash(state);
130 s.hash(state);
131 }
132 }
133 b"guilds" => {
134 if let Some(s) = segments.next() {
135 "guilds".hash(state);
136 s.hash(state);
137 }
138 }
139 b"webhooks" => {
140 if let Some(s) = segments.next() {
141 "webhooks".hash(state);
142 s.hash(state);
143 }
144 if let Some(s) = segments.next() {
145 s.hash(state);
146 }
147 }
148 _ => {}
149 }
150 }
151}
152
/// Rate limit information about a response, as handed back to the rate limiter through
/// [`Permit::complete`].
///
/// The associated constants name the response headers this information presumably comes from.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
#[cfg(not(target_os = "wasi"))]
pub struct RateLimitHeaders {
    /// Opaque identifier of the bucket (presumably the raw [`Self::BUCKET`] header value).
    pub bucket: Vec<u8>,
    /// Total number of requests allowed in the bucket.
    pub limit: u16,
    /// Number of requests left in the bucket.
    pub remaining: u16,
    /// Point in time at which the bucket resets.
    pub reset_at: Instant,
}

#[cfg(not(target_os = "wasi"))]
impl RateLimitHeaders {
    /// Name of the header carrying the bucket identifier.
    pub const BUCKET: &'static str = "x-ratelimit-bucket";

    /// Name of the header carrying the bucket's request limit.
    pub const LIMIT: &'static str = "x-ratelimit-limit";

    /// Name of the header carrying the number of remaining requests.
    pub const REMAINING: &'static str = "x-ratelimit-remaining";

    /// Name of the header carrying the time left until the bucket resets.
    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";

    /// Name of the header carrying the scope of the rate limit.
    pub const SCOPE: &'static str = "x-ratelimit-scope";

    /// Builds headers for a shared rate limit that resets `retry_after` seconds from now.
    ///
    /// The result reports neither a limit nor any remaining requests (both are `0`).
    pub fn shared(bucket: Vec<u8>, retry_after: u16) -> Self {
        let reset_at = Instant::now() + Duration::from_secs(u64::from(retry_after));

        Self {
            bucket,
            limit: 0,
            remaining: 0,
            reset_at,
        }
    }
}
207
208#[derive(Debug)]
210#[must_use = "dropping the permit immediately cancels itself"]
211#[cfg(not(target_os = "wasi"))]
212pub struct Permit(oneshot::Sender<Option<RateLimitHeaders>>);
213
214#[cfg(not(target_os = "wasi"))]
215impl Permit {
216 #[allow(clippy::missing_panics_doc)]
221 pub fn complete(self, headers: Option<RateLimitHeaders>) {
222 assert!(self.0.send(headers).is_ok(), "{ACTOR_PANIC_MESSAGE}");
223 }
224}
225
226#[derive(Debug)]
228#[must_use = "futures do nothing unless you `.await` or poll them"]
229#[cfg(not(target_os = "wasi"))]
230pub struct PermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
231
232#[cfg(not(target_os = "wasi"))]
233impl Future for PermitFuture {
234 type Output = Permit;
235
236 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
237 #[allow(clippy::match_wild_err_arm)]
238 Pin::new(&mut self.0).poll(cx).map(|r| match r {
239 Ok(sender) => Permit(sender),
240 Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
241 })
242 }
243}
244
245#[derive(Debug)]
247#[must_use = "futures do nothing unless you `.await` or poll them"]
248#[cfg(not(target_os = "wasi"))]
249pub struct MaybePermitFuture(oneshot::Receiver<oneshot::Sender<Option<RateLimitHeaders>>>);
250
251#[cfg(not(target_os = "wasi"))]
252impl Future for MaybePermitFuture {
253 type Output = Option<Permit>;
254
255 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
256 Pin::new(&mut self.0).poll(cx).map(|r| r.ok().map(Permit))
257 }
258}
259
/// Snapshot of a rate limit bucket.
///
/// `acquire_if` predicates receive it, and `RateLimiter::bucket` returns it. Its fields
/// presumably mirror the [`RateLimitHeaders`] the bucket was last updated with.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg(not(target_os = "wasi"))]
pub struct Bucket {
    /// Total number of requests allowed in the bucket.
    pub limit: u16,
    /// Number of requests left in the bucket.
    pub remaining: u16,
    /// Point in time at which the bucket resets.
    pub reset_at: Instant,
}
273
274#[cfg(not(target_os = "wasi"))]
276type Predicate = Box<dyn FnOnce(Option<Bucket>) -> bool + Send>;
277
278#[derive(Clone, Debug)]
286#[cfg(not(target_os = "wasi"))]
287pub struct RateLimiter {
288 tx: mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
290}
291
292#[cfg(not(target_os = "wasi"))]
293impl RateLimiter {
294 pub fn new(global_limit: u16) -> Self {
296 let (tx, rx) = mpsc::unbounded_channel();
297 tokio::spawn(actor::runner(global_limit, rx));
298
299 Self { tx }
300 }
301
302 #[allow(clippy::missing_panics_doc)]
306 pub fn acquire(&self, endpoint: Endpoint) -> PermitFuture {
307 let (notifier, rx) = oneshot::channel();
308 let message = actor::Message { endpoint, notifier };
309 assert!(
310 self.tx.send((message, None)).is_ok(),
311 "{ACTOR_PANIC_MESSAGE}"
312 );
313
314 PermitFuture(rx)
315 }
316
317 #[allow(clippy::missing_panics_doc)]
349 pub fn acquire_if<P>(&self, endpoint: Endpoint, predicate: P) -> MaybePermitFuture
350 where
351 P: FnOnce(Option<Bucket>) -> bool + Send + 'static,
352 {
353 fn acquire_if(
354 tx: &mpsc::UnboundedSender<(actor::Message, Option<Predicate>)>,
355 endpoint: Endpoint,
356 predicate: Predicate,
357 ) -> MaybePermitFuture {
358 let (notifier, rx) = oneshot::channel();
359 let message = actor::Message { endpoint, notifier };
360 assert!(
361 tx.send((message, Some(predicate))).is_ok(),
362 "{ACTOR_PANIC_MESSAGE}"
363 );
364
365 MaybePermitFuture(rx)
366 }
367
368 acquire_if(&self.tx, endpoint, Box::new(predicate))
369 }
370
371 #[allow(clippy::missing_panics_doc)]
375 pub async fn bucket(&self, endpoint: Endpoint) -> Option<Bucket> {
376 let (tx, rx) = oneshot::channel();
377 self.acquire_if(endpoint, |bucket| {
378 _ = tx.send(bucket);
380 false
381 })
382 .await;
383
384 #[allow(clippy::match_wild_err_arm)]
385 match rx.await {
386 Ok(bucket) => bucket,
387 Err(_) => panic!("{ACTOR_PANIC_MESSAGE}"),
388 }
389 }
390}
391
392#[cfg(not(target_os = "wasi"))]
393impl Default for RateLimiter {
394 fn default() -> Self {
398 Self::new(50)
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::{
        Bucket, Endpoint, MaybePermitFuture, Method, Permit, PermitFuture, RateLimitHeaders,
        RateLimiter,
    };
    use static_assertions::assert_impl_all;
    use std::{
        fmt::Debug,
        future::Future,
        hash::{DefaultHasher, Hash, Hasher as _},
        time::{Duration, Instant},
    };
    use tokio::task;

    assert_impl_all!(Bucket: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Endpoint: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(MaybePermitFuture: Debug, Future<Output = Option<Permit>>, Send);
    assert_impl_all!(Method: Clone, Copy, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(Permit: Debug, Send, Sync);
    assert_impl_all!(PermitFuture: Debug, Future<Output = Permit>, Send);
    assert_impl_all!(RateLimitHeaders: Clone, Debug, Eq, Hash, PartialEq, Send, Sync);
    assert_impl_all!(RateLimiter: Clone, Debug, Default, Send, Sync);

    const ENDPOINT: fn() -> Endpoint = || Endpoint {
        method: Method::Get,
        path: String::from("applications/@me"),
    };

    #[tokio::test]
    async fn acquire_if() {
        let rate_limiter = RateLimiter::default();

        assert!(
            rate_limiter
                .acquire_if(ENDPOINT(), |_| false)
                .await
                .is_none()
        );
        assert!(
            rate_limiter
                .acquire_if(ENDPOINT(), |_| true)
                .await
                .is_some()
        );
    }

    #[tokio::test]
    async fn bucket() {
        let rate_limiter = RateLimiter::default();

        let limit = 2;
        let remaining = 1;
        let reset_at = Instant::now() + Duration::from_secs(1);
        let headers = RateLimitHeaders {
            bucket: vec![1, 2, 3],
            limit,
            remaining,
            reset_at,
        };

        rate_limiter
            .acquire(ENDPOINT())
            .await
            .complete(Some(headers));
        task::yield_now().await;

        let bucket = rate_limiter.bucket(ENDPOINT()).await.unwrap();
        assert_eq!(bucket.limit, limit);
        assert_eq!(bucket.remaining, remaining);
        assert!(
            bucket.reset_at.saturating_duration_since(reset_at) < Duration::from_millis(1)
                && reset_at.saturating_duration_since(bucket.reset_at) < Duration::from_millis(1)
        );
    }

    #[test]
    fn shared_headers() {
        let before = Instant::now();
        let headers = RateLimitHeaders::shared(vec![1, 2, 3], 5);

        assert_eq!(headers.bucket, vec![1, 2, 3]);
        assert_eq!(headers.limit, 0);
        assert_eq!(headers.remaining, 0);
        assert!(headers.reset_at >= before + Duration::from_secs(5));
    }

    /// Runs `f` against a fresh hasher and returns the resulting hash.
    fn with_hasher(f: impl FnOnce(&mut DefaultHasher)) -> u64 {
        let mut hasher = DefaultHasher::new();
        f(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn endpoint() {
        let invalid = Endpoint {
            method: Method::Get,
            path: String::from("/guilds/745809834183753828/audit-logs?limit=10"),
        };
        let delete_webhook = Endpoint {
            method: Method::Delete,
            path: String::from("webhooks/1"),
        };
        let interaction_response = Endpoint {
            method: Method::Post,
            path: String::from("interactions/1/abc/callback"),
        };

        assert!(!invalid.is_valid());
        assert!(delete_webhook.is_valid());
        assert!(interaction_response.is_valid());

        assert!(delete_webhook.is_interaction());
        assert!(interaction_response.is_interaction());

        assert_eq!(
            with_hasher(|state| invalid.hash_resources(state)),
            with_hasher(|_| {})
        );
        assert_eq!(
            with_hasher(|state| delete_webhook.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| interaction_response.hash_resources(state)),
            with_hasher(|_| {})
        );
    }

    /// Covers the channel, guild and webhook-token branches of `hash_resources`, plus a route
    /// that is missing its resource ID.
    #[test]
    fn endpoint_resources() {
        let get = |path: &str| Endpoint {
            method: Method::Get,
            path: String::from(path),
        };

        let channel = get("channels/1/messages");
        let guild = get("guilds/1");
        let webhook_with_token = get("webhooks/1/abc/messages/2");
        let missing_id = get("channels");

        assert_eq!(
            with_hasher(|state| channel.hash_resources(state)),
            with_hasher(|state| {
                "channels".hash(state);
                b"1".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| guild.hash_resources(state)),
            with_hasher(|state| {
                "guilds".hash(state);
                b"1".hash(state);
            })
        );
        // Same ID under different top-level resources must not collide.
        assert_ne!(
            with_hasher(|state| channel.hash_resources(state)),
            with_hasher(|state| guild.hash_resources(state))
        );
        assert_eq!(
            with_hasher(|state| webhook_with_token.hash_resources(state)),
            with_hasher(|state| {
                "webhooks".hash(state);
                b"1".hash(state);
                b"abc".hash(state);
            })
        );
        assert_eq!(
            with_hasher(|state| missing_id.hash_resources(state)),
            with_hasher(|_| {})
        );
    }

    #[test]
    fn method_conversions() {
        assert_eq!("DELETE", Method::Delete.name());
        assert_eq!("GET", Method::Get.name());
        assert_eq!("PATCH", Method::Patch.name());
        assert_eq!("POST", Method::Post.name());
        assert_eq!("PUT", Method::Put.name());
    }
}