Skip to main content

twilight_gateway/
shard.rs

1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(any(feature = "zlib", feature = "zstd"))]
10use crate::compression::Decompressor;
11use crate::{
12    API_VERSION, Command, Config, Message, ShardId,
13    channel::{MessageChannel, MessageSender},
14    error::{ReceiveMessageError, ReceiveMessageErrorType},
15    json,
16    latency::Latency,
17    queue::{InMemoryQueue, Queue},
18    ratelimiter::CommandRatelimiter,
19    session::Session,
20};
21use futures_core::Stream;
22use futures_sink::Sink;
23use serde::{Deserialize, Serialize, de::DeserializeOwned};
24use std::{
25    env::consts::OS,
26    error::Error,
27    fmt,
28    future::Future,
29    io,
30    pin::Pin,
31    str,
32    sync::Arc,
33    task::{Context, Poll, ready},
34};
35use tokio::{
36    net::TcpStream,
37    sync::oneshot,
38    time::{self, Duration, Instant, Interval, MissedTickBehavior},
39};
40use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
41use twilight_model::gateway::{
42    CloseCode, CloseFrame, Intents, OpCode,
43    event::GatewayEventDeserializer,
44    payload::{
45        incoming::Hello,
46        outgoing::{
47            Heartbeat, Identify, Resume,
48            identify::{IdentifyInfo, IdentifyProperties},
49        },
50    },
51};
52
/// Default base URL of Discord's gateway, used when no resume or proxy URL is
/// available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// `compress` query argument appended to the gateway URL.
///
/// `zstd` wins when both compression features are enabled; the argument is
/// omitted entirely when neither is enabled.
const COMPRESSION_FEATURES: &str = match (cfg!(feature = "zstd"), cfg!(feature = "zlib")) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};

