twilight_gateway/shard.rs
1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13 any(feature = "zlib-stock", feature = "zlib-simd"),
14 not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18 channel::{MessageChannel, MessageSender},
19 error::{ReceiveMessageError, ReceiveMessageErrorType},
20 json,
21 latency::Latency,
22 queue::{InMemoryQueue, Queue},
23 ratelimiter::CommandRatelimiter,
24 session::Session,
25 Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30use std::{
31 env::consts::OS,
32 fmt,
33 future::Future,
34 io,
35 pin::Pin,
36 str,
37 task::{ready, Context, Poll},
38};
39use tokio::{
40 net::TcpStream,
41 sync::oneshot,
42 time::{self, Duration, Instant, Interval, MissedTickBehavior},
43};
44use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
45use twilight_model::gateway::{
46 event::GatewayEventDeserializer,
47 payload::{
48 incoming::Hello,
49 outgoing::{
50 identify::{IdentifyInfo, IdentifyProperties},
51 Heartbeat, Identify, Resume,
52 },
53 },
54 CloseCode, CloseFrame, Intents, OpCode,
55};
56
/// URL of the Discord gateway.
///
/// Used as the base URL when connecting if the shard has neither a resume URL
/// nor a configured proxy URL.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";
59
/// Compression query argument appended to the gateway URL.
///
/// Resolves to zstd when the `zstd` feature is enabled, to zlib when only one
/// of the zlib features is enabled, and to an empty string otherwise.
const COMPRESSION_FEATURES: &str = match (
    cfg!(feature = "zstd"),
    cfg!(any(feature = "zlib-stock", feature = "zlib-simd")),
) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};
68
/// [`tokio_websockets`] library Websocket connection.
///
/// The stream type is a [`MaybeTlsStream`] over a [`TcpStream`].
type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
71
72/// Wrapper struct around an `async fn` with a `Debug` implementation.
73struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, WebsocketError>> + Send>>);
74
75impl fmt::Debug for ConnectionFuture {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 f.debug_tuple("ConnectionFuture")
78 .field(&"<async fn>")
79 .finish()
80 }
81}
82
/// Close initiator of a websocket connection.
///
/// Passed to `Shard::disconnect` to decide the shard's next state and whether
/// a close frame has to be sent.
#[derive(Clone, Debug)]
enum CloseInitiator {
    /// Gateway initiated the close.
    ///
    /// Contains an optional close code.
    Gateway(Option<u16>),
    /// Shard initiated the close.
    ///
    /// Contains the close frame that is queued for sending to the gateway.
    Shard(CloseFrame<'static>),
    /// Transport error initiated the close.
    Transport,
}
97
/// Current state of a [`Shard`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Shard is connected to the gateway with an active session.
    Active,
    /// Shard is disconnected from the gateway but may reconnect in the future.
    ///
    /// The websocket connection may still be open.
    Disconnected {
        /// Number of reconnection attempts that have been made.
        ///
        /// The delay before the next connection attempt grows with this
        /// value.
        reconnect_attempts: u8,
    },
    /// Shard has fatally closed.
    ///
    /// Possible reasons may be due to [failed authentication],
    /// [invalid intents], or other reasons. Refer to the documentation for
    /// [`CloseCode`] for possible reasons.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Shard is waiting to establish or resume a session.
    Identifying,
    /// Shard is replaying missed dispatch events.
    ///
    /// The shard is considered identified whilst resuming.
    Resuming,
}
126
127impl ShardState {
128 /// Determine the connection status from the close code.
129 ///
130 /// Defers to [`CloseCode::can_reconnect`] to determine whether the
131 /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
132 /// the close code is unknown.
133 fn from_close_code(close_code: Option<u16>) -> Self {
134 match close_code.map(CloseCode::try_from) {
135 Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
136 _ => Self::Disconnected {
137 reconnect_attempts: 0,
138 },
139 }
140 }
141
142 /// Whether the shard has disconnected but may reconnect in the future.
143 const fn is_disconnected(self) -> bool {
144 matches!(self, Self::Disconnected { .. })
145 }
146
147 /// Whether the shard is identified with an active session.
148 ///
149 /// `true` if the status is [`Active`] or [`Resuming`].
150 ///
151 /// [`Active`]: Self::Active
152 /// [`Resuming`]: Self::Resuming
153 pub const fn is_identified(self) -> bool {
154 matches!(self, Self::Active | Self::Resuming)
155 }
156}
157
/// Gateway event with only minimal required data.
///
/// Only the `d` field of the payload is deserialized, into `T`.
#[derive(Deserialize)]
struct MinimalEvent<T> {
    /// Attached data of the gateway event.
    #[serde(rename = "d")]
    data: T,
}
165
/// Minimal [`Ready`] for light deserialization.
///
/// Contains only the fields the shard stores when processing a `READY`
/// dispatch.
///
/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
#[derive(Deserialize)]
struct MinimalReady {
    /// Used for resuming connections.
    resume_gateway_url: Box<str>,
    /// ID of the new identified session.
    session_id: String,
}
176
/// Pending outgoing message indicator.
#[derive(Debug)]
struct Pending {
    /// The pending message, if not already sent.
    ///
    /// Taken (leaving `None`) once the message has been handed to the
    /// websocket sink, so that only the flush remains to be completed.
    gateway_event: Option<Message>,
    /// Whether the pending gateway event is a heartbeat.
    is_heartbeat: bool,
}

