twilight_gateway/
shard.rs

//! Primary logic and implementation details of Discord gateway websocket
//! connections.
//!
//! Shards are, at their heart, a websocket connection with some state for
//! maintaining an identified session with the Discord gateway. For more
//! information about what a shard is in the context of Discord's gateway API,
//! refer to the documentation for [`Shard`].

9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13    any(feature = "zlib-stock", feature = "zlib-simd"),
14    not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18    channel::{MessageChannel, MessageSender},
19    error::{ReceiveMessageError, ReceiveMessageErrorType},
20    json,
21    latency::Latency,
22    queue::{InMemoryQueue, Queue},
23    ratelimiter::CommandRatelimiter,
24    session::Session,
25    Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30use std::{
31    env::consts::OS,
32    error::Error,
33    fmt,
34    future::Future,
35    io,
36    pin::Pin,
37    str,
38    task::{ready, Context, Poll},
39};
40use tokio::{
41    net::TcpStream,
42    sync::oneshot,
43    time::{self, error::Elapsed, timeout, Duration, Instant, Interval, MissedTickBehavior},
44};
45use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
46use twilight_model::gateway::{
47    event::GatewayEventDeserializer,
48    payload::{
49        incoming::Hello,
50        outgoing::{
51            identify::{IdentifyInfo, IdentifyProperties},
52            Heartbeat, Identify, Resume,
53        },
54    },
55    CloseCode, CloseFrame, Intents, OpCode,
56};
57
/// Base websocket URL of Discord's gateway, without any query arguments.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// `compress` query argument matching the enabled compression feature.
///
/// The `zstd` feature takes precedence over the zlib features. The string is
/// empty when no compression feature is enabled.
const COMPRESSION_FEATURES: &str = match (
    cfg!(feature = "zstd"),
    cfg!(feature = "zlib-stock") || cfg!(feature = "zlib-simd"),
) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};

