twilight_http/response/future.rs
1use super::{BytesFuture, Response};
2use crate::{
3 api_error::ApiError,
4 client::connector::Connector,
5 error::{Error, ErrorType},
6};
7use http::{HeaderMap, HeaderValue, Request, StatusCode, header};
8use http_body_util::Full;
9use hyper::body::Bytes;
10use hyper_util::client::legacy::{Client as HyperClient, ResponseFuture as HyperResponseFuture};
11use std::{
12 future::{Future, Ready, ready},
13 marker::PhantomData,
14 pin::Pin,
15 sync::{
16 Arc,
17 atomic::{AtomicBool, Ordering},
18 },
19 task::{Context, Poll, ready},
20 time::{Duration, Instant},
21};
22use tokio::time::{self, Timeout};
23use twilight_http_ratelimiting::{Endpoint, Permit, PermitFuture, RateLimitHeaders, RateLimiter};
24
25/// Parse ratelimit headers from a map of headers.
26///
27/// # Errors
28///
29/// Errors if a required header is missing or if a header value is of an
30/// invalid type.
31fn parse_ratelimit_headers(
32 headers: &HeaderMap,
33) -> Result<Option<RateLimitHeaders>, Box<dyn std::error::Error>> {
34 match headers
35 .get(RateLimitHeaders::SCOPE)
36 .map(HeaderValue::as_bytes)
37 {
38 Some(b"global") => {
39 tracing::info!("globally rate limited");
40
41 Ok(None)
42 }
43 Some(b"shared") => {
44 let bucket = headers
45 .get(RateLimitHeaders::BUCKET)
46 .ok_or("missing bucket header")?
47 .as_bytes()
48 .to_vec();
49 let retry_after = headers
50 .get(header::RETRY_AFTER)
51 .ok_or("missing retry-after header")?
52 .to_str()?
53 .parse()?;
54
55 Ok(Some(RateLimitHeaders::shared(bucket, retry_after)))
56 }
57 Some(b"user") => {
58 let bucket = headers
59 .get(RateLimitHeaders::BUCKET)
60 .ok_or("missing bucket header")?
61 .as_bytes()
62 .to_vec();
63 let limit = headers
64 .get(RateLimitHeaders::LIMIT)
65 .ok_or("missing limit header")?
66 .to_str()?
67 .parse()?;
68 let remaining = headers
69 .get(RateLimitHeaders::REMAINING)
70 .ok_or("missing remaining header")?
71 .to_str()?
72 .parse()?;
73 let reset_after = headers
74 .get(RateLimitHeaders::RESET_AFTER)
75 .ok_or("missing reset-after header")?
76 .to_str()?
77 .parse()?;
78
79 Ok(Some(RateLimitHeaders {
80 bucket,
81 limit,
82 remaining,
83 reset_at: Instant::now() + Duration::from_secs_f32(reset_after),
84 }))
85 }
86 _ => Ok(None),
87 }
88}
89
90/// Sub-futures of [`ResponseFuture`].
91enum ResponseStageFuture {
92 /// Future that completes with an error response body.
93 Error {
94 /// Inner response body future.
95 fut: BytesFuture,
96 /// Erroneous response status code.
97 status: StatusCode,
98 },
99 /// Future that completes when a rate limit permit is ready.
100 RateLimitPermit(PermitFuture),
101 /// Future that completes with a response or timeout.
102 Response {
103 /// Inner timed response future.
104 fut: Pin<Box<Timeout<HyperResponseFuture>>>,
105 /// Optional rate limit permit.
106 permit: Option<Permit>,
107 },
108}
109
110/// [`PermitFuture`] generator.
111struct PermitFutureGenerator {
112 /// Rate limiter to acquire permits from.
113 rate_limiter: RateLimiter,
114 /// Rate limiter endpoint to acquire permits for.
115 endpoint: Endpoint,
116}
117
118impl PermitFutureGenerator {
119 /// Generates a permit future.
120 fn generate(&self) -> PermitFuture {
121 tracing::debug!("awaiting permit");
122 self.rate_limiter.acquire(self.endpoint.clone())
123 }
124}
125
126/// [`Timeout<HyperResponseFuture>`] generator.
127struct TimedResponseFutureGenerator {
128 /// HTTP client to send requests from.
129 client: HyperClient<Connector, Full<Bytes>>,
130 /// HTTP request to send.
131 request: Request<Full<Bytes>>,
132 /// Duration after which the request times out.
133 timeout: Duration,
134}
135
136impl TimedResponseFutureGenerator {
137 /// Generates a timeout response future.
138 fn generate(&self) -> Pin<Box<Timeout<HyperResponseFuture>>> {
139 tracing::debug!("awaiting response");
140 Box::pin(time::timeout(
141 self.timeout,
142 self.client.request(self.request.clone()),
143 ))
144 }
145}
146
147/// Future that completes when a [`Response`] is received.
148///
149/// # Rate limits
150///
151/// Requests that exceed a rate limit are automatically and immediately retried
152/// until they succeed or fail with another error. If configured without a
153/// [`RateLimiter`], care must be taken that an external service intercepts and
154/// delays these retry requests.
155///
156/// # Canceling a response future pre-flight
157///
158/// Response futures can be canceled pre-flight via
159/// [`ResponseFuture::set_pre_flight`]. This allows you to cancel requests that
160/// are no longer necessary once they have been cleared by the ratelimit queue,
161/// which may be necessary in scenarios where requests are being spammed. Refer
162/// to its documentation for more information.
163///
164/// # Errors
165///
166/// Returns an [`ErrorType::Parsing`] error type if the request failed and the
167/// error in the response body could not be deserialized.
168///
169/// Returns an [`ErrorType::RequestCanceled`] error type if the request was
170/// canceled by the user.
171///
172/// Returns an [`ErrorType::RequestError`] error type if creating the request
173/// failed.
174///
175/// Returns an [`ErrorType::RequestTimedOut`] error type if the request timed
176/// out. The timeout value is configured via [`ClientBuilder::timeout`].
177///
178/// Returns an [`ErrorType::Response`] error type if the request failed.
179///
180/// [`ClientBuilder::timeout`]: crate::client::ClientBuilder::timeout
181/// [`ErrorType::Json`]: crate::error::ErrorType::Json
182/// [`ErrorType::Parsing`]: crate::error::ErrorType::Parsing
183/// [`ErrorType::RequestCanceled`]: crate::error::ErrorType::RequestCanceled
184/// [`ErrorType::RequestError`]: crate::error::ErrorType::RequestError
185/// [`ErrorType::RequestTimedOut`]: crate::error::ErrorType::RequestTimedOut
186/// [`ErrorType::Response`]: crate::error::ErrorType::Response
187/// [`Response`]: super::Response
188#[must_use = "futures do nothing unless you `.await` or poll them"]
189pub struct ResponseFuture<T>(Result<Inner<T>, Ready<Error>>);
190
191impl<T> ResponseFuture<T> {
192 pub(crate) fn new(
193 client: HyperClient<Connector, Full<Bytes>>,
194 invalid_token: Option<Arc<AtomicBool>>,
195 request: Request<Full<Bytes>>,
196 span: tracing::Span,
197 timeout: Duration,
198 rate_limiter: Option<RateLimiter>,
199 endpoint: Endpoint,
200 ) -> Self {
201 let entered = span.entered();
202
203 let permit_generator = rate_limiter.map(|rate_limiter| PermitFutureGenerator {
204 rate_limiter,
205 endpoint,
206 });
207 let response_generator = TimedResponseFutureGenerator {
208 client,
209 request,
210 timeout,
211 };
212 let stage = permit_generator.as_ref().map_or_else(
213 || ResponseStageFuture::Response {
214 fut: response_generator.generate(),
215 permit: None,
216 },
217 |generator| ResponseStageFuture::RateLimitPermit(generator.generate()),
218 );
219 Self(Ok(Inner {
220 invalid_token,
221 permit_generator,
222 phantom: PhantomData,
223 pre_flight_check: None,
224 response_generator,
225 span: entered.exit(),
226 stage,
227 }))
228 }
229
230 /// Set a function to call after clearing the ratelimiter but prior to
231 /// sending the request to determine if the request is still valid.
232 ///
233 /// This function will be a no-op if the request has failed, has already
234 /// passed the ratelimiter, or if there is no ratelimiter configured.
235 ///
236 /// Returns whether the pre flight function was set.
237 ///
238 /// # Examples
239 ///
240 /// Delete a message, but immediately before sending the request check if
241 /// the request should still be sent:
242 ///
243 /// ```no_run
244 /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
245 /// use std::{
246 /// collections::HashSet,
247 /// env,
248 /// future::IntoFuture,
249 /// sync::{Arc, Mutex},
250 /// };
251 /// use twilight_http::{Client, error::ErrorType};
252 /// use twilight_model::id::Id;
253 ///
254 /// let channel_id = Id::new(1);
255 /// let message_id = Id::new(2);
256 ///
257 /// let channels_ignored = {
258 /// let mut map = HashSet::new();
259 /// map.insert(channel_id);
260 ///
261 /// Arc::new(Mutex::new(map))
262 /// };
263 ///
264 /// let client = Client::new(env::var("DISCORD_TOKEN")?);
265 /// let mut req = client.delete_message(channel_id, message_id).into_future();
266 ///
267 /// let channels_ignored_clone = channels_ignored.clone();
268 /// req.set_pre_flight(move || {
269 /// // imagine you have some logic here to external state that checks
270 /// // whether the request should still be performed
271 /// let channels_ignored = channels_ignored_clone.lock().expect("channels poisoned");
272 ///
273 /// !channels_ignored.contains(&channel_id)
274 /// });
275 ///
276 /// // the pre-flight check will cancel the request
277 /// assert!(matches!(
278 /// req.await.unwrap_err().kind(),
279 /// ErrorType::RequestCanceled,
280 /// ));
281 /// # Ok(()) }
282 /// ```
283 pub fn set_pre_flight<P>(&mut self, predicate: P) -> bool
284 where
285 P: Fn() -> bool + Send + 'static,
286 {
287 if let Ok(inner) = &mut self.0
288 && inner.permit_generator.is_some()
289 && inner.pre_flight_check.is_none()
290 {
291 inner.pre_flight_check = Some(Box::new(predicate));
292
293 true
294 } else {
295 false
296 }
297 }
298
299 /// Creates a future that is immediately ready with an error.
300 pub(crate) fn error(source: Error) -> Self {
301 Self(Err(ready(source)))
302 }
303}
304
305impl<T: Unpin> Future for ResponseFuture<T> {
306 type Output = Result<Response<T>, Error>;
307
308 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309 let inner = match &mut self.0 {
310 Ok(inner) => inner,
311 Err(err) => return Pin::new(err).poll(cx).map(Err),
312 };
313
314 let _entered = inner.span.enter();
315
316 loop {
317 match &mut inner.stage {
318 ResponseStageFuture::Error { fut, status } => {
319 let body = ready!(Pin::new(fut).poll(cx)).map_err(|source| Error {
320 kind: ErrorType::RequestError,
321 source: Some(Box::new(source)),
322 })?;
323
324 return Poll::Ready(Err(match crate::json::from_bytes::<ApiError>(&body) {
325 Ok(error) => Error {
326 kind: ErrorType::Response {
327 body,
328 error,
329 status: super::StatusCode::new(status.as_u16()),
330 },
331 source: None,
332 },
333 Err(source) => Error {
334 kind: ErrorType::Parsing { body },
335 source: Some(Box::new(source)),
336 },
337 }));
338 }
339 ResponseStageFuture::RateLimitPermit(fut) => {
340 let permit = ready!(Pin::new(fut).poll(cx));
341 if inner
342 .pre_flight_check
343 .as_ref()
344 .is_some_and(|check| !check())
345 {
346 return Poll::Ready(Err(Error {
347 kind: ErrorType::RequestCanceled,
348 source: None,
349 }));
350 }
351
352 inner.stage = ResponseStageFuture::Response {
353 fut: inner.response_generator.generate(),
354 permit: Some(permit),
355 };
356 }
357 ResponseStageFuture::Response { fut, permit } => {
358 let response = ready!(Pin::new(fut).poll(cx))
359 .map_err(|source| Error {
360 kind: ErrorType::RequestTimedOut,
361 source: Some(Box::new(source)),
362 })?
363 .map_err(|source| Error {
364 kind: ErrorType::RequestError,
365 source: Some(Box::new(source)),
366 })?;
367
368 if response.status() == StatusCode::UNAUTHORIZED
369 && let Some(invalid) = &inner.invalid_token
370 {
371 invalid.store(true, Ordering::Relaxed);
372 }
373
374 if let Some(permit) = permit.take() {
375 match parse_ratelimit_headers(response.headers()) {
376 Ok(v) => permit.complete(v),
377 Err(source) => {
378 tracing::warn!("header parsing failed: {source}; {response:?}");
379
380 permit.complete(None);
381 }
382 }
383 }
384
385 if response.status().is_success() {
386 #[cfg(feature = "decompression")]
387 let mut response = response;
388 // Inaccurate since end-users can only access the decompressed body.
389 #[cfg(feature = "decompression")]
390 response.headers_mut().remove(header::CONTENT_LENGTH);
391
392 return Poll::Ready(Ok(Response::new(response)));
393 } else if response.status() == StatusCode::TOO_MANY_REQUESTS {
394 tracing::info!("rate limited; retrying");
395 inner.stage = match &inner.permit_generator {
396 Some(generator) => {
397 ResponseStageFuture::RateLimitPermit(generator.generate())
398 }
399 None => ResponseStageFuture::Response {
400 fut: inner.response_generator.generate(),
401 permit: None,
402 },
403 };
404 } else {
405 inner.stage = ResponseStageFuture::Error {
406 status: response.status(),
407 fut: Response::<()>::new(response).bytes(),
408 };
409 }
410 }
411 }
412 }
413 }
414}
415
416/// Internal response future fields.
417struct Inner<T> {
418 /// Whether the client's token is invalidated.
419 invalid_token: Option<Arc<AtomicBool>>,
420 /// Optional [`PermitFuture`] generator, if registered.
421 permit_generator: Option<PermitFutureGenerator>,
422 phantom: PhantomData<T>,
423 /// Predicate to check after completing [`ResponseStageFuture::RateLimitPermit`].
424 pre_flight_check: Option<Box<dyn Fn() -> bool + Send + 'static>>,
425 /// [`Timeout<HyperResponseFuture>`] generator.
426 response_generator: TimedResponseFutureGenerator,
427 /// This future's span.
428 span: tracing::Span,
429 /// This future's current stage.
430 stage: ResponseStageFuture,
431}