Skip to main content

twilight_http/response/
future.rs

1use super::{BytesFuture, Response};
2use crate::{
3    api_error::ApiError,
4    client::connector::Connector,
5    error::{Error, ErrorType},
6};
7use http::{HeaderMap, HeaderValue, Request, StatusCode, header};
8use http_body_util::Full;
9use hyper::body::Bytes;
10use hyper_util::client::legacy::{Client as HyperClient, ResponseFuture as HyperResponseFuture};
11use std::{
12    future::{Future, Ready, ready},
13    marker::PhantomData,
14    pin::Pin,
15    sync::{
16        Arc,
17        atomic::{AtomicBool, Ordering},
18    },
19    task::{Context, Poll, ready},
20    time::{Duration, Instant},
21};
22use tokio::time::{self, Timeout};
23use twilight_http_ratelimiting::{Endpoint, Permit, PermitFuture, RateLimitHeaders, RateLimiter};
24
25/// Parse ratelimit headers from a map of headers.
26///
27/// # Errors
28///
29/// Errors if a required header is missing or if a header value is of an
30/// invalid type.
31fn parse_ratelimit_headers(
32    headers: &HeaderMap,
33) -> Result<Option<RateLimitHeaders>, Box<dyn std::error::Error>> {
34    match headers
35        .get(RateLimitHeaders::SCOPE)
36        .map(HeaderValue::as_bytes)
37    {
38        Some(b"global") => {
39            tracing::info!("globally rate limited");
40
41            Ok(None)
42        }
43        Some(b"shared") => {
44            let bucket = headers
45                .get(RateLimitHeaders::BUCKET)
46                .ok_or("missing bucket header")?
47                .as_bytes()
48                .to_vec();
49            let retry_after = headers
50                .get(header::RETRY_AFTER)
51                .ok_or("missing retry-after header")?
52                .to_str()?
53                .parse()?;
54
55            Ok(Some(RateLimitHeaders::shared(bucket, retry_after)))
56        }
57        Some(b"user") => {
58            let bucket = headers
59                .get(RateLimitHeaders::BUCKET)
60                .ok_or("missing bucket header")?
61                .as_bytes()
62                .to_vec();
63            let limit = headers
64                .get(RateLimitHeaders::LIMIT)
65                .ok_or("missing limit header")?
66                .to_str()?
67                .parse()?;
68            let remaining = headers
69                .get(RateLimitHeaders::REMAINING)
70                .ok_or("missing remaining header")?
71                .to_str()?
72                .parse()?;
73            let reset_after = headers
74                .get(RateLimitHeaders::RESET_AFTER)
75                .ok_or("missing reset-after header")?
76                .to_str()?
77                .parse()?;
78
79            Ok(Some(RateLimitHeaders {
80                bucket,
81                limit,
82                remaining,
83                reset_at: Instant::now() + Duration::from_secs_f32(reset_after),
84            }))
85        }
86        _ => Ok(None),
87    }
88}
89
90/// Sub-futures of [`ResponseFuture`].
91enum ResponseStageFuture {
92    /// Future that completes with an error response body.
93    Error {
94        /// Inner response body future.
95        fut: BytesFuture,
96        /// Erroneous response status code.
97        status: StatusCode,
98    },
99    /// Future that completes when a rate limit permit is ready.
100    RateLimitPermit(PermitFuture),
101    /// Future that completes with a response or timeout.
102    Response {
103        /// Inner timed response future.
104        fut: Pin<Box<Timeout<HyperResponseFuture>>>,
105        /// Optional rate limit permit.
106        permit: Option<Permit>,
107    },
108}
109
110/// [`PermitFuture`] generator.
111struct PermitFutureGenerator {
112    /// Rate limiter to acquire permits from.
113    rate_limiter: RateLimiter,
114    /// Rate limiter endpoint to acquire permits for.
115    endpoint: Endpoint,
116}
117
118impl PermitFutureGenerator {
119    /// Generates a permit future.
120    fn generate(&self) -> PermitFuture {
121        tracing::debug!("awaiting permit");
122        self.rate_limiter.acquire(self.endpoint.clone())
123    }
124}
125
126/// [`Timeout<HyperResponseFuture>`] generator.
127struct TimedResponseFutureGenerator {
128    /// HTTP client to send requests from.
129    client: HyperClient<Connector, Full<Bytes>>,
130    /// HTTP request to send.
131    request: Request<Full<Bytes>>,
132    /// Duration after which the request times out.
133    timeout: Duration,
134}
135
136impl TimedResponseFutureGenerator {
137    /// Generates a timeout response future.
138    fn generate(&self) -> Pin<Box<Timeout<HyperResponseFuture>>> {
139        tracing::debug!("awaiting response");
140        Box::pin(time::timeout(
141            self.timeout,
142            self.client.request(self.request.clone()),
143        ))
144    }
145}
146
147/// Future that completes when a [`Response`] is received.
148///
149/// # Rate limits
150///
151/// Requests that exceed a rate limit are automatically and immediately retried
152/// until they succeed or fail with another error. If configured without a
153/// [`RateLimiter`], care must be taken that an external service intercepts and
154/// delays these retry requests.
155///
156/// # Canceling a response future pre-flight
157///
158/// Response futures can be canceled pre-flight via
159/// [`ResponseFuture::set_pre_flight`]. This allows you to cancel requests that
160/// are no longer necessary once they have been cleared by the ratelimit queue,
161/// which may be necessary in scenarios where requests are being spammed. Refer
162/// to its documentation for more information.
163///
164/// # Errors
165///
166/// Returns an [`ErrorType::Parsing`] error type if the request failed and the
167/// error in the response body could not be deserialized.
168///
169/// Returns an [`ErrorType::RequestCanceled`] error type if the request was
170/// canceled by the user.
171///
172/// Returns an [`ErrorType::RequestError`] error type if creating the request
173/// failed.
174///
175/// Returns an [`ErrorType::RequestTimedOut`] error type if the request timed
176/// out. The timeout value is configured via [`ClientBuilder::timeout`].
177///
178/// Returns an [`ErrorType::Response`] error type if the request failed.
179///
180/// [`ClientBuilder::timeout`]: crate::client::ClientBuilder::timeout
181/// [`ErrorType::Json`]: crate::error::ErrorType::Json
182/// [`ErrorType::Parsing`]: crate::error::ErrorType::Parsing
183/// [`ErrorType::RequestCanceled`]: crate::error::ErrorType::RequestCanceled
184/// [`ErrorType::RequestError`]: crate::error::ErrorType::RequestError
185/// [`ErrorType::RequestTimedOut`]: crate::error::ErrorType::RequestTimedOut
186/// [`ErrorType::Response`]: crate::error::ErrorType::Response
187/// [`Response`]: super::Response
188#[must_use = "futures do nothing unless you `.await` or poll them"]
189pub struct ResponseFuture<T>(Result<Inner<T>, Ready<Error>>);
190
191impl<T> ResponseFuture<T> {
192    pub(crate) fn new(
193        client: HyperClient<Connector, Full<Bytes>>,
194        invalid_token: Option<Arc<AtomicBool>>,
195        request: Request<Full<Bytes>>,
196        span: tracing::Span,
197        timeout: Duration,
198        rate_limiter: Option<RateLimiter>,
199        endpoint: Endpoint,
200    ) -> Self {
201        let entered = span.entered();
202
203        let permit_generator = rate_limiter.map(|rate_limiter| PermitFutureGenerator {
204            rate_limiter,
205            endpoint,
206        });
207        let response_generator = TimedResponseFutureGenerator {
208            client,
209            request,
210            timeout,
211        };
212        let stage = permit_generator.as_ref().map_or_else(
213            || ResponseStageFuture::Response {
214                fut: response_generator.generate(),
215                permit: None,
216            },
217            |generator| ResponseStageFuture::RateLimitPermit(generator.generate()),
218        );
219        Self(Ok(Inner {
220            invalid_token,
221            permit_generator,
222            phantom: PhantomData,
223            pre_flight_check: None,
224            response_generator,
225            span: entered.exit(),
226            stage,
227        }))
228    }
229
230    /// Set a function to call after clearing the ratelimiter but prior to
231    /// sending the request to determine if the request is still valid.
232    ///
233    /// This function will be a no-op if the request has failed, has already
234    /// passed the ratelimiter, or if there is no ratelimiter configured.
235    ///
236    /// Returns whether the pre flight function was set.
237    ///
238    /// # Examples
239    ///
240    /// Delete a message, but immediately before sending the request check if
241    /// the request should still be sent:
242    ///
243    /// ```no_run
244    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
245    /// use std::{
246    ///     collections::HashSet,
247    ///     env,
248    ///     future::IntoFuture,
249    ///     sync::{Arc, Mutex},
250    /// };
251    /// use twilight_http::{Client, error::ErrorType};
252    /// use twilight_model::id::Id;
253    ///
254    /// let channel_id = Id::new(1);
255    /// let message_id = Id::new(2);
256    ///
257    /// let channels_ignored = {
258    ///     let mut map = HashSet::new();
259    ///     map.insert(channel_id);
260    ///
261    ///     Arc::new(Mutex::new(map))
262    /// };
263    ///
264    /// let client = Client::new(env::var("DISCORD_TOKEN")?);
265    /// let mut req = client.delete_message(channel_id, message_id).into_future();
266    ///
267    /// let channels_ignored_clone = channels_ignored.clone();
268    /// req.set_pre_flight(move || {
269    ///     // imagine you have some logic here to external state that checks
270    ///     // whether the request should still be performed
271    ///     let channels_ignored = channels_ignored_clone.lock().expect("channels poisoned");
272    ///
273    ///     !channels_ignored.contains(&channel_id)
274    /// });
275    ///
276    /// // the pre-flight check will cancel the request
277    /// assert!(matches!(
278    ///     req.await.unwrap_err().kind(),
279    ///     ErrorType::RequestCanceled,
280    /// ));
281    /// # Ok(()) }
282    /// ```
283    pub fn set_pre_flight<P>(&mut self, predicate: P) -> bool
284    where
285        P: Fn() -> bool + Send + 'static,
286    {
287        if let Ok(inner) = &mut self.0
288            && inner.permit_generator.is_some()
289            && inner.pre_flight_check.is_none()
290        {
291            inner.pre_flight_check = Some(Box::new(predicate));
292
293            true
294        } else {
295            false
296        }
297    }
298
299    /// Creates a future that is immediately ready with an error.
300    pub(crate) fn error(source: Error) -> Self {
301        Self(Err(ready(source)))
302    }
303}
304
305impl<T: Unpin> Future for ResponseFuture<T> {
306    type Output = Result<Response<T>, Error>;
307
308    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309        let inner = match &mut self.0 {
310            Ok(inner) => inner,
311            Err(err) => return Pin::new(err).poll(cx).map(Err),
312        };
313
314        let _entered = inner.span.enter();
315
316        loop {
317            match &mut inner.stage {
318                ResponseStageFuture::Error { fut, status } => {
319                    let body = ready!(Pin::new(fut).poll(cx)).map_err(|source| Error {
320                        kind: ErrorType::RequestError,
321                        source: Some(Box::new(source)),
322                    })?;
323
324                    return Poll::Ready(Err(match crate::json::from_bytes::<ApiError>(&body) {
325                        Ok(error) => Error {
326                            kind: ErrorType::Response {
327                                body,
328                                error,
329                                status: super::StatusCode::new(status.as_u16()),
330                            },
331                            source: None,
332                        },
333                        Err(source) => Error {
334                            kind: ErrorType::Parsing { body },
335                            source: Some(Box::new(source)),
336                        },
337                    }));
338                }
339                ResponseStageFuture::RateLimitPermit(fut) => {
340                    let permit = ready!(Pin::new(fut).poll(cx));
341                    if inner
342                        .pre_flight_check
343                        .as_ref()
344                        .is_some_and(|check| !check())
345                    {
346                        return Poll::Ready(Err(Error {
347                            kind: ErrorType::RequestCanceled,
348                            source: None,
349                        }));
350                    }
351
352                    inner.stage = ResponseStageFuture::Response {
353                        fut: inner.response_generator.generate(),
354                        permit: Some(permit),
355                    };
356                }
357                ResponseStageFuture::Response { fut, permit } => {
358                    let response = ready!(Pin::new(fut).poll(cx))
359                        .map_err(|source| Error {
360                            kind: ErrorType::RequestTimedOut,
361                            source: Some(Box::new(source)),
362                        })?
363                        .map_err(|source| Error {
364                            kind: ErrorType::RequestError,
365                            source: Some(Box::new(source)),
366                        })?;
367
368                    if response.status() == StatusCode::UNAUTHORIZED
369                        && let Some(invalid) = &inner.invalid_token
370                    {
371                        invalid.store(true, Ordering::Relaxed);
372                    }
373
374                    if let Some(permit) = permit.take() {
375                        match parse_ratelimit_headers(response.headers()) {
376                            Ok(v) => permit.complete(v),
377                            Err(source) => {
378                                tracing::warn!("header parsing failed: {source}; {response:?}");
379
380                                permit.complete(None);
381                            }
382                        }
383                    }
384
385                    if response.status().is_success() {
386                        #[cfg(feature = "decompression")]
387                        let mut response = response;
388                        // Inaccurate since end-users can only access the decompressed body.
389                        #[cfg(feature = "decompression")]
390                        response.headers_mut().remove(header::CONTENT_LENGTH);
391
392                        return Poll::Ready(Ok(Response::new(response)));
393                    } else if response.status() == StatusCode::TOO_MANY_REQUESTS {
394                        tracing::info!("rate limited; retrying");
395                        inner.stage = match &inner.permit_generator {
396                            Some(generator) => {
397                                ResponseStageFuture::RateLimitPermit(generator.generate())
398                            }
399                            None => ResponseStageFuture::Response {
400                                fut: inner.response_generator.generate(),
401                                permit: None,
402                            },
403                        };
404                    } else {
405                        inner.stage = ResponseStageFuture::Error {
406                            status: response.status(),
407                            fut: Response::<()>::new(response).bytes(),
408                        };
409                    }
410                }
411            }
412        }
413    }
414}
415
416/// Internal response future fields.
417struct Inner<T> {
418    /// Whether the client's token is invalidated.
419    invalid_token: Option<Arc<AtomicBool>>,
420    /// Optional [`PermitFuture`] generator, if registered.
421    permit_generator: Option<PermitFutureGenerator>,
422    phantom: PhantomData<T>,
423    /// Predicate to check after completing [`ResponseStageFuture::RateLimitPermit`].
424    pre_flight_check: Option<Box<dyn Fn() -> bool + Send + 'static>>,
425    /// [`Timeout<HyperResponseFuture>`] generator.
426    response_generator: TimedResponseFutureGenerator,
427    /// This future's span.
428    span: tracing::Span,
429    /// This future's current stage.
430    stage: ResponseStageFuture,
431}