impl Pending {
    /// Constructor for a pending gateway event.
    ///
    /// Wraps `json` in a [`Message::Text`] and always returns `Some`, so the
    /// result can be assigned directly to an `Option<Pending>`.
    const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
        Some(Self {
            gateway_event: Some(Message::Text(json)),
            is_heartbeat,
        })
    }
}
195
/// Gateway API client responsible for up to 2500 guilds.
///
/// Shards are responsible for maintaining the gateway connection by processing
/// events relevant to the operation of shards---such as requests from the
/// gateway to re-connect or invalidate a session---and then to pass them on to
/// the user.
///
/// Shards start out disconnected, but will on the first successful call to
/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
/// be repeatedly called in order for the shard to maintain its connection and
/// update its internal state.
///
/// Shards go through an [identify queue][`queue`] that rate limits concurrent
/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
/// invalidates the shard's session and it is therefore **very important** to
/// reuse the same queue for all shards.
///
/// # Sharding
///
/// A shard may not be connected to more than 2500 guilds, so large bots must
/// split themselves across multiple shards. See the
/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
/// more info.
///
/// # Examples
///
/// Create and start a shard and print new and deleted messages:
///
/// ```no_run
/// use std::env;
/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
///
/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
/// // token. Of course, this value may be passed into the program however is
/// // preferred.
/// let token = env::var("DISCORD_TOKEN")?;
/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
///
/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
///
/// while let Some(item) = shard.next_event(wanted_event_types).await {
///     let Ok(event) = item else {
///         tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
///
///         continue;
///     };
///
///     match event {
///         Event::MessageCreate(message) => {
///             println!("message received with content: {}", message.content);
///         }
///         Event::MessageDelete(message) => {
///             println!("message with ID {} deleted", message.id);
///         }
///         _ => {}
///     }
/// }
/// # Ok(()) }
/// ```
///
/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
/// [gateway commands]: Shard::command
/// [`poll_next`]: Shard::poll_next
/// [`queue`]: crate::queue
#[derive(Debug)]
pub struct Shard<Q = InMemoryQueue> {
    /// User provided configuration.
    ///
    /// Configurations are provided or created in shard initializing via
    /// [`Shard::new`] or [`Shard::with_config`].
    config: Config<Q>,
    /// Future to establish a WebSocket connection with the Gateway.
    ///
    /// Only present while a connection attempt is in progress.
    connection_future: Option<ConnectionFuture>,
    /// Websocket connection, which may be connected to Discord's gateway.
    ///
    /// The connection should only be dropped after it has returned `Ok(None)`
    /// to comply with the WebSocket protocol.
    connection: Option<Connection>,
    /// Zstd decompressor.
    #[cfg(feature = "zstd")]
    decompressor: Decompressor,
    /// Interval of how often the gateway would like the shard to send
    /// heartbeats.
    ///
    /// The interval is received in the [`GatewayEvent::Hello`] event when
    /// first opening a new [connection].
    ///
    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
    /// [connection]: Self::connection
    heartbeat_interval: Option<Interval>,
    /// Whether an event has been received in the current heartbeat interval.
    ///
    /// Set whenever a gateway event is processed after a heartbeat has been
    /// sent and cleared when the next heartbeat is queued.
    heartbeat_interval_event: bool,
    /// ID of the shard.
    id: ShardId,
    /// Identify queue receiver.
    ///
    /// Set after enqueueing the shard in the identify queue and cleared once
    /// the identify has been queued for sending or the shard disconnects.
    identify_rx: Option<oneshot::Receiver<()>>,
    /// Zlib decompressor.
    #[allow(deprecated)]
    #[cfg(all(
        any(feature = "zlib-stock", feature = "zlib-simd"),
        not(feature = "zstd")
    ))]
    inflater: Inflater,
    /// Potentially pending outgoing message.
    pending: Option<Pending>,
    /// Recent heartbeat latency statistics.
    ///
    /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
    /// may have changed, invalidating previous latency statistic.
    ///
    /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
    latency: Latency,
    /// Command ratelimiter, if it was enabled via
    /// [`Config::ratelimit_messages`].
    ///
    /// Created on receiving the gateway's hello event and dropped when the
    /// shard disconnects.
    ratelimiter: Option<CommandRatelimiter>,
    /// Used for resuming connections.
    resume_url: Option<Box<str>>,
    /// Active session of the shard.
    ///
    /// The shard may not have an active session if it hasn't yet identified and
    /// received a `READY` dispatch event response.
    session: Option<Session>,
    /// Current state of the shard.
    state: ShardState,
    /// Messages from the user to be relayed and sent over the Websocket
    /// connection.
    user_channel: MessageChannel,
}
325
326impl Shard {
327 /// Create a new shard with the default configuration.
328 pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
329 Self::with_config(id, Config::new(token, intents))
330 }
331}
332
333impl<Q> Shard<Q> {
334 /// Create a new shard with the provided configuration.
335 pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
336 let session = config.take_session();
337 let mut resume_url = config.take_resume_url();
338 //ensure resume_url is only used if we have a session to resume
339 if session.is_none() {
340 resume_url = None;
341 }
342
343 Self {
344 config,
345 connection_future: None,
346 connection: None,
347 #[cfg(feature = "zstd")]
348 decompressor: Decompressor::new(),
349 heartbeat_interval: None,
350 heartbeat_interval_event: false,
351 id: shard_id,
352 identify_rx: None,
353 #[allow(deprecated)]
354 #[cfg(all(
355 any(feature = "zlib-stock", feature = "zlib-simd"),
356 not(feature = "zstd")
357 ))]
358 inflater: Inflater::new(),
359 pending: None,
360 latency: Latency::new(),
361 ratelimiter: None,
362 resume_url,
363 session,
364 state: ShardState::Disconnected {
365 reconnect_attempts: 0,
366 },
367 user_channel: MessageChannel::new(),
368 }
369 }
370
371 /// Immutable reference to the configuration used to instantiate this shard.
372 pub const fn config(&self) -> &Config<Q> {
373 &self.config
374 }
375
376 /// ID of the shard.
377 pub const fn id(&self) -> ShardId {
378 self.id
379 }
380
381 /// Zlib decompressor statistics.
382 ///
383 /// Reset when reconnecting to the gateway.
384 #[allow(deprecated)]
385 #[cfg(all(
386 any(feature = "zlib-stock", feature = "zlib-simd"),
387 not(feature = "zstd")
388 ))]
389 #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
390 pub const fn inflater(&self) -> &Inflater {
391 &self.inflater
392 }
393
394 /// State of the shard.
395 pub const fn state(&self) -> ShardState {
396 self.state
397 }
398
399 /// Shard latency statistics, including average latency and recent heartbeat
400 /// latency times.
401 ///
402 /// Reset when reconnecting to the gateway.
403 pub const fn latency(&self) -> &Latency {
404 &self.latency
405 }
406
407 /// Statistics about the number of available commands and when the command
408 /// ratelimiter will refresh.
409 ///
410 /// This won't be present if ratelimiting was disabled via
411 /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
412 ///
413 /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
414 pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
415 self.ratelimiter.as_ref()
416 }
417
418 /// Immutable reference to the gateways current resume URL.
419 ///
420 /// A resume URL might not be present if the shard had its session
421 /// invalidated and has not yet reconnected.
422 pub fn resume_url(&self) -> Option<&str> {
423 self.resume_url.as_deref()
424 }
425
426 /// Immutable reference to the active gateway session.
427 ///
428 /// An active session may not be present if the shard had its session
429 /// invalidated and has not yet reconnected.
430 pub const fn session(&self) -> Option<&Session> {
431 self.session.as_ref()
432 }
433
434 /// Queue a command to be sent to the gateway.
435 ///
436 /// Serializes the command and then calls [`send`].
437 ///
438 /// [`send`]: Self::send
439 #[allow(clippy::missing_panics_doc)]
440 pub fn command(&self, command: &impl Command) {
441 self.send(json::to_string(command).expect("serialization cannot fail"));
442 }
443
444 /// Queue a JSON encoded gateway event to be sent to the gateway.
445 #[allow(clippy::missing_panics_doc)]
446 pub fn send(&self, json: String) {
447 self.user_channel
448 .command_tx
449 .send(json)
450 .expect("channel open");
451 }
452
453 /// Queue a websocket close frame.
454 ///
455 /// Invalidates the session and shows the application's bot as offline if
456 /// the close frame code is `1000` or `1001`. Otherwise Discord will
457 /// continue showing the bot as online until its presence times out.
458 ///
459 /// To read all remaining messages, continue calling [`poll_next`] until it
460 /// returns [`Message::Close`].
461 ///
462 /// # Example
463 ///
464 /// Close the shard and process remaining messages:
465 ///
466 /// ```no_run
467 /// # use twilight_gateway::{Intents, Shard, ShardId};
468 /// # #[tokio::main] async fn main() {
469 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
470 /// use tokio_stream::StreamExt;
471 /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
472 ///
473 /// shard.close(CloseFrame::NORMAL);
474 ///
475 /// while let Some(item) = shard.next().await {
476 /// match item {
477 /// Ok(Message::Close(_)) => break,
478 /// Ok(Message::Text(_)) => unimplemented!(),
479 /// Err(source) => unimplemented!(),
480 /// }
481 /// }
482 /// # }
483 /// ```
484 ///
485 /// [`poll_next`]: Shard::poll_next
486 pub fn close(&self, close_frame: CloseFrame<'static>) {
487 _ = self.user_channel.close_tx.try_send(close_frame);
488 }
489
490 /// Retrieve a channel to send messages over the shard to the gateway.
491 ///
492 /// This is primarily useful for sending to other tasks and threads where
493 /// the shard won't be available.
494 ///
495 /// # Example
496 ///
497 /// Queue a command in another process:
498 ///
499 /// ```no_run
500 /// # use twilight_gateway::{Intents, Shard, ShardId};
501 /// # #[tokio::main] async fn main() {
502 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
503 /// use tokio_stream::StreamExt;
504 ///
505 /// while let Some(item) = shard.next().await {
506 /// match item {
507 /// Ok(message) => {
508 /// let sender = shard.sender();
509 /// tokio::spawn(async move {
510 /// let command = unimplemented!();
511 /// sender.send(command);
512 /// });
513 /// }
514 /// Err(source) => unimplemented!(),
515 /// }
516 /// }
517 /// # }
518 /// ```
519 pub fn sender(&self) -> MessageSender {
520 self.user_channel.sender()
521 }
522
523 /// Update internal state from gateway disconnect.
524 fn disconnect(&mut self, initiator: CloseInitiator) {
525 // May not send any additional WebSocket messages.
526 self.heartbeat_interval = None;
527 self.ratelimiter = None;
528 // Abort identify.
529 self.identify_rx = None;
530 self.state = match initiator {
531 CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
532 _ => ShardState::Disconnected {
533 reconnect_attempts: 0,
534 },
535 };
536 if let CloseInitiator::Shard(frame) = initiator {
537 // Not resuming, drop session and resume URL.
538 // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
539 if matches!(frame.code, 1000 | 1001) {
540 self.resume_url = None;
541 self.session = None;
542 }
543 self.pending = Some(Pending {
544 gateway_event: Some(Message::Close(Some(frame))),
545 is_heartbeat: false,
546 });
547 }
548 }
549
550 /// Parse a JSON message into an event with minimal data for [processing].
551 ///
552 /// # Errors
553 ///
554 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
555 /// event isn't a recognized structure, which may be the case for new or
556 /// undocumented events.
557 ///
558 /// [processing]: Self::process
559 fn parse_event<T: DeserializeOwned>(
560 json: &str,
561 ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
562 json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
563 kind: ReceiveMessageErrorType::Deserializing {
564 event: json.to_owned(),
565 },
566 source: Some(Box::new(source)),
567 })
568 }
569}
570
impl<Q: Queue> Shard<Q> {
    /// Attempts to send due commands to the gateway.
    ///
    /// Each loop iteration first drives the currently pending message (if
    /// any) to completion and then picks the next message to send, in this
    /// order: a user requested close frame, a due heartbeat, an identify and
    /// finally a user command.
    ///
    /// # Returns
    ///
    /// * `Poll::Pending` if sending is in progress
    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
    /// * `Poll::Ready(Err)` if sending a command failed.
    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
        loop {
            if let Some(pending) = self.pending.as_mut() {
                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;

                // `gateway_event` is taken once handed to the sink, so a
                // re-poll after a pending flush skips straight to flushing.
                if let Some(message) = &pending.gateway_event {
                    // Heartbeats and non-text messages are exempt from the
                    // command ratelimiter.
                    if let Some(ratelimiter) = self.ratelimiter.as_mut() {
                        if message.is_text() && !pending.is_heartbeat {
                            ready!(ratelimiter.poll_acquire(cx));
                        }
                    }

                    let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
                    Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
                }

                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;

                // The heartbeat only counts as sent after it was flushed.
                if pending.is_heartbeat {
                    self.latency.record_sent();
                }
                self.pending = None;
            }

            // User requested close frames are only honored while the shard is
            // not already disconnected.
            if !self.state.is_disconnected() {
                if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
                    let frame = frame.expect("shard owns channel");

                    tracing::debug!("sending close frame from user channel");
                    self.disconnect(CloseInitiator::Shard(frame));

                    continue;
                }
            }

