// twilight_gateway/shard.rs

//! Primary logic and implementation details of Discord gateway websocket
//! connections.
//!
//! Shards are, at their heart, a websocket connection with some state for
//! maintaining an identified session with the Discord gateway. For more
//! information about what a shard is in the context of Discord's gateway API,
//! refer to the documentation for [`Shard`].
8
9#[cfg(any(feature = "zlib", feature = "zstd"))]
10use crate::compression::Decompressor;
11use crate::{
12    API_VERSION, Command, Config, Message, ShardId,
13    channel::{MessageChannel, MessageSender},
14    error::{ReceiveMessageError, ReceiveMessageErrorType},
15    json,
16    latency::Latency,
17    queue::{InMemoryQueue, Queue},
18    ratelimiter::CommandRatelimiter,
19    session::Session,
20};
21use futures_core::Stream;
22use futures_sink::Sink;
23use serde::{Deserialize, Serialize, de::DeserializeOwned};
24use std::{
25    env::consts::OS,
26    error::Error,
27    fmt,
28    future::Future,
29    io,
30    pin::Pin,
31    str,
32    sync::Arc,
33    task::{Context, Poll, ready},
34};
35use tokio::{
36    net::TcpStream,
37    sync::oneshot,
38    time::{self, Duration, Instant, Interval, MissedTickBehavior},
39};
40use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
41use twilight_model::gateway::{
42    CloseCode, CloseFrame, Intents, OpCode,
43    event::GatewayEventDeserializer,
44    payload::{
45        incoming::Hello,
46        outgoing::{
47            Heartbeat, Identify, Resume,
48            identify::{IdentifyInfo, IdentifyProperties},
49        },
50    },
51};
52
/// Default Discord gateway URL, used when there is neither a resume URL nor a
/// proxy URL to connect to.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// `compress` query argument selected by the enabled compression features.
///
/// `zstd` takes precedence over `zlib`; with neither feature enabled the
/// argument is omitted entirely.
const COMPRESSION_FEATURES: &str = match (cfg!(feature = "zstd"), cfg!(feature = "zlib")) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};

