twilight_gateway/
shard.rs

1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13    any(feature = "zlib-stock", feature = "zlib-simd"),
14    not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18    channel::{MessageChannel, MessageSender},
19    error::{ReceiveMessageError, ReceiveMessageErrorType},
20    json,
21    latency::Latency,
22    queue::{InMemoryQueue, Queue},
23    ratelimiter::CommandRatelimiter,
24    session::Session,
25    Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30#[cfg(any(
31    feature = "native-tls",
32    feature = "rustls-native-roots",
33    feature = "rustls-platform-verifier",
34    feature = "rustls-webpki-roots"
35))]
36use std::io::ErrorKind as IoErrorKind;
37use std::{
38    env::consts::OS,
39    fmt,
40    future::Future,
41    pin::Pin,
42    str,
43    task::{ready, Context, Poll},
44};
45use tokio::{
46    net::TcpStream,
47    sync::oneshot,
48    time::{self, Duration, Instant, Interval, MissedTickBehavior},
49};
50use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
51use twilight_model::gateway::{
52    event::GatewayEventDeserializer,
53    payload::{
54        incoming::Hello,
55        outgoing::{
56            identify::{IdentifyInfo, IdentifyProperties},
57            Heartbeat, Identify, Resume,
58        },
59    },
60    CloseCode, CloseFrame, Intents, OpCode,
61};
62
/// Default Discord gateway URL, used when neither a resume URL nor a proxy URL
/// is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Compression query argument appended to the gateway URL.
///
/// zstd takes precedence when both zstd and a zlib feature are enabled; with
/// no compression feature enabled the argument is empty.
const COMPRESSION_FEATURES: &str = match (
    cfg!(feature = "zstd"),
    cfg!(any(feature = "zlib-stock", feature = "zlib-simd")),
) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};
74
75/// [`tokio_websockets`] library Websocket connection.
76type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
77
78/// Wrapper struct around an `async fn` with a `Debug` implementation.
79struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
80
81impl fmt::Debug for ConnectionFuture {
82    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83        f.debug_tuple("ConnectionFuture")
84            .field(&"<async fn>")
85            .finish()
86    }
87}
88
89/// Close initiator of a websocket connection.
90#[derive(Clone, Debug)]
91enum CloseInitiator {
92    /// Gateway initiated the close.
93    ///
94    /// Contains an optional close code.
95    Gateway(Option<u16>),
96    /// Shard initiated the close.
97    ///
98    /// Contains a close code.
99    Shard(CloseFrame<'static>),
100    /// Transport error initiated the close.
101    Transport,
102}
103
/// Current state of a [`Shard`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Connected to the gateway with an active session.
    Active,
    /// Not connected to the gateway, but may reconnect later.
    ///
    /// The underlying websocket connection may not have been closed yet.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed without any possibility of reconnecting.
    ///
    /// This may be caused by [failed authentication], [invalid intents], or
    /// another reason; see [`CloseCode`] for the possible causes.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting to establish or resume a session.
    Identifying,
    /// Replaying dispatch events that were missed.
    ///
    /// A resuming shard counts as identified.
    Resuming,
}
132
133impl ShardState {
134    /// Determine the connection status from the close code.
135    ///
136    /// Defers to [`CloseCode::can_reconnect`] to determine whether the
137    /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
138    /// the close code is unknown.
139    fn from_close_code(close_code: Option<u16>) -> Self {
140        match close_code.map(CloseCode::try_from) {
141            Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
142            _ => Self::Disconnected {
143                reconnect_attempts: 0,
144            },
145        }
146    }
147
148    /// Whether the shard has disconnected but may reconnect in the future.
149    const fn is_disconnected(self) -> bool {
150        matches!(self, Self::Disconnected { .. })
151    }
152
153    /// Whether the shard is identified with an active session.
154    ///
155    /// `true` if the status is [`Active`] or [`Resuming`].
156    ///
157    /// [`Active`]: Self::Active
158    /// [`Resuming`]: Self::Resuming
159    pub const fn is_identified(self) -> bool {
160        matches!(self, Self::Active | Self::Resuming)
161    }
162}
163
164/// Gateway event with only minimal required data.
165#[derive(Deserialize)]
166struct MinimalEvent<T> {
167    /// Attached data of the gateway event.
168    #[serde(rename = "d")]
169    data: T,
170}
171
172/// Minimal [`Ready`] for light deserialization.
173///
174/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
175#[derive(Deserialize)]
176struct MinimalReady {
177    /// Used for resuming connections.
178    resume_gateway_url: Box<str>,
179    /// ID of the new identified session.
180    session_id: String,
181}
182
183/// Pending outgoing message indicator.
184#[derive(Debug)]
185struct Pending {
186    /// The pending message, if not already sent.
187    gateway_event: Option<Message>,
188    /// Whether the pending gateway event is a heartbeat.
189    is_heartbeat: bool,
190}
191
192impl Pending {
193    /// Constructor for a pending gateway event.
194    const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
195        Some(Self {
196            gateway_event: Some(Message::Text(json)),
197            is_heartbeat,
198        })
199    }
200}
201
202/// Gateway API client responsible for up to 2500 guilds.
203///
204/// Shards are responsible for maintaining the gateway connection by processing
205/// events relevant to the operation of shards---such as requests from the
206/// gateway to re-connect or invalidate a session---and then to pass them on to
207/// the user.
208///
209/// Shards start out disconnected, but will on the first successful call to
210/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
211/// be repeatedly called in order for the shard to maintain its connection and
212/// update its internal state.
213///
214/// Shards go through an [identify queue][`queue`] that rate limits concurrent
215/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
216/// invalidates the shard's session and it is therefore **very important** to
217/// reuse the same queue for all shards.
218///
219/// # Sharding
220///
221/// A shard may not be connected to more than 2500 guilds, so large bots must
222/// split themselves across multiple shards. See the
223/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
224/// more info.
225///
226/// # Examples
227///
228/// Create and start a shard and print new and deleted messages:
229///
230/// ```no_run
231/// use std::env;
232/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
233///
234/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
235/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
236/// // token. Of course, this value may be passed into the program however is
237/// // preferred.
238/// let token = env::var("DISCORD_TOKEN")?;
239/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
240///
241/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
242///
243/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
244///     let Ok(event) = item else {
245///         tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
246///
247///         continue;
248///     };
249///
250///     match event {
251///         Event::MessageCreate(message) => {
252///             println!("message received with content: {}", message.content);
253///         }
254///         Event::MessageDelete(message) => {
255///             println!("message with ID {} deleted", message.id);
256///         }
257///         _ => {}
258///     }
259/// }
260/// # Ok(()) }
261/// ```
262///
263/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
264/// [gateway commands]: Shard::command
265/// [`poll_next`]: Shard::poll_next
266/// [`queue`]: crate::queue
267#[derive(Debug)]
268pub struct Shard<Q = InMemoryQueue> {
269    /// User provided configuration.
270    ///
271    /// Configurations are provided or created in shard initializing via
272    /// [`Shard::new`] or [`Shard::with_config`].
273    config: Config<Q>,
274    /// Future to establish a WebSocket connection with the Gateway.
275    connection_future: Option<ConnectionFuture>,
276    /// Websocket connection, which may be connected to Discord's gateway.
277    ///
278    /// The connection should only be dropped after it has returned `Ok(None)`
279    /// to comply with the WebSocket protocol.
280    connection: Option<Connection>,
281    /// Zstd decompressor.
282    #[cfg(feature = "zstd")]
283    decompressor: Decompressor,
284    /// Interval of how often the gateway would like the shard to send
285    /// heartbeats.
286    ///
287    /// The interval is received in the [`GatewayEvent::Hello`] event when
288    /// first opening a new [connection].
289    ///
290    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
291    /// [connection]: Self::connection
292    heartbeat_interval: Option<Interval>,
293    /// Whether an event has been received in the current heartbeat interval.
294    heartbeat_interval_event: bool,
295    /// ID of the shard.
296    id: ShardId,
297    /// Identify queue receiver.
298    identify_rx: Option<oneshot::Receiver<()>>,
299    /// Zlib decompressor.
300    #[allow(deprecated)]
301    #[cfg(all(
302        any(feature = "zlib-stock", feature = "zlib-simd"),
303        not(feature = "zstd")
304    ))]
305    inflater: Inflater,
306    /// Potentially pending outgoing message.
307    pending: Option<Pending>,
308    /// Recent heartbeat latency statistics.
309    ///
310    /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
311    /// may have changed, invalidating previous latency statistic.
312    ///
313    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
314    latency: Latency,
315    /// Command ratelimiter, if it was enabled via
316    /// [`Config::ratelimit_messages`].
317    ratelimiter: Option<CommandRatelimiter>,
318    /// Used for resuming connections.
319    resume_url: Option<Box<str>>,
320    /// Active session of the shard.
321    ///
322    /// The shard may not have an active session if it hasn't yet identified and
323    /// received a `READY` dispatch event response.
324    session: Option<Session>,
325    /// Current state of the shard.
326    state: ShardState,
327    /// Messages from the user to be relayed and sent over the Websocket
328    /// connection.
329    user_channel: MessageChannel,
330}
331
332impl Shard {
333    /// Create a new shard with the default configuration.
334    pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
335        Self::with_config(id, Config::new(token, intents))
336    }
337}
338
339impl<Q> Shard<Q> {
340    /// Create a new shard with the provided configuration.
341    pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
342        let session = config.take_session();
343        let mut resume_url = config.take_resume_url();
344        //ensure resume_url is only used if we have a session to resume
345        if session.is_none() {
346            resume_url = None;
347        }
348
349        Self {
350            config,
351            connection_future: None,
352            connection: None,
353            #[cfg(feature = "zstd")]
354            decompressor: Decompressor::new(),
355            heartbeat_interval: None,
356            heartbeat_interval_event: false,
357            id: shard_id,
358            identify_rx: None,
359            #[allow(deprecated)]
360            #[cfg(all(
361                any(feature = "zlib-stock", feature = "zlib-simd"),
362                not(feature = "zstd")
363            ))]
364            inflater: Inflater::new(),
365            pending: None,
366            latency: Latency::new(),
367            ratelimiter: None,
368            resume_url,
369            session,
370            state: ShardState::Disconnected {
371                reconnect_attempts: 0,
372            },
373            user_channel: MessageChannel::new(),
374        }
375    }
376
377    /// Immutable reference to the configuration used to instantiate this shard.
378    pub const fn config(&self) -> &Config<Q> {
379        &self.config
380    }
381
382    /// ID of the shard.
383    pub const fn id(&self) -> ShardId {
384        self.id
385    }
386
387    /// Zlib decompressor statistics.
388    ///
389    /// Reset when reconnecting to the gateway.
390    #[allow(deprecated)]
391    #[cfg(all(
392        any(feature = "zlib-stock", feature = "zlib-simd"),
393        not(feature = "zstd")
394    ))]
395    #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
396    pub const fn inflater(&self) -> &Inflater {
397        &self.inflater
398    }
399
400    /// State of the shard.
401    pub const fn state(&self) -> ShardState {
402        self.state
403    }
404
405    /// Shard latency statistics, including average latency and recent heartbeat
406    /// latency times.
407    ///
408    /// Reset when reconnecting to the gateway.
409    pub const fn latency(&self) -> &Latency {
410        &self.latency
411    }
412
413    /// Statistics about the number of available commands and when the command
414    /// ratelimiter will refresh.
415    ///
416    /// This won't be present if ratelimiting was disabled via
417    /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
418    ///
419    /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
420    pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
421        self.ratelimiter.as_ref()
422    }
423
424    /// Immutable reference to the gateways current resume URL.
425    ///
426    /// A resume URL might not be present if the shard had its session
427    /// invalidated and has not yet reconnected.
428    pub fn resume_url(&self) -> Option<&str> {
429        self.resume_url.as_deref()
430    }
431
432    /// Immutable reference to the active gateway session.
433    ///
434    /// An active session may not be present if the shard had its session
435    /// invalidated and has not yet reconnected.
436    pub const fn session(&self) -> Option<&Session> {
437        self.session.as_ref()
438    }
439
440    /// Queue a command to be sent to the gateway.
441    ///
442    /// Serializes the command and then calls [`send`].
443    ///
444    /// [`send`]: Self::send
445    #[allow(clippy::missing_panics_doc)]
446    pub fn command(&self, command: &impl Command) {
447        self.send(json::to_string(command).expect("serialization cannot fail"));
448    }
449
450    /// Queue a JSON encoded gateway event to be sent to the gateway.
451    #[allow(clippy::missing_panics_doc)]
452    pub fn send(&self, json: String) {
453        self.user_channel
454            .command_tx
455            .send(json)
456            .expect("channel open");
457    }
458
459    /// Queue a websocket close frame.
460    ///
461    /// Invalidates the session and shows the application's bot as offline if
462    /// the close frame code is `1000` or `1001`. Otherwise Discord will
463    /// continue showing the bot as online until its presence times out.
464    ///
465    /// To read all remaining messages, continue calling [`poll_next`] until it
466    /// returns [`Message::Close`].
467    ///
468    /// # Example
469    ///
470    /// Close the shard and process remaining messages:
471    ///
472    /// ```no_run
473    /// # use twilight_gateway::{Intents, Shard, ShardId};
474    /// # #[tokio::main] async fn main() {
475    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
476    /// use tokio_stream::StreamExt;
477    /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
478    ///
479    /// shard.close(CloseFrame::NORMAL);
480    ///
481    /// while let Some(item) = shard.next().await {
482    ///     match item {
483    ///         Ok(Message::Close(_)) => break,
484    ///         Ok(Message::Text(_)) => unimplemented!(),
485    ///         Err(source) => unimplemented!(),
486    ///     }
487    /// }
488    /// # }
489    /// ```
490    ///
491    /// [`poll_next`]: Shard::poll_next
492    pub fn close(&self, close_frame: CloseFrame<'static>) {
493        _ = self.user_channel.close_tx.try_send(close_frame);
494    }
495
496    /// Retrieve a channel to send messages over the shard to the gateway.
497    ///
498    /// This is primarily useful for sending to other tasks and threads where
499    /// the shard won't be available.
500    ///
501    /// # Example
502    ///
503    /// Queue a command in another process:
504    ///
505    /// ```no_run
506    /// # use twilight_gateway::{Intents, Shard, ShardId};
507    /// # #[tokio::main] async fn main() {
508    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
509    /// use tokio_stream::StreamExt;
510    ///
511    /// while let Some(item) = shard.next().await {
512    ///     match item {
513    ///         Ok(message) => {
514    ///             let sender = shard.sender();
515    ///             tokio::spawn(async move {
516    ///                 let command = unimplemented!();
517    ///                 sender.send(command);
518    ///             });
519    ///         }
520    ///         Err(source) => unimplemented!(),
521    ///     }
522    /// }
523    /// # }
524    /// ```
525    pub fn sender(&self) -> MessageSender {
526        self.user_channel.sender()
527    }
528
529    /// Update internal state from gateway disconnect.
530    fn disconnect(&mut self, initiator: CloseInitiator) {
531        // May not send any additional WebSocket messages.
532        self.heartbeat_interval = None;
533        self.ratelimiter = None;
534        // Abort identify.
535        self.identify_rx = None;
536        self.state = match initiator {
537            CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
538            _ => ShardState::Disconnected {
539                reconnect_attempts: 0,
540            },
541        };
542        if let CloseInitiator::Shard(frame) = initiator {
543            // Not resuming, drop session and resume URL.
544            // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
545            if matches!(frame.code, 1000 | 1001) {
546                self.resume_url = None;
547                self.session = None;
548            }
549            self.pending = Some(Pending {
550                gateway_event: Some(Message::Close(Some(frame))),
551                is_heartbeat: false,
552            });
553        }
554    }
555
556    /// Parse a JSON message into an event with minimal data for [processing].
557    ///
558    /// # Errors
559    ///
560    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
561    /// event isn't a recognized structure, which may be the case for new or
562    /// undocumented events.
563    ///
564    /// [processing]: Self::process
565    fn parse_event<T: DeserializeOwned>(
566        json: &str,
567    ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
568        json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
569            kind: ReceiveMessageErrorType::Deserializing {
570                event: json.to_owned(),
571            },
572            source: Some(Box::new(source)),
573        })
574    }
575}
576
577impl<Q: Queue> Shard<Q> {
578    /// Attempts to send due commands to the gateway.
579    ///
580    /// # Returns
581    ///
582    /// * `Poll::Pending` if sending is in progress
583    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
584    /// * `Poll::Ready(Err)` if sending a command failed.
585    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
586        loop {
587            if let Some(pending) = self.pending.as_mut() {
588                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
589
590                if let Some(message) = &pending.gateway_event {
591                    if let Some(ratelimiter) = self.ratelimiter.as_mut() {
592                        if message.is_text() && !pending.is_heartbeat {
593                            ready!(ratelimiter.poll_acquire(cx));
594                        }
595                    }
596
597                    let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
598                    Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
599                }
600
601                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
602
603                if pending.is_heartbeat {
604                    self.latency.record_sent();
605                }
606                self.pending = None;
607            }
608
609            if !self.state.is_disconnected() {
610                if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
611                    let frame = frame.expect("shard owns channel");
612
613                    tracing::debug!("sending close frame from user channel");
614                    self.disconnect(CloseInitiator::Shard(frame));
615
616                    continue;
617                }
618            }
619
620            if self
621                .heartbeat_interval
622                .as_mut()
623                .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
624            {
625                // Discord never responded after the last heartbeat, connection
626                // is failed or "zombied", see
627                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
628                // Note that unlike documented *any* event is okay; it does not
629                // have to be a heartbeat ACK.
630                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
631                    tracing::info!("connection is failed or \"zombied\"");
632                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
633                } else {
634                    tracing::debug!("sending heartbeat");
635                    self.pending = Pending::text(
636                        json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
637                            .expect("serialization cannot fail"),
638                        true,
639                    );
640                    self.heartbeat_interval_event = false;
641                }
642
643                continue;
644            }
645
646            let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
647                ratelimiter.poll_available(cx).is_ready()
648            });
649
650            if not_ratelimited {
651                if let Some(Poll::Ready(canceled)) = self
652                    .identify_rx
653                    .as_mut()
654                    .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
655                {
656                    if canceled {
657                        self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
658                        continue;
659                    }
660
661                    tracing::debug!("sending identify");
662
663                    self.pending = Pending::text(
664                        json::to_string(&Identify::new(IdentifyInfo {
665                            compress: false,
666                            intents: self.config.intents(),
667                            large_threshold: self.config.large_threshold(),
668                            presence: self.config.presence().cloned(),
669                            properties: self
670                                .config
671                                .identify_properties()
672                                .cloned()
673                                .unwrap_or_else(default_identify_properties),
674                            shard: Some(self.id),
675                            token: self.config.token().to_owned(),
676                        }))
677                        .expect("serialization cannot fail"),
678                        false,
679                    );
680                    self.identify_rx = None;
681
682                    continue;
683                }
684            }
685
686            if not_ratelimited && self.state.is_identified() {
687                if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
688                    let command = command.expect("shard owns channel");
689
690                    tracing::debug!("sending command from user channel");
691                    self.pending = Some(Pending {
692                        gateway_event: Some(Message::Text(command)),
693                        is_heartbeat: false,
694                    });
695
696                    continue;
697                }
698            }
699
700            return Poll::Ready(Ok(()));
701        }
702    }
703
704    /// Updates the shard's internal state from a gateway event by recording
705    /// and/or responding to certain Discord events.
706    ///
707    /// # Errors
708    ///
709    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
710    /// gateway event isn't a recognized structure.
711    #[allow(clippy::too_many_lines)]
712    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
713        let (raw_opcode, maybe_sequence, maybe_event_type) =
714            GatewayEventDeserializer::from_json(event)
715                .ok_or_else(|| ReceiveMessageError {
716                    kind: ReceiveMessageErrorType::Deserializing {
717                        event: event.to_owned(),
718                    },
719                    source: Some("missing opcode".into()),
720                })?
721                .into_parts();
722
723        if self.latency.sent().is_some() {
724            self.heartbeat_interval_event = true;
725        }
726
727        match OpCode::from(raw_opcode) {
728            Some(OpCode::Dispatch) => {
729                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
730                    kind: ReceiveMessageErrorType::Deserializing {
731                        event: event.to_owned(),
732                    },
733                    source: Some("missing dispatch event type".into()),
734                })?;
735                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
736                    kind: ReceiveMessageErrorType::Deserializing {
737                        event: event.to_owned(),
738                    },
739                    source: Some("missing sequence".into()),
740                })?;
741                tracing::debug!(%event_type, %sequence, "received dispatch");
742
743                match event_type.as_ref() {
744                    "READY" => {
745                        let event = Self::parse_event::<MinimalReady>(event)?;
746
747                        self.resume_url = Some(event.data.resume_gateway_url);
748                        self.session = Some(Session::new(sequence, event.data.session_id));
749                        self.state = ShardState::Active;
750                    }
751                    "RESUMED" => self.state = ShardState::Active,
752                    _ => {}
753                }
754
755                if let Some(session) = self.session.as_mut() {
756                    session.set_sequence(sequence);
757                }
758            }
759            Some(OpCode::Heartbeat) => {
760                tracing::debug!("received heartbeat");
761                self.pending = Pending::text(
762                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
763                        .expect("serialization cannot fail"),
764                    true,
765                );
766            }
767            Some(OpCode::HeartbeatAck) => {
768                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
769                if requested {
770                    tracing::debug!("received heartbeat ack");
771                    self.latency.record_received();
772                } else {
773                    tracing::info!("received unrequested heartbeat ack");
774                }
775            }
776            Some(OpCode::Hello) => {
777                let event = Self::parse_event::<Hello>(event)?;
778                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
779                // First heartbeat should have some jitter, see
780                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
781                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
782                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");
783
784                if self.config().ratelimit_messages() {
785                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
786                }
787
788                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
789                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
790                self.heartbeat_interval = Some(interval);
791
792                // Reset `Latency` since the shard might have connected to a new
793                // remote which invalidates the recorded latencies.
794                self.latency = Latency::new();
795
796                if let Some(session) = &self.session {
797                    self.pending = Pending::text(
798                        json::to_string(&Resume::new(
799                            session.sequence(),
800                            session.id(),
801                            self.config.token(),
802                        ))
803                        .expect("serialization cannot fail"),
804                        false,
805                    );
806                    self.state = ShardState::Resuming;
807                } else {
808                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
809                }
810            }
811            Some(OpCode::InvalidSession) => {
812                let resumable = Self::parse_event(event)?.data;
813                tracing::debug!(resumable, "received invalid session");
814                if resumable {
815                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
816                } else {
817                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
818                }
819            }
820            Some(OpCode::Reconnect) => {
821                tracing::debug!("received reconnect");
822                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
823            }
824            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
825        }
826
827        Ok(())
828    }
829}
830
831impl<Q: Queue + Unpin> Stream for Shard<Q> {
832    type Item = Result<Message, ReceiveMessageError>;
833
834    #[allow(clippy::too_many_lines)]
835    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
836        let message = loop {
837            match self.state {
838                ShardState::FatallyClosed => {
839                    _ = ready!(Pin::new(
840                        self.connection
841                            .as_mut()
842                            .expect("poll_next called after Poll::Ready(None)")
843                    )
844                    .poll_close(cx));
845                    self.connection = None;
846                    return Poll::Ready(None);
847                }
848                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
849                    if self.connection_future.is_none() {
850                        let base_url = self
851                            .resume_url
852                            .as_deref()
853                            .or_else(|| self.config.proxy_url())
854                            .unwrap_or(GATEWAY_URL);
855                        let uri = format!(
856                            "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
857                        );
858
859                        tracing::debug!(url = base_url, "connecting to gateway");
860
861                        let tls = self.config.tls.clone();
862                        self.connection_future = Some(ConnectionFuture(Box::pin(async move {
863                            let secs = 2u8.saturating_pow(reconnect_attempts.into());
864                            time::sleep(Duration::from_secs(secs.into())).await;
865
866                            Ok(ClientBuilder::new()
867                                .uri(&uri)
868                                .expect("URL should be valid")
869                                .limits(Limits::unlimited())
870                                .connector(&tls)
871                                .connect()
872                                .await?
873                                .0)
874                        })));
875                    }
876
877                    let res =
878                        ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
879                    self.connection_future = None;
880                    match res {
881                        Ok(connection) => {
882                            self.connection = Some(connection);
883                            self.state = ShardState::Identifying;
884                            #[cfg(feature = "zstd")]
885                            self.decompressor.reset();
886                            #[allow(deprecated)]
887                            #[cfg(all(
888                                not(feature = "zstd"),
889                                any(feature = "zlib-stock", feature = "zlib-simd")
890                            ))]
891                            self.inflater.reset();
892                        }
893                        Err(source) => {
894                            self.resume_url = None;
895                            self.state = ShardState::Disconnected {
896                                reconnect_attempts: reconnect_attempts + 1,
897                            };
898
899                            return Poll::Ready(Some(Err(ReceiveMessageError {
900                                kind: ReceiveMessageErrorType::Reconnect,
901                                source: Some(Box::new(source)),
902                            })));
903                        }
904                    }
905                }
906                _ => {}
907            }
908
909            if ready!(self.poll_send(cx)).is_err() {
910                self.disconnect(CloseInitiator::Transport);
911                self.connection = None;
912
913                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
914            }
915
916            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
917                Some(Ok(message)) => {
918                    #[cfg(feature = "zstd")]
919                    if message.is_binary() {
920                        match self.decompressor.decompress(message.as_payload()) {
921                            Ok(message) => break Message::Text(message),
922                            Err(source) => {
923                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
924                                return Poll::Ready(Some(Err(
925                                    ReceiveMessageError::from_compression(source),
926                                )));
927                            }
928                        }
929                    }
930                    #[cfg(all(
931                        not(feature = "zstd"),
932                        any(feature = "zlib-stock", feature = "zlib-simd")
933                    ))]
934                    if message.is_binary() {
935                        match self.inflater.inflate(message.as_payload()) {
936                            Ok(Some(message)) => break Message::Text(message),
937                            Ok(None) => continue,
938                            Err(source) => {
939                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
940                                return Poll::Ready(Some(Err(
941                                    ReceiveMessageError::from_compression(source),
942                                )));
943                            }
944                        }
945                    }
946                    if let Some(message) = Message::from_websocket_msg(&message) {
947                        break message;
948                    }
949                }
950                // Discord, against recommendations from the WebSocket spec,
951                // does not send a close_notify prior to shutting down the TCP
952                // stream. This arm tries to gracefully handle this. The
953                // connection is considered unusable after encountering an io
954                // error, returning `None`.
955                #[cfg(any(
956                    feature = "native-tls",
957                    feature = "rustls-native-roots",
958                    feature = "rustls-platform-verifier",
959                    feature = "rustls-webpki-roots"
960                ))]
961                Some(Err(WebsocketError::Io(e)))
962                    if e.kind() == IoErrorKind::UnexpectedEof
963                        && self.config.proxy_url().is_none()
964                        && self.state.is_disconnected() =>
965                {
966                    continue
967                }
968                Some(Err(_)) => {
969                    self.disconnect(CloseInitiator::Transport);
970                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
971                }
972                None => {
973                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
974                    tracing::debug!("gateway WebSocket connection closed");
975                    // Unclean closure.
976                    if !self.state.is_disconnected() {
977                        self.disconnect(CloseInitiator::Transport);
978                    }
979                    self.connection = None;
980                }
981            }
982        };
983
984        match &message {
985            Message::Close(frame) => {
986                // tokio-websockets automatically replies to the close message.
987                tracing::debug!(?frame, "received WebSocket close message");
988                // Don't run `disconnect` if we initiated the close.
989                if !self.state.is_disconnected() {
990                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
991                }
992            }
993            Message::Text(event) => {
994                self.process(event)?;
995            }
996        }
997
998        Poll::Ready(Some(Ok(message)))
999    }
1000}
1001
1002/// Default identify properties to use when the user hasn't customized it in
1003/// [`Config::identify_properties`].
1004///
1005/// [`Config::identify_properties`]: Config::identify_properties
1006fn default_identify_properties() -> IdentifyProperties {
1007    IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
1008}
1009
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};

    // Compile-time checks on the auto traits of `Shard`: it must be
    // debuggable and movable between threads.
    assert_impl_all!(Shard: std::fmt::Debug, Send);

    // `Shard` is asserted to be `!Sync`; adding `Sync` would trip this check
    // and should be a deliberate decision.
    assert_not_impl_any!(Shard: Sync);
}