twilight_gateway/
shard.rs

1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13    any(feature = "zlib-stock", feature = "zlib-simd"),
14    not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18    channel::{MessageChannel, MessageSender},
19    error::{ReceiveMessageError, ReceiveMessageErrorType},
20    json,
21    latency::Latency,
22    queue::{InMemoryQueue, Queue},
23    ratelimiter::CommandRatelimiter,
24    session::Session,
25    Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30use std::{
31    env::consts::OS,
32    fmt,
33    future::Future,
34    pin::Pin,
35    str,
36    task::{ready, Context, Poll},
37};
38use tokio::{
39    net::TcpStream,
40    sync::oneshot,
41    time::{self, Duration, Instant, Interval, MissedTickBehavior},
42};
43use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
44use twilight_model::gateway::{
45    event::GatewayEventDeserializer,
46    payload::{
47        incoming::Hello,
48        outgoing::{
49            identify::{IdentifyInfo, IdentifyProperties},
50            Heartbeat, Identify, Resume,
51        },
52    },
53    CloseCode, CloseFrame, Intents, OpCode,
54};
55
/// Base URL of Discord's gateway, used when neither a resume URL nor a proxy
/// URL is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// `compress` query argument matching the enabled compression feature.
///
/// `zstd` takes precedence over the zlib features; when no compression
/// feature is enabled the argument is omitted entirely.
const COMPRESSION_FEATURES: &str = match (
    cfg!(feature = "zstd"),
    cfg!(any(feature = "zlib-stock", feature = "zlib-simd")),
) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};
67
68/// [`tokio_websockets`] library Websocket connection.
69type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
70
71/// Wrapper struct around an `async fn` with a `Debug` implementation.
72struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
73
74impl fmt::Debug for ConnectionFuture {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        f.debug_tuple("ConnectionFuture")
77            .field(&"<async fn>")
78            .finish()
79    }
80}
81
82/// Close initiator of a websocket connection.
83#[derive(Clone, Debug)]
84enum CloseInitiator {
85    /// Gateway initiated the close.
86    ///
87    /// Contains an optional close code.
88    Gateway(Option<u16>),
89    /// Shard initiated the close.
90    ///
91    /// Contains a close code.
92    Shard(CloseFrame<'static>),
93    /// Transport error initiated the close.
94    Transport,
95}
96
/// Connection lifecycle state of a [`Shard`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Connected to the gateway with a session that is ready for use.
    Active,
    /// Not currently connected, but a reconnection may follow.
    ///
    /// The underlying websocket connection is not necessarily closed yet.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed for good; the shard will not reconnect.
    ///
    /// This happens for close codes that do not permit reconnecting, such as
    /// [failed authentication] or [invalid intents]. See [`CloseCode`] for the
    /// full list of reasons.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting to either establish a new session or resume an existing one.
    Identifying,
    /// Replaying dispatch events that were missed while disconnected.
    ///
    /// A resuming shard counts as identified.
    Resuming,
}
125
126impl ShardState {
127    /// Determine the connection status from the close code.
128    ///
129    /// Defers to [`CloseCode::can_reconnect`] to determine whether the
130    /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
131    /// the close code is unknown.
132    fn from_close_code(close_code: Option<u16>) -> Self {
133        match close_code.map(CloseCode::try_from) {
134            Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
135            _ => Self::Disconnected {
136                reconnect_attempts: 0,
137            },
138        }
139    }
140
141    /// Whether the shard has disconnected but may reconnect in the future.
142    const fn is_disconnected(self) -> bool {
143        matches!(self, Self::Disconnected { .. })
144    }
145
146    /// Whether the shard is identified with an active session.
147    ///
148    /// `true` if the status is [`Active`] or [`Resuming`].
149    ///
150    /// [`Active`]: Self::Active
151    /// [`Resuming`]: Self::Resuming
152    pub const fn is_identified(self) -> bool {
153        matches!(self, Self::Active | Self::Resuming)
154    }
155}
156
157/// Gateway event with only minimal required data.
158#[derive(Deserialize)]
159struct MinimalEvent<T> {
160    /// Attached data of the gateway event.
161    #[serde(rename = "d")]
162    data: T,
163}
164
165/// Minimal [`Ready`] for light deserialization.
166///
167/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
168#[derive(Deserialize)]
169struct MinimalReady {
170    /// Used for resuming connections.
171    resume_gateway_url: Box<str>,
172    /// ID of the new identified session.
173    session_id: String,
174}
175
176/// Pending outgoing message indicator.
177#[derive(Debug)]
178struct Pending {
179    /// The pending message, if not already sent.
180    gateway_event: Option<Message>,
181    /// Whether the pending gateway event is a heartbeat.
182    is_heartbeat: bool,
183}
184
185impl Pending {
186    /// Constructor for a pending gateway event.
187    const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
188        Some(Self {
189            gateway_event: Some(Message::Text(json)),
190            is_heartbeat,
191        })
192    }
193}
194
195/// Gateway API client responsible for up to 2500 guilds.
196///
197/// Shards are responsible for maintaining the gateway connection by processing
198/// events relevant to the operation of shards---such as requests from the
199/// gateway to re-connect or invalidate a session---and then to pass them on to
200/// the user.
201///
202/// Shards start out disconnected, but will on the first successful call to
203/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
204/// be repeatedly called in order for the shard to maintain its connection and
205/// update its internal state.
206///
207/// Shards go through an [identify queue][`queue`] that rate limits concurrent
208/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
209/// invalidates the shard's session and it is therefore **very important** to
210/// reuse the same queue for all shards.
211///
212/// # Sharding
213///
214/// A shard may not be connected to more than 2500 guilds, so large bots must
215/// split themselves across multiple shards. See the
216/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
217/// more info.
218///
219/// # Examples
220///
221/// Create and start a shard and print new and deleted messages:
222///
223/// ```no_run
224/// use std::env;
225/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
226///
227/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
228/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
229/// // token. Of course, this value may be passed into the program however is
230/// // preferred.
231/// let token = env::var("DISCORD_TOKEN")?;
232/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
233///
234/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
235///
236/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
237///     let Ok(event) = item else {
238///         tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
239///
240///         continue;
241///     };
242///
243///     match event {
244///         Event::MessageCreate(message) => {
245///             println!("message received with content: {}", message.content);
246///         }
247///         Event::MessageDelete(message) => {
248///             println!("message with ID {} deleted", message.id);
249///         }
250///         _ => {}
251///     }
252/// }
253/// # Ok(()) }
254/// ```
255///
256/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
257/// [gateway commands]: Shard::command
258/// [`poll_next`]: Shard::poll_next
259/// [`queue`]: crate::queue
260#[derive(Debug)]
261pub struct Shard<Q = InMemoryQueue> {
262    /// User provided configuration.
263    ///
264    /// Configurations are provided or created in shard initializing via
265    /// [`Shard::new`] or [`Shard::with_config`].
266    config: Config<Q>,
267    /// Future to establish a WebSocket connection with the Gateway.
268    connection_future: Option<ConnectionFuture>,
269    /// Websocket connection, which may be connected to Discord's gateway.
270    ///
271    /// The connection should only be dropped after it has returned `Ok(None)`
272    /// to comply with the WebSocket protocol.
273    connection: Option<Connection>,
274    /// Zstd decompressor.
275    #[cfg(feature = "zstd")]
276    decompressor: Decompressor,
277    /// Interval of how often the gateway would like the shard to send
278    /// heartbeats.
279    ///
280    /// The interval is received in the [`GatewayEvent::Hello`] event when
281    /// first opening a new [connection].
282    ///
283    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
284    /// [connection]: Self::connection
285    heartbeat_interval: Option<Interval>,
286    /// Whether an event has been received in the current heartbeat interval.
287    heartbeat_interval_event: bool,
288    /// ID of the shard.
289    id: ShardId,
290    /// Identify queue receiver.
291    identify_rx: Option<oneshot::Receiver<()>>,
292    /// Zlib decompressor.
293    #[allow(deprecated)]
294    #[cfg(all(
295        any(feature = "zlib-stock", feature = "zlib-simd"),
296        not(feature = "zstd")
297    ))]
298    inflater: Inflater,
299    /// Potentially pending outgoing message.
300    pending: Option<Pending>,
301    /// Recent heartbeat latency statistics.
302    ///
303    /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
304    /// may have changed, invalidating previous latency statistic.
305    ///
306    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
307    latency: Latency,
308    /// Command ratelimiter, if it was enabled via
309    /// [`Config::ratelimit_messages`].
310    ratelimiter: Option<CommandRatelimiter>,
311    /// Used for resuming connections.
312    resume_url: Option<Box<str>>,
313    /// Active session of the shard.
314    ///
315    /// The shard may not have an active session if it hasn't yet identified and
316    /// received a `READY` dispatch event response.
317    session: Option<Session>,
318    /// Current state of the shard.
319    state: ShardState,
320    /// Messages from the user to be relayed and sent over the Websocket
321    /// connection.
322    user_channel: MessageChannel,
323}
324
325impl Shard {
326    /// Create a new shard with the default configuration.
327    pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
328        Self::with_config(id, Config::new(token, intents))
329    }
330}
331
332impl<Q> Shard<Q> {
333    /// Create a new shard with the provided configuration.
334    pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
335        let session = config.take_session();
336        let mut resume_url = config.take_resume_url();
337        //ensure resume_url is only used if we have a session to resume
338        if session.is_none() {
339            resume_url = None;
340        }
341
342        Self {
343            config,
344            connection_future: None,
345            connection: None,
346            #[cfg(feature = "zstd")]
347            decompressor: Decompressor::new(),
348            heartbeat_interval: None,
349            heartbeat_interval_event: false,
350            id: shard_id,
351            identify_rx: None,
352            #[allow(deprecated)]
353            #[cfg(all(
354                any(feature = "zlib-stock", feature = "zlib-simd"),
355                not(feature = "zstd")
356            ))]
357            inflater: Inflater::new(),
358            pending: None,
359            latency: Latency::new(),
360            ratelimiter: None,
361            resume_url,
362            session,
363            state: ShardState::Disconnected {
364                reconnect_attempts: 0,
365            },
366            user_channel: MessageChannel::new(),
367        }
368    }
369
370    /// Immutable reference to the configuration used to instantiate this shard.
371    pub const fn config(&self) -> &Config<Q> {
372        &self.config
373    }
374
375    /// ID of the shard.
376    pub const fn id(&self) -> ShardId {
377        self.id
378    }
379
380    /// Zlib decompressor statistics.
381    ///
382    /// Reset when reconnecting to the gateway.
383    #[allow(deprecated)]
384    #[cfg(all(
385        any(feature = "zlib-stock", feature = "zlib-simd"),
386        not(feature = "zstd")
387    ))]
388    #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
389    pub const fn inflater(&self) -> &Inflater {
390        &self.inflater
391    }
392
393    /// State of the shard.
394    pub const fn state(&self) -> ShardState {
395        self.state
396    }
397
398    /// Shard latency statistics, including average latency and recent heartbeat
399    /// latency times.
400    ///
401    /// Reset when reconnecting to the gateway.
402    pub const fn latency(&self) -> &Latency {
403        &self.latency
404    }
405
406    /// Statistics about the number of available commands and when the command
407    /// ratelimiter will refresh.
408    ///
409    /// This won't be present if ratelimiting was disabled via
410    /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
411    ///
412    /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
413    pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
414        self.ratelimiter.as_ref()
415    }
416
417    /// Immutable reference to the gateways current resume URL.
418    ///
419    /// A resume URL might not be present if the shard had its session
420    /// invalidated and has not yet reconnected.
421    pub fn resume_url(&self) -> Option<&str> {
422        self.resume_url.as_deref()
423    }
424
425    /// Immutable reference to the active gateway session.
426    ///
427    /// An active session may not be present if the shard had its session
428    /// invalidated and has not yet reconnected.
429    pub const fn session(&self) -> Option<&Session> {
430        self.session.as_ref()
431    }
432
433    /// Queue a command to be sent to the gateway.
434    ///
435    /// Serializes the command and then calls [`send`].
436    ///
437    /// [`send`]: Self::send
438    #[allow(clippy::missing_panics_doc)]
439    pub fn command(&self, command: &impl Command) {
440        self.send(json::to_string(command).expect("serialization cannot fail"));
441    }
442
443    /// Queue a JSON encoded gateway event to be sent to the gateway.
444    #[allow(clippy::missing_panics_doc)]
445    pub fn send(&self, json: String) {
446        self.user_channel
447            .command_tx
448            .send(json)
449            .expect("channel open");
450    }
451
452    /// Queue a websocket close frame.
453    ///
454    /// Invalidates the session and shows the application's bot as offline if
455    /// the close frame code is `1000` or `1001`. Otherwise Discord will
456    /// continue showing the bot as online until its presence times out.
457    ///
458    /// To read all remaining messages, continue calling [`poll_next`] until it
459    /// returns [`Message::Close`].
460    ///
461    /// # Example
462    ///
463    /// Close the shard and process remaining messages:
464    ///
465    /// ```no_run
466    /// # use twilight_gateway::{Intents, Shard, ShardId};
467    /// # #[tokio::main] async fn main() {
468    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
469    /// use tokio_stream::StreamExt;
470    /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
471    ///
472    /// shard.close(CloseFrame::NORMAL);
473    ///
474    /// while let Some(item) = shard.next().await {
475    ///     match item {
476    ///         Ok(Message::Close(_)) => break,
477    ///         Ok(Message::Text(_)) => unimplemented!(),
478    ///         Err(source) => unimplemented!(),
479    ///     }
480    /// }
481    /// # }
482    /// ```
483    ///
484    /// [`poll_next`]: Shard::poll_next
485    pub fn close(&self, close_frame: CloseFrame<'static>) {
486        _ = self.user_channel.close_tx.try_send(close_frame);
487    }
488
489    /// Retrieve a channel to send messages over the shard to the gateway.
490    ///
491    /// This is primarily useful for sending to other tasks and threads where
492    /// the shard won't be available.
493    ///
494    /// # Example
495    ///
496    /// Queue a command in another process:
497    ///
498    /// ```no_run
499    /// # use twilight_gateway::{Intents, Shard, ShardId};
500    /// # #[tokio::main] async fn main() {
501    /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
502    /// use tokio_stream::StreamExt;
503    ///
504    /// while let Some(item) = shard.next().await {
505    ///     match item {
506    ///         Ok(message) => {
507    ///             let sender = shard.sender();
508    ///             tokio::spawn(async move {
509    ///                 let command = unimplemented!();
510    ///                 sender.send(command);
511    ///             });
512    ///         }
513    ///         Err(source) => unimplemented!(),
514    ///     }
515    /// }
516    /// # }
517    /// ```
518    pub fn sender(&self) -> MessageSender {
519        self.user_channel.sender()
520    }
521
522    /// Update internal state from gateway disconnect.
523    fn disconnect(&mut self, initiator: CloseInitiator) {
524        // May not send any additional WebSocket messages.
525        self.heartbeat_interval = None;
526        self.ratelimiter = None;
527        // Abort identify.
528        self.identify_rx = None;
529        self.state = match initiator {
530            CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
531            _ => ShardState::Disconnected {
532                reconnect_attempts: 0,
533            },
534        };
535        if let CloseInitiator::Shard(frame) = initiator {
536            // Not resuming, drop session and resume URL.
537            // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
538            if matches!(frame.code, 1000 | 1001) {
539                self.resume_url = None;
540                self.session = None;
541            }
542            self.pending = Some(Pending {
543                gateway_event: Some(Message::Close(Some(frame))),
544                is_heartbeat: false,
545            });
546        }
547    }
548
549    /// Parse a JSON message into an event with minimal data for [processing].
550    ///
551    /// # Errors
552    ///
553    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
554    /// event isn't a recognized structure, which may be the case for new or
555    /// undocumented events.
556    ///
557    /// [processing]: Self::process
558    fn parse_event<T: DeserializeOwned>(
559        json: &str,
560    ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
561        json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
562            kind: ReceiveMessageErrorType::Deserializing {
563                event: json.to_owned(),
564            },
565            source: Some(Box::new(source)),
566        })
567    }
568}
569
570impl<Q: Queue> Shard<Q> {
571    /// Attempts to send due commands to the gateway.
572    ///
573    /// # Returns
574    ///
575    /// * `Poll::Pending` if sending is in progress
576    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
577    /// * `Poll::Ready(Err)` if sending a command failed.
578    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
579        loop {
580            if let Some(pending) = self.pending.as_mut() {
581                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
582
583                if let Some(message) = &pending.gateway_event {
584                    if let Some(ratelimiter) = self.ratelimiter.as_mut() {
585                        if message.is_text() && !pending.is_heartbeat {
586                            ready!(ratelimiter.poll_acquire(cx));
587                        }
588                    }
589
590                    let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
591                    Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
592                }
593
594                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
595
596                if pending.is_heartbeat {
597                    self.latency.record_sent();
598                }
599                self.pending = None;
600            }
601
602            if !self.state.is_disconnected() {
603                if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
604                    let frame = frame.expect("shard owns channel");
605
606                    tracing::debug!("sending close frame from user channel");
607                    self.disconnect(CloseInitiator::Shard(frame));
608
609                    continue;
610                }
611            }
612
613            if self
614                .heartbeat_interval
615                .as_mut()
616                .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
617            {
618                // Discord never responded after the last heartbeat, connection
619                // is failed or "zombied", see
620                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
621                // Note that unlike documented *any* event is okay; it does not
622                // have to be a heartbeat ACK.
623                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
624                    tracing::info!("connection is failed or \"zombied\"");
625                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
626                } else {
627                    tracing::debug!("sending heartbeat");
628                    self.pending = Pending::text(
629                        json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
630                            .expect("serialization cannot fail"),
631                        true,
632                    );
633                    self.heartbeat_interval_event = false;
634                }
635
636                continue;
637            }
638
639            let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
640                ratelimiter.poll_available(cx).is_ready()
641            });
642
643            if not_ratelimited {
644                if let Some(Poll::Ready(canceled)) = self
645                    .identify_rx
646                    .as_mut()
647                    .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
648                {
649                    if canceled {
650                        self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
651                        continue;
652                    }
653
654                    tracing::debug!("sending identify");
655
656                    self.pending = Pending::text(
657                        json::to_string(&Identify::new(IdentifyInfo {
658                            compress: false,
659                            intents: self.config.intents(),
660                            large_threshold: self.config.large_threshold(),
661                            presence: self.config.presence().cloned(),
662                            properties: self
663                                .config
664                                .identify_properties()
665                                .cloned()
666                                .unwrap_or_else(default_identify_properties),
667                            shard: Some(self.id),
668                            token: self.config.token().to_owned(),
669                        }))
670                        .expect("serialization cannot fail"),
671                        false,
672                    );
673                    self.identify_rx = None;
674
675                    continue;
676                }
677            }
678
679            if not_ratelimited && self.state.is_identified() {
680                if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
681                    let command = command.expect("shard owns channel");
682
683                    tracing::debug!("sending command from user channel");
684                    self.pending = Some(Pending {
685                        gateway_event: Some(Message::Text(command)),
686                        is_heartbeat: false,
687                    });
688
689                    continue;
690                }
691            }
692
693            return Poll::Ready(Ok(()));
694        }
695    }
696
697    /// Updates the shard's internal state from a gateway event by recording
698    /// and/or responding to certain Discord events.
699    ///
700    /// # Errors
701    ///
702    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
703    /// gateway event isn't a recognized structure.
704    #[allow(clippy::too_many_lines)]
705    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
706        let (raw_opcode, maybe_sequence, maybe_event_type) =
707            GatewayEventDeserializer::from_json(event)
708                .ok_or_else(|| ReceiveMessageError {
709                    kind: ReceiveMessageErrorType::Deserializing {
710                        event: event.to_owned(),
711                    },
712                    source: Some("missing opcode".into()),
713                })?
714                .into_parts();
715
716        if self.latency.sent().is_some() {
717            self.heartbeat_interval_event = true;
718        }
719
720        match OpCode::from(raw_opcode) {
721            Some(OpCode::Dispatch) => {
722                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
723                    kind: ReceiveMessageErrorType::Deserializing {
724                        event: event.to_owned(),
725                    },
726                    source: Some("missing dispatch event type".into()),
727                })?;
728                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
729                    kind: ReceiveMessageErrorType::Deserializing {
730                        event: event.to_owned(),
731                    },
732                    source: Some("missing sequence".into()),
733                })?;
734                tracing::debug!(%event_type, %sequence, "received dispatch");
735
736                match event_type.as_ref() {
737                    "READY" => {
738                        let event = Self::parse_event::<MinimalReady>(event)?;
739
740                        self.resume_url = Some(event.data.resume_gateway_url);
741                        self.session = Some(Session::new(sequence, event.data.session_id));
742                        self.state = ShardState::Active;
743                    }
744                    "RESUMED" => self.state = ShardState::Active,
745                    _ => {}
746                }
747
748                if let Some(session) = self.session.as_mut() {
749                    session.set_sequence(sequence);
750                }
751            }
752            Some(OpCode::Heartbeat) => {
753                tracing::debug!("received heartbeat");
754                self.pending = Pending::text(
755                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
756                        .expect("serialization cannot fail"),
757                    true,
758                );
759            }
760            Some(OpCode::HeartbeatAck) => {
761                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
762                if requested {
763                    tracing::debug!("received heartbeat ack");
764                    self.latency.record_received();
765                } else {
766                    tracing::info!("received unrequested heartbeat ack");
767                }
768            }
769            Some(OpCode::Hello) => {
770                let event = Self::parse_event::<Hello>(event)?;
771                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
772                // First heartbeat should have some jitter, see
773                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
774                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
775                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");
776
777                if self.config().ratelimit_messages() {
778                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
779                }
780
781                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
782                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
783                self.heartbeat_interval = Some(interval);
784
785                // Reset `Latency` since the shard might have connected to a new
786                // remote which invalidates the recorded latencies.
787                self.latency = Latency::new();
788
789                if let Some(session) = &self.session {
790                    self.pending = Pending::text(
791                        json::to_string(&Resume::new(
792                            session.sequence(),
793                            session.id(),
794                            self.config.token(),
795                        ))
796                        .expect("serialization cannot fail"),
797                        false,
798                    );
799                    self.state = ShardState::Resuming;
800                } else {
801                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
802                }
803            }
804            Some(OpCode::InvalidSession) => {
805                let resumable = Self::parse_event(event)?.data;
806                tracing::debug!(resumable, "received invalid session");
807                if resumable {
808                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
809                } else {
810                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
811                }
812            }
813            Some(OpCode::Reconnect) => {
814                tracing::debug!("received reconnect");
815                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
816            }
817            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
818        }
819
820        Ok(())
821    }
822}
823
824impl<Q: Queue + Unpin> Stream for Shard<Q> {
825    type Item = Result<Message, ReceiveMessageError>;
826
827    #[allow(clippy::too_many_lines)]
828    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
829        let message = loop {
830            match self.state {
831                ShardState::FatallyClosed => {
832                    _ = ready!(Pin::new(
833                        self.connection
834                            .as_mut()
835                            .expect("poll_next called after Poll::Ready(None)")
836                    )
837                    .poll_close(cx));
838                    self.connection = None;
839                    return Poll::Ready(None);
840                }
841                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
842                    if self.connection_future.is_none() {
843                        let base_url = self
844                            .resume_url
845                            .as_deref()
846                            .or_else(|| self.config.proxy_url())
847                            .unwrap_or(GATEWAY_URL);
848                        let uri = format!(
849                            "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
850                        );
851
852                        tracing::debug!(url = base_url, "connecting to gateway");
853
854                        let tls = self.config.tls.clone();
855                        self.connection_future = Some(ConnectionFuture(Box::pin(async move {
856                            let secs = 2u8.saturating_pow(reconnect_attempts.into());
857                            time::sleep(Duration::from_secs(secs.into())).await;
858
859                            Ok(ClientBuilder::new()
860                                .uri(&uri)
861                                .expect("URL should be valid")
862                                .limits(Limits::unlimited())
863                                .connector(&tls)
864                                .connect()
865                                .await?
866                                .0)
867                        })));
868                    }
869
870                    let res =
871                        ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
872                    self.connection_future = None;
873                    match res {
874                        Ok(connection) => {
875                            self.connection = Some(connection);
876                            self.state = ShardState::Identifying;
877                            #[cfg(feature = "zstd")]
878                            self.decompressor.reset();
879                            #[allow(deprecated)]
880                            #[cfg(all(
881                                not(feature = "zstd"),
882                                any(feature = "zlib-stock", feature = "zlib-simd")
883                            ))]
884                            self.inflater.reset();
885                        }
886                        Err(source) => {
887                            self.resume_url = None;
888                            self.state = ShardState::Disconnected {
889                                reconnect_attempts: reconnect_attempts + 1,
890                            };
891
892                            return Poll::Ready(Some(Err(ReceiveMessageError {
893                                kind: ReceiveMessageErrorType::Reconnect,
894                                source: Some(Box::new(source)),
895                            })));
896                        }
897                    }
898                }
899                _ => {}
900            }
901
902            if ready!(self.poll_send(cx)).is_err() {
903                self.disconnect(CloseInitiator::Transport);
904                self.connection = None;
905
906                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
907            }
908
909            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
910                Some(Ok(message)) => {
911                    #[cfg(feature = "zstd")]
912                    if message.is_binary() {
913                        match self.decompressor.decompress(message.as_payload()) {
914                            Ok(message) => break Message::Text(message),
915                            Err(source) => {
916                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
917                                return Poll::Ready(Some(Err(
918                                    ReceiveMessageError::from_compression(source),
919                                )));
920                            }
921                        }
922                    }
923                    #[cfg(all(
924                        not(feature = "zstd"),
925                        any(feature = "zlib-stock", feature = "zlib-simd")
926                    ))]
927                    if message.is_binary() {
928                        match self.inflater.inflate(message.as_payload()) {
929                            Ok(Some(message)) => break Message::Text(message),
930                            Ok(None) => continue,
931                            Err(source) => {
932                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
933                                return Poll::Ready(Some(Err(
934                                    ReceiveMessageError::from_compression(source),
935                                )));
936                            }
937                        }
938                    }
939                    if let Some(message) = Message::from_websocket_msg(&message) {
940                        break message;
941                    }
942                }
943                Some(Err(_)) if self.state.is_disconnected() => {}
944                Some(Err(_)) => {
945                    self.disconnect(CloseInitiator::Transport);
946                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
947                }
948                None => {
949                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
950                    tracing::debug!("gateway WebSocket connection closed");
951                    // Unclean closure.
952                    if !self.state.is_disconnected() {
953                        self.disconnect(CloseInitiator::Transport);
954                    }
955                    self.connection = None;
956                }
957            }
958        };
959
960        match &message {
961            Message::Close(frame) => {
962                // tokio-websockets automatically replies to the close message.
963                tracing::debug!(?frame, "received WebSocket close message");
964                // Don't run `disconnect` if we initiated the close.
965                if !self.state.is_disconnected() {
966                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
967                }
968            }
969            Message::Text(event) => {
970                self.process(event)?;
971            }
972        }
973
974        Poll::Ready(Some(Ok(message)))
975    }
976}
977
978/// Default identify properties to use when the user hasn't customized it in
979/// [`Config::identify_properties`].
980///
981/// [`Config::identify_properties`]: Config::identify_properties
982fn default_identify_properties() -> IdentifyProperties {
983    IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
984}
985
#[cfg(test)]
mod tests {
    //! Compile-time checks on which traits `Shard` implements.

    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};

    // `Shard` must be printable with `{:?}` and movable across threads.
    assert_impl_all!(Shard: std::fmt::Debug, Send);
    // `Shard` is asserted not to be `Sync`; this fails to compile if that changes.
    assert_not_impl_any!(Shard: Sync);
}