/// Upper bound on how long establishing a gateway connection may take.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
67
68/// [`tokio_websockets`] library Websocket connection.
69type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
70
71/// Dynamically dispatched [`Error`].
72type GenericError = Box<dyn Error + Send + Sync>;
73
74/// Wrapper struct around an `async fn` with a `Debug` implementation.
75struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, GenericError>> + Send>>);
76
77impl fmt::Debug for ConnectionFuture {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_tuple("ConnectionFuture")
80            .field(&"<async fn>")
81            .finish()
82    }
83}
84
85/// Close initiator of a websocket connection.
86#[derive(Clone, Debug)]
87enum CloseInitiator {
88    /// Gateway initiated the close.
89    ///
90    /// Contains an optional close code.
91    Gateway(Option<u16>),
92    /// Shard initiated the close.
93    ///
94    /// Contains a close code.
95    Shard(CloseFrame<'static>),
96    /// Transport error initiated the close.
97    Transport,
98}
99
/// Lifecycle state of a [Shard].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Connected to the gateway and holding an active session.
    Active,
    /// Not connected to the gateway, although a reconnect may still happen.
    ///
    /// The underlying websocket connection might not be closed yet.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed in a way that does not allow reconnecting.
    ///
    /// Typical causes include [failed authentication] and [invalid intents];
    /// the documentation of [`CloseCode`] lists the possible reasons.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting to establish or resume a session.
    Identifying,
    /// Replaying dispatch events that were missed.
    ///
    /// A resuming shard already counts as identified.
    Resuming,
}
128
129impl ShardState {
130    /// Determine the connection status from the close code.
131    ///
132    /// Defers to [`CloseCode::can_reconnect`] to determine whether the
133    /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
134    /// the close code is unknown.
135    fn from_close_code(close_code: Option<u16>) -> Self {
136        match close_code.map(CloseCode::try_from) {
137            Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
138            _ => Self::Disconnected {
139                reconnect_attempts: 0,
140            },
141        }
142    }
143
144    /// Whether the shard has disconnected but may reconnect in the future.
145    const fn is_disconnected(self) -> bool {
146        matches!(self, Self::Disconnected { .. })
147    }
148
149    /// Whether the shard is identified with an active session.
150    ///
151    /// `true` if the status is [`Active`] or [`Resuming`].
152    ///
153    /// [`Active`]: Self::Active
154    /// [`Resuming`]: Self::Resuming
155    pub const fn is_identified(self) -> bool {
156        matches!(self, Self::Active | Self::Resuming)
157    }
158}
159
160/// Gateway event with only minimal required data.
161#[derive(Deserialize)]
162struct MinimalEvent<T> {
163    /// Attached data of the gateway event.
164    #[serde(rename = "d")]
165    data: T,
166}
167
168/// Minimal [`Ready`] for light deserialization.
169///
170/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
171#[derive(Deserialize)]
172struct MinimalReady {
173    /// Used for resuming connections.
174    resume_gateway_url: Box<str>,
175    /// ID of the new identified session.
176    session_id: String,
177}
178
179/// Pending outgoing message indicator.
180#[derive(Debug)]
181struct Pending {
182    /// The pending message, if not already sent.
183    gateway_event: Option<Message>,
184    /// Whether the pending gateway event is a heartbeat.
185    is_heartbeat: bool,
186}
187
188impl Pending {
189    /// Constructor for a pending gateway event.
190    fn event<T: Serialize>(event: T, is_heartbeat: bool) -> Option<Self> {
191        Some(Self {
192            gateway_event: Some(Message::Text(
193                json::to_string(&event).expect("json serialization is infallible"),
194            )),
195            is_heartbeat,
196        })
197    }
198}
199
200/// Gateway API client responsible for up to 2500 guilds.
201///
202/// Shards are responsible for maintaining the gateway connection by processing
203/// events relevant to the operation of shards---such as requests from the
204/// gateway to re-connect or invalidate a session---and then to pass them on to
205/// the user.
206///
207/// Shards start out disconnected, but will on the first successful call to
208/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
209/// be repeatedly called in order for the shard to maintain its connection and
210/// update its internal state.
211///
212/// Shards go through an [identify queue][`queue`] that rate limits concurrent
213/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
214/// invalidates the shard's session and it is therefore **very important** to
215/// reuse the same queue for all shards.
216///
217/// # Sharding
218///
219/// A shard may not be connected to more than 2500 guilds, so large bots must
220/// split themselves across multiple shards. See the
221/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
222/// more info.
223///
224/// # Examples
225///
226/// Create and start a shard and print new and deleted messages:
227///
228/// ```no_run
229/// use std::env;
230/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
231///
232/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
233/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
234/// // token. Of course, this value may be passed into the program however is
235/// // preferred.
236/// let token = env::var("DISCORD_TOKEN")?;
237/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
238///
239/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
240///
241/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
242///     let Ok(event) = item else {
243///         tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
244///
245///         continue;
246///     };
247///
248///     match event {
249///         Event::MessageCreate(message) => {
250///             println!("message received with content: {}", message.content);
251///         }
252///         Event::MessageDelete(message) => {
253///             println!("message with ID {} deleted", message.id);
254///         }
255///         _ => {}
256///     }
257/// }
258/// # Ok(()) }
259/// ```
260///
261/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
262/// [gateway commands]: Shard::command
263/// [`poll_next`]: Shard::poll_next
264/// [`queue`]: crate::queue
265#[derive(Debug)]
266pub struct Shard<Q = InMemoryQueue> {
267    /// User provided configuration.
268    ///
269    /// Configurations are provided or created in shard initializing via
270    /// [`Shard::new`] or [`Shard::with_config`].
271    config: Config<Q>,
272    /// Future to establish a WebSocket connection with the Gateway.
273    connection_future: Option<ConnectionFuture>,
274    /// Websocket connection, which may be connected to Discord's gateway.
275    ///
276    /// The connection should only be dropped after it has returned `Ok(None)`
277    /// to comply with the WebSocket protocol.
278    connection: Option<Connection>,
279    /// Event decompressor.
280    #[cfg(any(feature = "zlib", feature = "zstd"))]
281    decompressor: Decompressor,
282    /// Interval of how often the gateway would like the shard to send
283    /// heartbeats.
284    ///
285    /// The interval is received in the [`GatewayEvent::Hello`] event when
286    /// first opening a new [connection].
287    ///
288    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
289    /// [connection]: Self::connection
290    heartbeat_interval: Option<Interval>,
291    /// Whether an event has been received in the current heartbeat interval.
292    heartbeat_interval_event: bool,
293    /// ID of the shard.
294    id: ShardId,
295    /// Identify queue receiver.
296    identify_rx: Option<oneshot::Receiver<()>>,
297    /// Potentially pending outgoing message.
298    pending: Option<Pending>,
299    /// Recent heartbeat latency statistics.
300    ///
301    /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
302    /// may have changed, invalidating previous latency statistic.
303    ///
304    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
305    latency: Latency,
306    /// Command ratelimiter, if it was enabled via
307    /// [`Config::ratelimit_messages`].
308    ratelimiter: Option<CommandRatelimiter>,
309    /// Used for resuming connections.
310    resume_url: Option<Box<str>>,
311    /// Active session of the shard.
312    ///
313    /// The shard may not have an active session if it hasn't yet identified and
314    /// received a `READY` dispatch event response.
315    session: Option<Session>,
316    /// Current state of the shard.
317    state: ShardState,
318    /// Messages from the user to be relayed and sent over the Websocket
319    /// connection.
320    user_channel: MessageChannel,
321}
322
323impl Shard {
324    /// Create a new shard with the default configuration.
325    pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
326        Self::with_config(id, Config::new(token, intents))
327    }
328}
329
330impl<Q> Shard<Q> {
331    /// Create a new shard with the provided configuration.
332    pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
333        let session = config.take_session();
334        let mut resume_url = config.take_resume_url();
335        //ensure resume_url is only used if we have a session to resume
336        if session.is_none() {
337            resume_url = None;
338        }
339
340        Self {
341            config,
342            connection_future: None,
343            connection: None,
344            #[cfg(any(feature = "zlib", feature = "zstd"))]
345            decompressor: Decompressor::new(),
346            heartbeat_interval: None,
347            heartbeat_interval_event: false,
348            id: shard_id,
349            identify_rx: None,
350            pending: None,
351            latency: Latency::new(),
352            ratelimiter: None,
353            resume_url,
354            session,
355            state: ShardState::Disconnected {
356                reconnect_attempts: 0,
357            },
358            user_channel: MessageChannel::new(),
359        }
360    }
361
362    /// Immutable reference to the configuration used to instantiate this shard.
363    pub const fn config(&self) -> &Config<Q> {
364        &self.config
365    }
366
367    /// ID of the shard.
368    pub const fn id(&self) -> ShardId {
369        self.id
370    }
371
372    /// State of the shard.
373    pub const fn state(&self) -> ShardState {
374        self.state
375    }
376
377    /// Shard latency statistics, including average latency and recent heartbeat
378    /// latency times.
379    ///
380    /// Reset when reconnecting to the gateway.
381    pub const fn latency(&self) -> &Latency {
382        &self.latency
383    }
384
385    /// Statistics about the number of available commands and when the command
386    /// ratelimiter will refresh.
387    ///
388    /// This won't be present if ratelimiting was disabled via
389    /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
390    ///
391    /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
392    pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
393        self.ratelimiter.as_ref()
394    }
395
396    /// Immutable reference to the gateways current resume URL.
397    ///
398    /// A resume URL might not be present if the shard had its session
399    /// invalidated and has not yet reconnected.
400    pub fn resume_url(&self) -> Option<&str> {
401        self.resume_url.as_deref()
402    }
403
404    /// Immutable reference to the active gateway session.
405    ///
406    /// An active session may not be present if the shard had its session
407    /// invalidated and has not yet reconnected.
408    pub const fn session(&self) -> Option<&Session> {
409        self.session.as_ref()
410    }
411
412    /// Queue a command to be sent to the gateway.
413    ///
414    /// Serializes the command and then calls [`send`].
415    ///
416    /// [`send`]: Self::send
417    #[allow(clippy::missing_panics_doc)]
418    pub fn command(&self, command: &impl Command) {
419        self.send(json::to_string(command).expect("serialization cannot fail"));
420    }
421
422    /// Queue a JSON encoded gateway event to be sent to the gateway.
423    #[allow(clippy::missing_panics_doc)]
424    pub fn send(&self, json: String) {
425        self.user_channel
426            .command_tx
427            .send(json)
428            .expect("channel open");
429    }
430
431    /// Queue a websocket close frame.
432    ///
433    /// Invalidates the session and shows the application's bot as offline if
434    /// the close frame code is `1000` or `1001`. Otherwise Discord will
435    /// continue showing the bot as online until its presence times out.
436    ///
437    /// To read all remaining messages, continue calling [`poll_next`] until it
438    /// returns [`Message::Close`].
439    ///
440    /// # Example
441    ///
442    /// Close the shard and process remaining messages:
443    ///
444    /// ```no_run
445    /// # use twilight_gateway::{Intents, Shard, ShardId};
446    /// # #[tokio::main] async fn main() {
447    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
448    /// use tokio_stream::StreamExt;
449    /// use twilight_gateway::{CloseFrame, Message, error::ReceiveMessageErrorType};
450    ///
451    /// shard.close(CloseFrame::NORMAL);
452    ///
453    /// while let Some(item) = shard.next().await {
454    ///     match item {
455    ///         Ok(Message::Close(_)) => break,
456    ///         Ok(Message::Text(_)) => unimplemented!(),
457    ///         Err(source) => unimplemented!(),
458    ///     }
459    /// }
460    /// # }
461    /// ```
462    ///
463    /// [`poll_next`]: Shard::poll_next
464    pub fn close(&self, close_frame: CloseFrame<'static>) {
465        _ = self.user_channel.close_tx.try_send(close_frame);
466    }
467
468    /// Retrieve a channel to send messages over the shard to the gateway.
469    ///
470    /// This is primarily useful for sending to other tasks and threads where
471    /// the shard won't be available.
472    ///
473    /// # Example
474    ///
475    /// Queue a command in another process:
476    ///
477    /// ```no_run
478    /// # use twilight_gateway::{Intents, Shard, ShardId};
479    /// # #[tokio::main] async fn main() {
480    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
481    /// use tokio_stream::StreamExt;
482    ///
483    /// while let Some(item) = shard.next().await {
484    ///     match item {
485    ///         Ok(message) => {
486    ///             let sender = shard.sender();
487    ///             tokio::spawn(async move {
488    ///                 let command = unimplemented!();
489    ///                 sender.send(command);
490    ///             });
491    ///         }
492    ///         Err(source) => unimplemented!(),
493    ///     }
494    /// }
495    /// # }
496    /// ```
497    pub fn sender(&self) -> MessageSender {
498        self.user_channel.sender()
499    }
500
501    /// Update internal state from gateway disconnect.
502    fn disconnect(&mut self, initiator: CloseInitiator) {
503        // May not send any additional WebSocket messages.
504        self.heartbeat_interval = None;
505        self.ratelimiter = None;
506        // Abort identify.
507        self.identify_rx = None;
508        self.state = match initiator {
509            CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
510            _ => ShardState::Disconnected {
511                reconnect_attempts: 0,
512            },
513        };
514        if let CloseInitiator::Shard(frame) = initiator {
515            // Not resuming, drop session and resume URL.
516            // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
517            if matches!(frame.code, 1000 | 1001) {
518                self.resume_url = None;
519                self.session = None;
520            }
521            self.pending = Some(Pending {
522                gateway_event: Some(Message::Close(Some(frame))),
523                is_heartbeat: false,
524            });
525        }
526    }
527
528    /// Parse a JSON message into an event with minimal data for [processing].
529    ///
530    /// # Errors
531    ///
532    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
533    /// event isn't a recognized structure, which may be the case for new or
534    /// undocumented events.
535    ///
536    /// [processing]: Self::process
537    fn parse_event<T: DeserializeOwned>(
538        json: &str,
539    ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
540        json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
541            kind: ReceiveMessageErrorType::Deserializing {
542                event: json.to_owned(),
543            },
544            source: Some(Box::new(source)),
545        })
546    }
547
548    /// Attempts to connect to the gateway.
549    ///
550    /// # Returns
551    ///
552    /// * `Poll::Pending` if connection is in progress
553    /// * `Poll::Ready(Ok)` if connected
554    /// * `Poll::Ready(Err)` if connecting to the gateway failed.
555    fn poll_connect(
556        &mut self,
557        cx: &mut Context<'_>,
558        attempt: u8,
559    ) -> Poll<Result<(), ReceiveMessageError>> {
560        let fut = self.connection_future.get_or_insert_with(|| {
561            let base_url = self
562                .resume_url
563                .as_deref()
564                .or_else(|| self.config.proxy_url())
565                .unwrap_or(GATEWAY_URL);
566            let base_url_len = base_url.len();
567            let uri = format!("{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}");
568
569            let tls = Arc::clone(&self.config.tls);
570            ConnectionFuture(Box::pin(async move {
571                let delay = Duration::from_secs(2u8.saturating_pow(attempt.into()).into());
572                time::sleep(delay).await;
573                tracing::debug!(url = &uri[..base_url_len], "connecting");
574
575                let builder = ClientBuilder::new()
576                    .uri(&uri)
577                    .expect("valid URL")
578                    .limits(Limits::unlimited())
579                    .connector(&tls);
580                Ok(time::timeout(CONNECT_TIMEOUT, builder.connect()).await??.0)
581            }))
582        });
583
584        let res = ready!(Pin::new(&mut fut.0).poll(cx));
585        self.connection_future = None;
586        match res {
587            Ok(connection) => {
588                self.connection = Some(connection);
589                self.state = ShardState::Identifying;
590                #[cfg(any(feature = "zlib", feature = "zstd"))]
591                self.decompressor.reset();
592            }
593            Err(source) => {
594                self.resume_url = None;
595                self.state = ShardState::Disconnected {
596                    reconnect_attempts: attempt + 1,
597                };
598
599                return Poll::Ready(Err(ReceiveMessageError {
600                    kind: ReceiveMessageErrorType::Reconnect,
601                    source: Some(source),
602                }));
603            }
604        }
605
606        Poll::Ready(Ok(()))
607    }
608}
609
610impl<Q: Queue> Shard<Q> {
611    /// Attempts to send due commands to the gateway.
612    ///
613    /// # Returns
614    ///
615    /// * `Poll::Pending` if sending is in progress
616    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
617    /// * `Poll::Ready(Err)` if sending a command failed.
618    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
619        loop {
620            if let Some(pending) = self.pending.as_mut() {
621                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
622
623                let is_ratelimited = pending.gateway_event.as_ref().is_some_and(Message::is_text)
624                    && !pending.is_heartbeat;
625                if is_ratelimited && let Some(ratelimiter) = &mut self.ratelimiter {
626                    ready!(ratelimiter.poll_acquire(cx));
627                }
628
629                if let Some(msg) = pending.gateway_event.take().map(Message::into_websocket) {
630                    Pin::new(self.connection.as_mut().unwrap()).start_send(msg)?;
631                }
632
633                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
634
635                if pending.is_heartbeat {
636                    self.latency.record_sent();
637                }
638                self.pending = None;
639            }
640
641            if !self.state.is_disconnected()
642                && let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx)
643            {
644                let frame = frame.expect("shard owns channel");
645
646                tracing::debug!("sending close frame from user channel");
647                self.disconnect(CloseInitiator::Shard(frame));
648
649                continue;
650            }
651
652            if let Some(heartbeater) = &mut self.heartbeat_interval
653                && heartbeater.poll_tick(cx).is_ready()
654            {
655                // Discord never responded after the last heartbeat, connection
656                // is failed or "zombied", see
657                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
658                // Note that unlike documented *any* event is okay; it does not
659                // have to be a heartbeat ACK.
660                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
661                    tracing::info!("connection is failed or \"zombied\"");
662
663                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
664                }
665
666                tracing::debug!("sending heartbeat");
667                self.pending =
668                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
669                self.heartbeat_interval_event = false;
670
671                continue;
672            }
673
674            let not_ratelimited = self
675                .ratelimiter
676                .as_mut()
677                .is_none_or(|ratelimiter| ratelimiter.poll_available(cx).is_ready());
678
679            if not_ratelimited
680                && let Some(rx) = &mut self.identify_rx
681                && let Poll::Ready(canceled) = Pin::new(rx).poll(cx).map(|r| r.is_err())
682            {
683                if canceled {
684                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
685                    continue;
686                }
687
688                tracing::debug!("sending identify");
689
690                self.pending = Pending::event(
691                    Identify::new(IdentifyInfo {
692                        compress: false,
693                        intents: self.config.intents(),
694                        large_threshold: self.config.large_threshold(),
695                        presence: self.config.presence().cloned(),
696                        properties: self
697                            .config
698                            .identify_properties()
699                            .cloned()
700                            .unwrap_or_else(default_identify_properties),
701                        shard: Some(self.id),
702                        token: self.config.token().to_owned(),
703                    }),
704                    false,
705                );
706                self.identify_rx = None;
707
708                continue;
709            }
710
711            if not_ratelimited
712                && self.state.is_identified()
713                && let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx)
714            {
715                let command = command.expect("shard owns channel");
716
717                tracing::debug!("sending command from user channel");
718                self.pending = Some(Pending {
719                    gateway_event: Some(Message::Text(command)),
720                    is_heartbeat: false,
721                });
722
723                continue;
724            }
725
726            return Poll::Ready(Ok(()));
727        }
728    }
729
730    /// Updates the shard's internal state from a gateway event by recording
731    /// and/or responding to certain Discord events.
732    ///
733    /// # Errors
734    ///
735    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
736    /// gateway event isn't a recognized structure.
737    #[allow(clippy::too_many_lines)]
738    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
739        let (raw_opcode, maybe_sequence, maybe_event_type) =
740            GatewayEventDeserializer::from_json(event)
741                .ok_or_else(|| ReceiveMessageError {
742                    kind: ReceiveMessageErrorType::Deserializing {
743                        event: event.to_owned(),
744                    },
745                    source: Some("missing opcode".into()),
746                })?
747                .into_parts();
748
749        if self.latency.sent().is_some() {
750            self.heartbeat_interval_event = true;
751        }
752
753        match OpCode::from(raw_opcode) {
754            Some(OpCode::Dispatch) => {
755                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
756                    kind: ReceiveMessageErrorType::Deserializing {
757                        event: event.to_owned(),
758                    },
759                    source: Some("missing dispatch event type".into()),
760                })?;
761                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
762                    kind: ReceiveMessageErrorType::Deserializing {
763                        event: event.to_owned(),
764                    },
765                    source: Some("missing sequence".into()),
766                })?;
767                tracing::debug!(%event_type, %sequence, "received dispatch");
768
769                match event_type.as_ref() {
770                    "READY" => {
771                        let event = Self::parse_event::<MinimalReady>(event)?;
772
773                        self.resume_url = Some(event.data.resume_gateway_url);
774                        self.session = Some(Session::new(sequence, event.data.session_id));
775                        self.state = ShardState::Active;
776                    }
777                    "RESUMED" => self.state = ShardState::Active,
778                    _ => {}
779                }
780
781                if let Some(session) = self.session.as_mut() {
782                    session.set_sequence(sequence);
783                }
784            }
785            Some(OpCode::Heartbeat) => {
786                tracing::debug!("received heartbeat");
787                self.pending =
788                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
789            }
790            Some(OpCode::HeartbeatAck) => {
791                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
792                if requested {
793                    tracing::debug!("received heartbeat ack");
794                    self.latency.record_received();
795                } else {
796                    tracing::info!("received unrequested heartbeat ack");
797                }
798            }
799            Some(OpCode::Hello) => {
800                let event = Self::parse_event::<Hello>(event)?;
801                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
802                // First heartbeat should have some jitter, see
803                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
804                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
805                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");
806
807                if self.config().ratelimit_messages() {
808                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
809                }
810
811                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
812                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
813                self.heartbeat_interval = Some(interval);
814
815                // Reset `Latency` since the shard might have connected to a new
816                // remote which invalidates the recorded latencies.
817                self.latency = Latency::new();
818
819                if let Some(session) = &self.session {
820                    self.pending = Pending::event(
821                        Resume::new(session.sequence(), session.id(), self.config.token()),
822                        false,
823                    );
824                    self.state = ShardState::Resuming;
825                } else {
826                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
827                }
828            }
829            Some(OpCode::InvalidSession) => {
830                let resumable = Self::parse_event(event)?.data;
831                tracing::debug!(resumable, "received invalid session");
832                if resumable {
833                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
834                } else {
835                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
836                }
837            }
838            Some(OpCode::Reconnect) => {
839                tracing::debug!("received reconnect");
840                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
841            }
842            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
843        }
844
845        Ok(())
846    }
847}
848
impl<Q: Queue + Unpin> Stream for Shard<Q> {
    /// Each item is either a received gateway message or an error that
    /// occurred while receiving or processing one.
    type Item = Result<Message, ReceiveMessageError>;

