twilight_gateway/
shard.rs

1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13    any(feature = "zlib-stock", feature = "zlib-simd"),
14    not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18    channel::{MessageChannel, MessageSender},
19    error::{ReceiveMessageError, ReceiveMessageErrorType},
20    json,
21    latency::Latency,
22    queue::{InMemoryQueue, Queue},
23    ratelimiter::CommandRatelimiter,
24    session::Session,
25    Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30use std::{
31    env::consts::OS,
32    fmt,
33    future::Future,
34    io,
35    pin::Pin,
36    str,
37    task::{ready, Context, Poll},
38};
39use tokio::{
40    net::TcpStream,
41    sync::oneshot,
42    time::{self, Duration, Instant, Interval, MissedTickBehavior},
43};
44use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
45use twilight_model::gateway::{
46    event::GatewayEventDeserializer,
47    payload::{
48        incoming::Hello,
49        outgoing::{
50            identify::{IdentifyInfo, IdentifyProperties},
51            Heartbeat, Identify, Resume,
52        },
53    },
54    CloseCode, CloseFrame, Intents, OpCode,
55};
56
/// Base URL of Discord's gateway, used when neither a resume URL nor a proxy
/// URL is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// `compress` query argument appended to the gateway URL.
///
/// Zstd takes precedence over zlib when both are enabled; with no compression
/// feature enabled the argument is empty.
const COMPRESSION_FEATURES: &str = match (
    cfg!(feature = "zstd"),
    cfg!(any(feature = "zlib-stock", feature = "zlib-simd")),
) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};
68
69/// [`tokio_websockets`] library Websocket connection.
70type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
71
72/// Wrapper struct around an `async fn` with a `Debug` implementation.
73struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
74
75impl fmt::Debug for ConnectionFuture {
76    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77        f.debug_tuple("ConnectionFuture")
78            .field(&"<async fn>")
79            .finish()
80    }
81}
82
83/// Close initiator of a websocket connection.
84#[derive(Clone, Debug)]
85enum CloseInitiator {
86    /// Gateway initiated the close.
87    ///
88    /// Contains an optional close code.
89    Gateway(Option<u16>),
90    /// Shard initiated the close.
91    ///
92    /// Contains a close code.
93    Shard(CloseFrame<'static>),
94    /// Transport error initiated the close.
95    Transport,
96}
97
/// Lifecycle state of a [Shard].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Connected to the gateway with an active session.
    Active,
    /// Not connected to the gateway, although a reconnect may still happen.
    ///
    /// The underlying websocket connection might not be closed yet.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed with no possibility of reconnecting.
    ///
    /// Causes include [failed authentication] and [invalid intents], among
    /// others; see [`CloseCode`] for possible reasons.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting to establish or resume a session.
    Identifying,
    /// Replaying dispatch events that were missed.
    ///
    /// A resuming shard counts as identified.
    Resuming,
}
126
127impl ShardState {
128    /// Determine the connection status from the close code.
129    ///
130    /// Defers to [`CloseCode::can_reconnect`] to determine whether the
131    /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
132    /// the close code is unknown.
133    fn from_close_code(close_code: Option<u16>) -> Self {
134        match close_code.map(CloseCode::try_from) {
135            Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
136            _ => Self::Disconnected {
137                reconnect_attempts: 0,
138            },
139        }
140    }
141
142    /// Whether the shard has disconnected but may reconnect in the future.
143    const fn is_disconnected(self) -> bool {
144        matches!(self, Self::Disconnected { .. })
145    }
146
147    /// Whether the shard is identified with an active session.
148    ///
149    /// `true` if the status is [`Active`] or [`Resuming`].
150    ///
151    /// [`Active`]: Self::Active
152    /// [`Resuming`]: Self::Resuming
153    pub const fn is_identified(self) -> bool {
154        matches!(self, Self::Active | Self::Resuming)
155    }
156}
157
158/// Gateway event with only minimal required data.
159#[derive(Deserialize)]
160struct MinimalEvent<T> {
161    /// Attached data of the gateway event.
162    #[serde(rename = "d")]
163    data: T,
164}
165
166/// Minimal [`Ready`] for light deserialization.
167///
168/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
169#[derive(Deserialize)]
170struct MinimalReady {
171    /// Used for resuming connections.
172    resume_gateway_url: Box<str>,
173    /// ID of the new identified session.
174    session_id: String,
175}
176
177/// Pending outgoing message indicator.
178#[derive(Debug)]
179struct Pending {
180    /// The pending message, if not already sent.
181    gateway_event: Option<Message>,
182    /// Whether the pending gateway event is a heartbeat.
183    is_heartbeat: bool,
184}
185
186impl Pending {
187    /// Constructor for a pending gateway event.
188    const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
189        Some(Self {
190            gateway_event: Some(Message::Text(json)),
191            is_heartbeat,
192        })
193    }
194}
195
196/// Gateway API client responsible for up to 2500 guilds.
197///
198/// Shards are responsible for maintaining the gateway connection by processing
199/// events relevant to the operation of shards---such as requests from the
200/// gateway to re-connect or invalidate a session---and then to pass them on to
201/// the user.
202///
203/// Shards start out disconnected, but will on the first successful call to
204/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
205/// be repeatedly called in order for the shard to maintain its connection and
206/// update its internal state.
207///
208/// Shards go through an [identify queue][`queue`] that rate limits concurrent
209/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
210/// invalidates the shard's session and it is therefore **very important** to
211/// reuse the same queue for all shards.
212///
213/// # Sharding
214///
215/// A shard may not be connected to more than 2500 guilds, so large bots must
216/// split themselves across multiple shards. See the
217/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
218/// more info.
219///
220/// # Examples
221///
222/// Create and start a shard and print new and deleted messages:
223///
224/// ```no_run
225/// use std::env;
226/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
227///
228/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
229/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
230/// // token. Of course, this value may be passed into the program however is
231/// // preferred.
232/// let token = env::var("DISCORD_TOKEN")?;
233/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
234///
235/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
236///
237/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
238///     let Ok(event) = item else {
239///         tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
240///
241///         continue;
242///     };
243///
244///     match event {
245///         Event::MessageCreate(message) => {
246///             println!("message received with content: {}", message.content);
247///         }
248///         Event::MessageDelete(message) => {
249///             println!("message with ID {} deleted", message.id);
250///         }
251///         _ => {}
252///     }
253/// }
254/// # Ok(()) }
255/// ```
256///
257/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
258/// [gateway commands]: Shard::command
259/// [`poll_next`]: Shard::poll_next
260/// [`queue`]: crate::queue
261#[derive(Debug)]
262pub struct Shard<Q = InMemoryQueue> {
263    /// User provided configuration.
264    ///
265    /// Configurations are provided or created in shard initializing via
266    /// [`Shard::new`] or [`Shard::with_config`].
267    config: Config<Q>,
268    /// Future to establish a WebSocket connection with the Gateway.
269    connection_future: Option<ConnectionFuture>,
270    /// Websocket connection, which may be connected to Discord's gateway.
271    ///
272    /// The connection should only be dropped after it has returned `Ok(None)`
273    /// to comply with the WebSocket protocol.
274    connection: Option<Connection>,
275    /// Zstd decompressor.
276    #[cfg(feature = "zstd")]
277    decompressor: Decompressor,
278    /// Interval of how often the gateway would like the shard to send
279    /// heartbeats.
280    ///
281    /// The interval is received in the [`GatewayEvent::Hello`] event when
282    /// first opening a new [connection].
283    ///
284    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
285    /// [connection]: Self::connection
286    heartbeat_interval: Option<Interval>,
287    /// Whether an event has been received in the current heartbeat interval.
288    heartbeat_interval_event: bool,
289    /// ID of the shard.
290    id: ShardId,
291    /// Identify queue receiver.
292    identify_rx: Option<oneshot::Receiver<()>>,
293    /// Zlib decompressor.
294    #[allow(deprecated)]
295    #[cfg(all(
296        any(feature = "zlib-stock", feature = "zlib-simd"),
297        not(feature = "zstd")
298    ))]
299    inflater: Inflater,
300    /// Potentially pending outgoing message.
301    pending: Option<Pending>,
302    /// Recent heartbeat latency statistics.
303    ///
304    /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
305    /// may have changed, invalidating previous latency statistic.
306    ///
307    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
308    latency: Latency,
309    /// Command ratelimiter, if it was enabled via
310    /// [`Config::ratelimit_messages`].
311    ratelimiter: Option<CommandRatelimiter>,
312    /// Used for resuming connections.
313    resume_url: Option<Box<str>>,
314    /// Active session of the shard.
315    ///
316    /// The shard may not have an active session if it hasn't yet identified and
317    /// received a `READY` dispatch event response.
318    session: Option<Session>,
319    /// Current state of the shard.
320    state: ShardState,
321    /// Messages from the user to be relayed and sent over the Websocket
322    /// connection.
323    user_channel: MessageChannel,
324}
325
326impl Shard {
327    /// Create a new shard with the default configuration.
328    pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
329        Self::with_config(id, Config::new(token, intents))
330    }
331}
332
333impl<Q> Shard<Q> {
334    /// Create a new shard with the provided configuration.
335    pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
336        let session = config.take_session();
337        let mut resume_url = config.take_resume_url();
338        //ensure resume_url is only used if we have a session to resume
339        if session.is_none() {
340            resume_url = None;
341        }
342
343        Self {
344            config,
345            connection_future: None,
346            connection: None,
347            #[cfg(feature = "zstd")]
348            decompressor: Decompressor::new(),
349            heartbeat_interval: None,
350            heartbeat_interval_event: false,
351            id: shard_id,
352            identify_rx: None,
353            #[allow(deprecated)]
354            #[cfg(all(
355                any(feature = "zlib-stock", feature = "zlib-simd"),
356                not(feature = "zstd")
357            ))]
358            inflater: Inflater::new(),
359            pending: None,
360            latency: Latency::new(),
361            ratelimiter: None,
362            resume_url,
363            session,
364            state: ShardState::Disconnected {
365                reconnect_attempts: 0,
366            },
367            user_channel: MessageChannel::new(),
368        }
369    }
370
371    /// Immutable reference to the configuration used to instantiate this shard.
372    pub const fn config(&self) -> &Config<Q> {
373        &self.config
374    }
375
376    /// ID of the shard.
377    pub const fn id(&self) -> ShardId {
378        self.id
379    }
380
381    /// Zlib decompressor statistics.
382    ///
383    /// Reset when reconnecting to the gateway.
384    #[allow(deprecated)]
385    #[cfg(all(
386        any(feature = "zlib-stock", feature = "zlib-simd"),
387        not(feature = "zstd")
388    ))]
389    #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
390    pub const fn inflater(&self) -> &Inflater {
391        &self.inflater
392    }
393
394    /// State of the shard.
395    pub const fn state(&self) -> ShardState {
396        self.state
397    }
398
399    /// Shard latency statistics, including average latency and recent heartbeat
400    /// latency times.
401    ///
402    /// Reset when reconnecting to the gateway.
403    pub const fn latency(&self) -> &Latency {
404        &self.latency
405    }
406
407    /// Statistics about the number of available commands and when the command
408    /// ratelimiter will refresh.
409    ///
410    /// This won't be present if ratelimiting was disabled via
411    /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
412    ///
413    /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
414    pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
415        self.ratelimiter.as_ref()
416    }
417
418    /// Immutable reference to the gateways current resume URL.
419    ///
420    /// A resume URL might not be present if the shard had its session
421    /// invalidated and has not yet reconnected.
422    pub fn resume_url(&self) -> Option<&str> {
423        self.resume_url.as_deref()
424    }
425
426    /// Immutable reference to the active gateway session.
427    ///
428    /// An active session may not be present if the shard had its session
429    /// invalidated and has not yet reconnected.
430    pub const fn session(&self) -> Option<&Session> {
431        self.session.as_ref()
432    }
433
434    /// Queue a command to be sent to the gateway.
435    ///
436    /// Serializes the command and then calls [`send`].
437    ///
438    /// [`send`]: Self::send
439    #[allow(clippy::missing_panics_doc)]
440    pub fn command(&self, command: &impl Command) {
441        self.send(json::to_string(command).expect("serialization cannot fail"));
442    }
443
444    /// Queue a JSON encoded gateway event to be sent to the gateway.
445    #[allow(clippy::missing_panics_doc)]
446    pub fn send(&self, json: String) {
447        self.user_channel
448            .command_tx
449            .send(json)
450            .expect("channel open");
451    }
452
453    /// Queue a websocket close frame.
454    ///
455    /// Invalidates the session and shows the application's bot as offline if
456    /// the close frame code is `1000` or `1001`. Otherwise Discord will
457    /// continue showing the bot as online until its presence times out.
458    ///
459    /// To read all remaining messages, continue calling [`poll_next`] until it
460    /// returns [`Message::Close`].
461    ///
462    /// # Example
463    ///
464    /// Close the shard and process remaining messages:
465    ///
466    /// ```no_run
467    /// # use twilight_gateway::{Intents, Shard, ShardId};
468    /// # #[tokio::main] async fn main() {
469    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
470    /// use tokio_stream::StreamExt;
471    /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
472    ///
473    /// shard.close(CloseFrame::NORMAL);
474    ///
475    /// while let Some(item) = shard.next().await {
476    ///     match item {
477    ///         Ok(Message::Close(_)) => break,
478    ///         Ok(Message::Text(_)) => unimplemented!(),
479    ///         Err(source) => unimplemented!(),
480    ///     }
481    /// }
482    /// # }
483    /// ```
484    ///
485    /// [`poll_next`]: Shard::poll_next
486    pub fn close(&self, close_frame: CloseFrame<'static>) {
487        _ = self.user_channel.close_tx.try_send(close_frame);
488    }
489
490    /// Retrieve a channel to send messages over the shard to the gateway.
491    ///
492    /// This is primarily useful for sending to other tasks and threads where
493    /// the shard won't be available.
494    ///
495    /// # Example
496    ///
497    /// Queue a command in another process:
498    ///
499    /// ```no_run
500    /// # use twilight_gateway::{Intents, Shard, ShardId};
501    /// # #[tokio::main] async fn main() {
502    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
503    /// use tokio_stream::StreamExt;
504    ///
505    /// while let Some(item) = shard.next().await {
506    ///     match item {
507    ///         Ok(message) => {
508    ///             let sender = shard.sender();
509    ///             tokio::spawn(async move {
510    ///                 let command = unimplemented!();
511    ///                 sender.send(command);
512    ///             });
513    ///         }
514    ///         Err(source) => unimplemented!(),
515    ///     }
516    /// }
517    /// # }
518    /// ```
519    pub fn sender(&self) -> MessageSender {
520        self.user_channel.sender()
521    }
522
523    /// Update internal state from gateway disconnect.
524    fn disconnect(&mut self, initiator: CloseInitiator) {
525        // May not send any additional WebSocket messages.
526        self.heartbeat_interval = None;
527        self.ratelimiter = None;
528        // Abort identify.
529        self.identify_rx = None;
530        self.state = match initiator {
531            CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
532            _ => ShardState::Disconnected {
533                reconnect_attempts: 0,
534            },
535        };
536        if let CloseInitiator::Shard(frame) = initiator {
537            // Not resuming, drop session and resume URL.
538            // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
539            if matches!(frame.code, 1000 | 1001) {
540                self.resume_url = None;
541                self.session = None;
542            }
543            self.pending = Some(Pending {
544                gateway_event: Some(Message::Close(Some(frame))),
545                is_heartbeat: false,
546            });
547        }
548    }
549
550    /// Parse a JSON message into an event with minimal data for [processing].
551    ///
552    /// # Errors
553    ///
554    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
555    /// event isn't a recognized structure, which may be the case for new or
556    /// undocumented events.
557    ///
558    /// [processing]: Self::process
559    fn parse_event<T: DeserializeOwned>(
560        json: &str,
561    ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
562        json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
563            kind: ReceiveMessageErrorType::Deserializing {
564                event: json.to_owned(),
565            },
566            source: Some(Box::new(source)),
567        })
568    }
569}
570
impl<Q: Queue> Shard<Q> {
    /// Attempts to send due commands to the gateway.
    ///
    /// Each loop iteration considers outgoing messages in priority order:
    /// 1. an already pending message is driven to completion,
    /// 2. a close frame requested by the user,
    /// 3. a heartbeat when the heartbeat interval ticks,
    /// 4. an identify once the identify queue receiver resolves,
    /// 5. user commands, only while the shard is identified.
    ///
    /// A step that creates a new pending message `continue`s the loop, so that
    /// message is sent before anything else is considered.
    ///
    /// # Returns
    ///
    /// * `Poll::Pending` if sending is in progress
    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
    /// * `Poll::Ready(Err)` if sending a command failed, or if the connection
    ///   is deemed "zombied" (reported as an `io::ErrorKind::TimedOut` error).
    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
        loop {
            if let Some(pending) = self.pending.as_mut() {
                // NOTE(review): the `unwrap`s below assume `pending` is only
                // ever set while `connection` is `Some`.
                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;

                // `gateway_event` is `None` when the message was already handed
                // to the sink on an earlier poll that then returned `Pending`
                // while flushing; in that case only the flush remains.
                if let Some(message) = &pending.gateway_event {
                    // Only non-heartbeat text messages consume a ratelimiter
                    // permit; heartbeats and close frames bypass it.
                    if let Some(ratelimiter) = self.ratelimiter.as_mut() {
                        if message.is_text() && !pending.is_heartbeat {
                            ready!(ratelimiter.poll_acquire(cx));
                        }
                    }

                    let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
                    Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
                }

                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;

                // Heartbeat latency is measured from the moment it was flushed.
                if pending.is_heartbeat {
                    self.latency.record_sent();
                }
                self.pending = None;
            }

            // A user-requested close is only honoured while not disconnected;
            // `disconnect` queues the close frame as the next pending message.
            if !self.state.is_disconnected() {
                if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
                    let frame = frame.expect("shard owns channel");

                    tracing::debug!("sending close frame from user channel");
                    self.disconnect(CloseInitiator::Shard(frame));

                    continue;
                }
            }

