twilight_gateway/shard.rs
1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(feature = "zstd")]
10use crate::compression::Decompressor;
11#[allow(deprecated)]
12#[cfg(all(
13 any(feature = "zlib-stock", feature = "zlib-simd"),
14 not(feature = "zstd")
15))]
16use crate::inflater::Inflater;
17use crate::{
18 channel::{MessageChannel, MessageSender},
19 error::{ReceiveMessageError, ReceiveMessageErrorType},
20 json,
21 latency::Latency,
22 queue::{InMemoryQueue, Queue},
23 ratelimiter::CommandRatelimiter,
24 session::Session,
25 Command, Config, Message, ShardId, API_VERSION,
26};
27use futures_core::Stream;
28use futures_sink::Sink;
29use serde::{de::DeserializeOwned, Deserialize};
30use std::{
31 env::consts::OS,
32 error::Error,
33 fmt,
34 future::Future,
35 io,
36 pin::Pin,
37 str,
38 task::{ready, Context, Poll},
39};
40use tokio::{
41 net::TcpStream,
42 sync::oneshot,
43 time::{self, error::Elapsed, timeout, Duration, Instant, Interval, MissedTickBehavior},
44};
45use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
46use twilight_model::gateway::{
47 event::GatewayEventDeserializer,
48 payload::{
49 incoming::Hello,
50 outgoing::{
51 identify::{IdentifyInfo, IdentifyProperties},
52 Heartbeat, Identify, Resume,
53 },
54 },
55 CloseCode, CloseFrame, Intents, OpCode,
56};
57
/// Default Discord gateway endpoint, used when neither a resume URL nor a
/// proxy URL is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Extra query string selecting the transport compression matching the
/// enabled crate features; zstd takes precedence over zlib.
const COMPRESSION_FEATURES: &str = match (
    cfg!(feature = "zstd"),
    cfg!(feature = "zlib-stock") || cfg!(feature = "zlib-simd"),
) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};