    /// Receive the next message from the gateway.
    ///
    /// Every poll drives the shard's connection in this order:
    ///
    /// 1. If the shard is fatally closed, finish closing the connection, drop
    ///    it, and end the stream with `Poll::Ready(None)`.
    /// 2. If the shard is disconnected and has no connection, (re)connect.
    /// 3. Flush any pending outgoing message.
    /// 4. Read from the connection until a message can be yielded.
    ///
    /// The shard handles close and text messages itself before yielding
    /// them. Text messages are passed to `process`, and its errors are yielded
    /// as `Err` items.
    ///
    /// # Panics
    ///
    /// Panics, as the `expect` message states, if polled again after the
    /// stream has already returned `Poll::Ready(None)`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // Loop until a message worth yielding has been read. Arms that neither
        // `break` nor `return` (skipped messages, ignored errors, connection
        // end) go around again.
        let message = loop {
            match self.state {
                // Terminal state: close and drop the connection, then end the
                // stream.
                ShardState::FatallyClosed => {
                    _ = ready!(
                        Pin::new(
                            self.connection
                                .as_mut()
                                .expect("poll_next called after Poll::Ready(None)")
                        )
                        .poll_close(cx)
                    );
                    self.connection = None;
                    return Poll::Ready(None);
                }
                // Disconnected with no connection: (re)connect first. Any
                // connection error is yielded to the caller via `?`.
                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
                    ready!(self.poll_connect(cx, reconnect_attempts))?;
                }
                _ => {}
            }

