1use super::{Response, StatusCode};
2use crate::{
3 api_error::ApiError,
4 error::{Error, ErrorType},
5};
6use http::StatusCode as HyperStatusCode;
7use hyper_util::client::legacy::ResponseFuture as HyperResponseFuture;
8use std::{
9 future::Future,
10 marker::PhantomData,
11 mem,
12 pin::Pin,
13 sync::{
14 atomic::{AtomicBool, Ordering},
15 Arc,
16 },
17 task::{Context, Poll},
18 time::Duration,
19};
20use tokio::time::{self, Timeout};
21use twilight_http_ratelimiting::{ticket::TicketSender, RatelimitHeaders, WaitForTicketFuture};
22
23type Output<T> = Result<Response<T>, Error>;
24
25enum InnerPoll<T> {
26 Advance(ResponseFutureStage),
27 Pending(ResponseFutureStage),
28 Ready(Output<T>),
29}
30
31struct Chunking {
32 future: Pin<Box<dyn Future<Output = Result<Vec<u8>, Error>> + Send + Sync + 'static>>,
33 status: HyperStatusCode,
34}
35
36impl Chunking {
37 fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
38 let bytes = match Pin::new(&mut self.future).poll(cx) {
39 Poll::Ready(Ok(bytes)) => bytes,
40 Poll::Ready(Err(source)) => return InnerPoll::Ready(Err(source)),
41 Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::Chunking(self)),
42 };
43
44 let error = match crate::json::from_bytes::<ApiError>(&bytes) {
45 Ok(error) => error,
46 Err(source) => {
47 return InnerPoll::Ready(Err(Error {
48 kind: ErrorType::Parsing { body: bytes },
49 source: Some(Box::new(source)),
50 }));
51 }
52 };
53
54 InnerPoll::Ready(Err(Error {
55 kind: ErrorType::Response {
56 body: bytes,
57 error,
58 status: StatusCode::new(self.status.as_u16()),
59 },
60 source: None,
61 }))
62 }
63}
64
65struct Failed {
66 source: Error,
67}
68
69impl Failed {
70 fn poll<T>(self, _: &mut Context<'_>) -> InnerPoll<T> {
71 InnerPoll::Ready(Err(self.source))
72 }
73}
74
75struct InFlight {
76 future: Pin<Box<Timeout<HyperResponseFuture>>>,
77 invalid_token: Option<Arc<AtomicBool>>,
78 tx: Option<TicketSender>,
79}
80
81impl InFlight {
82 fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
83 let resp = match Pin::new(&mut self.future).poll(cx) {
84 Poll::Ready(Ok(Ok(resp))) => resp,
85 Poll::Ready(Ok(Err(source))) => {
86 return InnerPoll::Ready(Err(Error {
87 kind: ErrorType::RequestError,
88 source: Some(Box::new(source)),
89 }))
90 }
91 Poll::Ready(Err(source)) => {
92 return InnerPoll::Ready(Err(Error {
93 kind: ErrorType::RequestTimedOut,
94 source: Some(Box::new(source)),
95 }))
96 }
97 Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::InFlight(self)),
98 };
99
100 if resp.status() == HyperStatusCode::UNAUTHORIZED {
104 if let Some(invalid_token) = self.invalid_token {
105 invalid_token.store(true, Ordering::Relaxed);
106 }
107 }
108
109 if let Some(tx) = self.tx {
110 let headers = resp
111 .headers()
112 .iter()
113 .map(|(key, value)| (key.as_str(), value.as_bytes()));
114
115 match RatelimitHeaders::from_pairs(headers) {
116 Ok(v) => {
117 let _res = tx.headers(Some(v));
118 }
119 Err(source) => {
120 tracing::warn!("header parsing failed: {source:?}; {resp:?}");
121
122 let _res = tx.headers(None);
123 }
124 }
125 }
126
127 let status = resp.status();
128
129 if status.is_success() {
130 #[cfg(feature = "decompression")]
131 let mut resp = resp;
132 #[cfg(feature = "decompression")]
134 resp.headers_mut().remove(http::header::CONTENT_LENGTH);
135
136 return InnerPoll::Ready(Ok(Response::new(resp)));
137 }
138
139 match status {
140 HyperStatusCode::TOO_MANY_REQUESTS => {
141 tracing::warn!("429 response: {resp:?}");
142 }
143 HyperStatusCode::SERVICE_UNAVAILABLE => {
144 return InnerPoll::Ready(Err(Error {
145 kind: ErrorType::ServiceUnavailable { response: resp },
146 source: None,
147 }));
148 }
149 _ => {}
150 }
151
152 let fut = async {
153 Response::<()>::new(resp)
154 .bytes()
155 .await
156 .map_err(|source| Error {
157 kind: ErrorType::ChunkingResponse,
158 source: Some(Box::new(source)),
159 })
160 };
161
162 InnerPoll::Advance(ResponseFutureStage::Chunking(Chunking {
163 future: Box::pin(fut),
164 status,
165 }))
166 }
167}
168
169struct RatelimitQueue {
170 invalid_token: Option<Arc<AtomicBool>>,
171 response_future: HyperResponseFuture,
172 timeout: Duration,
173 pre_flight_check: Option<Box<dyn FnOnce() -> bool + Send + 'static>>,
174 wait_for_sender: WaitForTicketFuture,
175}
176
177impl RatelimitQueue {
178 fn poll<T>(mut self, cx: &mut Context<'_>) -> InnerPoll<T> {
179 let tx = match Pin::new(&mut self.wait_for_sender).poll(cx) {
180 Poll::Ready(Ok(tx)) => tx,
181 Poll::Ready(Err(source)) => {
182 return InnerPoll::Ready(Err(Error {
183 kind: ErrorType::RatelimiterTicket,
184 source: Some(source),
185 }))
186 }
187 Poll::Pending => return InnerPoll::Pending(ResponseFutureStage::RatelimitQueue(self)),
188 };
189
190 if let Some(pre_flight_check) = self.pre_flight_check {
191 if !pre_flight_check() {
192 return InnerPoll::Ready(Err(Error {
193 kind: ErrorType::RequestCanceled,
194 source: None,
195 }));
196 }
197 }
198
199 InnerPoll::Advance(ResponseFutureStage::InFlight(InFlight {
200 future: Box::pin(time::timeout(self.timeout, self.response_future)),
201 invalid_token: self.invalid_token,
202 tx: Some(tx),
203 }))
204 }
205}
206
207enum ResponseFutureStage {
208 Chunking(Chunking),
209 Completed,
210 Failed(Failed),
211 InFlight(InFlight),
212 RatelimitQueue(RatelimitQueue),
213}
214
215#[must_use = "futures do nothing unless you `.await` or poll them"]
257pub struct ResponseFuture<T> {
258 phantom: PhantomData<T>,
259 stage: ResponseFutureStage,
260}
261
262impl<T> ResponseFuture<T> {
263 pub(crate) const fn new(
264 future: Pin<Box<Timeout<HyperResponseFuture>>>,
265 invalid_token: Option<Arc<AtomicBool>>,
266 ) -> Self {
267 Self {
268 phantom: PhantomData,
269 stage: ResponseFutureStage::InFlight(InFlight {
270 future,
271 invalid_token,
272 tx: None,
273 }),
274 }
275 }
276
277 pub fn set_pre_flight(
331 &mut self,
332 pre_flight: Box<dyn FnOnce() -> bool + Send + 'static>,
333 ) -> bool {
334 if let ResponseFutureStage::RatelimitQueue(queue) = &mut self.stage {
335 queue.pre_flight_check = Some(pre_flight);
336
337 true
338 } else {
339 false
340 }
341 }
342
343 pub(crate) const fn error(source: Error) -> Self {
344 Self {
345 phantom: PhantomData,
346 stage: ResponseFutureStage::Failed(Failed { source }),
347 }
348 }
349
350 pub(crate) fn ratelimit(
351 invalid_token: Option<Arc<AtomicBool>>,
352 response_future: HyperResponseFuture,
353 timeout: Duration,
354 wait_for_sender: WaitForTicketFuture,
355 ) -> Self {
356 Self {
357 phantom: PhantomData,
358 stage: ResponseFutureStage::RatelimitQueue(RatelimitQueue {
359 invalid_token,
360 response_future,
361 timeout,
362 pre_flight_check: None,
363 wait_for_sender,
364 }),
365 }
366 }
367}
368
369impl<T: Unpin> Future for ResponseFuture<T> {
370 type Output = Output<T>;
371
372 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
373 loop {
374 let stage = mem::replace(&mut self.stage, ResponseFutureStage::Completed);
375
376 let result = match stage {
377 ResponseFutureStage::Chunking(chunking) => chunking.poll(cx),
378 ResponseFutureStage::Completed => panic!("future already completed"),
379 ResponseFutureStage::Failed(failed) => failed.poll(cx),
380 ResponseFutureStage::InFlight(in_flight) => in_flight.poll(cx),
381 ResponseFutureStage::RatelimitQueue(queue) => queue.poll(cx),
382 };
383
384 match result {
385 InnerPoll::Advance(stage) => {
386 self.stage = stage;
387 }
388 InnerPoll::Pending(stage) => {
389 self.stage = stage;
390
391 return Poll::Pending;
392 }
393 InnerPoll::Ready(output) => {
394 self.stage = ResponseFutureStage::Completed;
395
396 return Poll::Ready(output);
397 }
398 }
399 }
400 }
401}