/// Upper bound on how long establishing a gateway connection may take.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
72
73/// [`tokio_websockets`] library Websocket connection.
74type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
75
76/// Wrapper enum around [`WebsocketError`] with a timeout case.
77enum ConnectionError {
78    /// Connection attempt timed out.
79    Timeout(Elapsed),
80    /// Error from the websocket library, [`tokio_websockets`].
81    Websocket(WebsocketError),
82}
83
84impl ConnectionError {
85    /// Returns the boxed wrapped error.
86    fn into_boxed_error(self) -> Box<dyn Error + Send + Sync> {
87        match self {
88            Self::Websocket(e) => Box::new(e),
89            Self::Timeout(e) => Box::new(e),
90        }
91    }
92}
93
94impl From<WebsocketError> for ConnectionError {
95    fn from(value: WebsocketError) -> Self {
96        Self::Websocket(value)
97    }
98}
99
100impl From<Elapsed> for ConnectionError {
101    fn from(value: Elapsed) -> Self {
102        Self::Timeout(value)
103    }
104}
105
106/// Wrapper struct around an `async fn` with a `Debug` implementation.
107struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, ConnectionError>> + Send>>);
108
109impl fmt::Debug for ConnectionFuture {
110    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111        f.debug_tuple("ConnectionFuture")
112            .field(&"<async fn>")
113            .finish()
114    }
115}
116
117/// Close initiator of a websocket connection.
118#[derive(Clone, Debug)]
119enum CloseInitiator {
120    /// Gateway initiated the close.
121    ///
122    /// Contains an optional close code.
123    Gateway(Option<u16>),
124    /// Shard initiated the close.
125    ///
126    /// Contains a close code.
127    Shard(CloseFrame<'static>),
128    /// Transport error initiated the close.
129    Transport,
130}
131
132/// Current state of a [Shard].
133#[derive(Clone, Copy, Debug, Eq, PartialEq)]
134pub enum ShardState {
135    /// Shard is connected to the gateway with an active session.
136    Active,
137    /// Shard is disconnected from the gateway but may reconnect in the future.
138    ///
139    /// The websocket connection may still be open.
140    Disconnected {
141        /// Number of reconnection attempts that have been made.
142        reconnect_attempts: u8,
143    },
144    /// Shard has fatally closed.
145    ///
146    /// Possible reasons may be due to [failed authentication],
147    /// [invalid intents], or other reasons. Refer to the documentation for
148    /// [`CloseCode`] for possible reasons.
149    ///
150    /// [failed authentication]: CloseCode::AuthenticationFailed
151    /// [invalid intents]: CloseCode::InvalidIntents
152    FatallyClosed,
153    /// Shard is waiting to establish or resume a session.
154    Identifying,
155    /// Shard is replaying missed dispatch events.
156    ///
157    /// The shard is considered identified whilst resuming.
158    Resuming,
159}
160
161impl ShardState {
162    /// Determine the connection status from the close code.
163    ///
164    /// Defers to [`CloseCode::can_reconnect`] to determine whether the
165    /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
166    /// the close code is unknown.
167    fn from_close_code(close_code: Option<u16>) -> Self {
168        match close_code.map(CloseCode::try_from) {
169            Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
170            _ => Self::Disconnected {
171                reconnect_attempts: 0,
172            },
173        }
174    }
175
176    /// Whether the shard has disconnected but may reconnect in the future.
177    const fn is_disconnected(self) -> bool {
178        matches!(self, Self::Disconnected { .. })
179    }
180
181    /// Whether the shard is identified with an active session.
182    ///
183    /// `true` if the status is [`Active`] or [`Resuming`].
184    ///
185    /// [`Active`]: Self::Active
186    /// [`Resuming`]: Self::Resuming
187    pub const fn is_identified(self) -> bool {
188        matches!(self, Self::Active | Self::Resuming)
189    }
190}
191
192/// Gateway event with only minimal required data.
193#[derive(Deserialize)]
194struct MinimalEvent<T> {
195    /// Attached data of the gateway event.
196    #[serde(rename = "d")]
197    data: T,
198}
199
200/// Minimal [`Ready`] for light deserialization.
201///
202/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
203#[derive(Deserialize)]
204struct MinimalReady {
205    /// Used for resuming connections.
206    resume_gateway_url: Box<str>,
207    /// ID of the new identified session.
208    session_id: String,
209}
210
211/// Pending outgoing message indicator.
212#[derive(Debug)]
213struct Pending {
214    /// The pending message, if not already sent.
215    gateway_event: Option<Message>,
216    /// Whether the pending gateway event is a heartbeat.
217    is_heartbeat: bool,
218}
219
220impl Pending {
221    /// Constructor for a pending gateway event.
222    const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
223        Some(Self {
224            gateway_event: Some(Message::Text(json)),
225            is_heartbeat,
226        })
227    }
228}
229
230/// Gateway API client responsible for up to 2500 guilds.
231///
232/// Shards are responsible for maintaining the gateway connection by processing
233/// events relevant to the operation of shards---such as requests from the
234/// gateway to re-connect or invalidate a session---and then to pass them on to
235/// the user.
236///
237/// Shards start out disconnected, but will on the first successful call to
238/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
239/// be repeatedly called in order for the shard to maintain its connection and
240/// update its internal state.
241///
242/// Shards go through an [identify queue][`queue`] that rate limits concurrent
243/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
244/// invalidates the shard's session and it is therefore **very important** to
245/// reuse the same queue for all shards.
246///
247/// # Sharding
248///
249/// A shard may not be connected to more than 2500 guilds, so large bots must
250/// split themselves across multiple shards. See the
251/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
252/// more info.
253///
254/// # Examples
255///
256/// Create and start a shard and print new and deleted messages:
257///
258/// ```no_run
259/// use std::env;
260/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
261///
262/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
263/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
264/// // token. Of course, this value may be passed into the program however is
265/// // preferred.
266/// let token = env::var("DISCORD_TOKEN")?;
267/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
268///
269/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
270///
271/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
272///     let Ok(event) = item else {
273///         tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
274///
275///         continue;
276///     };
277///
278///     match event {
279///         Event::MessageCreate(message) => {
280///             println!("message received with content: {}", message.content);
281///         }
282///         Event::MessageDelete(message) => {
283///             println!("message with ID {} deleted", message.id);
284///         }
285///         _ => {}
286///     }
287/// }
288/// # Ok(()) }
289/// ```
290///
291/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
292/// [gateway commands]: Shard::command
293/// [`poll_next`]: Shard::poll_next
294/// [`queue`]: crate::queue
295#[derive(Debug)]
296pub struct Shard<Q = InMemoryQueue> {
297    /// User provided configuration.
298    ///
299    /// Configurations are provided or created in shard initializing via
300    /// [`Shard::new`] or [`Shard::with_config`].
301    config: Config<Q>,
302    /// Future to establish a WebSocket connection with the Gateway.
303    connection_future: Option<ConnectionFuture>,
304    /// Websocket connection, which may be connected to Discord's gateway.
305    ///
306    /// The connection should only be dropped after it has returned `Ok(None)`
307    /// to comply with the WebSocket protocol.
308    connection: Option<Connection>,
309    /// Zstd decompressor.
310    #[cfg(feature = "zstd")]
311    decompressor: Decompressor,
312    /// Interval of how often the gateway would like the shard to send
313    /// heartbeats.
314    ///
315    /// The interval is received in the [`GatewayEvent::Hello`] event when
316    /// first opening a new [connection].
317    ///
318    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
319    /// [connection]: Self::connection
320    heartbeat_interval: Option<Interval>,
321    /// Whether an event has been received in the current heartbeat interval.
322    heartbeat_interval_event: bool,
323    /// ID of the shard.
324    id: ShardId,
325    /// Identify queue receiver.
326    identify_rx: Option<oneshot::Receiver<()>>,
327    /// Zlib decompressor.
328    #[allow(deprecated)]
329    #[cfg(all(
330        any(feature = "zlib-stock", feature = "zlib-simd"),
331        not(feature = "zstd")
332    ))]
333    inflater: Inflater,
334    /// Potentially pending outgoing message.
335    pending: Option<Pending>,
336    /// Recent heartbeat latency statistics.
337    ///
338    /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
339    /// may have changed, invalidating previous latency statistic.
340    ///
341    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
342    latency: Latency,
343    /// Command ratelimiter, if it was enabled via
344    /// [`Config::ratelimit_messages`].
345    ratelimiter: Option<CommandRatelimiter>,
346    /// Used for resuming connections.
347    resume_url: Option<Box<str>>,
348    /// Active session of the shard.
349    ///
350    /// The shard may not have an active session if it hasn't yet identified and
351    /// received a `READY` dispatch event response.
352    session: Option<Session>,
353    /// Current state of the shard.
354    state: ShardState,
355    /// Messages from the user to be relayed and sent over the Websocket
356    /// connection.
357    user_channel: MessageChannel,
358}
359
360impl Shard {
361    /// Create a new shard with the default configuration.
362    pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
363        Self::with_config(id, Config::new(token, intents))
364    }
365}
366
367impl<Q> Shard<Q> {
368    /// Create a new shard with the provided configuration.
369    pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
370        let session = config.take_session();
371        let mut resume_url = config.take_resume_url();
372        //ensure resume_url is only used if we have a session to resume
373        if session.is_none() {
374            resume_url = None;
375        }
376
377        Self {
378            config,
379            connection_future: None,
380            connection: None,
381            #[cfg(feature = "zstd")]
382            decompressor: Decompressor::new(),
383            heartbeat_interval: None,
384            heartbeat_interval_event: false,
385            id: shard_id,
386            identify_rx: None,
387            #[allow(deprecated)]
388            #[cfg(all(
389                any(feature = "zlib-stock", feature = "zlib-simd"),
390                not(feature = "zstd")
391            ))]
392            inflater: Inflater::new(),
393            pending: None,
394            latency: Latency::new(),
395            ratelimiter: None,
396            resume_url,
397            session,
398            state: ShardState::Disconnected {
399                reconnect_attempts: 0,
400            },
401            user_channel: MessageChannel::new(),
402        }
403    }
404
405    /// Immutable reference to the configuration used to instantiate this shard.
406    pub const fn config(&self) -> &Config<Q> {
407        &self.config
408    }
409
410    /// ID of the shard.
411    pub const fn id(&self) -> ShardId {
412        self.id
413    }
414
415    /// Zlib decompressor statistics.
416    ///
417    /// Reset when reconnecting to the gateway.
418    #[allow(deprecated)]
419    #[cfg(all(
420        any(feature = "zlib-stock", feature = "zlib-simd"),
421        not(feature = "zstd")
422    ))]
423    #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
424    pub const fn inflater(&self) -> &Inflater {
425        &self.inflater
426    }
427
428    /// State of the shard.
429    pub const fn state(&self) -> ShardState {
430        self.state
431    }
432
433    /// Shard latency statistics, including average latency and recent heartbeat
434    /// latency times.
435    ///
436    /// Reset when reconnecting to the gateway.
437    pub const fn latency(&self) -> &Latency {
438        &self.latency
439    }
440
441    /// Statistics about the number of available commands and when the command
442    /// ratelimiter will refresh.
443    ///
444    /// This won't be present if ratelimiting was disabled via
445    /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
446    ///
447    /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
448    pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
449        self.ratelimiter.as_ref()
450    }
451
452    /// Immutable reference to the gateways current resume URL.
453    ///
454    /// A resume URL might not be present if the shard had its session
455    /// invalidated and has not yet reconnected.
456    pub fn resume_url(&self) -> Option<&str> {
457        self.resume_url.as_deref()
458    }
459
460    /// Immutable reference to the active gateway session.
461    ///
462    /// An active session may not be present if the shard had its session
463    /// invalidated and has not yet reconnected.
464    pub const fn session(&self) -> Option<&Session> {
465        self.session.as_ref()
466    }
467
468    /// Queue a command to be sent to the gateway.
469    ///
470    /// Serializes the command and then calls [`send`].
471    ///
472    /// [`send`]: Self::send
473    #[allow(clippy::missing_panics_doc)]
474    pub fn command(&self, command: &impl Command) {
475        self.send(json::to_string(command).expect("serialization cannot fail"));
476    }
477
478    /// Queue a JSON encoded gateway event to be sent to the gateway.
479    #[allow(clippy::missing_panics_doc)]
480    pub fn send(&self, json: String) {
481        self.user_channel
482            .command_tx
483            .send(json)
484            .expect("channel open");
485    }
486
487    /// Queue a websocket close frame.
488    ///
489    /// Invalidates the session and shows the application's bot as offline if
490    /// the close frame code is `1000` or `1001`. Otherwise Discord will
491    /// continue showing the bot as online until its presence times out.
492    ///
493    /// To read all remaining messages, continue calling [`poll_next`] until it
494    /// returns [`Message::Close`].
495    ///
496    /// # Example
497    ///
498    /// Close the shard and process remaining messages:
499    ///
500    /// ```no_run
501    /// # use twilight_gateway::{Intents, Shard, ShardId};
502    /// # #[tokio::main] async fn main() {
503    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
504    /// use tokio_stream::StreamExt;
505    /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
506    ///
507    /// shard.close(CloseFrame::NORMAL);
508    ///
509    /// while let Some(item) = shard.next().await {
510    ///     match item {
511    ///         Ok(Message::Close(_)) => break,
512    ///         Ok(Message::Text(_)) => unimplemented!(),
513    ///         Err(source) => unimplemented!(),
514    ///     }
515    /// }
516    /// # }
517    /// ```
518    ///
519    /// [`poll_next`]: Shard::poll_next
520    pub fn close(&self, close_frame: CloseFrame<'static>) {
521        _ = self.user_channel.close_tx.try_send(close_frame);
522    }
523
524    /// Retrieve a channel to send messages over the shard to the gateway.
525    ///
526    /// This is primarily useful for sending to other tasks and threads where
527    /// the shard won't be available.
528    ///
529    /// # Example
530    ///
531    /// Queue a command in another process:
532    ///
533    /// ```no_run
534    /// # use twilight_gateway::{Intents, Shard, ShardId};
535    /// # #[tokio::main] async fn main() {
536    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
537    /// use tokio_stream::StreamExt;
538    ///
539    /// while let Some(item) = shard.next().await {
540    ///     match item {
541    ///         Ok(message) => {
542    ///             let sender = shard.sender();
543    ///             tokio::spawn(async move {
544    ///                 let command = unimplemented!();
545    ///                 sender.send(command);
546    ///             });
547    ///         }
548    ///         Err(source) => unimplemented!(),
549    ///     }
550    /// }
551    /// # }
552    /// ```
553    pub fn sender(&self) -> MessageSender {
554        self.user_channel.sender()
555    }
556
557    /// Update internal state from gateway disconnect.
558    fn disconnect(&mut self, initiator: CloseInitiator) {
559        // May not send any additional WebSocket messages.
560        self.heartbeat_interval = None;
561        self.ratelimiter = None;
562        // Abort identify.
563        self.identify_rx = None;
564        self.state = match initiator {
565            CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
566            _ => ShardState::Disconnected {
567                reconnect_attempts: 0,
568            },
569        };
570        if let CloseInitiator::Shard(frame) = initiator {
571            // Not resuming, drop session and resume URL.
572            // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
573            if matches!(frame.code, 1000 | 1001) {
574                self.resume_url = None;
575                self.session = None;
576            }
577            self.pending = Some(Pending {
578                gateway_event: Some(Message::Close(Some(frame))),
579                is_heartbeat: false,
580            });
581        }
582    }
583
584    /// Parse a JSON message into an event with minimal data for [processing].
585    ///
586    /// # Errors
587    ///
588    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
589    /// event isn't a recognized structure, which may be the case for new or
590    /// undocumented events.
591    ///
592    /// [processing]: Self::process
593    fn parse_event<T: DeserializeOwned>(
594        json: &str,
595    ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
596        json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
597            kind: ReceiveMessageErrorType::Deserializing {
598                event: json.to_owned(),
599            },
600            source: Some(Box::new(source)),
601        })
602    }
603}
604
605impl<Q: Queue> Shard<Q> {
606    /// Attempts to send due commands to the gateway.
607    ///
608    /// # Returns
609    ///
610    /// * `Poll::Pending` if sending is in progress
611    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
612    /// * `Poll::Ready(Err)` if sending a command failed.
613    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
614        loop {
615            if let Some(pending) = self.pending.as_mut() {
616                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
617
618                if let Some(message) = &pending.gateway_event {
619                    if let Some(ratelimiter) = self.ratelimiter.as_mut() {
620                        if message.is_text() && !pending.is_heartbeat {
621                            ready!(ratelimiter.poll_acquire(cx));
622                        }
623                    }
624
625                    let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
626                    Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
627                }
628
629                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
630
631                if pending.is_heartbeat {
632                    self.latency.record_sent();
633                }
634                self.pending = None;
635            }
636
637            if !self.state.is_disconnected() {
638                if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
639                    let frame = frame.expect("shard owns channel");
640
641                    tracing::debug!("sending close frame from user channel");
642                    self.disconnect(CloseInitiator::Shard(frame));
643
644                    continue;
645                }
646            }
647
648            if self
649                .heartbeat_interval
650                .as_mut()
651                .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
652            {
653                // Discord never responded after the last heartbeat, connection
654                // is failed or "zombied", see
655                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
656                // Note that unlike documented *any* event is okay; it does not
657                // have to be a heartbeat ACK.
658                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
659                    tracing::info!("connection is failed or \"zombied\"");
660
661                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
662                }
663
664                tracing::debug!("sending heartbeat");
665                self.pending = Pending::text(
666                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
667                        .expect("serialization cannot fail"),
668                    true,
669                );
670                self.heartbeat_interval_event = false;
671
672                continue;
673            }
674
675            let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
676                ratelimiter.poll_available(cx).is_ready()
677            });
678
679            if not_ratelimited {
680                if let Some(Poll::Ready(canceled)) = self
681                    .identify_rx
682                    .as_mut()
683                    .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
684                {
685                    if canceled {
686                        self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
687                        continue;
688                    }
689
690                    tracing::debug!("sending identify");
691
692                    self.pending = Pending::text(
693                        json::to_string(&Identify::new(IdentifyInfo {
694                            compress: false,
695                            intents: self.config.intents(),
696                            large_threshold: self.config.large_threshold(),
697                            presence: self.config.presence().cloned(),
698                            properties: self
699                                .config
700                                .identify_properties()
701                                .cloned()
702                                .unwrap_or_else(default_identify_properties),
703                            shard: Some(self.id),
704                            token: self.config.token().to_owned(),
705                        }))
706                        .expect("serialization cannot fail"),
707                        false,
708                    );
709                    self.identify_rx = None;
710
711                    continue;
712                }
713            }
714
715            if not_ratelimited && self.state.is_identified() {
716                if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
717                    let command = command.expect("shard owns channel");
718
719                    tracing::debug!("sending command from user channel");
720                    self.pending = Some(Pending {
721                        gateway_event: Some(Message::Text(command)),
722                        is_heartbeat: false,
723                    });
724
725                    continue;
726                }
727            }
728
729            return Poll::Ready(Ok(()));
730        }
731    }
732
733    /// Updates the shard's internal state from a gateway event by recording
734    /// and/or responding to certain Discord events.
735    ///
736    /// # Errors
737    ///
738    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
739    /// gateway event isn't a recognized structure.
740    #[allow(clippy::too_many_lines)]
741    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
742        let (raw_opcode, maybe_sequence, maybe_event_type) =
743            GatewayEventDeserializer::from_json(event)
744                .ok_or_else(|| ReceiveMessageError {
745                    kind: ReceiveMessageErrorType::Deserializing {
746                        event: event.to_owned(),
747                    },
748                    source: Some("missing opcode".into()),
749                })?
750                .into_parts();
751
752        if self.latency.sent().is_some() {
753            self.heartbeat_interval_event = true;
754        }
755
756        match OpCode::from(raw_opcode) {
757            Some(OpCode::Dispatch) => {
758                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
759                    kind: ReceiveMessageErrorType::Deserializing {
760                        event: event.to_owned(),
761                    },
762                    source: Some("missing dispatch event type".into()),
763                })?;
764                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
765                    kind: ReceiveMessageErrorType::Deserializing {
766                        event: event.to_owned(),
767                    },
768                    source: Some("missing sequence".into()),
769                })?;
770                tracing::debug!(%event_type, %sequence, "received dispatch");
771
772                match event_type.as_ref() {
773                    "READY" => {
774                        let event = Self::parse_event::<MinimalReady>(event)?;
775
776                        self.resume_url = Some(event.data.resume_gateway_url);
777                        self.session = Some(Session::new(sequence, event.data.session_id));
778                        self.state = ShardState::Active;
779                    }
780                    "RESUMED" => self.state = ShardState::Active,
781                    _ => {}
782                }
783
784                if let Some(session) = self.session.as_mut() {
785                    session.set_sequence(sequence);
786                }
787            }
788            Some(OpCode::Heartbeat) => {
789                tracing::debug!("received heartbeat");
790                self.pending = Pending::text(
791                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
792                        .expect("serialization cannot fail"),
793                    true,
794                );
795            }
796            Some(OpCode::HeartbeatAck) => {
797                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
798                if requested {
799                    tracing::debug!("received heartbeat ack");
800                    self.latency.record_received();
801                } else {
802                    tracing::info!("received unrequested heartbeat ack");
803                }
804            }
805            Some(OpCode::Hello) => {
806                let event = Self::parse_event::<Hello>(event)?;
807                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
808                // First heartbeat should have some jitter, see
809                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
810                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
811                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");
812
813                if self.config().ratelimit_messages() {
814                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
815                }
816
817                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
818                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
819                self.heartbeat_interval = Some(interval);
820
821                // Reset `Latency` since the shard might have connected to a new
822                // remote which invalidates the recorded latencies.
823                self.latency = Latency::new();
824
825                if let Some(session) = &self.session {
826                    self.pending = Pending::text(
827                        json::to_string(&Resume::new(
828                            session.sequence(),
829                            session.id(),
830                            self.config.token(),
831                        ))
832                        .expect("serialization cannot fail"),
833                        false,
834                    );
835                    self.state = ShardState::Resuming;
836                } else {
837                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
838                }
839            }
840            Some(OpCode::InvalidSession) => {
841                let resumable = Self::parse_event(event)?.data;
842                tracing::debug!(resumable, "received invalid session");
843                if resumable {
844                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
845                } else {
846                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
847                }
848            }
849            Some(OpCode::Reconnect) => {
850                tracing::debug!("received reconnect");
851                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
852            }
853            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
854        }
855
856        Ok(())
857    }
858}
859
860impl<Q: Queue + Unpin> Stream for Shard<Q> {
861    type Item = Result<Message, ReceiveMessageError>;
862
863    #[allow(clippy::too_many_lines)]
864    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
865        let message = loop {
866            match self.state {
867                ShardState::FatallyClosed => {
868                    _ = ready!(Pin::new(
869                        self.connection
870                            .as_mut()
871                            .expect("poll_next called after Poll::Ready(None)")
872                    )
873                    .poll_close(cx));
874                    self.connection = None;
875                    return Poll::Ready(None);
876                }
877                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
878                    if self.connection_future.is_none() {
879                        let base_url = self
880                            .resume_url
881                            .as_deref()
882                            .or_else(|| self.config.proxy_url())
883                            .unwrap_or(GATEWAY_URL);
884                        let uri = format!(
885                            "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
886                        );
887
888                        tracing::debug!(url = base_url, "connecting to gateway");
889
890                        let tls = self.config.tls.clone();
891                        self.connection_future = Some(ConnectionFuture(Box::pin(async move {
892                            let secs = 2u8.saturating_pow(reconnect_attempts.into());
893                            time::sleep(Duration::from_secs(secs.into())).await;
894
895                            Ok(timeout(
896                                CONNECT_TIMEOUT,
897                                ClientBuilder::new()
898                                    .uri(&uri)
899                                    .expect("URL should be valid")
900                                    .limits(Limits::unlimited())
901                                    .connector(&tls)
902                                    .connect(),
903                            )
904                            .await??
905                            .0)
906                        })));
907                    }
908
909                    let res =
910                        ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
911                    self.connection_future = None;
912                    match res {
913                        Ok(connection) => {
914                            self.connection = Some(connection);
915                            self.state = ShardState::Identifying;
916                            #[cfg(feature = "zstd")]
917                            self.decompressor.reset();
918                            #[allow(deprecated)]
919                            #[cfg(all(
920                                not(feature = "zstd"),
921                                any(feature = "zlib-stock", feature = "zlib-simd")
922                            ))]
923                            self.inflater.reset();
924                        }
925                        Err(source) => {
926                            self.resume_url = None;
927                            self.state = ShardState::Disconnected {
928                                reconnect_attempts: reconnect_attempts + 1,
929                            };
930
931                            return Poll::Ready(Some(Err(ReceiveMessageError {
932                                kind: ReceiveMessageErrorType::Reconnect,
933                                source: Some(source.into_boxed_error()),
934                            })));
935                        }
936                    }
937                }
938                _ => {}
939            }
940
941            if ready!(self.poll_send(cx)).is_err() {
942                self.disconnect(CloseInitiator::Transport);
943                self.connection = None;
944
945                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
946            }
947
948            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
949                Some(Ok(message)) => {
950                    #[cfg(feature = "zstd")]
951                    if message.is_binary() {
952                        match self.decompressor.decompress(message.as_payload()) {
953                            Ok(message) => break Message::Text(message),
954                            Err(source) => {
955                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
956                                return Poll::Ready(Some(Err(
957                                    ReceiveMessageError::from_compression(source),
958                                )));
959                            }
960                        }
961                    }
962                    #[cfg(all(
963                        not(feature = "zstd"),
964                        any(feature = "zlib-stock", feature = "zlib-simd")
965                    ))]
966                    if message.is_binary() {
967                        match self.inflater.inflate(message.as_payload()) {
968                            Ok(Some(message)) => break Message::Text(message),
969                            Ok(None) => continue,
970                            Err(source) => {
971                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
972                                return Poll::Ready(Some(Err(
973                                    ReceiveMessageError::from_compression(source),
974                                )));
975                            }
976                        }
977                    }
978                    if let Some(message) = Message::from_websocket_msg(&message) {
979                        break message;
980                    }
981                }
982                Some(Err(_)) if self.state.is_disconnected() => {}
983                Some(Err(_)) => {
984                    self.disconnect(CloseInitiator::Transport);
985                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
986                }
987                None => {
988                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
989                    tracing::debug!("gateway WebSocket connection closed");
990                    // Unclean closure.
991                    if !self.state.is_disconnected() {
992                        self.disconnect(CloseInitiator::Transport);
993                    }
994                    self.connection = None;
995                }
996            }
997        };
998
999        match &message {
1000            Message::Close(frame) => {
1001                // tokio-websockets automatically replies to the close message.
1002                tracing::debug!(?frame, "received WebSocket close message");
1003                // Don't run `disconnect` if we initiated the close.
1004                if !self.state.is_disconnected() {
1005                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
1006                }
1007            }
1008            Message::Text(event) => {
1009                self.process(event)?;
1010            }
1011        }
1012
1013        Poll::Ready(Some(Ok(message)))
1014    }
1015}
1016
1017/// Default identify properties to use when the user hasn't customized it in
1018/// [`Config::identify_properties`].
1019///
1020/// [`Config::identify_properties`]: Config::identify_properties
1021fn default_identify_properties() -> IdentifyProperties {
1022    IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
1023}
1024
#[cfg(test)]
mod tests {
    use super::Shard;
    use std::fmt::Debug;

    use static_assertions::{assert_impl_all, assert_not_impl_any};

    // Compile-time checks: `Shard` must be movable across threads and
    // debug-printable.
    assert_impl_all!(Shard: Send, Debug);
    // ...but it must not be shareable by reference across threads.
    assert_not_impl_any!(Shard: Sync);
}