twilight_http_ratelimiting/in_memory/
mod.rs1mod bucket;
4
5use self::bucket::{Bucket, BucketQueueTask};
6use super::{
7 ticket::{self, TicketNotifier},
8 Bucket as InfoBucket, Ratelimiter,
9};
10use crate::{
11 request::Path, GetBucketFuture, GetTicketFuture, HasBucketFuture, IsGloballyLockedFuture,
12};
13use std::{
14 collections::hash_map::{Entry, HashMap},
15 future,
16 sync::{
17 atomic::{AtomicBool, Ordering},
18 Arc, Mutex,
19 },
20 time::Duration,
21};
22use tokio::sync::Mutex as AsyncMutex;
23
/// Global ratelimit lock state.
///
/// The pair holds an asynchronous mutex and an atomic flag. The flag is what
/// [`GlobalLockPair::is_locked`] reports, so checking whether the global
/// ratelimit is in effect never has to touch the mutex.
// NOTE(review): the mutex (`.0`) is not used in this file; presumably the
// bucket queue task holds its guard while a global ratelimit is waited out —
// confirm in the `bucket` module.
#[derive(Debug, Default)]
struct GlobalLockPair(AsyncMutex<()>, AtomicBool);

impl GlobalLockPair {
    /// Mark the global ratelimit as being in effect.
    pub fn lock(&self) {
        self.1.store(true, Ordering::Release);
    }

    /// Mark the global ratelimit as no longer being in effect.
    pub fn unlock(&self) {
        self.1.store(false, Ordering::Release);
    }

    /// Whether the global ratelimit is currently marked as in effect.
    pub fn is_locked(&self) -> bool {
        // Acquire pairs with the Release stores in `lock`/`unlock`; with a
        // Relaxed load the Release ordering on those stores would be moot.
        self.1.load(Ordering::Acquire)
    }
}
47
48#[derive(Clone, Debug, Default)]
58pub struct InMemoryRatelimiter {
59 buckets: Arc<Mutex<HashMap<Path, Arc<Bucket>>>>,
61 global: Arc<GlobalLockPair>,
63}
64
65impl InMemoryRatelimiter {
66 #[must_use]
71 pub fn new() -> Self {
72 Self::default()
73 }
74
75 fn entry(&self, path: Path, tx: TicketNotifier) -> Option<Arc<Bucket>> {
79 let mut buckets = self.buckets.lock().expect("buckets poisoned");
80
81 match buckets.entry(path.clone()) {
82 Entry::Occupied(bucket) => {
83 tracing::debug!("got existing bucket: {path:?}");
84
85 bucket.get().queue.push(tx);
86
87 tracing::debug!("added request into bucket queue: {path:?}");
88
89 None
90 }
91 Entry::Vacant(entry) => {
92 tracing::debug!("making new bucket for path: {path:?}");
93
94 let bucket = Bucket::new(path);
95 bucket.queue.push(tx);
96
97 let bucket = Arc::new(bucket);
98 entry.insert(Arc::clone(&bucket));
99
100 Some(bucket)
101 }
102 }
103 }
104}
105
106impl Ratelimiter for InMemoryRatelimiter {
107 fn bucket(&self, path: &Path) -> GetBucketFuture {
108 self.buckets
109 .lock()
110 .expect("buckets poisoned")
111 .get(path)
112 .map_or_else(
113 || Box::pin(future::ready(Ok(None))),
114 |bucket| {
115 let started_at = bucket.started_at.lock().expect("bucket poisoned");
116 let reset_after = Duration::from_millis(bucket.reset_after());
117
118 Box::pin(future::ready(Ok(Some(InfoBucket::new(
119 bucket.limit(),
120 bucket.remaining(),
121 reset_after,
122 *started_at,
123 )))))
124 },
125 )
126 }
127
128 fn is_globally_locked(&self) -> IsGloballyLockedFuture {
129 Box::pin(future::ready(Ok(self.global.is_locked())))
130 }
131
132 fn has(&self, path: &Path) -> HasBucketFuture {
133 let has = self
134 .buckets
135 .lock()
136 .expect("buckets poisoned")
137 .contains_key(path);
138
139 Box::pin(future::ready(Ok(has)))
140 }
141
142 fn ticket(&self, path: Path) -> GetTicketFuture {
143 tracing::debug!("getting bucket for path: {path:?}");
144
145 let (tx, rx) = ticket::channel();
146
147 if let Some(bucket) = self.entry(path.clone(), tx) {
148 tokio::spawn(
149 BucketQueueTask::new(
150 bucket,
151 Arc::clone(&self.buckets),
152 Arc::clone(&self.global),
153 path,
154 )
155 .run(),
156 );
157 }
158
159 Box::pin(future::ready(Ok(rx)))
160 }
161}