twilight_http_ratelimiting/in_memory/
bucket.rs1use super::GlobalLockPair;
7use crate::{headers::RatelimitHeaders, request::Path, ticket::TicketNotifier};
8use std::{
9 collections::HashMap,
10 sync::{
11 atomic::{AtomicU64, Ordering},
12 Arc, Mutex,
13 },
14 time::{Duration, Instant},
15};
16use tokio::{
17 sync::{
18 mpsc::{self, UnboundedReceiver, UnboundedSender},
19 Mutex as AsyncMutex,
20 },
21 time::{sleep, timeout},
22};
23
/// Where a bucket currently stands relative to its ratelimit window.
///
/// Produced by `Bucket::time_remaining`.
#[derive(Debug, Clone)]
pub enum TimeRemaining {
    /// The window was started and more time than its length has elapsed.
    Finished,
    /// No window start has been recorded for the bucket.
    NotStarted,
    /// The window is still running; the payload is the time left until it
    /// ends.
    Some(Duration),
}
34
35#[derive(Debug)]
39pub struct Bucket {
40 pub limit: AtomicU64,
42 #[allow(dead_code)]
45 pub path: Path,
46 pub queue: BucketQueue,
48 pub remaining: AtomicU64,
50 pub reset_after: AtomicU64,
52 pub started_at: Mutex<Option<Instant>>,
54}
55
56impl Bucket {
57 pub fn new(path: Path) -> Self {
59 Self {
60 limit: AtomicU64::new(u64::MAX),
61 path,
62 queue: BucketQueue::default(),
63 remaining: AtomicU64::new(u64::MAX),
64 reset_after: AtomicU64::new(u64::MAX),
65 started_at: Mutex::new(None),
66 }
67 }
68
69 pub fn limit(&self) -> u64 {
71 self.limit.load(Ordering::Relaxed)
72 }
73
74 pub fn remaining(&self) -> u64 {
76 self.remaining.load(Ordering::Relaxed)
77 }
78
79 pub fn reset_after(&self) -> u64 {
83 self.reset_after.load(Ordering::Relaxed)
84 }
85
86 pub fn time_remaining(&self) -> TimeRemaining {
88 let reset_after = self.reset_after();
89 let maybe_started_at = *self.started_at.lock().expect("bucket poisoned");
90
91 let Some(started_at) = maybe_started_at else {
92 return TimeRemaining::NotStarted;
93 };
94
95 let elapsed = started_at.elapsed();
96
97 if elapsed > Duration::from_millis(reset_after) {
98 return TimeRemaining::Finished;
99 }
100
101 TimeRemaining::Some(Duration::from_millis(reset_after) - elapsed)
102 }
103
104 pub fn try_reset(&self) -> bool {
110 if self.started_at.lock().expect("bucket poisoned").is_none() {
111 return false;
112 }
113
114 if let TimeRemaining::Finished = self.time_remaining() {
115 self.remaining.store(self.limit(), Ordering::Relaxed);
116 *self.started_at.lock().expect("bucket poisoned") = None;
117
118 true
119 } else {
120 false
121 }
122 }
123
124 pub fn update(&self, ratelimits: Option<(u64, u64, u64)>) {
126 let bucket_limit = self.limit();
127
128 {
129 let mut started_at = self.started_at.lock().expect("bucket poisoned");
130
131 if started_at.is_none() {
132 started_at.replace(Instant::now());
133 }
134 }
135
136 if let Some((limit, remaining, reset_after)) = ratelimits {
137 if bucket_limit != limit && bucket_limit == u64::MAX {
138 self.reset_after.store(reset_after, Ordering::SeqCst);
139 self.limit.store(limit, Ordering::SeqCst);
140 }
141
142 self.remaining.store(remaining, Ordering::Relaxed);
143 } else {
144 self.remaining.fetch_sub(1, Ordering::Relaxed);
145 }
146 }
147}
148
149#[derive(Debug)]
151pub struct BucketQueue {
152 rx: AsyncMutex<UnboundedReceiver<TicketNotifier>>,
154 tx: UnboundedSender<TicketNotifier>,
156}
157
158impl BucketQueue {
159 pub fn push(&self, tx: TicketNotifier) {
161 let _sent = self.tx.send(tx);
162 }
163
164 pub async fn pop(&self, timeout_duration: Duration) -> Option<TicketNotifier> {
166 let mut rx = self.rx.lock().await;
167
168 timeout(timeout_duration, rx.recv()).await.ok().flatten()
169 }
170}
171
172impl Default for BucketQueue {
173 fn default() -> Self {
174 let (tx, rx) = mpsc::unbounded_channel();
175
176 Self {
177 rx: AsyncMutex::new(rx),
178 tx,
179 }
180 }
181}
182
183pub(super) struct BucketQueueTask {
187 bucket: Arc<Bucket>,
189 buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
191 global: Arc<GlobalLockPair>,
193 path: Path,
195}
196
197impl BucketQueueTask {
198 const WAIT: Duration = Duration::from_secs(10);
200
201 pub const fn new(
203 bucket: Arc<Bucket>,
204 buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
205 global: Arc<GlobalLockPair>,
206 path: Path,
207 ) -> Self {
208 Self {
209 bucket,
210 buckets,
211 global,
212 path,
213 }
214 }
215
216 #[tracing::instrument(name = "background queue task", skip(self), fields(path = ?self.path))]
219 pub async fn run(self) {
220 while let Some(queue_tx) = self.next().await {
221 if self.global.is_locked() {
222 drop(self.global.0.lock().await);
223 }
224
225 let Some(ticket_headers) = queue_tx.available() else {
226 continue;
227 };
228
229 tracing::debug!("starting to wait for response headers");
230
231 match timeout(Self::WAIT, ticket_headers).await {
232 Ok(Ok(Some(headers))) => self.handle_headers(&headers).await,
233 Ok(Ok(None)) => {
234 tracing::debug!("request aborted");
235 }
236 Ok(Err(_)) => {
237 tracing::debug!("ticket channel closed");
238 }
239 Err(_) => {
240 tracing::debug!("receiver timed out");
241 }
242 }
243 }
244
245 tracing::debug!("bucket appears finished, removing");
246
247 self.buckets
248 .lock()
249 .expect("ratelimit buckets poisoned")
250 .remove(&self.path);
251 }
252
253 async fn handle_headers(&self, headers: &RatelimitHeaders) {
255 let ratelimits = match headers {
256 RatelimitHeaders::Global(global) => {
257 self.lock_global(Duration::from_secs(global.retry_after()))
258 .await;
259
260 None
261 }
262 RatelimitHeaders::None => return,
263 RatelimitHeaders::Present(present) => {
264 Some((present.limit(), present.remaining(), present.reset_after()))
265 }
266 };
267
268 tracing::debug!(path=?self.path, "updating bucket");
269 self.bucket.update(ratelimits);
270 }
271
272 async fn lock_global(&self, wait: Duration) {
274 tracing::debug!(path=?self.path, "request got global ratelimited");
275 self.global.lock();
276 let lock = self.global.0.lock().await;
277 sleep(wait).await;
278 self.global.unlock();
279
280 drop(lock);
281 }
282
283 async fn next(&self) -> Option<TicketNotifier> {
285 tracing::debug!(path=?self.path, "starting to get next in queue");
286
287 self.wait_if_needed().await;
288
289 self.bucket.queue.pop(Self::WAIT).await
290 }
291
292 #[tracing::instrument(name = "waiting for bucket to refresh", skip(self), fields(path = ?self.path))]
294 async fn wait_if_needed(&self) {
295 let wait = {
296 if self.bucket.remaining() > 0 {
297 return;
298 }
299
300 tracing::debug!("0 tickets remaining, may have to wait");
301
302 match self.bucket.time_remaining() {
303 TimeRemaining::Finished => {
304 self.bucket.try_reset();
305
306 return;
307 }
308 TimeRemaining::NotStarted => return,
309 TimeRemaining::Some(dur) => dur,
310 }
311 };
312
313 tracing::debug!(
314 milliseconds=%wait.as_millis(),
315 "waiting for ratelimit to pass",
316 );
317
318 sleep(wait).await;
319
320 tracing::debug!("done waiting for ratelimit to pass");
321
322 self.bucket.try_reset();
323 }
324}