twilight_gateway/shard.rs
1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13 any(feature = "zlib-stock", feature = "zlib-simd"),
14 not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18 channel::{MessageChannel, MessageSender},
19 error::{ReceiveMessageError, ReceiveMessageErrorType},
20 json,
21 latency::Latency,
22 queue::{InMemoryQueue, Queue},
23 ratelimiter::CommandRatelimiter,
24 session::Session,
25 Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30#[cfg(any(
31 feature = "native-tls",
32 feature = "rustls-native-roots",
33 feature = "rustls-platform-verifier",
34 feature = "rustls-webpki-roots"
35))]
36use std::io::ErrorKind as IoErrorKind;
37use std::{
38 env::consts::OS,
39 fmt,
40 future::Future,
41 pin::Pin,
42 str,
43 task::{ready, Context, Poll},
44};
45use tokio::{
46 net::TcpStream,
47 sync::oneshot,
48 time::{self, Duration, Instant, Interval, MissedTickBehavior},
49};
50use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
51use twilight_model::gateway::{
52 event::GatewayEventDeserializer,
53 payload::{
54 incoming::Hello,
55 outgoing::{
56 identify::{IdentifyInfo, IdentifyProperties},
57 Heartbeat, Identify, Resume,
58 },
59 },
60 CloseCode, CloseFrame, Intents, OpCode,
61};
62
/// Base URL of Discord's public gateway.
///
/// Used for new connections when neither a resume URL nor a proxy URL is
/// available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";
65
/// Compression query parameter appended to the gateway connection URL.
///
/// The `zstd` feature takes precedence over the zlib features. When no
/// compression feature is enabled the parameter is empty.
const COMPRESSION_FEATURES: &str = if cfg!(feature = "zstd") {
    "&compress=zstd-stream"
} else if cfg!(any(feature = "zlib-stock", feature = "zlib-simd")) {
    "&compress=zlib-stream"
} else {
    ""
};
74
75/// [`tokio_websockets`] library Websocket connection.
76type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
77
78/// Wrapper struct around an `async fn` with a `Debug` implementation.
79struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
80
81impl fmt::Debug for ConnectionFuture {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 f.debug_tuple("ConnectionFuture")
84 .field(&"<async fn>")
85 .finish()
86 }
87}
88
89/// Close initiator of a websocket connection.
90#[derive(Clone, Debug)]
91enum CloseInitiator {
92 /// Gateway initiated the close.
93 ///
94 /// Contains an optional close code.
95 Gateway(Option<u16>),
96 /// Shard initiated the close.
97 ///
98 /// Contains a close code.
99 Shard(CloseFrame<'static>),
100 /// Transport error initiated the close.
101 Transport,
102}
103
/// Connection state of a [`Shard`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Connected to the gateway with an active session.
    Active,
    /// Not connected to the gateway, but a reconnect may follow.
    ///
    /// The websocket connection may still be open.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed with no possibility of reconnecting.
    ///
    /// This happens on, for example, [failed authentication] or
    /// [invalid intents]. See [`CloseCode`] for the possible reasons.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting to establish or resume a session.
    Identifying,
    /// Replaying dispatch events that were missed.
    ///
    /// A resuming shard counts as identified.
    Resuming,
}
132
133impl ShardState {
134 /// Determine the connection status from the close code.
135 ///
136 /// Defers to [`CloseCode::can_reconnect`] to determine whether the
137 /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
138 /// the close code is unknown.
139 fn from_close_code(close_code: Option<u16>) -> Self {
140 match close_code.map(CloseCode::try_from) {
141 Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
142 _ => Self::Disconnected {
143 reconnect_attempts: 0,
144 },
145 }
146 }
147
148 /// Whether the shard has disconnected but may reconnect in the future.
149 const fn is_disconnected(self) -> bool {
150 matches!(self, Self::Disconnected { .. })
151 }
152
153 /// Whether the shard is identified with an active session.
154 ///
155 /// `true` if the status is [`Active`] or [`Resuming`].
156 ///
157 /// [`Active`]: Self::Active
158 /// [`Resuming`]: Self::Resuming
159 pub const fn is_identified(self) -> bool {
160 matches!(self, Self::Active | Self::Resuming)
161 }
162}
163
164/// Gateway event with only minimal required data.
165#[derive(Deserialize)]
166struct MinimalEvent<T> {
167 /// Attached data of the gateway event.
168 #[serde(rename = "d")]
169 data: T,
170}
171
172/// Minimal [`Ready`] for light deserialization.
173///
174/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
175#[derive(Deserialize)]
176struct MinimalReady {
177 /// Used for resuming connections.
178 resume_gateway_url: Box<str>,
179 /// ID of the new identified session.
180 session_id: String,
181}
182
183/// Pending outgoing message indicator.
184#[derive(Debug)]
185struct Pending {
186 /// The pending message, if not already sent.
187 gateway_event: Option<Message>,
188 /// Whether the pending gateway event is a heartbeat.
189 is_heartbeat: bool,
190}
191
192impl Pending {
193 /// Constructor for a pending gateway event.
194 const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
195 Some(Self {
196 gateway_event: Some(Message::Text(json)),
197 is_heartbeat,
198 })
199 }
200}
201
202/// Gateway API client responsible for up to 2500 guilds.
203///
204/// Shards are responsible for maintaining the gateway connection by processing
205/// events relevant to the operation of shards---such as requests from the
206/// gateway to re-connect or invalidate a session---and then to pass them on to
207/// the user.
208///
209/// Shards start out disconnected, but will on the first successful call to
210/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
211/// be repeatedly called in order for the shard to maintain its connection and
212/// update its internal state.
213///
214/// Shards go through an [identify queue][`queue`] that rate limits concurrent
215/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
216/// invalidates the shard's session and it is therefore **very important** to
217/// reuse the same queue for all shards.
218///
219/// # Sharding
220///
221/// A shard may not be connected to more than 2500 guilds, so large bots must
222/// split themselves across multiple shards. See the
223/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
224/// more info.
225///
226/// # Examples
227///
228/// Create and start a shard and print new and deleted messages:
229///
230/// ```no_run
231/// use std::env;
232/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
233///
234/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
235/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
236/// // token. Of course, this value may be passed into the program however is
237/// // preferred.
238/// let token = env::var("DISCORD_TOKEN")?;
239/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
240///
241/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
242///
243/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
244/// let Ok(event) = item else {
245/// tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
246///
247/// continue;
248/// };
249///
250/// match event {
251/// Event::MessageCreate(message) => {
252/// println!("message received with content: {}", message.content);
253/// }
254/// Event::MessageDelete(message) => {
255/// println!("message with ID {} deleted", message.id);
256/// }
257/// _ => {}
258/// }
259/// }
260/// # Ok(()) }
261/// ```
262///
263/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
264/// [gateway commands]: Shard::command
265/// [`poll_next`]: Shard::poll_next
266/// [`queue`]: crate::queue
267#[derive(Debug)]
268pub struct Shard<Q = InMemoryQueue> {
269 /// User provided configuration.
270 ///
271 /// Configurations are provided or created in shard initializing via
272 /// [`Shard::new`] or [`Shard::with_config`].
273 config: Config<Q>,
274 /// Future to establish a WebSocket connection with the Gateway.
275 connection_future: Option<ConnectionFuture>,
276 /// Websocket connection, which may be connected to Discord's gateway.
277 ///
278 /// The connection should only be dropped after it has returned `Ok(None)`
279 /// to comply with the WebSocket protocol.
280 connection: Option<Connection>,
281 /// Zstd decompressor.
282 #[cfg(feature = "zstd")]
283 decompressor: Decompressor,
284 /// Interval of how often the gateway would like the shard to send
285 /// heartbeats.
286 ///
287 /// The interval is received in the [`GatewayEvent::Hello`] event when
288 /// first opening a new [connection].
289 ///
290 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
291 /// [connection]: Self::connection
292 heartbeat_interval: Option<Interval>,
293 /// Whether an event has been received in the current heartbeat interval.
294 heartbeat_interval_event: bool,
295 /// ID of the shard.
296 id: ShardId,
297 /// Identify queue receiver.
298 identify_rx: Option<oneshot::Receiver<()>>,
299 /// Zlib decompressor.
300 #[allow(deprecated)]
301 #[cfg(all(
302 any(feature = "zlib-stock", feature = "zlib-simd"),
303 not(feature = "zstd")
304 ))]
305 inflater: Inflater,
306 /// Potentially pending outgoing message.
307 pending: Option<Pending>,
308 /// Recent heartbeat latency statistics.
309 ///
310 /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
311 /// may have changed, invalidating previous latency statistic.
312 ///
313 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
314 latency: Latency,
315 /// Command ratelimiter, if it was enabled via
316 /// [`Config::ratelimit_messages`].
317 ratelimiter: Option<CommandRatelimiter>,
318 /// Used for resuming connections.
319 resume_url: Option<Box<str>>,
320 /// Active session of the shard.
321 ///
322 /// The shard may not have an active session if it hasn't yet identified and
323 /// received a `READY` dispatch event response.
324 session: Option<Session>,
325 /// Current state of the shard.
326 state: ShardState,
327 /// Messages from the user to be relayed and sent over the Websocket
328 /// connection.
329 user_channel: MessageChannel,
330}
331
332impl Shard {
333 /// Create a new shard with the default configuration.
334 pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
335 Self::with_config(id, Config::new(token, intents))
336 }
337}
338
339impl<Q> Shard<Q> {
340 /// Create a new shard with the provided configuration.
341 pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
342 let session = config.take_session();
343 let mut resume_url = config.take_resume_url();
344 //ensure resume_url is only used if we have a session to resume
345 if session.is_none() {
346 resume_url = None;
347 }
348
349 Self {
350 config,
351 connection_future: None,
352 connection: None,
353 #[cfg(feature = "zstd")]
354 decompressor: Decompressor::new(),
355 heartbeat_interval: None,
356 heartbeat_interval_event: false,
357 id: shard_id,
358 identify_rx: None,
359 #[allow(deprecated)]
360 #[cfg(all(
361 any(feature = "zlib-stock", feature = "zlib-simd"),
362 not(feature = "zstd")
363 ))]
364 inflater: Inflater::new(),
365 pending: None,
366 latency: Latency::new(),
367 ratelimiter: None,
368 resume_url,
369 session,
370 state: ShardState::Disconnected {
371 reconnect_attempts: 0,
372 },
373 user_channel: MessageChannel::new(),
374 }
375 }
376
377 /// Immutable reference to the configuration used to instantiate this shard.
378 pub const fn config(&self) -> &Config<Q> {
379 &self.config
380 }
381
382 /// ID of the shard.
383 pub const fn id(&self) -> ShardId {
384 self.id
385 }
386
387 /// Zlib decompressor statistics.
388 ///
389 /// Reset when reconnecting to the gateway.
390 #[allow(deprecated)]
391 #[cfg(all(
392 any(feature = "zlib-stock", feature = "zlib-simd"),
393 not(feature = "zstd")
394 ))]
395 #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
396 pub const fn inflater(&self) -> &Inflater {
397 &self.inflater
398 }
399
400 /// State of the shard.
401 pub const fn state(&self) -> ShardState {
402 self.state
403 }
404
405 /// Shard latency statistics, including average latency and recent heartbeat
406 /// latency times.
407 ///
408 /// Reset when reconnecting to the gateway.
409 pub const fn latency(&self) -> &Latency {
410 &self.latency
411 }
412
413 /// Statistics about the number of available commands and when the command
414 /// ratelimiter will refresh.
415 ///
416 /// This won't be present if ratelimiting was disabled via
417 /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
418 ///
419 /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
420 pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
421 self.ratelimiter.as_ref()
422 }
423
424 /// Immutable reference to the gateways current resume URL.
425 ///
426 /// A resume URL might not be present if the shard had its session
427 /// invalidated and has not yet reconnected.
428 pub fn resume_url(&self) -> Option<&str> {
429 self.resume_url.as_deref()
430 }
431
432 /// Immutable reference to the active gateway session.
433 ///
434 /// An active session may not be present if the shard had its session
435 /// invalidated and has not yet reconnected.
436 pub const fn session(&self) -> Option<&Session> {
437 self.session.as_ref()
438 }
439
440 /// Queue a command to be sent to the gateway.
441 ///
442 /// Serializes the command and then calls [`send`].
443 ///
444 /// [`send`]: Self::send
445 #[allow(clippy::missing_panics_doc)]
446 pub fn command(&self, command: &impl Command) {
447 self.send(json::to_string(command).expect("serialization cannot fail"));
448 }
449
450 /// Queue a JSON encoded gateway event to be sent to the gateway.
451 #[allow(clippy::missing_panics_doc)]
452 pub fn send(&self, json: String) {
453 self.user_channel
454 .command_tx
455 .send(json)
456 .expect("channel open");
457 }
458
459 /// Queue a websocket close frame.
460 ///
461 /// Invalidates the session and shows the application's bot as offline if
462 /// the close frame code is `1000` or `1001`. Otherwise Discord will
463 /// continue showing the bot as online until its presence times out.
464 ///
465 /// To read all remaining messages, continue calling [`poll_next`] until it
466 /// returns [`Message::Close`].
467 ///
468 /// # Example
469 ///
470 /// Close the shard and process remaining messages:
471 ///
472 /// ```no_run
473 /// # use twilight_gateway::{Intents, Shard, ShardId};
474 /// # #[tokio::main] async fn main() {
475 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
476 /// use tokio_stream::StreamExt;
477 /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
478 ///
479 /// shard.close(CloseFrame::NORMAL);
480 ///
481 /// while let Some(item) = shard.next().await {
482 /// match item {
483 /// Ok(Message::Close(_)) => break,
484 /// Ok(Message::Text(_)) => unimplemented!(),
485 /// Err(source) => unimplemented!(),
486 /// }
487 /// }
488 /// # }
489 /// ```
490 ///
491 /// [`poll_next`]: Shard::poll_next
492 pub fn close(&self, close_frame: CloseFrame<'static>) {
493 _ = self.user_channel.close_tx.try_send(close_frame);
494 }
495
496 /// Retrieve a channel to send messages over the shard to the gateway.
497 ///
498 /// This is primarily useful for sending to other tasks and threads where
499 /// the shard won't be available.
500 ///
501 /// # Example
502 ///
503 /// Queue a command in another process:
504 ///
505 /// ```no_run
506 /// # use twilight_gateway::{Intents, Shard, ShardId};
507 /// # #[tokio::main] async fn main() {
508 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
509 /// use tokio_stream::StreamExt;
510 ///
511 /// while let Some(item) = shard.next().await {
512 /// match item {
513 /// Ok(message) => {
514 /// let sender = shard.sender();
515 /// tokio::spawn(async move {
516 /// let command = unimplemented!();
517 /// sender.send(command);
518 /// });
519 /// }
520 /// Err(source) => unimplemented!(),
521 /// }
522 /// }
523 /// # }
524 /// ```
525 pub fn sender(&self) -> MessageSender {
526 self.user_channel.sender()
527 }
528
529 /// Update internal state from gateway disconnect.
530 fn disconnect(&mut self, initiator: CloseInitiator) {
531 // May not send any additional WebSocket messages.
532 self.heartbeat_interval = None;
533 self.ratelimiter = None;
534 // Abort identify.
535 self.identify_rx = None;
536 self.state = match initiator {
537 CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
538 _ => ShardState::Disconnected {
539 reconnect_attempts: 0,
540 },
541 };
542 if let CloseInitiator::Shard(frame) = initiator {
543 // Not resuming, drop session and resume URL.
544 // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
545 if matches!(frame.code, 1000 | 1001) {
546 self.resume_url = None;
547 self.session = None;
548 }
549 self.pending = Some(Pending {
550 gateway_event: Some(Message::Close(Some(frame))),
551 is_heartbeat: false,
552 });
553 }
554 }
555
556 /// Parse a JSON message into an event with minimal data for [processing].
557 ///
558 /// # Errors
559 ///
560 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
561 /// event isn't a recognized structure, which may be the case for new or
562 /// undocumented events.
563 ///
564 /// [processing]: Self::process
565 fn parse_event<T: DeserializeOwned>(
566 json: &str,
567 ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
568 json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
569 kind: ReceiveMessageErrorType::Deserializing {
570 event: json.to_owned(),
571 },
572 source: Some(Box::new(source)),
573 })
574 }
575}
576
577impl<Q: Queue> Shard<Q> {
578 /// Attempts to send due commands to the gateway.
579 ///
580 /// # Returns
581 ///
582 /// * `Poll::Pending` if sending is in progress
583 /// * `Poll::Ready(Ok)` if no more scheduled commands remain
584 /// * `Poll::Ready(Err)` if sending a command failed.
585 fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
586 loop {
587 if let Some(pending) = self.pending.as_mut() {
588 ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
589
590 if let Some(message) = &pending.gateway_event {
591 if let Some(ratelimiter) = self.ratelimiter.as_mut() {
592 if message.is_text() && !pending.is_heartbeat {
593 ready!(ratelimiter.poll_acquire(cx));
594 }
595 }
596
597 let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
598 Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
599 }
600
601 ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
602
603 if pending.is_heartbeat {
604 self.latency.record_sent();
605 }
606 self.pending = None;
607 }
608
609 if !self.state.is_disconnected() {
610 if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
611 let frame = frame.expect("shard owns channel");
612
613 tracing::debug!("sending close frame from user channel");
614 self.disconnect(CloseInitiator::Shard(frame));
615
616 continue;
617 }
618 }
619
620 if self
621 .heartbeat_interval
622 .as_mut()
623 .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
624 {
625 // Discord never responded after the last heartbeat, connection
626 // is failed or "zombied", see
627 // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
628 // Note that unlike documented *any* event is okay; it does not
629 // have to be a heartbeat ACK.
630 if self.latency.sent().is_some() && !self.heartbeat_interval_event {
631 tracing::info!("connection is failed or \"zombied\"");
632 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
633 } else {
634 tracing::debug!("sending heartbeat");
635 self.pending = Pending::text(
636 json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
637 .expect("serialization cannot fail"),
638 true,
639 );
640 self.heartbeat_interval_event = false;
641 }
642
643 continue;
644 }
645
646 let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
647 ratelimiter.poll_available(cx).is_ready()
648 });
649
650 if not_ratelimited {
651 if let Some(Poll::Ready(canceled)) = self
652 .identify_rx
653 .as_mut()
654 .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
655 {
656 if canceled {
657 self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
658 continue;
659 }
660
661 tracing::debug!("sending identify");
662
663 self.pending = Pending::text(
664 json::to_string(&Identify::new(IdentifyInfo {
665 compress: false,
666 intents: self.config.intents(),
667 large_threshold: self.config.large_threshold(),
668 presence: self.config.presence().cloned(),
669 properties: self
670 .config
671 .identify_properties()
672 .cloned()
673 .unwrap_or_else(default_identify_properties),
674 shard: Some(self.id),
675 token: self.config.token().to_owned(),
676 }))
677 .expect("serialization cannot fail"),
678 false,
679 );
680 self.identify_rx = None;
681
682 continue;
683 }
684 }
685
686 if not_ratelimited && self.state.is_identified() {
687 if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
688 let command = command.expect("shard owns channel");
689
690 tracing::debug!("sending command from user channel");
691 self.pending = Some(Pending {
692 gateway_event: Some(Message::Text(command)),
693 is_heartbeat: false,
694 });
695
696 continue;
697 }
698 }
699
700 return Poll::Ready(Ok(()));
701 }
702 }
703
704 /// Updates the shard's internal state from a gateway event by recording
705 /// and/or responding to certain Discord events.
706 ///
707 /// # Errors
708 ///
709 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
710 /// gateway event isn't a recognized structure.
711 #[allow(clippy::too_many_lines)]
712 fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
713 let (raw_opcode, maybe_sequence, maybe_event_type) =
714 GatewayEventDeserializer::from_json(event)
715 .ok_or_else(|| ReceiveMessageError {
716 kind: ReceiveMessageErrorType::Deserializing {
717 event: event.to_owned(),
718 },
719 source: Some("missing opcode".into()),
720 })?
721 .into_parts();
722
723 if self.latency.sent().is_some() {
724 self.heartbeat_interval_event = true;
725 }
726
727 match OpCode::from(raw_opcode) {
728 Some(OpCode::Dispatch) => {
729 let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
730 kind: ReceiveMessageErrorType::Deserializing {
731 event: event.to_owned(),
732 },
733 source: Some("missing dispatch event type".into()),
734 })?;
735 let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
736 kind: ReceiveMessageErrorType::Deserializing {
737 event: event.to_owned(),
738 },
739 source: Some("missing sequence".into()),
740 })?;
741 tracing::debug!(%event_type, %sequence, "received dispatch");
742
743 match event_type.as_ref() {
744 "READY" => {
745 let event = Self::parse_event::<MinimalReady>(event)?;
746
747 self.resume_url = Some(event.data.resume_gateway_url);
748 self.session = Some(Session::new(sequence, event.data.session_id));
749 self.state = ShardState::Active;
750 }
751 "RESUMED" => self.state = ShardState::Active,
752 _ => {}
753 }
754
755 if let Some(session) = self.session.as_mut() {
756 session.set_sequence(sequence);
757 }
758 }
759 Some(OpCode::Heartbeat) => {
760 tracing::debug!("received heartbeat");
761 self.pending = Pending::text(
762 json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
763 .expect("serialization cannot fail"),
764 true,
765 );
766 }
767 Some(OpCode::HeartbeatAck) => {
768 let requested = self.latency.received().is_none() && self.latency.sent().is_some();
769 if requested {
770 tracing::debug!("received heartbeat ack");
771 self.latency.record_received();
772 } else {
773 tracing::info!("received unrequested heartbeat ack");
774 }
775 }
776 Some(OpCode::Hello) => {
777 let event = Self::parse_event::<Hello>(event)?;
778 let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
779 // First heartbeat should have some jitter, see
780 // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
781 let jitter = heartbeat_interval.mul_f64(fastrand::f64());
782 tracing::debug!(?heartbeat_interval, ?jitter, "received hello");
783
784 if self.config().ratelimit_messages() {
785 self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
786 }
787
788 let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
789 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
790 self.heartbeat_interval = Some(interval);
791
792 // Reset `Latency` since the shard might have connected to a new
793 // remote which invalidates the recorded latencies.
794 self.latency = Latency::new();
795
796 if let Some(session) = &self.session {
797 self.pending = Pending::text(
798 json::to_string(&Resume::new(
799 session.sequence(),
800 session.id(),
801 self.config.token(),
802 ))
803 .expect("serialization cannot fail"),
804 false,
805 );
806 self.state = ShardState::Resuming;
807 } else {
808 self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
809 }
810 }
811 Some(OpCode::InvalidSession) => {
812 let resumable = Self::parse_event(event)?.data;
813 tracing::debug!(resumable, "received invalid session");
814 if resumable {
815 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
816 } else {
817 self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
818 }
819 }
820 Some(OpCode::Reconnect) => {
821 tracing::debug!("received reconnect");
822 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
823 }
824 _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
825 }
826
827 Ok(())
828 }
829}
830
831impl<Q: Queue + Unpin> Stream for Shard<Q> {
832 type Item = Result<Message, ReceiveMessageError>;
833
834 #[allow(clippy::too_many_lines)]
835 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
836 let message = loop {
837 match self.state {
838 ShardState::FatallyClosed => {
839 _ = ready!(Pin::new(
840 self.connection
841 .as_mut()
842 .expect("poll_next called after Poll::Ready(None)")
843 )
844 .poll_close(cx));
845 self.connection = None;
846 return Poll::Ready(None);
847 }
848 ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
849 if self.connection_future.is_none() {
850 let base_url = self
851 .resume_url
852 .as_deref()
853 .or_else(|| self.config.proxy_url())
854 .unwrap_or(GATEWAY_URL);
855 let uri = format!(
856 "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
857 );
858
859 tracing::debug!(url = base_url, "connecting to gateway");
860
861 let tls = self.config.tls.clone();
862 self.connection_future = Some(ConnectionFuture(Box::pin(async move {
863 let secs = 2u8.saturating_pow(reconnect_attempts.into());
864 time::sleep(Duration::from_secs(secs.into())).await;
865
866 Ok(ClientBuilder::new()
867 .uri(&uri)
868 .expect("URL should be valid")
869 .limits(Limits::unlimited())
870 .connector(&tls)
871 .connect()
872 .await?
873 .0)
874 })));
875 }
876
877 let res =
878 ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
879 self.connection_future = None;
880 match res {
881 Ok(connection) => {
882 self.connection = Some(connection);
883 self.state = ShardState::Identifying;
884 #[cfg(feature = "zstd")]
885 self.decompressor.reset();
886 #[allow(deprecated)]
887 #[cfg(all(
888 not(feature = "zstd"),
889 any(feature = "zlib-stock", feature = "zlib-simd")
890 ))]
891 self.inflater.reset();
892 }
893 Err(source) => {
894 self.resume_url = None;
895 self.state = ShardState::Disconnected {
896 reconnect_attempts: reconnect_attempts + 1,
897 };
898
899 return Poll::Ready(Some(Err(ReceiveMessageError {
900 kind: ReceiveMessageErrorType::Reconnect,
901 source: Some(Box::new(source)),
902 })));
903 }
904 }
905 }
906 _ => {}
907 }
908
909 if ready!(self.poll_send(cx)).is_err() {
910 self.disconnect(CloseInitiator::Transport);
911 self.connection = None;
912
913 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
914 }
915
916 match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
917 Some(Ok(message)) => {
918 #[cfg(feature = "zstd")]
919 if message.is_binary() {
920 match self.decompressor.decompress(message.as_payload()) {
921 Ok(message) => break Message::Text(message),
922 Err(source) => {
923 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
924 return Poll::Ready(Some(Err(
925 ReceiveMessageError::from_compression(source),
926 )));
927 }
928 }
929 }
930 #[cfg(all(
931 not(feature = "zstd"),
932 any(feature = "zlib-stock", feature = "zlib-simd")
933 ))]
934 if message.is_binary() {
935 match self.inflater.inflate(message.as_payload()) {
936 Ok(Some(message)) => break Message::Text(message),
937 Ok(None) => continue,
938 Err(source) => {
939 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
940 return Poll::Ready(Some(Err(
941 ReceiveMessageError::from_compression(source),
942 )));
943 }
944 }
945 }
946 if let Some(message) = Message::from_websocket_msg(&message) {
947 break message;
948 }
949 }
950 // Discord, against recommendations from the WebSocket spec,
951 // does not send a close_notify prior to shutting down the TCP
952 // stream. This arm tries to gracefully handle this. The
953 // connection is considered unusable after encountering an io
954 // error, returning `None`.
955 #[cfg(any(
956 feature = "native-tls",
957 feature = "rustls-native-roots",
958 feature = "rustls-platform-verifier",
959 feature = "rustls-webpki-roots"
960 ))]
961 Some(Err(WebsocketError::Io(e)))
962 if e.kind() == IoErrorKind::UnexpectedEof
963 && self.config.proxy_url().is_none()
964 && self.state.is_disconnected() =>
965 {
966 continue
967 }
968 Some(Err(_)) => {
969 self.disconnect(CloseInitiator::Transport);
970 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
971 }
972 None => {
973 _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
974 tracing::debug!("gateway WebSocket connection closed");
975 // Unclean closure.
976 if !self.state.is_disconnected() {
977 self.disconnect(CloseInitiator::Transport);
978 }
979 self.connection = None;
980 }
981 }
982 };
983
984 match &message {
985 Message::Close(frame) => {
986 // tokio-websockets automatically replies to the close message.
987 tracing::debug!(?frame, "received WebSocket close message");
988 // Don't run `disconnect` if we initiated the close.
989 if !self.state.is_disconnected() {
990 self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
991 }
992 }
993 Message::Text(event) => {
994 self.process(event)?;
995 }
996 }
997
998 Poll::Ready(Some(Ok(message)))
999 }
1000}
1001
1002/// Default identify properties to use when the user hasn't customized it in
1003/// [`Config::identify_properties`].
1004///
1005/// [`Config::identify_properties`]: Config::identify_properties
1006fn default_identify_properties() -> IdentifyProperties {
1007 IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
1008}
1009
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // `Shard` must be debuggable and `Send`, and is asserted to not be `Sync`.
    assert_impl_all!(Shard: Send, Debug);
    assert_not_impl_any!(Shard: Sync);
}