/// Upper bound on how long establishing a gateway connection may take.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
67
68/// [`tokio_websockets`] library Websocket connection.
69type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
70
71/// Dynamically dispatched [`Error`].
72type GenericError = Box<dyn Error + Send + Sync>;
73
74/// Wrapper struct around an `async fn` with a `Debug` implementation.
75struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, GenericError>> + Send>>);
76
77impl fmt::Debug for ConnectionFuture {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_tuple("ConnectionFuture")
80            .field(&"<async fn>")
81            .finish()
82    }
83}
84
85/// Close initiator of a websocket connection.
86#[derive(Clone, Debug)]
87enum CloseInitiator {
88    /// Gateway initiated the close.
89    ///
90    /// Contains an optional close code.
91    Gateway(Option<u16>),
92    /// Shard initiated the close.
93    ///
94    /// Contains a close code.
95    Shard(CloseFrame<'static>),
96    /// Transport error initiated the close.
97    Transport,
98}
99
100/// Current state of a [Shard].
101#[derive(Clone, Copy, Debug, Eq, PartialEq)]
102pub enum ShardState {
103    /// Shard is connected to the gateway with an active session.
104    Active,
105    /// Shard is disconnected from the gateway but may reconnect in the future.
106    ///
107    /// The websocket connection may still be open.
108    Disconnected {
109        /// Number of reconnection attempts that have been made.
110        reconnect_attempts: u8,
111    },
112    /// Shard has fatally closed.
113    ///
114    /// Possible reasons may be due to [failed authentication],
115    /// [invalid intents], or other reasons. Refer to the documentation for
116    /// [`CloseCode`] for possible reasons.
117    ///
118    /// [failed authentication]: CloseCode::AuthenticationFailed
119    /// [invalid intents]: CloseCode::InvalidIntents
120    FatallyClosed,
121    /// Shard is waiting to establish or resume a session.
122    Identifying,
123    /// Shard is replaying missed dispatch events.
124    ///
125    /// The shard is considered identified whilst resuming.
126    Resuming,
127}
128
129impl ShardState {
130    /// Determine the connection status from the close code.
131    ///
132    /// Defers to [`CloseCode::can_reconnect`] to determine whether the
133    /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
134    /// the close code is unknown.
135    fn from_close_code(close_code: Option<u16>) -> Self {
136        match close_code.map(CloseCode::try_from) {
137            Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
138            _ => Self::Disconnected {
139                reconnect_attempts: 0,
140            },
141        }
142    }
143
144    /// Whether the shard has disconnected but may reconnect in the future.
145    const fn is_disconnected(self) -> bool {
146        matches!(self, Self::Disconnected { .. })
147    }
148
149    /// Whether the shard is identified with an active session.
150    ///
151    /// `true` if the status is [`Active`] or [`Resuming`].
152    ///
153    /// [`Active`]: Self::Active
154    /// [`Resuming`]: Self::Resuming
155    pub const fn is_identified(self) -> bool {
156        matches!(self, Self::Active | Self::Resuming)
157    }
158}
159
160/// Gateway event with only minimal required data.
161#[derive(Deserialize)]
162struct MinimalEvent<T> {
163    /// Attached data of the gateway event.
164    #[serde(rename = "d")]
165    data: T,
166}
167
168/// Minimal [`Ready`] for light deserialization.
169///
170/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
171#[derive(Deserialize)]
172struct MinimalReady {
173    /// Used for resuming connections.
174    resume_gateway_url: Box<str>,
175    /// ID of the new identified session.
176    session_id: String,
177}
178
179/// Pending outgoing message indicator.
180#[derive(Debug)]
181struct Pending {
182    /// The pending message, if not already sent.
183    gateway_event: Option<Message>,
184    /// Whether the pending gateway event is a heartbeat.
185    is_heartbeat: bool,
186}
187
188impl Pending {
189    /// Constructor for a pending gateway event.
190    fn event<T: Serialize>(event: T, is_heartbeat: bool) -> Option<Self> {
191        Some(Self {
192            gateway_event: Some(Message::Text(
193                json::to_string(&event).expect("json serialization is infallible"),
194            )),
195            is_heartbeat,
196        })
197    }
198}
199
200/// Gateway API client responsible for up to 2500 guilds.
201///
202/// Shards are responsible for maintaining the gateway connection by processing
203/// events relevant to the operation of shards---such as requests from the
204/// gateway to re-connect or invalidate a session---and then to pass them on to
205/// the user.
206///
207/// Shards start out disconnected, but will on the first successful call to
208/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
209/// be repeatedly called in order for the shard to maintain its connection and
210/// update its internal state.
211///
212/// Shards go through an [identify queue][`queue`] that rate limits concurrent
213/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
214/// invalidates the shard's session and it is therefore **very important** to
215/// reuse the same queue for all shards.
216///
217/// # Sharding
218///
219/// A shard may not be connected to more than 2500 guilds, so large bots must
220/// split themselves across multiple shards. See the
221/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
222/// more info.
223///
224/// # Examples
225///
226/// Create and start a shard and print new and deleted messages:
227///
228/// ```no_run
229/// use std::env;
230/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
231///
232/// # #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box<dyn std::error::Error>> {
233/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
234/// // token. Of course, this value may be passed into the program however is
235/// // preferred.
236/// let token = env::var("DISCORD_TOKEN")?;
237/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
238///
239/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
240///
241/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
242///     let Ok(event) = item else {
243///         tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
244///
245///         continue;
246///     };
247///
248///     match event {
249///         Event::MessageCreate(message) => {
250///             println!("message received with content: {}", message.content);
251///         }
252///         Event::MessageDelete(message) => {
253///             println!("message with ID {} deleted", message.id);
254///         }
255///         _ => {}
256///     }
257/// }
258/// # Ok(()) }
259/// ```
260///
261/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
262/// [gateway commands]: Shard::command
263/// [`poll_next`]: Shard::poll_next
264/// [`queue`]: crate::queue
265#[derive(Debug)]
266pub struct Shard<Q = InMemoryQueue> {
267    /// User provided configuration.
268    ///
269    /// Configurations are provided or created in shard initializing via
270    /// [`Shard::new`] or [`Shard::with_config`].
271    config: Config<Q>,
272    /// Future to establish a WebSocket connection with the Gateway.
273    connection_future: Option<ConnectionFuture>,
274    /// Websocket connection, which may be connected to Discord's gateway.
275    ///
276    /// The connection should only be dropped after it has returned `Ok(None)`
277    /// to comply with the WebSocket protocol.
278    connection: Option<Connection>,
279    /// Event decompressor.
280    #[cfg(any(feature = "zlib", feature = "zstd"))]
281    decompressor: Decompressor,
282    /// Interval of how often the gateway would like the shard to send
283    /// heartbeats.
284    ///
285    /// The interval is received in the [`GatewayEvent::Hello`] event when
286    /// first opening a new [connection].
287    ///
288    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
289    /// [connection]: Self::connection
290    heartbeat_interval: Option<Interval>,
291    /// Whether an event has been received in the current heartbeat interval.
292    heartbeat_interval_event: bool,
293    /// ID of the shard.
294    id: ShardId,
295    /// Identify queue receiver.
296    identify_rx: Option<oneshot::Receiver<()>>,
297    /// Potentially pending outgoing message.
298    pending: Option<Pending>,
299    /// Recent heartbeat latency statistics.
300    ///
301    /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
302    /// may have changed, invalidating previous latency statistic.
303    ///
304    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
305    latency: Latency,
306    /// Command ratelimiter, if it was enabled via
307    /// [`Config::ratelimit_messages`].
308    ratelimiter: Option<CommandRatelimiter>,
309    /// Used for resuming connections.
310    resume_url: Option<Box<str>>,
311    /// Active session of the shard.
312    ///
313    /// The shard may not have an active session if it hasn't yet identified and
314    /// received a `READY` dispatch event response.
315    session: Option<Session>,
316    /// Current state of the shard.
317    state: ShardState,
318    /// Messages from the user to be relayed and sent over the Websocket
319    /// connection.
320    user_channel: MessageChannel,
321}
322
323impl Shard {
324    /// Create a new shard with the default configuration.
325    pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
326        Self::with_config(id, Config::new(token, intents))
327    }
328}
329
330impl<Q> Shard<Q> {
331    /// Create a new shard with the provided configuration.
332    pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
333        let session = config.take_session();
334        let mut resume_url = config.take_resume_url();
335        //ensure resume_url is only used if we have a session to resume
336        if session.is_none() {
337            resume_url = None;
338        }
339
340        Self {
341            config,
342            connection_future: None,
343            connection: None,
344            #[cfg(any(feature = "zlib", feature = "zstd"))]
345            decompressor: Decompressor::new(),
346            heartbeat_interval: None,
347            heartbeat_interval_event: false,
348            id: shard_id,
349            identify_rx: None,
350            pending: None,
351            latency: Latency::new(),
352            ratelimiter: None,
353            resume_url,
354            session,
355            state: ShardState::Disconnected {
356                reconnect_attempts: 0,
357            },
358            user_channel: MessageChannel::new(),
359        }
360    }
361
362    /// Immutable reference to the configuration used to instantiate this shard.
363    pub const fn config(&self) -> &Config<Q> {
364        &self.config
365    }
366
367    /// ID of the shard.
368    pub const fn id(&self) -> ShardId {
369        self.id
370    }
371
372    /// State of the shard.
373    pub const fn state(&self) -> ShardState {
374        self.state
375    }
376
377    /// Shard latency statistics, including average latency and recent heartbeat
378    /// latency times.
379    ///
380    /// Reset when reconnecting to the gateway.
381    pub const fn latency(&self) -> &Latency {
382        &self.latency
383    }
384
385    /// Statistics about the number of available commands and when the command
386    /// ratelimiter will refresh.
387    ///
388    /// This won't be present if ratelimiting was disabled via
389    /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
390    ///
391    /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
392    pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
393        self.ratelimiter.as_ref()
394    }
395
396    /// Immutable reference to the gateways current resume URL.
397    ///
398    /// A resume URL might not be present if the shard had its session
399    /// invalidated and has not yet reconnected.
400    pub fn resume_url(&self) -> Option<&str> {
401        self.resume_url.as_deref()
402    }
403
404    /// Immutable reference to the active gateway session.
405    ///
406    /// An active session may not be present if the shard had its session
407    /// invalidated and has not yet reconnected.
408    pub const fn session(&self) -> Option<&Session> {
409        self.session.as_ref()
410    }
411
412    /// Queue a command to be sent to the gateway.
413    ///
414    /// Serializes the command and then calls [`send`].
415    ///
416    /// [`send`]: Self::send
417    #[allow(clippy::missing_panics_doc)]
418    pub fn command(&self, command: &impl Command) {
419        self.send(json::to_string(command).expect("serialization cannot fail"));
420    }
421
422    /// Queue a JSON encoded gateway event to be sent to the gateway.
423    #[allow(clippy::missing_panics_doc)]
424    pub fn send(&self, json: String) {
425        self.user_channel
426            .command_tx
427            .send(json)
428            .expect("channel open");
429    }
430
431    /// Queue a websocket close frame.
432    ///
433    /// Invalidates the session and shows the application's bot as offline if
434    /// the close frame code is `1000` or `1001`. Otherwise Discord will
435    /// continue showing the bot as online until its presence times out.
436    ///
437    /// To read all remaining messages, continue calling [`poll_next`] until it
438    /// returns [`Message::Close`].
439    ///
440    /// # Example
441    ///
442    /// Close the shard and process remaining messages:
443    ///
444    /// ```no_run
445    /// # use twilight_gateway::{Intents, Shard, ShardId};
446    /// # #[tokio::main(flavor = "current_thread")] async fn main() {
447    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
448    /// use tokio_stream::StreamExt;
449    /// use twilight_gateway::{CloseFrame, Message, error::ReceiveMessageErrorType};
450    ///
451    /// shard.close(CloseFrame::NORMAL);
452    ///
453    /// while let Some(item) = shard.next().await {
454    ///     match item {
455    ///         Ok(Message::Close(_)) => break,
456    ///         Ok(Message::Text(_)) => unimplemented!(),
457    ///         Err(source) => unimplemented!(),
458    ///     }
459    /// }
460    /// # }
461    /// ```
462    ///
463    /// [`poll_next`]: Shard::poll_next
464    pub fn close(&self, close_frame: CloseFrame<'static>) {
465        _ = self.user_channel.close_tx.try_send(close_frame);
466    }
467
468    /// Retrieve a channel to send messages over the shard to the gateway.
469    ///
470    /// This is primarily useful for sending to other tasks and threads where
471    /// the shard won't be available.
472    ///
473    /// # Example
474    ///
475    /// Queue a command in another process:
476    ///
477    /// ```no_run
478    /// # use twilight_gateway::{Intents, Shard, ShardId};
479    /// # #[tokio::main(flavor = "current_thread")] async fn main() {
480    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
481    /// use tokio_stream::StreamExt;
482    ///
483    /// while let Some(item) = shard.next().await {
484    ///     match item {
485    ///         Ok(message) => {
486    ///             let sender = shard.sender();
487    ///             tokio::spawn(async move {
488    ///                 let command = unimplemented!();
489    ///                 sender.send(command);
490    ///             });
491    ///         }
492    ///         Err(source) => unimplemented!(),
493    ///     }
494    /// }
495    /// # }
496    /// ```
497    pub fn sender(&self) -> MessageSender {
498        self.user_channel.sender()
499    }
500
501    /// Update internal state from gateway disconnect.
502    fn disconnect(&mut self, initiator: CloseInitiator) {
503        // May not send any additional WebSocket messages.
504        self.heartbeat_interval = None;
505        self.ratelimiter = None;
506        // Abort identify.
507        self.identify_rx = None;
508        self.state = match initiator {
509            CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
510            _ => ShardState::Disconnected {
511                reconnect_attempts: 0,
512            },
513        };
514        if let CloseInitiator::Shard(frame) = initiator {
515            // Not resuming, drop session and resume URL.
516            // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
517            if matches!(frame.code, 1000 | 1001) {
518                self.resume_url = None;
519                self.session = None;
520            }
521            self.pending = Some(Pending {
522                gateway_event: Some(Message::Close(Some(frame))),
523                is_heartbeat: false,
524            });
525        }
526    }
527
528    /// Parse a JSON message into an event with minimal data for [processing].
529    ///
530    /// # Errors
531    ///
532    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
533    /// event isn't a recognized structure, which may be the case for new or
534    /// undocumented events.
535    ///
536    /// [processing]: Self::process
537    fn parse_event<T: DeserializeOwned>(
538        json: &str,
539    ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
540        json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
541            kind: ReceiveMessageErrorType::Deserializing {
542                event: json.to_owned(),
543            },
544            source: Some(Box::new(source)),
545        })
546    }
547
548    /// Attempts to connect to the gateway.
549    ///
550    /// # Returns
551    ///
552    /// * `Poll::Pending` if connection is in progress
553    /// * `Poll::Ready(Ok)` if connected
554    /// * `Poll::Ready(Err)` if connecting to the gateway failed.
555    fn poll_connect(
556        &mut self,
557        cx: &mut Context<'_>,
558        attempt: u8,
559    ) -> Poll<Result<(), ReceiveMessageError>> {
560        let fut = self.connection_future.get_or_insert_with(|| {
561            let base_url = self
562                .resume_url
563                .as_deref()
564                .or_else(|| self.config.proxy_url())
565                .unwrap_or(GATEWAY_URL);
566            let base_url_len = base_url.len();
567            let uri = format!("{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}");
568
569            let tls = Arc::clone(&self.config.tls);
570            ConnectionFuture(Box::pin(async move {
571                if attempt != 0 {
572                    let secs = 2u8.saturating_pow(u32::from(attempt) - 1);
573                    time::sleep(Duration::from_secs(secs.into())).await;
574                }
575                tracing::debug!(url = &uri[..base_url_len], "connecting");
576
577                let builder = ClientBuilder::new()
578                    .uri(&uri)
579                    .expect("valid URL")
580                    .limits(Limits::unlimited())
581                    .connector(&tls);
582                Ok(time::timeout(CONNECT_TIMEOUT, builder.connect()).await??.0)
583            }))
584        });
585
586        let res = ready!(Pin::new(&mut fut.0).poll(cx));
587        self.connection_future = None;
588        match res {
589            Ok(connection) => {
590                self.connection = Some(connection);
591                self.state = ShardState::Identifying;
592                #[cfg(any(feature = "zlib", feature = "zstd"))]
593                self.decompressor.reset();
594            }
595            Err(source) => {
596                self.resume_url = None;
597                self.state = ShardState::Disconnected {
598                    reconnect_attempts: attempt.saturating_add(1),
599                };
600
601                return Poll::Ready(Err(ReceiveMessageError {
602                    kind: ReceiveMessageErrorType::Reconnect,
603                    source: Some(source),
604                }));
605            }
606        }
607
608        Poll::Ready(Ok(()))
609    }
610}
611
impl<Q: Queue> Shard<Q> {
    /// Attempts to send due commands to the gateway.
    ///
    /// Messages are handled one at a time, in this priority order:
    /// 1. an already pending (possibly half-sent) message is finished;
    /// 2. a close frame requested through [`Shard::close`] (unless already
    ///    disconnected);
    /// 3. a due heartbeat;
    /// 4. an identify, once the identify queue has granted permission;
    /// 5. user commands, only while the shard is identified.
    ///
    /// Identify and user commands also wait for the command ratelimiter, if
    /// one is enabled.
    ///
    /// # Returns
    ///
    /// * `Poll::Pending` if sending is in progress
    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
    /// * `Poll::Ready(Err)` if sending a command failed, or if the connection
    ///   was detected as "zombied" (reported as an `io::ErrorKind::TimedOut`
    ///   I/O error).
    ///
    /// NOTE(review): the `unwrap`s on `self.connection` assume a message is
    /// only pending while a connection exists; confirm against the caller.
    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
        loop {
            // Finish the in-flight message before scheduling anything else.
            if let Some(pending) = self.pending.as_mut() {
                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;

                // Only non-heartbeat text messages count against the command
                // ratelimit. Once the message has been taken (see below) this
                // is `false`, so re-entering after a pending flush does not
                // acquire a second permit.
                let is_ratelimited = pending.gateway_event.as_ref().is_some_and(Message::is_text)
                    && !pending.is_heartbeat;
                if is_ratelimited && let Some(ratelimiter) = &mut self.ratelimiter {
                    ready!(ratelimiter.poll_acquire(cx));
                }

                // `take` ensures the message is started only once, even when
                // the flush below returns `Pending` and this branch re-runs.
                if let Some(msg) = pending.gateway_event.take().map(Message::into_websocket) {
                    Pin::new(self.connection.as_mut().unwrap()).start_send(msg)?;
                }

                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;

                // Latency is measured from when the heartbeat was flushed.
                if pending.is_heartbeat {
                    self.latency.record_sent();
                }
                self.pending = None;
            }

            // A user-requested close; `disconnect` queues the close frame as
            // the pending message, which the next iteration sends.
            if !self.state.is_disconnected()
                && let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx)
            {
                let frame = frame.expect("shard owns channel");

                tracing::debug!("sending close frame from user channel");
                self.disconnect(CloseInitiator::Shard(frame));

                continue;
            }