            // Flush any pending outgoing message before reading. A send failure
            // means the transport is unusable, so drop the connection and report
            // an abnormal close.
            if ready!(self.poll_send(cx)).is_err() {
                self.disconnect(CloseInitiator::Transport);
                self.connection = None;

                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
            }

            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
                Some(Ok(message)) => {
                    // With a compression feature enabled, binary messages are
                    // decompressed and yielded as text.
                    #[cfg(any(feature = "zlib", feature = "zstd"))]
                    if message.is_binary() {
                        // Arms are feature-gated: with zstd, `decompress`
                        // returns the text directly. With zlib only, it may
                        // return no message for this frame, and the loop then
                        // reads the next frame.
                        match self.decompressor.decompress(message.as_payload()) {
                            #[cfg(feature = "zstd")]
                            Ok(message) => break Message::Text(message),
                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
                            Ok(Some(message)) => break Message::Text(message),
                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
                            Ok(None) => continue,
                            // Decompression failed: disconnect with the resume
                            // close frame and yield the error.
                            Err(source) => {
                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                                return Poll::Ready(Some(Err(
                                    ReceiveMessageError::from_compression(source),
                                )));
                            }
                        }
                    }
                    // WebSocket messages that `from_websocket_msg` does not map
                    // (returns `None`) are skipped. These are presumably
                    // ping/pong frames (NOTE(review): confirm in `Message`).
                    if let Some(message) = Message::from_websocket_msg(&message) {
                        break message;
                    }
                }
                // Errors after the shard has already disconnected are ignored.
                // Keep reading until the connection ends.
                Some(Err(_)) if self.state.is_disconnected() => {}
                // Transport error on a live connection: mark the shard
                // disconnected and report an abnormal close. The connection is
                // not dropped here. Presumably a later poll drains it through
                // the `None` arm below.
                Some(Err(_)) => {
                    self.disconnect(CloseInitiator::Transport);
                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
                }
                // The connection ended: finish closing it and drop it, so the
                // next iteration can reconnect if the shard is disconnected.
                None => {
                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
                    tracing::debug!("gateway WebSocket connection closed");
                    // Unclean closure.
                    if !self.state.is_disconnected() {
                        self.disconnect(CloseInitiator::Transport);
                    }
                    self.connection = None;
                }
            }
        };

        match &message {
            // Record a gateway-initiated disconnect, carrying the close code
            // when the frame has one.
            Message::Close(frame) => {
                // tokio-websockets automatically replies to the close message.
                tracing::debug!(?frame, "received WebSocket close message");
                // Don't run `disconnect` if we initiated the close.
                if !self.state.is_disconnected() {
                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
                }
            }
            // Let the shard process the gateway event before yielding it.
            // Processing errors are yielded as an `Err` item via `?`.
            Message::Text(event) => {
                self.process(event)?;
            }
        }

        Poll::Ready(Some(Ok(message)))
    }
}
937
938/// Default identify properties to use when the user hasn't customized it in
939/// [`Config::identify_properties`].
940///
941/// [`Config::identify_properties`]: Config::identify_properties
942fn default_identify_properties() -> IdentifyProperties {
943    IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
944}
945
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // A shard must be printable for diagnostics and movable to another
    // thread...
    assert_impl_all!(Shard: Debug);
    assert_impl_all!(Shard: Send);
    // ...but it is not shareable by reference across threads.
    assert_not_impl_any!(Shard: Sync);
}