twilight_http/response/future.rs

1use super::{Response, StatusCode};
2use crate::{
3    api_error::ApiError,
4    error::{Error, ErrorType},
5};
6use http::StatusCode as HyperStatusCode;
7use hyper_util::client::legacy::ResponseFuture as HyperResponseFuture;
8use std::{
9    future::Future,
10    marker::PhantomData,
11    mem,
12    pin::Pin,
13    sync::{
14        atomic::{AtomicBool, Ordering},
15        Arc,
16    },
17    task::{Context, Poll},
18    time::Duration,
19};
20use tokio::time::{self, Timeout};
21use twilight_http_ratelimiting::{ticket::TicketSender, RatelimitHeaders, WaitForTicketFuture};
22
23type Output<T> = Result<Response<T>, Error>;
24
25enum InnerPoll<T> {
26    Advance(ResponseFutureStage),
27    Pending(ResponseFutureStage),
28    Ready(Output<T>),
29}
30
31struct Chunking {
32    future: Pin<Box<dyn Future<Output = Result<Vec<u8>, Error>> + Send + Sync + 'static>>,
33    status: HyperStatusCode,
34}
35
36impl Chunking {
37    fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
38        let bytes = match Pin::new(&mut self.future).poll(cx) {
39            Poll::Ready(Ok(bytes)) => bytes,
40            Poll::Ready(Err(source)) => return InnerPoll::Ready(Err(source)),
41            Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::Chunking(self)),
42        };
43
44        let error = match crate::json::from_bytes::<ApiError>(&bytes) {
45            Ok(error) => error,
46            Err(source) => {
47                return InnerPoll::Ready(Err(Error {
48                    kind: ErrorType::Parsing { body: bytes },
49                    source: Some(Box::new(source)),
50                }));
51            }
52        };
53
54        InnerPoll::Ready(Err(Error {
55            kind: ErrorType::Response {
56                body: bytes,
57                error,
58                status: StatusCode::new(self.status.as_u16()),
59            },
60            source: None,
61        }))
62    }
63}
64
65struct Failed {
66    source: Error,
67}
68
69impl Failed {
70    fn poll<T>(self, _: &mut Context<'_>) -> InnerPoll<T> {
71        InnerPoll::Ready(Err(self.source))
72    }
73}
74
75struct InFlight {
76    future: Pin<Box<Timeout<HyperResponseFuture>>>,
77    invalid_token: Option<Arc<AtomicBool>>,
78    tx: Option<TicketSender>,
79}
80
81impl InFlight {
82    fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
83        let resp = match Pin::new(&mut self.future).poll(cx) {
84            Poll::Ready(Ok(Ok(resp))) => resp,
85            Poll::Ready(Ok(Err(source))) => {
86                return InnerPoll::Ready(Err(Error {
87                    kind: ErrorType::RequestError,
88                    source: Some(Box::new(source)),
89                }))
90            }
91            Poll::Ready(Err(source)) => {
92                return InnerPoll::Ready(Err(Error {
93                    kind: ErrorType::RequestTimedOut,
94                    source: Some(Box::new(source)),
95                }))
96            }
97            Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::InFlight(self)),
98        };
99
100        // If the API sent back an Unauthorized response, then the client's
101        // configured token is permanently invalid and future requests must be
102        // ignored to avoid API bans.
103        if resp.status() == HyperStatusCode::UNAUTHORIZED {
104            if let Some(invalid_token) = self.invalid_token {
105                invalid_token.store(true, Ordering::Relaxed);
106            }
107        }
108
109        if let Some(tx) = self.tx {
110            let headers = resp
111                .headers()
112                .iter()
113                .map(|(key, value)| (key.as_str(), value.as_bytes()));
114
115            match RatelimitHeaders::from_pairs(headers) {
116                Ok(v) => {
117                    let _res = tx.headers(Some(v));
118                }
119                Err(source) => {
120                    tracing::warn!("header parsing failed: {source:?}; {resp:?}");
121
122                    let _res = tx.headers(None);
123                }
124            }
125        }
126
127        let status = resp.status();
128
129        if status.is_success() {
130            #[cfg(feature = "decompression")]
131            let mut resp = resp;
132            // Inaccurate since end-users can only access the decompressed body.
133            #[cfg(feature = "decompression")]
134            resp.headers_mut().remove(http::header::CONTENT_LENGTH);
135
136            return InnerPoll::Ready(Ok(Response::new(resp)));
137        }
138
139        match status {
140            HyperStatusCode::TOO_MANY_REQUESTS => {
141                tracing::warn!("429 response: {resp:?}");
142            }
143            HyperStatusCode::SERVICE_UNAVAILABLE => {
144                return InnerPoll::Ready(Err(Error {
145                    kind: ErrorType::ServiceUnavailable { response: resp },
146                    source: None,
147                }));
148            }
149            _ => {}
150        }
151
152        let fut = async {
153            Response::<()>::new(resp)
154                .bytes()
155                .await
156                .map_err(|source| Error {
157                    kind: ErrorType::ChunkingResponse,
158                    source: Some(Box::new(source)),
159                })
160        };
161
162        InnerPoll::Advance(ResponseFutureStage::Chunking(Chunking {
163            future: Box::pin(fut),
164            status,
165        }))
166    }
167}
168
169struct RatelimitQueue {
170    invalid_token: Option<Arc<AtomicBool>>,
171    response_future: HyperResponseFuture,
172    timeout: Duration,
173    pre_flight_check: Option<Box<dyn FnOnce() -> bool + Send + 'static>>,
174    wait_for_sender: WaitForTicketFuture,
175}
176
177impl RatelimitQueue {
178    fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
179        let tx = match Pin::new(&mut self.wait_for_sender).poll(cx) {
180            Poll::Ready(Ok(tx)) => tx,
181            Poll::Ready(Err(source)) => {
182                return InnerPoll::Ready(Err(Error {
183                    kind: ErrorType::RatelimiterTicket,
184                    source: Some(source),
185                }))
186            }
187            Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::RatelimitQueue(self)),
188        };
189
190        if let Some(pre_flight_check) = self.pre_flight_check {
191            if !pre_flight_check() {
192                return InnerPoll::Ready(Err(Error {
193                    kind: ErrorType::RequestCanceled,
194                    source: None,
195                }));
196            }
197        }
198
199        InnerPoll::Advance(ResponseFutureStage::InFlight(InFlight {
200            future: Box::pin(time::timeout(self.timeout, self.response_future)),
201            invalid_token: self.invalid_token,
202            tx: Some(tx),
203        }))
204    }
205}
206
207enum ResponseFutureStage {
208    Chunking(Chunking),
209    Completed,
210    Failed(Failed),
211    InFlight(InFlight),
212    RatelimitQueue(RatelimitQueue),
213}
214
215/// Future that will resolve to a [`Response`].
216///
217/// # Canceling a response future pre-flight
218///
219/// Response futures can be canceled pre-flight via
220/// [`ResponseFuture::set_pre_flight`]. This allows you to cancel requests that
221/// are no longer necessary once they have been cleared by the ratelimit queue,
222/// which may be necessary in scenarios where requests are being spammed. Refer
223/// to its documentation for more information.
224///
225/// # Errors
226///
227/// Returns an [`ErrorType::Json`] error type if serializing the response body
228/// of the request failed.
229///
230/// Returns an [`ErrorType::Parsing`] error type if the request failed and the
231/// error in the response body could not be deserialized.
232///
233/// Returns an [`ErrorType::RequestCanceled`] error type if the request was
234/// canceled by the user.
235///
236/// Returns an [`ErrorType::RequestError`] error type if creating the request
237/// failed.
238///
239/// Returns an [`ErrorType::RequestTimedOut`] error type if the request timed
240/// out. The timeout value is configured via [`ClientBuilder::timeout`].
241///
242/// Returns an [`ErrorType::Response`] error type if the request failed.
243///
244/// Returns an [`ErrorType::ServiceUnavailable`] error type if the Discord API
245/// is unavailable.
246///
247/// [`ClientBuilder::timeout`]: crate::client::ClientBuilder::timeout
248/// [`ErrorType::Json`]: crate::error::ErrorType::Json
249/// [`ErrorType::Parsing`]: crate::error::ErrorType::Parsing
250/// [`ErrorType::RequestCanceled`]: crate::error::ErrorType::RequestCanceled
251/// [`ErrorType::RequestError`]: crate::error::ErrorType::RequestError
252/// [`ErrorType::RequestTimedOut`]: crate::error::ErrorType::RequestTimedOut
253/// [`ErrorType::Response`]: crate::error::ErrorType::Response
254/// [`ErrorType::ServiceUnavailable`]: crate::error::ErrorType::ServiceUnavailable
255/// [`Response`]: super::Response
256#[must_use = "futures do nothing unless you `.await` or poll them"]
257pub struct ResponseFuture<T> {
258    phantom: PhantomData<T>,
259    stage: ResponseFutureStage,
260}
261
262impl<T> ResponseFuture<T> {
263    pub(crate) const fn new(
264        future: Pin<Box<Timeout<HyperResponseFuture>>>,
265        invalid_token: Option<Arc<AtomicBool>>,
266    ) -> Self {
267        Self {
268            phantom: PhantomData,
269            stage: ResponseFutureStage::InFlight(InFlight {
270                future,
271                invalid_token,
272                tx: None,
273            }),
274        }
275    }
276
277    /// Set a function to call after clearing the ratelimiter but prior to
278    /// sending the request to determine if the request is still valid.
279    ///
280    /// This function will be a no-op if the request has failed, has already
281    /// passed the ratelimiter, or if there is no ratelimiter configured.
282    ///
283    /// Returns whether the pre flight function was set.
284    ///
285    /// # Examples
286    ///
287    /// Delete a message, but immediately before sending the request check if
288    /// the request should still be sent:
289    ///
290    /// ```no_run
291    /// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
292    /// use std::{
293    ///     collections::HashSet,
294    ///     env,
295    ///     future::IntoFuture,
296    ///     sync::{Arc, Mutex},
297    /// };
298    /// use twilight_http::{error::ErrorType, Client};
299    /// use twilight_model::id::Id;
300    ///
301    /// let channel_id = Id::new(1);
302    /// let message_id = Id::new(2);
303    ///
304    /// let channels_ignored = {
305    ///     let mut map = HashSet::new();
306    ///     map.insert(channel_id);
307    ///
308    ///     Arc::new(Mutex::new(map))
309    /// };
310    ///
311    /// let client = Client::new(env::var("DISCORD_TOKEN")?);
312    /// let mut req = client.delete_message(channel_id, message_id).into_future();
313    ///
314    /// let channels_ignored_clone = channels_ignored.clone();
315    /// req.set_pre_flight(Box::new(move || {
316    ///     // imagine you have some logic here to external state that checks
317    ///     // whether the request should still be performed
318    ///     let channels_ignored = channels_ignored_clone.lock().expect("channels poisoned");
319    ///
320    ///     !channels_ignored.contains(&channel_id)
321    /// }));
322    ///
323    /// // the pre-flight check will cancel the request
324    /// assert!(matches!(
325    ///     req.await.unwrap_err().kind(),
326    ///     ErrorType::RequestCanceled,
327    /// ));
328    /// # Ok(()) }
329    /// ```
330    pub fn set_pre_flight(
331        &mut self,
332        pre_flight: Box<dyn FnOnce() -> bool + Send + 'static>,
333    ) -> bool {
334        if let ResponseFutureStage::RatelimitQueue(queue) = &mut self.stage {
335            queue.pre_flight_check = Some(pre_flight);
336
337            true
338        } else {
339            false
340        }
341    }
342
343    pub(crate) const fn error(source: Error) -> Self {
344        Self {
345            phantom: PhantomData,
346            stage: ResponseFutureStage::Failed(Failed { source }),
347        }
348    }
349
350    pub(crate) fn ratelimit(
351        invalid_token: Option<Arc<AtomicBool>>,
352        response_future: HyperResponseFuture,
353        timeout: Duration,
354        wait_for_sender: WaitForTicketFuture,
355    ) -> Self {
356        Self {
357            phantom: PhantomData,
358            stage: ResponseFutureStage::RatelimitQueue(RatelimitQueue {
359                invalid_token,
360                response_future,
361                timeout,
362                pre_flight_check: None,
363                wait_for_sender,
364            }),
365        }
366    }
367}
368
369impl<T: Unpin> Future for ResponseFuture<T> {
370    type Output = Output<T>;
371
372    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
373        loop {
374            let stage = mem::replace(&mut self.stage, ResponseFutureStage::Completed);
375
376            let result = match stage {
377                ResponseFutureStage::Chunking(chunking) => chunking.poll(cx),
378                ResponseFutureStage::Completed => panic!("future already completed"),
379                ResponseFutureStage::Failed(failed) => failed.poll(cx),
380                ResponseFutureStage::InFlight(in_flight) => in_flight.poll(cx),
381                ResponseFutureStage::RatelimitQueue(queue) => queue.poll(cx),
382            };
383
384            match result {
385                InnerPoll::Advance(stage) => {
386                    self.stage = stage;
387                }
388                InnerPoll::Pending(stage) => {
389                    self.stage = stage;
390
391                    return Poll::Pending;
392                }
393                InnerPoll::Ready(output) => {
394                    self.stage = ResponseFutureStage::Completed;
395
396                    return Poll::Ready(output);
397                }
398            }
399        }
400    }
401}