            if let Some(heartbeater) = &mut self.heartbeat_interval
                && heartbeater.poll_tick(cx).is_ready()
            {
                // Discord never responded after the last heartbeat, connection
                // is failed or "zombied", see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
                // Note that unlike documented *any* event is okay; it does not
                // have to be a heartbeat ACK.
                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
                    tracing::info!("connection is failed or \"zombied\"");

                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
                }

                tracing::debug!("sending heartbeat");
                self.pending =
                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
                // Start tracking activity afresh for the new interval.
                self.heartbeat_interval_event = false;

                continue;
            }

            // Identify and user commands are only scheduled when the
            // ratelimiter (if any) reports capacity.
            let not_ratelimited = self
                .ratelimiter
                .as_mut()
                .is_none_or(|ratelimiter| ratelimiter.poll_available(cx).is_ready());

            // The identify queue grants permission through `identify_rx`. An
            // error means the sender was dropped without granting it, so the
            // shard re-enqueues itself and keeps waiting.
            if not_ratelimited
                && let Some(rx) = &mut self.identify_rx
                && let Poll::Ready(canceled) = Pin::new(rx).poll(cx).map(|r| r.is_err())
            {
                if canceled {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                    continue;
                }

                tracing::debug!("sending identify");

                self.pending = Pending::event(
                    Identify::new(IdentifyInfo {
                        // Payload compression is not requested; transport
                        // compression is selected in the connection URL.
                        compress: false,
                        intents: self.config.intents(),
                        large_threshold: self.config.large_threshold(),
                        presence: self.config.presence().cloned(),
                        properties: self
                            .config
                            .identify_properties()
                            .cloned()
                            .unwrap_or_else(default_identify_properties),
                        shard: Some(self.id),
                        token: self.config.token().to_owned(),
                    }),
                    false,
                );
                self.identify_rx = None;

                continue;
            }

            // User commands are held back until the shard is identified.
            if not_ratelimited
                && self.state.is_identified()
                && let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx)
            {
                let command = command.expect("shard owns channel");

                tracing::debug!("sending command from user channel");
                self.pending = Some(Pending {
                    gateway_event: Some(Message::Text(command)),
                    is_heartbeat: false,
                });

                continue;
            }

            return Poll::Ready(Ok(()));
        }
    }

    /// Updates the shard's internal state from a gateway event by recording
    /// and/or responding to certain Discord events.
    ///
    /// * Dispatch: tracks the sequence number; `READY` stores the new session
    ///   and resume URL, and both `READY` and `RESUMED` mark the shard
    ///   [`ShardState::Active`].
    /// * Heartbeat: queues an immediate heartbeat.
    /// * Heartbeat ACK: records latency if a heartbeat was awaiting an ACK.
    /// * Hello: sets up heartbeating (and the ratelimiter, if enabled), then
    ///   either resumes the stored session or enqueues an identify.
    /// * Invalid Session / Reconnect: closes the connection from the shard
    ///   side.
    /// * Unknown opcodes are logged and otherwise ignored.
    ///
    /// # Errors
    ///
    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
    /// gateway event isn't a recognized structure.
    // One flat match arm per opcode makes this long but straightforward.
    #[allow(clippy::too_many_lines)]
    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
        let (raw_opcode, maybe_sequence, maybe_event_type) =
            GatewayEventDeserializer::from_json(event)
                .ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing opcode".into()),
                })?
                .into_parts();

        // Any event received after a heartbeat was sent counts as activity
        // for the zombied-connection check in `poll_send`.
        if self.latency.sent().is_some() {
            self.heartbeat_interval_event = true;
        }

        match OpCode::from(raw_opcode) {
            Some(OpCode::Dispatch) => {
                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing dispatch event type".into()),
                })?;
                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing sequence".into()),
                })?;
                tracing::debug!(%event_type, %sequence, "received dispatch");

                match event_type.as_ref() {
                    "READY" => {
                        let event = Self::parse_event::<MinimalReady>(event)?;

                        self.resume_url = Some(event.data.resume_gateway_url);
                        self.session = Some(Session::new(sequence, event.data.session_id));
                        self.state = ShardState::Active;
                    }
                    "RESUMED" => self.state = ShardState::Active,
                    _ => {}
                }

                // Keep the sequence current so heartbeats and resumes use it.
                if let Some(session) = self.session.as_mut() {
                    session.set_sequence(sequence);
                }
            }
            Some(OpCode::Heartbeat) => {
                tracing::debug!("received heartbeat");
                // NOTE(review): this replaces any message still pending.
                self.pending =
                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
            }
            Some(OpCode::HeartbeatAck) => {
                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
                if requested {
                    tracing::debug!("received heartbeat ack");
                    self.latency.record_received();
                } else {
                    tracing::info!("received unrequested heartbeat ack");
                }
            }
            Some(OpCode::Hello) => {
                let event = Self::parse_event::<Hello>(event)?;
                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
                // First heartbeat should have some jitter, see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

                if self.config().ratelimit_messages() {
                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
                }

                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
                self.heartbeat_interval = Some(interval);

                // Reset `Latency` since the shard might have connected to a new
                // remote which invalidates the recorded latencies.
                self.latency = Latency::new();

                // Resume a stored session; otherwise wait for the identify
                // queue before identifying (handled in `poll_send`).
                if let Some(session) = &self.session {
                    self.pending = Pending::event(
                        Resume::new(session.sequence(), session.id(), self.config.token()),
                        false,
                    );
                    self.state = ShardState::Resuming;
                } else {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                }
            }
            Some(OpCode::InvalidSession) => {
                let resumable = Self::parse_event(event)?.data;
                tracing::debug!(resumable, "received invalid session");
                // `disconnect` drops the session for close codes 1000/1001;
                // presumably `NORMAL` is such a code and `RESUME` is not, so
                // only a resumable session is kept.
                if resumable {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                } else {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
                }
            }
            Some(OpCode::Reconnect) => {
                tracing::debug!("received reconnect");
                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
            }
            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
        }

        Ok(())
    }
}
850
851impl<Q: Queue + Unpin> Stream for Shard<Q> {
852    type Item = Result<Message, ReceiveMessageError>;
853
854    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
855        let message = loop {
856            match self.state {
857                ShardState::FatallyClosed => {
858                    _ = ready!(
859                        Pin::new(
860                            self.connection
861                                .as_mut()
862                                .expect("poll_next called after Poll::Ready(None)")
863                        )
864                        .poll_close(cx)
865                    );
866                    self.connection = None;
867                    return Poll::Ready(None);
868                }
869                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
870                    ready!(self.poll_connect(cx, reconnect_attempts))?;
871                }
872                _ => {}
873            }
874
875            if ready!(self.poll_send(cx)).is_err() {
876                self.disconnect(CloseInitiator::Transport);
877                self.connection = None;
878
879                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
880            }
881
882            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
883                Some(Ok(message)) => {
884                    #[cfg(any(feature = "zlib", feature = "zstd"))]
885                    if message.is_binary() {
886                        match self.decompressor.decompress(message.as_payload()) {
887                            #[cfg(feature = "zstd")]
888                            Ok(message) => break Message::Text(message),
889                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
890                            Ok(Some(message)) => break Message::Text(message),
891                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
892                            Ok(None) => continue,
893                            Err(source) => {
894                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
895                                return Poll::Ready(Some(Err(
896                                    ReceiveMessageError::from_compression(source),
897                                )));
898                            }
899                        }
900                    }
901                    if let Some(message) = Message::from_websocket_msg(&message) {
902                        break message;
903                    }
904                }
905                Some(Err(_)) if self.state.is_disconnected() => {}
906                Some(Err(_)) => {
907                    self.disconnect(CloseInitiator::Transport);
908                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
909                }
910                None => {
911                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
912                    tracing::debug!("gateway WebSocket connection closed");
913                    // Unclean closure.
914                    if !self.state.is_disconnected() {
915                        self.disconnect(CloseInitiator::Transport);
916                    }
917                    self.connection = None;
918                }
919            }
920        };
921
922        match &message {
923            Message::Close(frame) => {
924                // tokio-websockets automatically replies to the close message.
925                tracing::debug!(?frame, "received WebSocket close message");
926                // Don't run `disconnect` if we initiated the close.
927                if !self.state.is_disconnected() {
928                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
929                }
930            }
931            Message::Text(event) => {
932                self.process(event)?;
933            }
934        }
935
936        Poll::Ready(Some(Ok(message)))
937    }
938}
939
940/// Default identify properties to use when the user hasn't customized it in
941/// [`Config::identify_properties`].
942///
943/// [`Config::identify_properties`]: Config::identify_properties
944fn default_identify_properties() -> IdentifyProperties {
945    IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
946}
947
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt;

    // Compile-time guarantees about the traits `Shard` implements: a change
    // that breaks one of them fails the test build.

    // `Shard` can be formatted with `{:?}`.
    assert_impl_all!(Shard: fmt::Debug);
    // `Shard` can be moved to another thread, e.g. into a spawned task.
    assert_impl_all!(Shard: Send);
    // `Shard` is not `Sync`.
    assert_not_impl_any!(Shard: Sync);
}