twilight_gateway_queue/
in_memory.rs

1//! Memory based [`Queue`] implementation and supporting items.
2
3use super::{Queue, IDENTIFY_DELAY, LIMIT_PERIOD};
4use std::{collections::VecDeque, fmt::Debug, iter};
5use tokio::{
6    sync::{mpsc, oneshot},
7    task::yield_now,
8    time::{sleep_until, Duration, Instant},
9};
10
/// Possible messages from the [`InMemoryQueue`] to the [`runner`].
#[derive(Debug)]
enum Message {
    /// Request a permit.
    Request {
        /// For this shard.
        shard: u32,
        /// Indicate readiness through this sender.
        ///
        /// Dropping the receiving half abandons the request; the runner
        /// detects the closed channel and does not consume a permit for it.
        tx: oneshot::Sender<()>,
    },
    /// Update the runner's settings.
    Update(Settings),
}
24
/// [`runner`]'s settings.
///
/// Mirrors the session start limit fields reported by the Get Gateway Bot
/// endpoint (see [`InMemoryQueue::update`]).
#[derive(Debug)]
struct Settings {
    /// The maximum number of concurrent permits to grant. `0` instantly grants
    /// all permits.
    max_concurrency: u16,
    /// Remaining daily permits.
    remaining: u32,
    /// Time until the daily permits reset.
    reset_after: Duration,
    /// The number of permits to reset to.
    total: u32,
}
38
/// [`InMemoryQueue`]'s background task runner.
///
/// Buckets requests such that only one timer is necessary.
async fn runner(
    mut rx: mpsc::UnboundedReceiver<Message>,
    Settings {
        max_concurrency,
        mut remaining,
        reset_after,
        mut total,
    }: Settings,
) {
    // Two timers drive the runner: `interval` paces permit grants
    // (`IDENTIFY_DELAY` apart) and `reset_at` tracks when the daily
    // permit allowance refills.
    let (interval, reset_at) = {
        let now = Instant::now();
        (sleep_until(now), sleep_until(now + reset_after))
    };
    tokio::pin!(interval, reset_at);

    // One FIFO queue per rate-limit bucket, keyed by
    // `shard % max_concurrency`. Empty (zero buckets) when
    // `max_concurrency` is 0, which disables queuing entirely.
    let mut queues = iter::repeat_with(VecDeque::new)
        .take(max_concurrency.into())
        .collect::<Box<_>>();

    #[allow(clippy::ignored_unit_patterns)]
    loop {
        tokio::select! {
            // `biased` polls branches top-to-bottom, so the daily refill is
            // observed before new requests or permit grants are processed.
            biased;
            _ = &mut reset_at, if remaining != total => {
                remaining = total;
            }
            message = rx.recv() => {
                match message {
                    Some(Message::Request { shard, tx }) => {
                        if queues.is_empty() {
                            // `max_concurrency` == 0: grant instantly;
                            // a closed receiver is simply ignored.
                            _ = tx.send(());
                        } else {
                            let key = shard as usize % queues.len();
                            queues[key].push_back((shard, tx));
                        }
                    }
                    Some(Message::Update(update)) => {
                        // `remaining` and `total` destructure into the
                        // function-level `mut` bindings; the other two are
                        // fresh locals shadowing only within this arm.
                        let (max_concurrency, reset_after);
                        Settings {
                            max_concurrency,
                            remaining,
                            reset_after,
                            total,
                        } = update;

                        if remaining != total {
                            // Permits already used this window: schedule the
                            // refill per the freshly reported `reset_after`.
                            reset_at.as_mut().reset(Instant::now() + reset_after);
                        }

                        if max_concurrency as usize != queues.len() {
                            // Bucket count changed: rehash every pending
                            // request into the new key space.
                            let unbalanced = queues.into_vec().into_iter().flatten();
                            queues = iter::repeat_with(VecDeque::new)
                                .take(max_concurrency.into())
                                .collect();
                            for (shard, tx) in unbalanced {
                                let key = (shard % u32::from(max_concurrency)) as usize;
                                queues[key].push_back((shard, tx));
                            }
                        }
                    }
                    // All senders dropped: the queue handle is gone, shut down.
                    None => break,
                }
            }
            // Grant at most one permit per bucket every `IDENTIFY_DELAY`,
            // but only while some bucket has a pending request.
            _ = &mut interval, if queues.iter().any(|queue| !queue.is_empty()) => {
                let now = Instant::now();
                let span = tracing::info_span!("bucket", moment = ?now);

                interval.as_mut().reset(now + IDENTIFY_DELAY);

                if remaining == total {
                    // First grant of a fresh window starts the refill clock.
                    reset_at.as_mut().reset(now + LIMIT_PERIOD);
                }

                for (key, queue) in queues.iter_mut().enumerate() {
                    if remaining == 0 {
                        // Daily allowance exhausted: wait for the refill.
                        // NOTE(review): awaiting here also pauses processing
                        // of incoming messages until `reset_at` elapses —
                        // presumably acceptable since no permits could be
                        // granted in the meantime anyway.
                        tracing::debug!(
                            refill_delay = ?reset_at.deadline().saturating_duration_since(now),
                            "exhausted available permits"
                        );
                        (&mut reset_at).await;
                        remaining = total;

                        break;
                    }

                    while let Some((shard, tx)) = queue.pop_front() {
                        // Receiver dropped (shard gave up): skip without
                        // consuming a permit and try the next request.
                        if tx.send(()).is_err() {
                            continue;
                        }

                        tracing::debug!(parent: &span, key, shard);
                        remaining -= 1;
                        // Reschedule behind shard for ordering correctness.
                        yield_now().await;

                        break;
                    }
                }
            }
        }
    }
}
144
/// Memory based [`Queue`] implementation backed by an efficient background task.
///
/// [`InMemoryQueue::update`] allows for dynamically changing the queue's
/// settings.
///
/// Cloning the queue is cheap and just increments a reference counter.
///
/// **Note:** A `max_concurrency` of `0` processes all requests instantly,
/// effectively disabling the queue.
#[derive(Clone, Debug)]
pub struct InMemoryQueue {
    /// Sender to communicate with the background [task runner].
    ///
    /// The runner exits when every clone of this sender is dropped.
    ///
    /// [task runner]: runner
    tx: mpsc::UnboundedSender<Message>,
}
161
162impl InMemoryQueue {
163    /// Creates a new `InMemoryQueue` with custom settings.
164    ///
165    /// # Panics
166    ///
167    /// Panics if `total` < `remaining`.
168    pub fn new(max_concurrency: u16, remaining: u32, reset_after: Duration, total: u32) -> Self {
169        assert!(total >= remaining);
170        let (tx, rx) = mpsc::unbounded_channel();
171
172        tokio::spawn(runner(
173            rx,
174            Settings {
175                max_concurrency,
176                remaining,
177                reset_after,
178                total,
179            },
180        ));
181
182        Self { tx }
183    }
184
185    /// Update the queue with new info from the [Get Gateway Bot] endpoint.
186    ///
187    /// May be regularly called as the bot joins/leaves guilds.
188    ///
189    /// # Example
190    ///
191    /// ```no_run
192    /// # use twilight_gateway_queue::InMemoryQueue;
193    /// # let rt = tokio::runtime::Builder::new_current_thread()
194    /// #     .enable_time()
195    /// #     .build()
196    /// #     .unwrap();
197    /// use std::time::Duration;
198    /// use twilight_http::Client;
199    ///
200    /// # rt.block_on(async {
201    /// # let queue = InMemoryQueue::default();
202    /// # let token = String::new();
203    /// let client = Client::new(token);
204    /// let session = client
205    ///     .gateway()
206    ///     .authed()
207    ///     .await?
208    ///     .model()
209    ///     .await?
210    ///     .session_start_limit;
211    /// queue.update(
212    ///     session.max_concurrency,
213    ///     session.remaining,
214    ///     Duration::from_millis(session.reset_after),
215    ///     session.total,
216    /// );
217    /// # Ok::<(), Box<dyn std::error::Error>>(())
218    /// # });
219    /// ```
220    ///
221    /// # Panics
222    ///
223    /// Panics if `total` < `remaining`.
224    ///
225    /// [Get Gateway Bot]: https://discord.com/developers/docs/topics/gateway#get-gateway-bot
226    pub fn update(&self, max_concurrency: u16, remaining: u32, reset_after: Duration, total: u32) {
227        assert!(total >= remaining);
228
229        self.tx
230            .send(Message::Update(Settings {
231                max_concurrency,
232                remaining,
233                reset_after,
234                total,
235            }))
236            .expect("receiver dropped after sender");
237    }
238}
239
240impl Default for InMemoryQueue {
241    /// Creates a new `InMemoryQueue` with Discord's default settings.
242    ///
243    /// Currently these are:
244    ///
245    /// * `max_concurrency`: 1
246    /// * `remaining`: 1000
247    /// * `reset_after`: [`LIMIT_PERIOD`]
248    /// * `total`: 1000.
249    fn default() -> Self {
250        Self::new(1, 1000, LIMIT_PERIOD, 1000)
251    }
252}
253
254impl Queue for InMemoryQueue {
255    fn enqueue(&self, shard: u32) -> oneshot::Receiver<()> {
256        let (tx, rx) = oneshot::channel();
257
258        self.tx
259            .send(Message::Request { shard, tx })
260            .expect("receiver dropped after sender");
261
262        rx
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::InMemoryQueue;
    use crate::Queue;
    use static_assertions::assert_impl_all;
    use std::fmt::Debug;

    // Compile-time assertion that the public type implements the expected
    // auto and library traits; fails the build rather than a test run.
    assert_impl_all!(InMemoryQueue: Clone, Debug, Default, Send, Sync, Queue);
}