/// Upper bound on how long a single gateway connection attempt may take.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
72
73/// [`tokio_websockets`] library Websocket connection.
74type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
75
76/// Wrapper enum around [`WebsocketError`] with a timeout case.
77enum ConnectionError {
78 /// Connection attempt timed out.
79 Timeout(Elapsed),
80 /// Error from the websocket library, [`tokio_websockets`].
81 Websocket(WebsocketError),
82}
83
84impl ConnectionError {
85 /// Returns the boxed wrapped error.
86 fn into_boxed_error(self) -> Box<dyn Error + Send + Sync> {
87 match self {
88 Self::Websocket(e) => Box::new(e),
89 Self::Timeout(e) => Box::new(e),
90 }
91 }
92}
93
94impl From<WebsocketError> for ConnectionError {
95 fn from(value: WebsocketError) -> Self {
96 Self::Websocket(value)
97 }
98}
99
100impl From<Elapsed> for ConnectionError {
101 fn from(value: Elapsed) -> Self {
102 Self::Timeout(value)
103 }
104}
105
106/// Wrapper struct around an `async fn` with a `Debug` implementation.
107struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, ConnectionError>> + Send>>);
108
109impl fmt::Debug for ConnectionFuture {
110 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
111 f.debug_tuple("ConnectionFuture")
112 .field(&"<async fn>")
113 .finish()
114 }
115}
116
117/// Close initiator of a websocket connection.
118#[derive(Clone, Debug)]
119enum CloseInitiator {
120 /// Gateway initiated the close.
121 ///
122 /// Contains an optional close code.
123 Gateway(Option<u16>),
124 /// Shard initiated the close.
125 ///
126 /// Contains a close code.
127 Shard(CloseFrame<'static>),
128 /// Transport error initiated the close.
129 Transport,
130}
131
132/// Current state of a [Shard].
133#[derive(Clone, Copy, Debug, Eq, PartialEq)]
134pub enum ShardState {
135 /// Shard is connected to the gateway with an active session.
136 Active,
137 /// Shard is disconnected from the gateway but may reconnect in the future.
138 ///
139 /// The websocket connection may still be open.
140 Disconnected {
141 /// Number of reconnection attempts that have been made.
142 reconnect_attempts: u8,
143 },
144 /// Shard has fatally closed.
145 ///
146 /// Possible reasons may be due to [failed authentication],
147 /// [invalid intents], or other reasons. Refer to the documentation for
148 /// [`CloseCode`] for possible reasons.
149 ///
150 /// [failed authentication]: CloseCode::AuthenticationFailed
151 /// [invalid intents]: CloseCode::InvalidIntents
152 FatallyClosed,
153 /// Shard is waiting to establish or resume a session.
154 Identifying,
155 /// Shard is replaying missed dispatch events.
156 ///
157 /// The shard is considered identified whilst resuming.
158 Resuming,
159}
160
161impl ShardState {
162 /// Determine the connection status from the close code.
163 ///
164 /// Defers to [`CloseCode::can_reconnect`] to determine whether the
165 /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
166 /// the close code is unknown.
167 fn from_close_code(close_code: Option<u16>) -> Self {
168 match close_code.map(CloseCode::try_from) {
169 Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
170 _ => Self::Disconnected {
171 reconnect_attempts: 0,
172 },
173 }
174 }
175
176 /// Whether the shard has disconnected but may reconnect in the future.
177 const fn is_disconnected(self) -> bool {
178 matches!(self, Self::Disconnected { .. })
179 }
180
181 /// Whether the shard is identified with an active session.
182 ///
183 /// `true` if the status is [`Active`] or [`Resuming`].
184 ///
185 /// [`Active`]: Self::Active
186 /// [`Resuming`]: Self::Resuming
187 pub const fn is_identified(self) -> bool {
188 matches!(self, Self::Active | Self::Resuming)
189 }
190}
191
192/// Gateway event with only minimal required data.
193#[derive(Deserialize)]
194struct MinimalEvent<T> {
195 /// Attached data of the gateway event.
196 #[serde(rename = "d")]
197 data: T,
198}
199
200/// Minimal [`Ready`] for light deserialization.
201///
202/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
203#[derive(Deserialize)]
204struct MinimalReady {
205 /// Used for resuming connections.
206 resume_gateway_url: Box<str>,
207 /// ID of the new identified session.
208 session_id: String,
209}
210
211/// Pending outgoing message indicator.
212#[derive(Debug)]
213struct Pending {
214 /// The pending message, if not already sent.
215 gateway_event: Option<Message>,
216 /// Whether the pending gateway event is a heartbeat.
217 is_heartbeat: bool,
218}
219
220impl Pending {
221 /// Constructor for a pending gateway event.
222 const fn text(json: String, is_heartbeat: bool) -> Option<Self> {
223 Some(Self {
224 gateway_event: Some(Message::Text(json)),
225 is_heartbeat,
226 })
227 }
228}
229
230/// Gateway API client responsible for up to 2500 guilds.
231///
232/// Shards are responsible for maintaining the gateway connection by processing
233/// events relevant to the operation of shards---such as requests from the
234/// gateway to re-connect or invalidate a session---and then to pass them on to
235/// the user.
236///
237/// Shards start out disconnected, but will on the first successful call to
238/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
239/// be repeatedly called in order for the shard to maintain its connection and
240/// update its internal state.
241///
242/// Shards go through an [identify queue][`queue`] that rate limits concurrent
243/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
244/// invalidates the shard's session and it is therefore **very important** to
245/// reuse the same queue for all shards.
246///
247/// # Sharding
248///
249/// A shard may not be connected to more than 2500 guilds, so large bots must
250/// split themselves across multiple shards. See the
251/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
252/// more info.
253///
254/// # Examples
255///
256/// Create and start a shard and print new and deleted messages:
257///
258/// ```no_run
259/// use std::env;
260/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
261///
262/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
263/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
264/// // token. Of course, this value may be passed into the program however is
265/// // preferred.
266/// let token = env::var("DISCORD_TOKEN")?;
267/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
268///
269/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
270///
271/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
272/// let Ok(event) = item else {
273/// tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
274///
275/// continue;
276/// };
277///
278/// match event {
279/// Event::MessageCreate(message) => {
280/// println!("message received with content: {}", message.content);
281/// }
282/// Event::MessageDelete(message) => {
283/// println!("message with ID {} deleted", message.id);
284/// }
285/// _ => {}
286/// }
287/// }
288/// # Ok(()) }
289/// ```
290///
291/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
292/// [gateway commands]: Shard::command
293/// [`poll_next`]: Shard::poll_next
294/// [`queue`]: crate::queue
295#[derive(Debug)]
296pub struct Shard<Q = InMemoryQueue> {
297 /// User provided configuration.
298 ///
299 /// Configurations are provided or created in shard initializing via
300 /// [`Shard::new`] or [`Shard::with_config`].
301 config: Config<Q>,
302 /// Future to establish a WebSocket connection with the Gateway.
303 connection_future: Option<ConnectionFuture>,
304 /// Websocket connection, which may be connected to Discord's gateway.
305 ///
306 /// The connection should only be dropped after it has returned `Ok(None)`
307 /// to comply with the WebSocket protocol.
308 connection: Option<Connection>,
309 /// Zstd decompressor.
310 #[cfg(feature = "zstd")]
311 decompressor: Decompressor,
312 /// Interval of how often the gateway would like the shard to send
313 /// heartbeats.
314 ///
315 /// The interval is received in the [`GatewayEvent::Hello`] event when
316 /// first opening a new [connection].
317 ///
318 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
319 /// [connection]: Self::connection
320 heartbeat_interval: Option<Interval>,
321 /// Whether an event has been received in the current heartbeat interval.
322 heartbeat_interval_event: bool,
323 /// ID of the shard.
324 id: ShardId,
325 /// Identify queue receiver.
326 identify_rx: Option<oneshot::Receiver<()>>,
327 /// Zlib decompressor.
328 #[allow(deprecated)]
329 #[cfg(all(
330 any(feature = "zlib-stock", feature = "zlib-simd"),
331 not(feature = "zstd")
332 ))]
333 inflater: Inflater,
334 /// Potentially pending outgoing message.
335 pending: Option<Pending>,
336 /// Recent heartbeat latency statistics.
337 ///
338 /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
339 /// may have changed, invalidating previous latency statistic.
340 ///
341 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
342 latency: Latency,
343 /// Command ratelimiter, if it was enabled via
344 /// [`Config::ratelimit_messages`].
345 ratelimiter: Option<CommandRatelimiter>,
346 /// Used for resuming connections.
347 resume_url: Option<Box<str>>,
348 /// Active session of the shard.
349 ///
350 /// The shard may not have an active session if it hasn't yet identified and
351 /// received a `READY` dispatch event response.
352 session: Option<Session>,
353 /// Current state of the shard.
354 state: ShardState,
355 /// Messages from the user to be relayed and sent over the Websocket
356 /// connection.
357 user_channel: MessageChannel,
358}
359
360impl Shard {
361 /// Create a new shard with the default configuration.
362 pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
363 Self::with_config(id, Config::new(token, intents))
364 }
365}
366
367impl<Q> Shard<Q> {
368 /// Create a new shard with the provided configuration.
369 pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
370 let session = config.take_session();
371 let mut resume_url = config.take_resume_url();
372 //ensure resume_url is only used if we have a session to resume
373 if session.is_none() {
374 resume_url = None;
375 }
376
377 Self {
378 config,
379 connection_future: None,
380 connection: None,
381 #[cfg(feature = "zstd")]
382 decompressor: Decompressor::new(),
383 heartbeat_interval: None,
384 heartbeat_interval_event: false,
385 id: shard_id,
386 identify_rx: None,
387 #[allow(deprecated)]
388 #[cfg(all(
389 any(feature = "zlib-stock", feature = "zlib-simd"),
390 not(feature = "zstd")
391 ))]
392 inflater: Inflater::new(),
393 pending: None,
394 latency: Latency::new(),
395 ratelimiter: None,
396 resume_url,
397 session,
398 state: ShardState::Disconnected {
399 reconnect_attempts: 0,
400 },
401 user_channel: MessageChannel::new(),
402 }
403 }
404
405 /// Immutable reference to the configuration used to instantiate this shard.
406 pub const fn config(&self) -> &Config<Q> {
407 &self.config
408 }
409
410 /// ID of the shard.
411 pub const fn id(&self) -> ShardId {
412 self.id
413 }
414
415 /// Zlib decompressor statistics.
416 ///
417 /// Reset when reconnecting to the gateway.
418 #[allow(deprecated)]
419 #[cfg(all(
420 any(feature = "zlib-stock", feature = "zlib-simd"),
421 not(feature = "zstd")
422 ))]
423 #[deprecated(since = "0.16.1", note = "replaced by zstd compression")]
424 pub const fn inflater(&self) -> &Inflater {
425 &self.inflater
426 }
427
428 /// State of the shard.
429 pub const fn state(&self) -> ShardState {
430 self.state
431 }
432
433 /// Shard latency statistics, including average latency and recent heartbeat
434 /// latency times.
435 ///
436 /// Reset when reconnecting to the gateway.
437 pub const fn latency(&self) -> &Latency {
438 &self.latency
439 }
440
441 /// Statistics about the number of available commands and when the command
442 /// ratelimiter will refresh.
443 ///
444 /// This won't be present if ratelimiting was disabled via
445 /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
446 ///
447 /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
448 pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
449 self.ratelimiter.as_ref()
450 }
451
452 /// Immutable reference to the gateways current resume URL.
453 ///
454 /// A resume URL might not be present if the shard had its session
455 /// invalidated and has not yet reconnected.
456 pub fn resume_url(&self) -> Option<&str> {
457 self.resume_url.as_deref()
458 }
459
460 /// Immutable reference to the active gateway session.
461 ///
462 /// An active session may not be present if the shard had its session
463 /// invalidated and has not yet reconnected.
464 pub const fn session(&self) -> Option<&Session> {
465 self.session.as_ref()
466 }
467
468 /// Queue a command to be sent to the gateway.
469 ///
470 /// Serializes the command and then calls [`send`].
471 ///
472 /// [`send`]: Self::send
473 #[allow(clippy::missing_panics_doc)]
474 pub fn command(&self, command: &impl Command) {
475 self.send(json::to_string(command).expect("serialization cannot fail"));
476 }
477
478 /// Queue a JSON encoded gateway event to be sent to the gateway.
479 #[allow(clippy::missing_panics_doc)]
480 pub fn send(&self, json: String) {
481 self.user_channel
482 .command_tx
483 .send(json)
484 .expect("channel open");
485 }
486
487 /// Queue a websocket close frame.
488 ///
489 /// Invalidates the session and shows the application's bot as offline if
490 /// the close frame code is `1000` or `1001`. Otherwise Discord will
491 /// continue showing the bot as online until its presence times out.
492 ///
493 /// To read all remaining messages, continue calling [`poll_next`] until it
494 /// returns [`Message::Close`].
495 ///
496 /// # Example
497 ///
498 /// Close the shard and process remaining messages:
499 ///
500 /// ```no_run
501 /// # use twilight_gateway::{Intents, Shard, ShardId};
502 /// # #[tokio::main] async fn main() {
503 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
504 /// use tokio_stream::StreamExt;
505 /// use twilight_gateway::{error::ReceiveMessageErrorType, CloseFrame, Message};
506 ///
507 /// shard.close(CloseFrame::NORMAL);
508 ///
509 /// while let Some(item) = shard.next().await {
510 /// match item {
511 /// Ok(Message::Close(_)) => break,
512 /// Ok(Message::Text(_)) => unimplemented!(),
513 /// Err(source) => unimplemented!(),
514 /// }
515 /// }
516 /// # }
517 /// ```
518 ///
519 /// [`poll_next`]: Shard::poll_next
520 pub fn close(&self, close_frame: CloseFrame<'static>) {
521 _ = self.user_channel.close_tx.try_send(close_frame);
522 }
523
524 /// Retrieve a channel to send messages over the shard to the gateway.
525 ///
526 /// This is primarily useful for sending to other tasks and threads where
527 /// the shard won't be available.
528 ///
529 /// # Example
530 ///
531 /// Queue a command in another process:
532 ///
533 /// ```no_run
534 /// # use twilight_gateway::{Intents, Shard, ShardId};
535 /// # #[tokio::main] async fn main() {
536 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
537 /// use tokio_stream::StreamExt;
538 ///
539 /// while let Some(item) = shard.next().await {
540 /// match item {
541 /// Ok(message) => {
542 /// let sender = shard.sender();
543 /// tokio::spawn(async move {
544 /// let command = unimplemented!();
545 /// sender.send(command);
546 /// });
547 /// }
548 /// Err(source) => unimplemented!(),
549 /// }
550 /// }
551 /// # }
552 /// ```
553 pub fn sender(&self) -> MessageSender {
554 self.user_channel.sender()
555 }
556
557 /// Update internal state from gateway disconnect.
558 fn disconnect(&mut self, initiator: CloseInitiator) {
559 // May not send any additional WebSocket messages.
560 self.heartbeat_interval = None;
561 self.ratelimiter = None;
562 // Abort identify.
563 self.identify_rx = None;
564 self.state = match initiator {
565 CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
566 _ => ShardState::Disconnected {
567 reconnect_attempts: 0,
568 },
569 };
570 if let CloseInitiator::Shard(frame) = initiator {
571 // Not resuming, drop session and resume URL.
572 // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
573 if matches!(frame.code, 1000 | 1001) {
574 self.resume_url = None;
575 self.session = None;
576 }
577 self.pending = Some(Pending {
578 gateway_event: Some(Message::Close(Some(frame))),
579 is_heartbeat: false,
580 });
581 }
582 }
583
584 /// Parse a JSON message into an event with minimal data for [processing].
585 ///
586 /// # Errors
587 ///
588 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
589 /// event isn't a recognized structure, which may be the case for new or
590 /// undocumented events.
591 ///
592 /// [processing]: Self::process
593 fn parse_event<T: DeserializeOwned>(
594 json: &str,
595 ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
596 json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
597 kind: ReceiveMessageErrorType::Deserializing {
598 event: json.to_owned(),
599 },
600 source: Some(Box::new(source)),
601 })
602 }
603}
604
impl<Q: Queue> Shard<Q> {
    /// Attempts to send due commands to the gateway.
    ///
    /// Each loop iteration considers outgoing work in this priority order:
    ///
    /// 1. finish submitting and flushing the currently pending message;
    /// 2. a close frame queued via [`Shard::close`], unless already
    ///    disconnected;
    /// 3. a heartbeat, when the heartbeat interval has ticked;
    /// 4. an identify, once the identify queue receiver resolves and the
    ///    ratelimiter (if any) reports availability;
    /// 5. a user command queued via [`Shard::send`], once the shard is
    ///    identified and the ratelimiter (if any) reports availability.
    ///
    /// While a message is pending, `self.connection` is unwrapped, so it must
    /// be present at that point.
    ///
    /// # Returns
    ///
    /// * `Poll::Pending` if sending is in progress
    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
    /// * `Poll::Ready(Err)` if sending a command failed, or if no event was
    ///   received since the last heartbeat was sent (a "zombied" connection,
    ///   reported as an `io::ErrorKind::TimedOut` error).
    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
        loop {
            // Step 1: drive the pending message through ready -> send -> flush.
            if let Some(pending) = self.pending.as_mut() {
                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;

                // `gateway_event` is already `None` when the message was handed
                // to the sink on an earlier poll and only the flush remains.
                if let Some(message) = &pending.gateway_event {
                    if let Some(ratelimiter) = self.ratelimiter.as_mut() {
                        // Heartbeats and non-text (close) messages are not
                        // counted against the command ratelimit.
                        if message.is_text() && !pending.is_heartbeat {
                            ready!(ratelimiter.poll_acquire(cx));
                        }
                    }

                    let ws_message = pending.gateway_event.take().unwrap().into_websocket_msg();
                    Pin::new(self.connection.as_mut().unwrap()).start_send(ws_message)?;
                }

                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;

                // Latency is measured from the moment the heartbeat was flushed.
                if pending.is_heartbeat {
                    self.latency.record_sent();
                }
                self.pending = None;
            }

            // Step 2: a user-requested close, unless already disconnecting.
            if !self.state.is_disconnected() {
                if let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx) {
                    let frame = frame.expect("shard owns channel");

                    tracing::debug!("sending close frame from user channel");
                    self.disconnect(CloseInitiator::Shard(frame));

                    continue;
                }
            }

