twilight_gateway/shard.rs
1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13 any(feature = "zlib-stock", feature = "zlib-simd"),
14 not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18 channel::{MessageChannel, MessageSender},
19 error::{ReceiveMessageError, ReceiveMessageErrorType},
20 json,
21 latency::Latency,
22 queue::{InMemoryQueue, Queue},
23 ratelimiter::CommandRatelimiter,
24 session::Session,
25 Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30use std::{
31 env::consts::OS,
32 fmt,
33 future::Future,
34 pin::Pin,
35 str,
36 task::{ready, Context, Poll},
37};
38use tokio::{
39 net::TcpStream,
40 sync::oneshot,
41 time::{self, Duration, Instant, Interval, MissedTickBehavior},
42};
43use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
44use twilight_model::gateway::{
45 event::GatewayEventDeserializer,
46 payload::{
47 incoming::Hello,
48 outgoing::{
49 identify::{IdentifyInfo, IdentifyProperties},
50 Heartbeat, Identify, Resume,
51 },
52 },
53 CloseCode, CloseFrame, Intents, OpCode,
54};
55
/// Base URL of the Discord gateway, used when neither a resume URL nor a
/// proxy URL is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Query string fragment selecting the transport compression scheme.
///
/// Chosen at compile time from the enabled crate features: Zstandard takes
/// precedence over zlib, and the fragment is empty when no compression feature
/// is enabled.
const COMPRESSION_FEATURES: &str = if cfg!(feature = "zstd") {
    "&compress=zstd-stream"
} else if cfg!(any(feature = "zlib-stock", feature = "zlib-simd")) {
    "&compress=zlib-stream"
} else {
    ""
};
67
68/// [`tokio_websockets`] library Websocket connection.
69type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
70
71/// Wrapper struct around an `async fn` with a `Debug` implementation.
72struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
73
74impl fmt::Debug for ConnectionFuture {
75 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76 f.debug_tuple("ConnectionFuture")
77 .field(&"<async fn>")
78 .finish()
79 }
80}
81
82/// Close initiator of a websocket connection.
83#[derive(Clone, Debug)]
84enum CloseInitiator {
85 /// Gateway initiated the close.
86 ///
87 /// Contains an optional close code.
88 Gateway(Option<u16>),
89 /// Shard initiated the close.
90 ///
91 /// Contains a close code.
92 Shard(CloseFrame<'static>),
93 /// Transport error initiated the close.
94 Transport,
95}
96
/// Current state of a [`Shard`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// The shard is connected to the gateway and has an active session.
    Active,
    /// The shard is disconnected from the gateway but may reconnect later.
    ///
    /// The underlying websocket connection may still be open.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// The shard has fatally closed.
    ///
    /// This can be caused by [failed authentication], [invalid intents], or
    /// other reasons; see the documentation for [`CloseCode`] for the possible
    /// causes.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// The shard is waiting to establish or resume a session.
    Identifying,
    /// The shard is replaying dispatch events it missed.
    ///
    /// A resuming shard counts as identified.
    Resuming,
}
125
126impl ShardState {
127 /// Determine the connection status from the close code.
128 ///
129 /// Defers to [`CloseCode::can_reconnect`] to determine whether the
130 /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
131 /// the close code is unknown.
132 fn from_close_code(close_code: Option<u16>) -> Self {
133 match close_code.map(CloseCode::try_from) {
134 Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
135 _ => Self::Disconnected {
136 reconnect_attempts: 0,
137 },
138 }
139 }
140
141 /// Whether the shard has disconnected but may reconnect in the future.
142 const fn is_disconnected(self) -> bool {
143 matches!(self, Self::Disconnected { .. })
144 }
145
146 /// Whether the shard is identified with an active session.
147 ///
148 /// `true` if the status is [`Active`] or [`Resuming`].
149 ///
150 /// [`Active`]: Self::Active
151 /// [`Resuming`]: Self::Resuming
152 pub const fn is_identified(self) -> bool {
153 matches!(self, Self::Active | Self::Resuming)
154 }
155}
156
157/// Gateway event with only minimal required data.
158#[derive(Deserialize)]
159struct MinimalEvent<T> {
160 /// Attached data of the gateway event.
161 #[serde(rename = "d")]
162 data: T,
163}
164
165/// Minimal [`Ready`] for light deserialization.
166///
167/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
168#[derive(Deserialize)]
169struct MinimalReady {
170 /// Used for resuming connections.
171 resume_gateway_url: Box<str>,
172 /// ID of the new identified session.
173 session_id: String,
174}
175
176/// Pending outgoing message indicator.
177#[derive(Debug)]
178struct Pending {
179 /// The pending message, if not already sent.
180 gateway_event: Option<Message>,
181 /// Whether the pending gateway event is a heartbeat.
182 is_heartbeat: bool,
183}
184
185impl Pending {
186 /// Constructor for a pending gateway event.
187 const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
188 Some(Self {
189 gateway_event: Some(Message::Text(json)),
190 is_heartbeat,
191 })
192 }
193}
194
195/// Gateway API client responsible for up to 2500 guilds.
196///
197/// Shards are responsible for maintaining the gateway connection by processing
198/// events relevant to the operation of shards---such as requests from the
199/// gateway to re-connect or invalidate a session---and then to pass them on to
200/// the user.
201///
202/// Shards start out disconnected, but will on the first successful call to
203/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
204/// be repeatedly called in order for the shard to maintain its connection and
205/// update its internal state.
206///
207/// Shards go through an [identify queue][`queue`] that rate limits concurrent
208/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
209/// invalidates the shard's session and it is therefore **very important** to
210/// reuse the same queue for all shards.
211///
212/// # Sharding
213///
214/// A shard may not be connected to more than 2500 guilds, so large bots must
215/// split themselves across multiple shards. See the
216/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
217/// more info.
218///
219/// # Examples
220///
221/// Create and start a shard and print new and deleted messages:
222///
223/// ```no_run
224/// use std::env;
225/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
226///
227/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
228/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
229/// // token. Of course, this value may be passed into the program however is
230/// // preferred.
231/// let token = env::var("DISCORD_TOKEN")?;
232/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
233///
234/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
235///
236/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
237/// let Ok(event) = item else {
238/// tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
239///
240/// continue;
241/// };
242///
243/// match event {
244/// Event::MessageCreate(message) => {
245/// println!("message received with content: {}", message.content);
246/// }
247/// Event::MessageDelete(message) => {
248/// println!("message with ID {} deleted", message.id);
249/// }
250/// _ => {}
251/// }
252/// }
253/// # Ok(()) }
254/// ```
255///
256/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
257/// [gateway commands]: Shard::command
258/// [`poll_next`]: Shard::poll_next
259/// [`queue`]: crate::queue
260#[derive(Debug)]
261pub struct Shard<Q = InMemoryQueue> {
262 /// User provided configuration.
263 ///
264 /// Configurations are provided or created in shard initializing via
265 /// [`Shard::new`] or [`Shard::with_config`].
266 config: Config<Q>,
267 /// Future to establish a WebSocket connection with the Gateway.
268 connection_future: Option<ConnectionFuture>,
269 /// Websocket connection, which may be connected to Discord's gateway.
270 ///
271 /// The connection should only be dropped after it has returned `Ok(None)`
272 /// to comply with the WebSocket protocol.
273 connection: Option<Connection>,
274 /// Zstd decompressor.
275 #[cfg(feature = "zstd")]
276 decompressor: Decompressor,
277 /// Interval of how often the gateway would like the shard to send
278 /// heartbeats.
279 ///
280 /// The interval is received in the [`GatewayEvent::Hello`] event when
281 /// first opening a new [connection].
282 ///
283 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
284 /// [connection]: Self::connection
285 heartbeat_interval: Option<Interval>,
286 /// Whether an event has been received in the current heartbeat interval.
287 heartbeat_interval_event: bool,
288 /// ID of the shard.
289 id: ShardId,
290 /// Identify queue receiver.
291 identify_rx: Option<oneshot::Receiver<()>>,
292 /// Zlib decompressor.
293 #[allow(deprecated)]
294 #[cfg(all(
295 any(feature = "zlib-stock", feature = "zlib-simd"),
296 not(feature = "zstd")
297 ))]
298 inflater: Inflater,
299 /// Potentially pending outgoing message.
300 pending: Option<Pending>,
301 /// Recent heartbeat latency statistics.
302 ///
303 /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
304 /// may have changed, invalidating previous latency statistic.
305 ///
306 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
307 latency: Latency,
308 /// Command ratelimiter, if it was enabled via
309 /// [`Config::ratelimit_messages`].
310 ratelimiter: Option<CommandRatelimiter>,
311 /// Used for resuming connections.
312 resume_url: Option<Box<str>>,
313 /// Active session of the shard.
314 ///
315 /// The shard may not have an active session if it hasn't yet identified and
316 /// received a `READY` dispatch event response.
317 session: Option<Session>,
318 /// Current state of the shard.
319 state: ShardState,
320 /// Messages from the user to be relayed and sent over the Websocket
321 /// connection.
322 user_channel: MessageChannel,
323}
324
325impl Shard {
326 /// Create a new shard with the default configuration.
327 pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
328 Self::with_config(id, Config::new(token, intents))
329 }
330}
331
332impl<Q> Shard<Q> {
333 /// Create a new shard with the provided configuration.
334 pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
335 let session = config.take_session();
336 let mut resume_url = config.take_resume_url();
337 //ensure resume_url is only used if we have a session to resume
338 if session.is_none() {
339 resume_url = None;
340 }
341
342 Self {
343 config,
344 connection_future: None,
345 connection: None,
346 #[cfg(feature = "zstd")]
347 decompressor: Decompressor::new(),
348 heartbeat_interval: None,
349 heartbeat_interval_event: false,
350 id: shard_id,
351 identify_rx: None,
352 #[allow(deprecated)]
353 #[cfg(all(
354 any(feature = "zlib-stock", feature = "zlib-simd"),
355 not(feature = "zstd")
356 ))]
357 inflater: Inflater::new(),
358 pending: None,
359 latency: Latency::new(),
360 ratelimiter: None,
361 resume_url,
362 session,
363 state: ShardState::Disconnected {
364 reconnect_attempts: 0,
365 },
366 user_channel: MessageChannel::new(),
367 }
368 }
369
370 /// Immutable reference to the configuration used to instantiate this shard.
371 pub const fn config(&self) -> &Config<Q> {
372 &self.config
373 }
374
375 /// ID of the shard.
376 pub const fn id(&self) -> ShardId {
377 self.id
378 }
379
380 /// Zlib decompressor statistics.
381 ///
382 /// Reset when reconnecting to the gateway.
383 #[allow(deprecated)]
384 #[cfg(all(
385 any(feature = "zlib-stock", feature = "zlib-simd"),
386 not(feature = "zstd")
387 ))]
388 #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
389 pub const fn inflater(&self) -> &Inflater {
390 &self.inflater
391 }
392
393 /// State of the shard.
394 pub const fn state(&self) -> ShardState {
395 self.state
396 }
397
398 /// Shard latency statistics, including average latency and recent heartbeat
399 /// latency times.
400 ///
401 /// Reset when reconnecting to the gateway.
402 pub const fn latency(&self) -> &Latency {
403 &self.latency
404 }
405
406 /// Statistics about the number of available commands and when the command
407 /// ratelimiter will refresh.
408 ///
409 /// This won't be present if ratelimiting was disabled via
410 /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
411 ///
412 /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
413 pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
414 self.ratelimiter.as_ref()
415 }
416
417 /// Immutable reference to the gateways current resume URL.
418 ///
419 /// A resume URL might not be present if the shard had its session
420 /// invalidated and has not yet reconnected.
421 pub fn resume_url(&self) -> Option<&str> {
422 self.resume_url.as_deref()
423 }
424
425 /// Immutable reference to the active gateway session.
426 ///
427 /// An active session may not be present if the shard had its session
428 /// invalidated and has not yet reconnected.
429 pub const fn session(&self) -> Option<&Session> {
430 self.session.as_ref()
431 }
432
433 /// Queue a command to be sent to the gateway.
434 ///
435 /// Serializes the command and then calls [`send`].
436 ///
437 /// [`send`]: Self::send
438 #[allow(clippy::missing_panics_doc)]
439 pub fn command(&self, command: &impl Command) {
440 self.send(json::to_string(command).expect("serialization cannot fail"));
441 }
442
443 /// Queue a JSON encoded gateway event to be sent to the gateway.
444 #[allow(clippy::missing_panics_doc)]
445 pub fn send(&self, json: String) {
446 self.user_channel
447 .command_tx
448 .send(json)
449 .expect("channel open");
450 }
451
452 /// Queue a websocket close frame.
453 ///
454 /// Invalidates the session and shows the application's bot as offline if
455 /// the close frame code is `1000` or `1001`. Otherwise Discord will
456 /// continue showing the bot as online until its presence times out.
457 ///
458 /// To read all remaining messages, continue calling [`poll_next`] until it
459 /// returns [`Message::Close`].
460 ///
461 /// # Example
462 ///
463 /// Close the shard and process remaining messages:
464 ///
465 /// ```no_run
466 /// # use twilight_gateway::{Intents, Shard, ShardId};
467 /// # #[tokio::main] async fn main() {
468 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
469 /// use tokio_stream::StreamExt;
470 /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
471 ///
472 /// shard.close(CloseFrame::NORMAL);
473 ///
474 /// while let Some(item) = shard.next().await {
475 /// match item {
476 /// Ok(Message::Close(_)) => break,
477 /// Ok(Message::Text(_)) => unimplemented!(),
478 /// Err(source) => unimplemented!(),
479 /// }
480 /// }
481 /// # }
482 /// ```
483 ///
484 /// [`poll_next`]: Shard::poll_next
485 pub fn close(&self, close_frame: CloseFrame<'static>) {
486 _ = self.user_channel.close_tx.try_send(close_frame);
487 }
488
489 /// Retrieve a channel to send messages over the shard to the gateway.
490 ///
491 /// This is primarily useful for sending to other tasks and threads where
492 /// the shard won't be available.
493 ///
494 /// # Example
495 ///
496 /// Queue a command in another process:
497 ///
498 /// ```no_run
499 /// # use twilight_gateway::{Intents, Shard, ShardId};
500 /// # #[tokio::main] async fn main() {
501 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
502 /// use tokio_stream::StreamExt;
503 ///
504 /// while let Some(item) = shard.next().await {
505 /// match item {
506 /// Ok(message) => {
507 /// let sender = shard.sender();
508 /// tokio::spawn(async move {
509 /// let command = unimplemented!();
510 /// sender.send(command);
511 /// });
512 /// }
513 /// Err(source) => unimplemented!(),
514 /// }
515 /// }
516 /// # }
517 /// ```
518 pub fn sender(&self) -> MessageSender {
519 self.user_channel.sender()
520 }
521
522 /// Update internal state from gateway disconnect.
523 fn disconnect(&mut self, initiator: CloseInitiator) {
524 // May not send any additional WebSocket messages.
525 self.heartbeat_interval = None;
526 self.ratelimiter = None;
527 // Abort identify.
528 self.identify_rx = None;
529 self.state = match initiator {
530 CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
531 _ => ShardState::Disconnected {
532 reconnect_attempts: 0,
533 },
534 };
535 if let CloseInitiator::Shard(frame) = initiator {
536 // Not resuming, drop session and resume URL.
537 // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
538 if matches!(frame.code, 1000 | 1001) {
539 self.resume_url = None;
540 self.session = None;
541 }
542 self.pending = Some(Pending {
543 gateway_event: Some(Message::Close(Some(frame))),
544 is_heartbeat: false,
545 });
546 }
547 }
548
549 /// Parse a JSON message into an event with minimal data for [processing].
550 ///
551 /// # Errors
552 ///
553 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
554 /// event isn't a recognized structure, which may be the case for new or
555 /// undocumented events.
556 ///
557 /// [processing]: Self::process
558 fn parse_event<T: DeserializeOwned>(
559 json: &str,
560 ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
561 json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
562 kind: ReceiveMessageErrorType::Deserializing {
563 event: json.to_owned(),
564 },
565 source: Some(Box::new(source)),
566 })
567 }
568}
569
570impl<Q: Queue> Shard<Q> {
571 /// Attempts to send due commands to the gateway.
572 ///
573 /// # Returns
574 ///
575 /// * `Poll::Pending` if sending is in progress
576 /// * `Poll::Ready(Ok)` if no more scheduled commands remain
577 /// * `Poll::Ready(Err)` if sending a command failed.
578 fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
579 loop {
580 if let Some(pending) = self.pending.as_mut() {
581 ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;
582
583 if let Some(message) = &pending.gateway_event {
584 if let Some(ratelimiter) = self.ratelimiter.as_mut() {
585 if message.is_text() && !pending.is_heartbeat {
586 ready!(ratelimiter.poll_acquire(cx));
587 }
588 }
589
590 let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
591 Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
592 }
593
594 ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;
595
596 if pending.is_heartbeat {
597 self.latency.record_sent();
598 }
599 self.pending = None;
600 }
601
602 if !self.state.is_disconnected() {
603 if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
604 let frame = frame.expect("shard owns channel");
605
606 tracing::debug!("sending close frame from user channel");
607 self.disconnect(CloseInitiator::Shard(frame));
608
609 continue;
610 }
611 }
612
613 if self
614 .heartbeat_interval
615 .as_mut()
616 .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
617 {
618 // Discord never responded after the last heartbeat, connection
619 // is failed or "zombied", see
620 // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
621 // Note that unlike documented *any* event is okay; it does not
622 // have to be a heartbeat ACK.
623 if self.latency.sent().is_some() && !self.heartbeat_interval_event {
624 tracing::info!("connection is failed or \"zombied\"");
625 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
626 } else {
627 tracing::debug!("sending heartbeat");
628 self.pending = Pending::text(
629 json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
630 .expect("serialization cannot fail"),
631 true,
632 );
633 self.heartbeat_interval_event = false;
634 }
635
636 continue;
637 }
638
639 let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
640 ratelimiter.poll_available(cx).is_ready()
641 });
642
643 if not_ratelimited {
644 if let Some(Poll::Ready(canceled)) = self
645 .identify_rx
646 .as_mut()
647 .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
648 {
649 if canceled {
650 self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
651 continue;
652 }
653
654 tracing::debug!("sending identify");
655
656 self.pending = Pending::text(
657 json::to_string(&Identify::new(IdentifyInfo {
658 compress: false,
659 intents: self.config.intents(),
660 large_threshold: self.config.large_threshold(),
661 presence: self.config.presence().cloned(),
662 properties: self
663 .config
664 .identify_properties()
665 .cloned()
666 .unwrap_or_else(default_identify_properties),
667 shard: Some(self.id),
668 token: self.config.token().to_owned(),
669 }))
670 .expect("serialization cannot fail"),
671 false,
672 );
673 self.identify_rx = None;
674
675 continue;
676 }
677 }
678
679 if not_ratelimited && self.state.is_identified() {
680 if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
681 let command = command.expect("shard owns channel");
682
683 tracing::debug!("sending command from user channel");
684 self.pending = Some(Pending {
685 gateway_event: Some(Message::Text(command)),
686 is_heartbeat: false,
687 });
688
689 continue;
690 }
691 }
692
693 return Poll::Ready(Ok(()));
694 }
695 }
696
697 /// Updates the shard's internal state from a gateway event by recording
698 /// and/or responding to certain Discord events.
699 ///
700 /// # Errors
701 ///
702 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
703 /// gateway event isn't a recognized structure.
704 #[allow(clippy::too_many_lines)]
705 fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
706 let (raw_opcode, maybe_sequence, maybe_event_type) =
707 GatewayEventDeserializer::from_json(event)
708 .ok_or_else(|| ReceiveMessageError {
709 kind: ReceiveMessageErrorType::Deserializing {
710 event: event.to_owned(),
711 },
712 source: Some("missing opcode".into()),
713 })?
714 .into_parts();
715
716 if self.latency.sent().is_some() {
717 self.heartbeat_interval_event = true;
718 }
719
720 match OpCode::from(raw_opcode) {
721 Some(OpCode::Dispatch) => {
722 let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
723 kind: ReceiveMessageErrorType::Deserializing {
724 event: event.to_owned(),
725 },
726 source: Some("missing dispatch event type".into()),
727 })?;
728 let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
729 kind: ReceiveMessageErrorType::Deserializing {
730 event: event.to_owned(),
731 },
732 source: Some("missing sequence".into()),
733 })?;
734 tracing::debug!(%event_type, %sequence, "received dispatch");
735
736 match event_type.as_ref() {
737 "READY" => {
738 let event = Self::parse_event::<MinimalReady>(event)?;
739
740 self.resume_url = Some(event.data.resume_gateway_url);
741 self.session = Some(Session::new(sequence, event.data.session_id));
742 self.state = ShardState::Active;
743 }
744 "RESUMED" => self.state = ShardState::Active,
745 _ => {}
746 }
747
748 if let Some(session) = self.session.as_mut() {
749 session.set_sequence(sequence);
750 }
751 }
752 Some(OpCode::Heartbeat) => {
753 tracing::debug!("received heartbeat");
754 self.pending = Pending::text(
755 json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
756 .expect("serialization cannot fail"),
757 true,
758 );
759 }
760 Some(OpCode::HeartbeatAck) => {
761 let requested = self.latency.received().is_none() && self.latency.sent().is_some();
762 if requested {
763 tracing::debug!("received heartbeat ack");
764 self.latency.record_received();
765 } else {
766 tracing::info!("received unrequested heartbeat ack");
767 }
768 }
769 Some(OpCode::Hello) => {
770 let event = Self::parse_event::<Hello>(event)?;
771 let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
772 // First heartbeat should have some jitter, see
773 // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
774 let jitter = heartbeat_interval.mul_f64(fastrand::f64());
775 tracing::debug!(?heartbeat_interval, ?jitter, "received hello");
776
777 if self.config().ratelimit_messages() {
778 self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
779 }
780
781 let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
782 interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
783 self.heartbeat_interval = Some(interval);
784
785 // Reset `Latency` since the shard might have connected to a new
786 // remote which invalidates the recorded latencies.
787 self.latency = Latency::new();
788
789 if let Some(session) = &self.session {
790 self.pending = Pending::text(
791 json::to_string(&Resume::new(
792 session.sequence(),
793 session.id(),
794 self.config.token(),
795 ))
796 .expect("serialization cannot fail"),
797 false,
798 );
799 self.state = ShardState::Resuming;
800 } else {
801 self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
802 }
803 }
804 Some(OpCode::InvalidSession) => {
805 let resumable = Self::parse_event(event)?.data;
806 tracing::debug!(resumable, "received invalid session");
807 if resumable {
808 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
809 } else {
810 self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
811 }
812 }
813 Some(OpCode::Reconnect) => {
814 tracing::debug!("received reconnect");
815 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
816 }
817 _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
818 }
819
820 Ok(())
821 }
822}
823
824impl<Q: Queue + Unpin> Stream for Shard<Q> {
825 type Item = Result<Message, ReceiveMessageError>;
826
827 #[allow(clippy::too_many_lines)]
828 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
829 let message = loop {
830 match self.state {
831 ShardState::FatallyClosed => {
832 _ = ready!(Pin::new(
833 self.connection
834 .as_mut()
835 .expect("poll_next called after Poll::Ready(None)")
836 )
837 .poll_close(cx));
838 self.connection = None;
839 return Poll::Ready(None);
840 }
841 ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
842 if self.connection_future.is_none() {
843 let base_url = self
844 .resume_url
845 .as_deref()
846 .or_else(|| self.config.proxy_url())
847 .unwrap_or(GATEWAY_URL);
848 let uri = format!(
849 "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
850 );
851
852 tracing::debug!(url = base_url, "connecting to gateway");
853
854 let tls = self.config.tls.clone();
855 self.connection_future = Some(ConnectionFuture(Box::pin(async move {
856 let secs = 2u8.saturating_pow(reconnect_attempts.into());
857 time::sleep(Duration::from_secs(secs.into())).await;
858
859 Ok(ClientBuilder::new()
860 .uri(&uri)
861 .expect("URL should be valid")
862 .limits(Limits::unlimited())
863 .connector(&tls)
864 .connect()
865 .await?
866 .0)
867 })));
868 }
869
870 let res =
871 ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
872 self.connection_future = None;
873 match res {
874 Ok(connection) => {
875 self.connection = Some(connection);
876 self.state = ShardState::Identifying;
877 #[cfg(feature = "zstd")]
878 self.decompressor.reset();
879 #[allow(deprecated)]
880 #[cfg(all(
881 not(feature = "zstd"),
882 any(feature = "zlib-stock", feature = "zlib-simd")
883 ))]
884 self.inflater.reset();
885 }
886 Err(source) => {
887 self.resume_url = None;
888 self.state = ShardState::Disconnected {
889 reconnect_attempts: reconnect_attempts + 1,
890 };
891
892 return Poll::Ready(Some(Err(ReceiveMessageError {
893 kind: ReceiveMessageErrorType::Reconnect,
894 source: Some(Box::new(source)),
895 })));
896 }
897 }
898 }
899 _ => {}
900 }
901
902 if ready!(self.poll_send(cx)).is_err() {
903 self.disconnect(CloseInitiator::Transport);
904 self.connection = None;
905
906 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
907 }
908
909 match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
910 Some(Ok(message)) => {
911 #[cfg(feature = "zstd")]
912 if message.is_binary() {
913 match self.decompressor.decompress(message.as_payload()) {
914 Ok(message) => break Message::Text(message),
915 Err(source) => {
916 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
917 return Poll::Ready(Some(Err(
918 ReceiveMessageError::from_compression(source),
919 )));
920 }
921 }
922 }
923 #[cfg(all(
924 not(feature = "zstd"),
925 any(feature = "zlib-stock", feature = "zlib-simd")
926 ))]
927 if message.is_binary() {
928 match self.inflater.inflate(message.as_payload()) {
929 Ok(Some(message)) => break Message::Text(message),
930 Ok(None) => continue,
931 Err(source) => {
932 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
933 return Poll::Ready(Some(Err(
934 ReceiveMessageError::from_compression(source),
935 )));
936 }
937 }
938 }
939 if let Some(message) = Message::from_websocket_msg(&message) {
940 break message;
941 }
942 }
943 Some(Err(_)) if self.state.is_disconnected() => {}
944 Some(Err(_)) => {
945 self.disconnect(CloseInitiator::Transport);
946 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
947 }
948 None => {
949 _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
950 tracing::debug!("gateway WebSocket connection closed");
951 // Unclean closure.
952 if !self.state.is_disconnected() {
953 self.disconnect(CloseInitiator::Transport);
954 }
955 self.connection = None;
956 }
957 }
958 };
959
960 match &message {
961 Message::Close(frame) => {
962 // tokio-websockets automatically replies to the close message.
963 tracing::debug!(?frame, "received WebSocket close message");
964 // Don't run `disconnect` if we initiated the close.
965 if !self.state.is_disconnected() {
966 self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
967 }
968 }
969 Message::Text(event) => {
970 self.process(event)?;
971 }
972 }
973
974 Poll::Ready(Some(Ok(message)))
975 }
976}
977
978/// Default identify properties to use when the user hasn't customized it in
979/// [`Config::identify_properties`].
980///
981/// [`Config::identify_properties`]: Config::identify_properties
982fn default_identify_properties() -> IdentifyProperties {
983 IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
984}
985
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // Compile-time checks: a shard is debuggable and sendable...
    assert_impl_all!(Shard: Debug);
    assert_impl_all!(Shard: Send);
    // ...but does not implement `Sync`.
    assert_not_impl_any!(Shard: Sync);
}