twilight_http_ratelimiting/in_memory/bucket.rs

//! [`Bucket`] management used internally by the [`super::InMemoryRatelimiter`].
//! Each bucket has an associated [`BucketQueue`] in which API requests are
//! queued; the queue is consumed by a [`BucketQueueTask`] that manages the
//! bucket's ratelimit and respects the global ratelimit.
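//!
//! A minimal sketch of the intended flow, as driven by
//! [`super::InMemoryRatelimiter`]. The `path`, `buckets`, and `global` values
//! and the `ticket::channel()` constructor are assumptions for illustration,
//! not guarantees of this module's API:
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! // One bucket per ratelimit path, shared with its queue task.
//! let bucket = Arc::new(Bucket::new(path.clone()));
//!
//! // The ratelimiter spawns a task that drains this bucket's queue.
//! tokio::spawn(BucketQueueTask::new(Arc::clone(&bucket), buckets, global, path).run());
//!
//! // Callers enqueue a ticket and are notified when they may send a request.
//! let (notifier, receiver) = ticket::channel(); // assumed constructor
//! bucket.queue.push(notifier);
//! ```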

use super::GlobalLockPair;
use crate::{headers::RatelimitHeaders, request::Path, ticket::TicketNotifier};
use std::{
    collections::HashMap,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc, Mutex,
    },
    time::{Duration, Instant},
};
use tokio::{
    sync::{
        mpsc::{self, UnboundedReceiver, UnboundedSender},
        Mutex as AsyncMutex,
    },
    time::{sleep, timeout},
};

/// Time remaining until a bucket will reset.
#[derive(Clone, Debug)]
pub enum TimeRemaining {
    /// Bucket has already reset.
    Finished,
    /// Bucket's ratelimit refresh countdown has not started yet.
    NotStarted,
    /// Amount of time until the bucket resets.
    Some(Duration),
}

/// Ratelimit information for a bucket used in the [`super::InMemoryRatelimiter`].
///
/// A generic version not specific to this ratelimiter is [`crate::Bucket`].
#[derive(Debug)]
pub struct Bucket {
    /// Total number of tickets allotted in a cycle.
    pub limit: AtomicU64,
    /// Path this ratelimit applies to.
    // This is dead code, but it is useful for debugging.
    #[allow(dead_code)]
    pub path: Path,
    /// Queue associated with this bucket.
    pub queue: BucketQueue,
    /// Number of tickets remaining.
    pub remaining: AtomicU64,
    /// Duration, in milliseconds, after the [`Self::started_at`] time at which
    /// the bucket will refresh.
    pub reset_after: AtomicU64,
    /// When the bucket's ratelimit refresh countdown started.
    pub started_at: Mutex<Option<Instant>>,
}

impl Bucket {
    /// Create a new bucket for the specified [`Path`].
    pub fn new(path: Path) -> Self {
        Self {
            limit: AtomicU64::new(u64::MAX),
            path,
            queue: BucketQueue::default(),
            remaining: AtomicU64::new(u64::MAX),
            reset_after: AtomicU64::new(u64::MAX),
            started_at: Mutex::new(None),
        }
    }

    /// Total number of tickets allotted in a cycle.
    pub fn limit(&self) -> u64 {
        self.limit.load(Ordering::Relaxed)
    }

    /// Number of tickets remaining.
    pub fn remaining(&self) -> u64 {
        self.remaining.load(Ordering::Relaxed)
    }

    /// Duration, in milliseconds, after the [`started_at`] time at which the
    /// bucket will refresh.
    ///
    /// [`started_at`]: Self::started_at
    pub fn reset_after(&self) -> u64 {
        self.reset_after.load(Ordering::Relaxed)
    }

    /// Time remaining until this bucket will reset.
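    ///
    /// A sketch of the state transitions, assuming `bucket` is a freshly
    /// constructed [`Bucket`] (values are illustrative, not real header data):
    ///
    /// ```ignore
    /// // Freshly created: the refresh countdown has not started yet.
    /// assert!(matches!(bucket.time_remaining(), TimeRemaining::NotStarted));
    ///
    /// // Record headers: limit 5, remaining 4, reset after 2_500 ms.
    /// bucket.update(Some((5, 4, 2_500)));
    /// assert!(matches!(bucket.time_remaining(), TimeRemaining::Some(_)));
    ///
    /// // Once more than 2_500 ms have elapsed this reports `Finished`, and
    /// // `try_reset` will restore `remaining` back to `limit`.
    /// ```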
    pub fn time_remaining(&self) -> TimeRemaining {
        let reset_after = self.reset_after();
        let maybe_started_at = *self.started_at.lock().expect("bucket poisoned");

        let Some(started_at) = maybe_started_at else {
            return TimeRemaining::NotStarted;
        };

        let elapsed = started_at.elapsed();

        if elapsed > Duration::from_millis(reset_after) {
            return TimeRemaining::Finished;
        }

        TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
    }

    /// Try to reset this bucket's [`started_at`] value if it has finished.
    ///
    /// Returns whether resetting was possible.
    ///
    /// [`started_at`]: Self::started_at
    pub fn try_reset(&self) -> bool {
        if self.started_at.lock().expect("bucket poisoned").is_none() {
            return false;
        }

        if let TimeRemaining::Finished = self.time_remaining() {
            self.remaining.store(self.limit(), Ordering::Relaxed);
            *self.started_at.lock().expect("bucket poisoned") = None;

            true
        } else {
            false
        }
    }

    /// Update this bucket's ratelimit data after a request has been made.
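    ///
    /// The tuple is `(limit, remaining, reset_after)` as parsed from response
    /// headers, with `reset_after` in milliseconds. A sketch of both branches,
    /// assuming `bucket` is a freshly constructed [`Bucket`] (values are
    /// illustrative):
    ///
    /// ```ignore
    /// // Headers were present: adopt the limit and reset window, store the
    /// // remaining count, and start the refresh countdown.
    /// bucket.update(Some((5, 4, 2_500)));
    /// assert_eq!(bucket.limit(), 5);
    /// assert_eq!(bucket.remaining(), 4);
    ///
    /// // No usable headers: only decrement the remaining ticket count.
    /// bucket.update(None);
    /// assert_eq!(bucket.remaining(), 3);
    /// ```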
    pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) {
        let bucket_limit = self.limit();

        {
            let mut started_at = self.started_at.lock().expect("bucket poisoned");

            if started_at.is_none() {
                started_at.replace(Instant::now());
            }
        }

        if let Some((limit, remaining, reset_after)) = ratelimits {
            if bucket_limit != limit && bucket_limit == u64::MAX {
                self.reset_after.store(reset_after, Ordering::SeqCst);
                self.limit.store(limit, Ordering::SeqCst);
            }

            self.remaining.store(remaining, Ordering::Relaxed);
        } else {
            self.remaining.fetch_sub(1, Ordering::Relaxed);
        }
    }
}

/// Queue of ratelimit requests for a bucket.
#[derive(Debug)]
pub struct BucketQueue {
    /// Receiver for the ratelimit requests.
    rx: AsyncMutex<UnboundedReceiver<TicketNotifier>>,
    /// Sender for the ratelimit requests.
    tx: UnboundedSender<TicketNotifier>,
}

impl BucketQueue {
    /// Add a new ratelimit request to the queue.
    pub fn push(&self, tx: TicketNotifier) {
        let _sent = self.tx.send(tx);
    }

    /// Receive the first incoming ratelimit request, waiting up to
    /// `timeout_duration` and returning [`None`] if none arrives in time.
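    ///
    /// A minimal sketch of the timeout behavior on an empty queue (inside an
    /// async context):
    ///
    /// ```ignore
    /// let queue = BucketQueue::default();
    ///
    /// // Nothing was pushed, so this resolves to `None` once the timeout lapses.
    /// assert!(queue.pop(Duration::from_millis(10)).await.is_none());
    /// ```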
    pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
        let mut rx = self.rx.lock().await;

        timeout(timeout_duration, rx.recv()).await.ok().flatten()
    }
}

impl Default for BucketQueue {
    fn default() -> Self {
        let (tx, rx) = mpsc::unbounded_channel();

        Self {
            rx: AsyncMutex::new(rx),
            tx,
        }
    }
}

/// A background task that handles ratelimit requests to a [`Bucket`]
/// and processes them in order, keeping track of both the global and
/// the [`Path`]-specific ratelimits.
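///
/// A sketch of how the parent ratelimiter is expected to drive it, assuming
/// `bucket`, `buckets`, `global`, and `path` have already been assembled by
/// [`super::InMemoryRatelimiter`]:
///
/// ```ignore
/// // `run` consumes the task and loops until the queue stays empty long
/// // enough for the bucket to be considered finished.
/// tokio::spawn(BucketQueueTask::new(bucket, buckets, global, path).run());
/// ```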
pub(super) struct BucketQueueTask {
    /// The [`Bucket`] managed by this task.
    bucket: Arc<Bucket>,
    /// All buckets managed by the associated [`super::InMemoryRatelimiter`].
    buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
    /// Global ratelimit data.
    global: Arc<GlobalLockPair>,
    /// The [`Path`] this [`Bucket`] belongs to.
    path: Path,
}

impl BucketQueueTask {
    /// Timeout to wait for response headers after initiating a request.
    const WAIT: Duration = Duration::from_secs(10);

    /// Create a new task to manage the ratelimit for a [`Bucket`].
    pub const fn new(
        bucket: Arc<Bucket>,
        buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
        global: Arc<GlobalLockPair>,
        path: Path,
    ) -> Self {
        Self {
            bucket,
            buckets,
            global,
            path,
        }
    }

    /// Process incoming ratelimit requests to this bucket and update the state
    /// based on received [`RatelimitHeaders`].
    #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
    pub async fn run(self) {
        while let Some(queue_tx) = self.next().await {
            if self.global.is_locked() {
                drop(self.global.0.lock().await);
            }

            let Some(ticket_headers) = queue_tx.available() else {
                continue;
            };

            tracing::debug!("starting to wait for response headers");

            match timeout(Self::WAIT, ticket_headers).await {
                Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
                Ok(Ok(None)) => {
                    tracing::debug!("request aborted");
                }
                Ok(Err(_)) => {
                    tracing::debug!("ticket channel closed");
                }
                Err(_) => {
                    tracing::debug!("receiver timed out");
                }
            }
        }

        tracing::debug!("bucket appears finished, removing");

        self.buckets
            .lock()
            .expect("ratelimit buckets poisoned")
            .remove(&self.path);
    }

    /// Update the bucket's ratelimit state.
    async fn handle_headers(&self, headers: &RatelimitHeaders) {
        let ratelimits = match headers {
            RatelimitHeaders::Global(global) => {
                self.lock_global(Duration::from_secs(global.retry_after()))
                    .await;

                None
            }
            RatelimitHeaders::None => return,
            RatelimitHeaders::Present(present) => {
                Some((present.limit(), present.remaining(), present.reset_after()))
            }
        };

        tracing::debug!(path=?self.path, "updating bucket");
        self.bucket.update(ratelimits);
    }

    /// Lock the global ratelimit for a specified duration.
    async fn lock_global(&self, wait: Duration) {
        tracing::debug!(path=?self.path, "request got global ratelimited");
        self.global.lock();
        let lock = self.global.0.lock().await;
        sleep(wait).await;
        self.global.unlock();

        drop(lock);
    }

    /// Get the next [`TicketNotifier`] in the queue.
    async fn next(&self) -> Option<TicketNotifier> {
        tracing::debug!(path=?self.path, "starting to get next in queue");

        self.wait_if_needed().await;

        self.bucket.queue.pop(Self::WAIT).await
    }

    /// Wait for this bucket to refresh if it isn't ready yet.
    #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))]
    async fn wait_if_needed(&self) {
        let wait = {
            if self.bucket.remaining() > 0 {
                return;
            }

            tracing::debug!("0 tickets remaining, may have to wait");

            match self.bucket.time_remaining() {
                TimeRemaining::Finished => {
                    self.bucket.try_reset();

                    return;
                }
                TimeRemaining::NotStarted => return,
                TimeRemaining::Some(dur) => dur,
            }
        };

        tracing::debug!(
            milliseconds=%wait.as_millis(),
            "waiting for ratelimit to pass",
        );

        sleep(wait).await;

        tracing::debug!("done waiting for ratelimit to pass");

        self.bucket.try_reset();
    }
}