            // Step 3: heartbeat when the interval ticks.
            if self
                .heartbeat_interval
                .as_mut()
                .is_some_and(|heartbeater| heartbeater.poll_tick(cx).is_ready())
            {
                // Discord never responded after the last heartbeat, connection
                // is failed or "zombied", see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
                // Note that unlike documented *any* event is okay; it does not
                // have to be a heartbeat ACK.
                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
                    tracing::info!("connection is failed or \"zombied\"");

                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
                }

                tracing::debug!("sending heartbeat");
                self.pending = Pending::text(
                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
                        .expect("serialization cannot fail"),
                    true,
                );
                // Start a new observation window for the zombie check above.
                self.heartbeat_interval_event = false;

                continue;
            }

            // Polling `poll_available` also registers the waker, so this task is
            // woken when the ratelimiter refreshes.
            let not_ratelimited = self.ratelimiter.as_mut().map_or(true, |ratelimiter| {
                ratelimiter.poll_available(cx).is_ready()
            });

            // Step 4: identify once the identify queue lets this shard through.
            if not_ratelimited {
                if let Some(Poll::Ready(canceled)) = self
                    .identify_rx
                    .as_mut()
                    .map(|rx| Pin::new(rx).poll(cx).map(|r| r.is_err()))
                {
                    // The queue dropped its sender without permitting us;
                    // re-enqueue instead of identifying.
                    if canceled {
                        self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                        continue;
                    }

                    tracing::debug!("sending identify");

                    self.pending = Pending::text(
                        json::to_string(&Identify::new(IdentifyInfo {
                            compress: false,
                            intents: self.config.intents(),
                            large_threshold: self.config.large_threshold(),
                            presence: self.config.presence().cloned(),
                            properties: self
                                .config
                                .identify_properties()
                                .cloned()
                                .unwrap_or_else(default_identify_properties),
                            shard: Some(self.id),
                            token: self.config.token().to_owned(),
                        }))
                        .expect("serialization cannot fail"),
                        false,
                    );
                    self.identify_rx = None;

                    continue;
                }
            }

            // Step 5: user commands, only once the session is identified.
            if not_ratelimited && self.state.is_identified() {
                if let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx) {
                    let command = command.expect("shard owns channel");

                    tracing::debug!("sending command from user channel");
                    self.pending = Some(Pending {
                        gateway_event: Some(Message::Text(command)),
                        is_heartbeat: false,
                    });

                    continue;
                }
            }

            return Poll::Ready(Ok(()));
        }
    }

    /// Updates the shard's internal state from a gateway event by recording
    /// and/or responding to certain Discord events.
    ///
    /// Behaviour per opcode:
    /// - `Dispatch`: records the sequence; `READY` stores the session and
    ///   resume URL and sets the state to `Active`, and `RESUMED` sets the
    ///   state to `Active`;
    /// - `Heartbeat`: queues a heartbeat reply;
    /// - `HeartbeatAck`: records latency if a heartbeat was sent and not yet
    ///   acknowledged;
    /// - `Hello`: (re)creates the ratelimiter and heartbeat interval, resets
    ///   latency, and queues a resume (with a session) or enqueues an identify;
    /// - `InvalidSession`: disconnects with `CloseFrame::RESUME` if resumable,
    ///   otherwise with `CloseFrame::NORMAL`;
    /// - `Reconnect`: disconnects with `CloseFrame::RESUME`.
    ///
    /// # Errors
    ///
    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
    /// gateway event isn't a recognized structure.
    #[allow(clippy::too_many_lines)]
    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
        let (raw_opcode, maybe_sequence, maybe_event_type) =
            GatewayEventDeserializer::from_json(event)
                .ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing opcode".into()),
                })?
                .into_parts();

        // Any event after a heartbeat was sent proves the connection is alive
        // (consumed by the zombie check in `poll_send`).
        if self.latency.sent().is_some() {
            self.heartbeat_interval_event = true;
        }

        match OpCode::from(raw_opcode) {
            Some(OpCode::Dispatch) => {
                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing dispatch event type".into()),
                })?;
                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing sequence".into()),
                })?;
                tracing::debug!(%event_type, %sequence, "received dispatch");

                match event_type.as_ref() {
                    "READY" => {
                        let event = Self::parse_event::<MinimalReady>(event)?;

                        self.resume_url = Some(event.data.resume_gateway_url);
                        self.session = Some(Session::new(sequence, event.data.session_id));
                        self.state = ShardState::Active;
                    }
                    "RESUMED" => self.state = ShardState::Active,
                    _ => {}
                }

                // Track the latest sequence for heartbeats and resumes.
                if let Some(session) = self.session.as_mut() {
                    session.set_sequence(sequence);
                }
            }
            Some(OpCode::Heartbeat) => {
                tracing::debug!("received heartbeat");
                self.pending = Pending::text(
                    json::to_string(&Heartbeat::new(self.session().map(Session::sequence)))
                        .expect("serialization cannot fail"),
                    true,
                );
            }
            Some(OpCode::HeartbeatAck) => {
                // Only an ACK for an outstanding heartbeat counts as a latency sample.
                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
                if requested {
                    tracing::debug!("received heartbeat ack");
                    self.latency.record_received();
                } else {
                    tracing::info!("received unrequested heartbeat ack");
                }
            }
            Some(OpCode::Hello) => {
                let event = Self::parse_event::<Hello>(event)?;
                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
                // First heartbeat should have some jitter, see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

                if self.config().ratelimit_messages() {
                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
                }

                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
                self.heartbeat_interval = Some(interval);

                // Reset `Latency` since the shard might have connected to a new
                // remote which invalidates the recorded latencies.
                self.latency = Latency::new();

                // Resume an existing session; otherwise wait for the identify queue.
                if let Some(session) = &self.session {
                    self.pending = Pending::text(
                        json::to_string(&Resume::new(
                            session.sequence(),
                            session.id(),
                            self.config.token(),
                        ))
                        .expect("serialization cannot fail"),
                        false,
                    );
                    self.state = ShardState::Resuming;
                } else {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                }
            }
            Some(OpCode::InvalidSession) => {
                // The payload data is a bool saying whether the session can be resumed.
                let resumable = Self::parse_event(event)?.data;
                tracing::debug!(resumable, "received invalid session");
                if resumable {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                } else {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
                }
            }
            Some(OpCode::Reconnect) => {
                tracing::debug!("received reconnect");
                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
            }
            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
        }

        Ok(())
    }
}
859
860impl<Q: Queue + Unpin> Stream for Shard<Q> {
861 type Item = Result<Message, ReceiveMessageError>;
862
863 #[allow(clippy::too_many_lines)]
864 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
865 let message = loop {
866 match self.state {
867 ShardState::FatallyClosed => {
868 _ = ready!(Pin::new(
869 self.connection
870 .as_mut()
871 .expect("poll_next called after Poll::Ready(None)")
872 )
873 .poll_close(cx));
874 self.connection = None;
875 return Poll::Ready(None);
876 }
877 ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
878 if self.connection_future.is_none() {
879 let base_url = self
880 .resume_url
881 .as_deref()
882 .or_else(|| self.config.proxy_url())
883 .unwrap_or(GATEWAY_URL);
884 let uri = format!(
885 "{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}"
886 );
887
888 tracing::debug!(url = base_url, "connecting to gateway");
889
890 let tls = self.config.tls.clone();
891 self.connection_future = Some(ConnectionFuture(Box::pin(async move {
892 let secs = 2u8.saturating_pow(reconnect_attempts.into());
893 time::sleep(Duration::from_secs(secs.into())).await;
894
895 Ok(timeout(
896 CONNECT_TIMEOUT,
897 ClientBuilder::new()
898 .uri(&uri)
899 .expect("URL should be valid")
900 .limits(Limits::unlimited())
901 .connector(&tls)
902 .connect(),
903 )
904 .await??
905 .0)
906 })));
907 }
908
909 let res =
910 ready!(Pin::new(&mut self.connection_future.as_mut().unwrap().0).poll(cx));
911 self.connection_future = None;
912 match res {
913 Ok(connection) => {
914 self.connection = Some(connection);
915 self.state = ShardState::Identifying;
916 #[cfg(feature = "zstd")]
917 self.decompressor.reset();
918 #[allow(deprecated)]
919 #[cfg(all(
920 not(feature = "zstd"),
921 any(feature = "zlib-stock", feature = "zlib-simd")
922 ))]
923 self.inflater.reset();
924 }
925 Err(source) => {
926 self.resume_url = None;
927 self.state = ShardState::Disconnected {
928 reconnect_attempts: reconnect_attempts + 1,
929 };
930
931 return Poll::Ready(Some(Err(ReceiveMessageError {
932 kind: ReceiveMessageErrorType::Reconnect,
933 source: Some(source.into_boxed_error()),
934 })));
935 }
936 }
937 }
938 _ => {}
939 }
940
941 if ready!(self.poll_send(cx)).is_err() {
942 self.disconnect(CloseInitiator::Transport);
943 self.connection = None;
944
945 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
946 }
947
948 match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
949 Some(Ok(message)) => {
950 #[cfg(feature = "zstd")]
951 if message.is_binary() {
952 match self.decompressor.decompress(message.as_payload()) {
953 Ok(message) => break Message::Text(message),
954 Err(source) => {
955 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
956 return Poll::Ready(Some(Err(
957 ReceiveMessageError::from_compression(source),
958 )));
959 }
960 }
961 }
962 #[cfg(all(
963 not(feature = "zstd"),
964 any(feature = "zlib-stock", feature = "zlib-simd")
965 ))]
966 if message.is_binary() {
967 match self.inflater.inflate(message.as_payload()) {
968 Ok(Some(message)) => break Message::Text(message),
969 Ok(None) => continue,
970 Err(source) => {
971 self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
972 return Poll::Ready(Some(Err(
973 ReceiveMessageError::from_compression(source),
974 )));
975 }
976 }
977 }
978 if let Some(message) = Message::from_websocket_msg(&message) {
979 break message;
980 }
981 }
982 Some(Err(_)) if self.state.is_disconnected() => {}
983 Some(Err(_)) => {
984 self.disconnect(CloseInitiator::Transport);
985 return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
986 }
987 None => {
988 _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
989 tracing::debug!("gateway WebSocket connection closed");
990 // Unclean closure.
991 if !self.state.is_disconnected() {
992 self.disconnect(CloseInitiator::Transport);
993 }
994 self.connection = None;
995 }
996 }
997 };
998
999 match &message {
1000 Message::Close(frame) => {
1001 // tokio-websockets automatically replies to the close message.
1002 tracing::debug!(?frame, "received WebSocket close message");
1003 // Don't run `disconnect` if we initiated the close.
1004 if !self.state.is_disconnected() {
1005 self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
1006 }
1007 }
1008 Message::Text(event) => {
1009 self.process(event)?;
1010 }
1011 }
1012
1013 Poll::Ready(Some(Ok(message)))
1014 }
1015}
1016
1017/// Default identify properties to use when the user hasn't customized it in
1018/// [`Config::identify_properties`].
1019///
1020/// [`Config::identify_properties`]: Config::identify_properties
1021fn default_identify_properties() -> IdentifyProperties {
1022 IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
1023}
1024
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};

    // `Shard` must be debuggable and movable across threads.
    assert_impl_all!(Shard: std::fmt::Debug, Send);
    // `Shard` is asserted not to be shareable between threads.
    assert_not_impl_any!(Shard: Sync);
}