            if self
                .heartbeat_interval
                .as_mut()
                .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
            {
                // Discord never responded after the last heartbeat, connection
                // is failed or "zombied", see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
                // Note that unlike documented *any* event is okay; it does not
                // have to be a heartbeat ACK.
                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
                    tracing::info!("connection is failed or \"zombied\"");

                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
                }

                tracing::debug!("sending heartbeat");
                self.pending = Pending::text(
                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
                        .expect("serialization cannot fail"),
                    true,
                );
                // Start watching for events in the new heartbeat interval.
                self.heartbeat_interval_event = false;

                continue;
            }

            // `true` if there is no ratelimiter or it has capacity available.
            let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
                ratelimiter.poll_available(cx).is_ready()
            });

            if not_ratelimited {
                if let Some(Poll::Ready(canceled)) = self
                    .identify_rx
                    .as_mut()
                    .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
                {
                    // The receiver resolved with an error: enqueue the shard
                    // again and poll the new receiver.
                    if canceled {
                        self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                        continue;
                    }

                    tracing::debug!("sending identify");

                    self.pending = Pending::text(
                        json::to_string(&Identify::new(IdentifyInfo {
                            compress: false,
                            intents: self.config.intents(),
                            large_threshold: self.config.large_threshold(),
                            presence: self.config.presence().cloned(),
                            properties: self
                                .config
                                .identify_properties()
                                .cloned()
                                .unwrap_or_else(default_identify_properties),
                            shard: Some(self.id),
                            token: self.config.token().to_owned(),
                        }))
                        .expect("serialization cannot fail"),
                        false,
                    );
                    self.identify_rx = None;

                    continue;
                }
            }

            // User commands are only relayed while the shard is identified.
            if not_ratelimited && self.state.is_identified() {
                if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
                    let command = command.expect("shard owns channel");

                    tracing::debug!("sending command from user channel");
                    self.pending = Some(Pending {
                        gateway_event: Some(Message::Text(command)),
                        is_heartbeat: false,
                    });

                    continue;
                }
            }

            return Poll::Ready(Ok(()));
        }
    }

    /// Updates the shard's internal state from a gateway event by recording
    /// and/or responding to certain Discord events.
    ///
    /// Only the opcode, sequence and event type are extracted up front; the
    /// event data is deserialized solely for the events the shard acts on.
    ///
    /// # Errors
    ///
    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
    /// gateway event isn't a recognized structure.
    #[allow(clippy::too_many_lines)]
    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
        let (raw_opcode, maybe_sequence, maybe_event_type) =
            GatewayEventDeserializer::from_json(event)
                .ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing opcode".into()),
                })?
                .into_parts();

        // Once a heartbeat has been sent, any event counts as a sign of life
        // for the zombie connection check in `poll_send`.
        if self.latency.sent().is_some() {
            self.heartbeat_interval_event = true;
        }

        match OpCode::from(raw_opcode) {
            Some(OpCode::Dispatch) => {
                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing dispatch event type".into()),
                })?;
                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing sequence".into()),
                })?;
                tracing::debug!(%event_type, %sequence, "received dispatch");

                match event_type.as_ref() {
                    "READY" => {
                        let event = Self::parse_event::<MinimalReady>(event)?;

                        self.resume_url = Some(event.data.resume_gateway_url);
                        self.session = Some(Session::new(sequence, event.data.session_id));
                        self.state = ShardState::Active;
                    }
                    "RESUMED" => self.state = ShardState::Active,
                    _ => {}
                }

                // Track the sequence of the latest dispatch in the session.
                if let Some(session) = self.session.as_mut() {
                    session.set_sequence(sequence);
                }
            }
            Some(OpCode::Heartbeat) => {
                // The gateway requested a heartbeat: queue one right away.
                tracing::debug!("received heartbeat");
                self.pending = Pending::text(
                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
                        .expect("serialization cannot fail"),
                    true,
                );
            }
            Some(OpCode::HeartbeatAck) => {
                // An ack is only expected when a heartbeat was sent and no ack
                // has been recorded for it yet.
                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
                if requested {
                    tracing::debug!("received heartbeat ack");
                    self.latency.record_received();
                } else {
                    tracing::info!("received unrequested heartbeat ack");
                }
            }
            Some(OpCode::Hello) => {
                let event = Self::parse_event::<Hello>(event)?;
                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
                // First heartbeat should have some jitter, see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

                if self.config().ratelimit_messages() {
                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
                }

                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
                self.heartbeat_interval = Some(interval);

                // Reset `Latency` since the shard might have connected to a new
                // remote which invalidates the recorded latencies.
                self.latency = Latency::new();

                // Resume if a session exists, otherwise line up in the
                // identify queue.
                if let Some(session) = &self.session {
                    self.pending = Pending::text(
                        json::to_string(&Resume::new(
                            session.sequence(),
                            session.id(),
                            self.config.token(),
                        ))
                        .expect("serialization cannot fail"),
                        false,
                    );
                    self.state = ShardState::Resuming;
                } else {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                }
            }
            Some(OpCode::InvalidSession) => {
                // The event data is a boolean: whether the session may be
                // resumed.
                let resumable = Self::parse_event(event)?.data;
                tracing::debug!(resumable, "received invalid session");
                if resumable {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                } else {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
                }
            }
            Some(OpCode::Reconnect) => {
                tracing::debug!("received reconnect");
                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
            }
            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
        }

        Ok(())
    }
}
825
826impl<Q: Queue + Unpin> Stream for Shard<Q> {
827 type Item = Result<Message, ReceiveMessageError>;
828
829 #[allow(clippy::too_many_lines)]
830 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
831 let message = loop {
832 match self.state {
833 ShardState::FatallyClosed => {
834 _ = ready!(Pin::new(
835 self.connection
836 .as_mut()
837 .expect("poll_next called after Poll::Ready(None)")
838 )
839 .poll_close(cx));
840 self.connection = None;
841 return Poll::Ready(None);
842 }
843 ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
844 if self.connection_future.is_none() {
845 let base_url = self
846 .resume_url
847 .as_deref()
848 .or_else(|| self.config.proxy_url())
849 .unwrap_or(GATEWAY_URL);
850 let uri = format!(
851 "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
852 );
853
854 tracing::debug!(url = base_url, "connecting to gateway");
855
856 let tls = self.config.tls.clone();
857 self.connection_future = Some(ConnectionFuture(Box::pin(async move {
858 let secs = 2u8.saturating_pow(reconnect_attempts.into());
859 time::sleep(Duration::from_secs(secs.into())).await;
860
861 Ok(ClientBuilder::new()
862 .uri(&uri)
863 .expect("URL should be valid")
864 .limits(Limits::unlimited())
865 .connector(&tls)
866 .connect()
867 .await?
868 .0)
869 })));
870 }
871
872 let res =
873 ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
874 self.connection_future = None;
875 match res {
876 Ok(connection) => {
877 self.connection = Some(connection);
878 self.state = ShardState::Identifying;
879 #[cfg(feature = "zstd")]
880 self.decompressor.reset();
881 #[allow(deprecated)]
882 #[cfg(all(
883 not(feature = "zstd"),
884 any(feature = "zlib-stock", feature = "zlib-simd")
885 ))]
886 self.inflater.reset();
887 }
888 Err(source) => {
889 self.resume_url = None;
890 self.state = ShardState::Disconnected {
891 reconnect_attempts: reconnect_attempts + 1,
892 };
893
894 return Poll::Ready(Some(Err(ReceiveMessageError {
895 kind: ReceiveMessageErrorType::Reconnect,
896 source: Some(Box::new(source)),
897 })));
898 }
899 }
900 }
901 _ => {}
902 }
903
904 if ready!(self.poll_send(cx)).is_err() {
905 self.disconnect(CloseInitiator::Transport);
906 self.connection = None;
907
908 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
909 }
910
911 match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
912 Some(Ok(message)) => {
913 #[cfg(feature = "zstd")]
914 if message.is_binary() {
915 match self.decompressor.decompress(message.as_payload()) {
916 Ok(message) => break Message::Text(message),
917 Err(source) => {
918 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
919 return Poll::Ready(Some(Err(
920 ReceiveMessageError::from_compression(source),
921 )));
922 }
923 }
924 }
925 #[cfg(all(
926 not(feature = "zstd"),
927 any(feature = "zlib-stock", feature = "zlib-simd")
928 ))]
929 if message.is_binary() {
930 match self.inflater.inflate(message.as_payload()) {
931 Ok(Some(message)) => break Message::Text(message),
932 Ok(None) => continue,
933 Err(source) => {
934 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
935 return Poll::Ready(Some(Err(
936 ReceiveMessageError::from_compression(source),
937 )));
938 }
939 }
940 }
941 if let Some(message) = Message::from_websocket_msg(&message) {
942 break message;
943 }
944 }
945 Some(Err(_)) if self.state.is_disconnected() => {}
946 Some(Err(_)) => {
947 self.disconnect(CloseInitiator::Transport);
948 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
949 }
950 None => {
951 _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
952 tracing::debug!("gateway WebSocket connection closed");
953 // Unclean closure.
954 if !self.state.is_disconnected() {
955 self.disconnect(CloseInitiator::Transport);
956 }
957 self.connection = None;
958 }
959 }
960 };
961
962 match &message {
963 Message::Close(frame) => {
964 // tokio-websockets automatically replies to the close message.
965 tracing::debug!(?frame, "received WebSocket close message");
966 // Don't run `disconnect` if we initiated the close.
967 if !self.state.is_disconnected() {
968 self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
969 }
970 }
971 Message::Text(event) => {
972 self.process(event)?;
973 }
974 }
975
976 Poll::Ready(Some(Ok(message)))
977 }
978}
979
980/// Default identify properties to use when the user hasn't customized it in
981/// [`Config::identify_properties`].
982///
983/// [`Config::identify_properties`]: Config::identify_properties
984fn default_identify_properties() -> IdentifyProperties {
985 IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
986}
987
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // Compile-time check: `Shard` implements `Debug` and `Send`.
    assert_impl_all!(Shard: Debug, Send);
    // Compile-time check: `Shard` does not implement `Sync`.
    assert_not_impl_any!(Shard: Sync);
}