twilight_http_ratelimiting/in_memory/mod.rs

1//! In-memory based default [`Ratelimiter`] implementation used in `twilight-http`.
2
3mod bucket;
4
5use self::bucket::{Bucket, BucketQueueTask};
6use super::{
7    ticket::{self, TicketNotifier},
8    Bucket as InfoBucket, Ratelimiter,
9};
10use crate::{
11    request::Path, GetBucketFuture, GetTicketFuture, HasBucketFuture, IsGloballyLockedFuture,
12};
13use std::{
14    collections::hash_map::{Entry, HashMap},
15    future,
16    sync::{
17        atomic::{AtomicBool, Ordering},
18        Arc, Mutex,
19    },
20    time::Duration,
21};
22use tokio::sync::Mutex as AsyncMutex;
23
24/// Global lock. We use a pair to avoid actually locking the mutex every check.
25/// This allows futures to only wait on the global lock when a global ratelimit
26/// is in place by, in turn, waiting for a guard, and then each immediately
27/// dropping it.
28#[derive(Debug, Default)]
29struct GlobalLockPair(AsyncMutex<()>, AtomicBool);
30
31impl GlobalLockPair {
32    /// Set the global ratelimit as exhausted.
33    pub fn lock(&self) {
34        self.1.store(true, Ordering::Release);
35    }
36
37    /// Set the global ratelimit as no longer exhausted.
38    pub fn unlock(&self) {
39        self.1.store(false, Ordering::Release);
40    }
41
42    /// Whether the global ratelimit is exhausted.
43    pub fn is_locked(&self) -> bool {
44        self.1.load(Ordering::Relaxed)
45    }
46}
47
48/// Default ratelimiter implementation used in twilight that
49/// stores ratelimit information in an in-memory mapping.
50///
51/// This will meet most users' needs for simple ratelimiting,
52/// but for multi-processed bots, consider either implementing
53/// your own [`Ratelimiter`] that uses a shared storage backend
54/// or use the [HTTP proxy].
55///
56/// [HTTP proxy]: https://twilight.rs/chapter_2_multi-serviced_approach.html#http-proxy-ratelimiting
57#[derive(Clone, Debug, Default)]
58pub struct InMemoryRatelimiter {
59    /// Mapping of [`Path`]s to their associated [`Bucket`]s.
60    buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
61    /// Global ratelimit data.
62    global: Arc<GlobalLockPair>,
63}
64
65impl InMemoryRatelimiter {
66    /// Create a new in-memory ratelimiter.
67    ///
68    /// This is used by HTTP client to queue requests in order to avoid
69    /// hitting the API's ratelimits.
70    #[must_use]
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    /// Enqueue the [`TicketNotifier`] to the [`Path`]'s [`Bucket`].
76    ///
77    /// Returns the new [`Bucket`] if none existed.
78    fn entry(&self, path: Path, tx: TicketNotifier) -> Option<Arc<Bucket>> {
79        let mut buckets = self.buckets.lock().expect("buckets poisoned");
80
81        match buckets.entry(path.clone()) {
82            Entry::Occupied(bucket) => {
83                tracing::debug!("got existing bucket: {path:?}");
84
85                bucket.get().queue.push(tx);
86
87                tracing::debug!("added request into bucket queue: {path:?}");
88
89                None
90            }
91            Entry::Vacant(entry) => {
92                tracing::debug!("making new bucket for path: {path:?}");
93
94                let bucket = Bucket::new(path);
95                bucket.queue.push(tx);
96
97                let bucket = Arc::new(bucket);
98                entry.insert(Arc::clone(&bucket));
99
100                Some(bucket)
101            }
102        }
103    }
104}
105
106impl Ratelimiter for InMemoryRatelimiter {
107    fn bucket(&self, path: &Path) -> GetBucketFuture {
108        self.buckets
109            .lock()
110            .expect("buckets poisoned")
111            .get(path)
112            .map_or_else(
113                || Box::pin(future::ready(Ok(None))),
114                |bucket| {
115                    let started_at = bucket.started_at.lock().expect("bucket poisoned");
116                    let reset_after = Duration::from_millis(bucket.reset_after());
117
118                    Box::pin(future::ready(Ok(Some(InfoBucket::new(
119                        bucket.limit(),
120                        bucket.remaining(),
121                        reset_after,
122                        *started_at,
123                    )))))
124                },
125            )
126    }
127
128    fn is_globally_locked(&self) -> IsGloballyLockedFuture {
129        Box::pin(future::ready(Ok(self.global.is_locked())))
130    }
131
132    fn has(&self, path: &Path) -> HasBucketFuture {
133        let has = self
134            .buckets
135            .lock()
136            .expect("buckets poisoned")
137            .contains_key(path);
138
139        Box::pin(future::ready(Ok(has)))
140    }
141
142    fn ticket(&self, path: Path) -> GetTicketFuture {
143        tracing::debug!("getting bucket for path: {path:?}");
144
145        let (tx, rx) = ticket::channel();
146
147        if let Some(bucket) = self.entry(path.clone(), tx) {
148            tokio::spawn(
149                BucketQueueTask::new(
150                    bucket,
151                    Arc::clone(&self.buckets),
152                    Arc::clone(&self.global),
153                    path,
154                )
155                .run(),
156            );
157        }
158
159        Box::pin(future::ready(Ok(rx)))
160    }
161}