            if self
                .heartbeat_interval
                .as_mut()
                .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
            {
                // Discord never responded after the last heartbeat, connection
                // is failed or "zombied", see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
                // Note that unlike documented *any* event is okay; it does not
                // have to be a heartbeat ACK.
                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
                    tracing::info!("connection is failed or \"zombied\"");

                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
                }

                tracing::debug!("sending heartbeat");
                self.pending = Pending::text(
                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
                        .expect("serialization cannot fail"),
                    true,
                );
                // Start tracking whether any event arrives before the next tick.
                self.heartbeat_interval_event = false;

                continue;
            }

            // Identify and user commands wait while the ratelimiter reports no
            // capacity; the permit itself is acquired when the message is sent
            // (see the pending-message handling above).
            let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
                ratelimiter.poll_available(cx).is_ready()
            });

            if not_ratelimited {
                // `canceled` is `true` when the identify queue receiver
                // resolved with an error instead of a go-ahead.
                if let Some(Poll::Ready(canceled)) = self
                    .identify_rx
                    .as_mut()
                    .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
                {
                    // The sender was dropped without permitting the identify;
                    // enqueue again and wait for a new receiver.
                    if canceled {
                        self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                        continue;
                    }

                    tracing::debug!("sending identify");

                    self.pending = Pending::text(
                        json::to_string(&Identify::new(IdentifyInfo {
                            // Payload compression is not requested; transport
                            // compression is selected via the URL query
                            // (`COMPRESSION_FEATURES`).
                            compress: false,
                            intents: self.config.intents(),
                            large_threshold: self.config.large_threshold(),
                            presence: self.config.presence().cloned(),
                            properties: self
                                .config
                                .identify_properties()
                                .cloned()
                                .unwrap_or_else(default_identify_properties),
                            shard: Some(self.id),
                            token: self.config.token().to_owned(),
                        }))
                        .expect("serialization cannot fail"),
                        false,
                    );
                    self.identify_rx = None;

                    continue;
                }
            }

            // User commands are only relayed once the shard is identified
            // (`Active` or `Resuming`).
            if not_ratelimited && self.state.is_identified() {
                if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
                    let command = command.expect("shard owns channel");

                    tracing::debug!("sending command from user channel");
                    self.pending = Some(Pending {
                        gateway_event: Some(Message::Text(command)),
                        is_heartbeat: false,
                    });

                    continue;
                }
            }

            return Poll::Ready(Ok(()));
        }
    }

    /// Updates the shard's internal state from a gateway event by recording
    /// and/or responding to certain Discord events.
    ///
    /// Opcode handling:
    /// * `Dispatch`: updates the session's sequence; `READY` stores a new
    ///   session and resume URL, and both `READY` and `RESUMED` mark the shard
    ///   [`ShardState::Active`].
    /// * `Heartbeat`: queues a heartbeat as the pending message.
    /// * `HeartbeatAck`: records latency if a sent heartbeat had not yet been
    ///   acknowledged.
    /// * `Hello`: sets up the heartbeat interval (and, if enabled, the command
    ///   ratelimiter), resets latency, then queues a resume if a session
    ///   exists or otherwise enqueues an identify.
    /// * `InvalidSession`: disconnects with [`CloseFrame::RESUME`] if
    ///   resumable, otherwise with [`CloseFrame::NORMAL`].
    /// * `Reconnect`: disconnects with [`CloseFrame::RESUME`].
    ///
    /// # Errors
    ///
    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
    /// gateway event isn't a recognized structure.
    #[allow(clippy::too_many_lines)]
    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
        let (raw_opcode, maybe_sequence, maybe_event_type) =
            GatewayEventDeserializer::from_json(event)
                .ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing opcode".into()),
                })?
                .into_parts();

        // Any event received after a heartbeat was sent counts as a sign of
        // life for the zombie-connection check in `poll_send`.
        if self.latency.sent().is_some() {
            self.heartbeat_interval_event = true;
        }

        match OpCode::from(raw_opcode) {
            Some(OpCode::Dispatch) => {
                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing dispatch event type".into()),
                })?;
                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing sequence".into()),
                })?;
                tracing::debug!(%event_type, %sequence, "received dispatch");

                match event_type.as_ref() {
                    "READY" => {
                        let event = Self::parse_event::<MinimalReady>(event)?;

                        self.resume_url = Some(event.data.resume_gateway_url);
                        self.session = Some(Session::new(sequence, event.data.session_id));
                        self.state = ShardState::Active;
                    }
                    "RESUMED" => self.state = ShardState::Active,
                    _ => {}
                }

                // Track the latest sequence so heartbeats and resumes use it.
                if let Some(session) = self.session.as_mut() {
                    session.set_sequence(sequence);
                }
            }
            Some(OpCode::Heartbeat) => {
                tracing::debug!("received heartbeat");
                self.pending = Pending::text(
                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
                        .expect("serialization cannot fail"),
                    true,
                );
            }
            Some(OpCode::HeartbeatAck) => {
                // An ack is expected only if a heartbeat was sent and no ack has
                // been recorded for it yet.
                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
                if requested {
                    tracing::debug!("received heartbeat ack");
                    self.latency.record_received();
                } else {
                    tracing::info!("received unrequested heartbeat ack");
                }
            }
            Some(OpCode::Hello) => {
                let event = Self::parse_event::<Hello>(event)?;
                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
                // First heartbeat should have some jitter, see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

                if self.config().ratelimit_messages() {
                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
                }

                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
                self.heartbeat_interval = Some(interval);

                // Reset `Latency` since the shard might have connected to a new
                // remote which invalidates the recorded latencies.
                self.latency = Latency::new();

                // Resume an existing session; otherwise wait for the identify
                // queue before identifying (see `poll_send`).
                if let Some(session) = &self.session {
                    self.pending = Pending::text(
                        json::to_string(&Resume::new(
                            session.sequence(),
                            session.id(),
                            self.config.token(),
                        ))
                        .expect("serialization cannot fail"),
                        false,
                    );
                    self.state = ShardState::Resuming;
                } else {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                }
            }
            Some(OpCode::InvalidSession) => {
                // The event data is a boolean indicating whether the session
                // may be resumed.
                let resumable = Self::parse_event(event)?.data;
                tracing::debug!(resumable, "received invalid session");
                if resumable {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                } else {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
                }
            }
            Some(OpCode::Reconnect) => {
                tracing::debug!("received reconnect");
                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
            }
            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
        }

        Ok(())
    }
}
825
826impl<Q: Queue + Unpin> Stream for Shard<Q> {
827    type Item = Result<Message, ReceiveMessageError>;
828
829    #[allow(clippy::too_many_lines)]
830    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
831        let message = loop {
832            match self.state {
833                ShardState::FatallyClosed => {
834                    _ = ready!(Pin::new(
835                        self.connection
836                            .as_mut()
837                            .expect("poll_next called after Poll::Ready(None)")
838                    )
839                    .poll_close(cx));
840                    self.connection = None;
841                    return Poll::Ready(None);
842                }
843                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
844                    if self.connection_future.is_none() {
845                        let base_url = self
846                            .resume_url
847                            .as_deref()
848                            .or_else(|| self.config.proxy_url())
849                            .unwrap_or(GATEWAY_URL);
850                        let uri = format!(
851                            "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
852                        );
853
854                        tracing::debug!(url = base_url, "connecting to gateway");
855
856                        let tls = self.config.tls.clone();
857                        self.connection_future = Some(ConnectionFuture(Box::pin(async move {
858                            let secs = 2u8.saturating_pow(reconnect_attempts.into());
859                            time::sleep(Duration::from_secs(secs.into())).await;
860
861                            Ok(ClientBuilder::new()
862                                .uri(&uri)
863                                .expect("URL should be valid")
864                                .limits(Limits::unlimited())
865                                .connector(&tls)
866                                .connect()
867                                .await?
868                                .0)
869                        })));
870                    }
871
872                    let res =
873                        ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
874                    self.connection_future = None;
875                    match res {
876                        Ok(connection) => {
877                            self.connection = Some(connection);
878                            self.state = ShardState::Identifying;
879                            #[cfg(feature = "zstd")]
880                            self.decompressor.reset();
881                            #[allow(deprecated)]
882                            #[cfg(all(
883                                not(feature = "zstd"),
884                                any(feature = "zlib-stock", feature = "zlib-simd")
885                            ))]
886                            self.inflater.reset();
887                        }
888                        Err(source) => {
889                            self.resume_url = None;
890                            self.state = ShardState::Disconnected {
891                                reconnect_attempts: reconnect_attempts + 1,
892                            };
893
894                            return Poll::Ready(Some(Err(ReceiveMessageError {
895                                kind: ReceiveMessageErrorType::Reconnect,
896                                source: Some(Box::new(source)),
897                            })));
898                        }
899                    }
900                }
901                _ => {}
902            }
903
904            if ready!(self.poll_send(cx)).is_err() {
905                self.disconnect(CloseInitiator::Transport);
906                self.connection = None;
907
908                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
909            }
910
911            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
912                Some(Ok(message)) => {
913                    #[cfg(feature = "zstd")]
914                    if message.is_binary() {
915                        match self.decompressor.decompress(message.as_payload()) {
916                            Ok(message) => break Message::Text(message),
917                            Err(source) => {
918                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
919                                return Poll::Ready(Some(Err(
920                                    ReceiveMessageError::from_compression(source),
921                                )));
922                            }
923                        }
924                    }
925                    #[cfg(all(
926                        not(feature = "zstd"),
927                        any(feature = "zlib-stock", feature = "zlib-simd")
928                    ))]
929                    if message.is_binary() {
930                        match self.inflater.inflate(message.as_payload()) {
931                            Ok(Some(message)) => break Message::Text(message),
932                            Ok(None) => continue,
933                            Err(source) => {
934                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
935                                return Poll::Ready(Some(Err(
936                                    ReceiveMessageError::from_compression(source),
937                                )));
938                            }
939                        }
940                    }
941                    if let Some(message) = Message::from_websocket_msg(&message) {
942                        break message;
943                    }
944                }
945                Some(Err(_)) if self.state.is_disconnected() => {}
946                Some(Err(_)) => {
947                    self.disconnect(CloseInitiator::Transport);
948                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
949                }
950                None => {
951                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
952                    tracing::debug!("gateway WebSocket connection closed");
953                    // Unclean closure.
954                    if !self.state.is_disconnected() {
955                        self.disconnect(CloseInitiator::Transport);
956                    }
957                    self.connection = None;
958                }
959            }
960        };
961
962        match &message {
963            Message::Close(frame) => {
964                // tokio-websockets automatically replies to the close message.
965                tracing::debug!(?frame, "received WebSocket close message");
966                // Don't run `disconnect` if we initiated the close.
967                if !self.state.is_disconnected() {
968                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
969                }
970            }
971            Message::Text(event) => {
972                self.process(event)?;
973            }
974        }
975
976        Poll::Ready(Some(Ok(message)))
977    }
978}
979
980/// Default identify properties to use when the user hasn't customized it in
981/// [`Config::identify_properties`].
982///
983/// [`Config::identify_properties`]: Config::identify_properties
984fn default_identify_properties() -> IdentifyProperties {
985    IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
986}
987
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // Compile-time checks: `Shard` must be `Send` and `Debug`...
    assert_impl_all!(Shard: Send, Debug);

    // ...and must not be `Sync`.
    assert_not_impl_any!(Shard: Sync);
}