twilight_gateway/shard.rs
1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(any(feature = "zlib", feature = "zstd"))]
10use crate::compression::Decompressor;
11use crate::{
12 API_VERSION, Command, Config, Message, ShardId,
13 channel::{MessageChannel, MessageSender},
14 error::{ReceiveMessageError, ReceiveMessageErrorType},
15 json,
16 latency::Latency,
17 queue::{InMemoryQueue, Queue},
18 ratelimiter::CommandRatelimiter,
19 session::Session,
20};
21use futures_core::Stream;
22use futures_sink::Sink;
23use serde::{Deserialize, Serialize, de::DeserializeOwned};
24use std::{
25 env::consts::OS,
26 error::Error,
27 fmt,
28 future::Future,
29 io,
30 pin::Pin,
31 str,
32 sync::Arc,
33 task::{Context, Poll, ready},
34};
35use tokio::{
36 net::TcpStream,
37 sync::oneshot,
38 time::{self, Duration, Instant, Interval, MissedTickBehavior},
39};
40use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
41use twilight_model::gateway::{
42 CloseCode, CloseFrame, Intents, OpCode,
43 event::GatewayEventDeserializer,
44 payload::{
45 incoming::Hello,
46 outgoing::{
47 Heartbeat, Identify, Resume,
48 identify::{IdentifyInfo, IdentifyProperties},
49 },
50 },
51};
52
/// Default address of Discord's gateway.
///
/// Used when neither a resume URL nor a proxy URL is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";

/// Compression query argument appended to the gateway URL.
///
/// `zstd` takes precedence over `zlib` when both features are enabled; the
/// argument is empty when neither is.
const COMPRESSION_FEATURES: &str = match (cfg!(feature = "zstd"), cfg!(feature = "zlib")) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};

/// Upper bound on the time spent establishing a gateway connection.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
67
68/// [`tokio_websockets`] library Websocket connection.
69type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
70
71/// Dynamically dispatched [`Error`].
72type GenericError = Box<dyn Error + Send + Sync>;
73
74/// Wrapper struct around an `async fn` with a `Debug` implementation.
75struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, GenericError>> + Send>>);
76
77impl fmt::Debug for ConnectionFuture {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 f.debug_tuple("ConnectionFuture")
80 .field(&"<async fn>")
81 .finish()
82 }
83}
84
85/// Close initiator of a websocket connection.
86#[derive(Clone, Debug)]
87enum CloseInitiator {
88 /// Gateway initiated the close.
89 ///
90 /// Contains an optional close code.
91 Gateway(Option<u16>),
92 /// Shard initiated the close.
93 ///
94 /// Contains a close code.
95 Shard(CloseFrame<'static>),
96 /// Transport error initiated the close.
97 Transport,
98}
99
/// Lifecycle state of a [Shard].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardState {
    /// Connected to the gateway with an active session.
    Active,
    /// Not connected to the gateway, but a reconnect may still happen.
    ///
    /// The websocket connection itself may still be open.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed for good; the shard will not reconnect.
    ///
    /// This may be caused by [failed authentication], [invalid intents], or
    /// other reasons; see [`CloseCode`] for the possible causes.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting to establish or resume a session.
    Identifying,
    /// Replaying dispatch events missed while disconnected.
    ///
    /// A resuming shard counts as identified.
    Resuming,
}
128
129impl ShardState {
130 /// Determine the connection status from the close code.
131 ///
132 /// Defers to [`CloseCode::can_reconnect`] to determine whether the
133 /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
134 /// the close code is unknown.
135 fn from_close_code(close_code: Option<u16>) -> Self {
136 match close_code.map(CloseCode::try_from) {
137 Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
138 _ => Self::Disconnected {
139 reconnect_attempts: 0,
140 },
141 }
142 }
143
144 /// Whether the shard has disconnected but may reconnect in the future.
145 const fn is_disconnected(self) -> bool {
146 matches!(self, Self::Disconnected { .. })
147 }
148
149 /// Whether the shard is identified with an active session.
150 ///
151 /// `true` if the status is [`Active`] or [`Resuming`].
152 ///
153 /// [`Active`]: Self::Active
154 /// [`Resuming`]: Self::Resuming
155 pub const fn is_identified(self) -> bool {
156 matches!(self, Self::Active | Self::Resuming)
157 }
158}
159
160/// Gateway event with only minimal required data.
161#[derive(Deserialize)]
162struct MinimalEvent<T> {
163 /// Attached data of the gateway event.
164 #[serde(rename = "d")]
165 data: T,
166}
167
168/// Minimal [`Ready`] for light deserialization.
169///
170/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
171#[derive(Deserialize)]
172struct MinimalReady {
173 /// Used for resuming connections.
174 resume_gateway_url: Box<str>,
175 /// ID of the new identified session.
176 session_id: String,
177}
178
179/// Pending outgoing message indicator.
180#[derive(Debug)]
181struct Pending {
182 /// The pending message, if not already sent.
183 gateway_event: Option<Message>,
184 /// Whether the pending gateway event is a heartbeat.
185 is_heartbeat: bool,
186}
187
188impl Pending {
189 /// Constructor for a pending gateway event.
190 fn event<T: Serialize>(event: T, is_heartbeat: bool) -> Option<Self> {
191 Some(Self {
192 gateway_event: Some(Message::Text(
193 json::to_string(&event).expect("json serialization is infallible"),
194 )),
195 is_heartbeat,
196 })
197 }
198}
199
200/// Gateway API client responsible for up to 2500 guilds.
201///
202/// Shards are responsible for maintaining the gateway connection by processing
203/// events relevant to the operation of shards---such as requests from the
204/// gateway to re-connect or invalidate a session---and then to pass them on to
205/// the user.
206///
207/// Shards start out disconnected, but will on the first successful call to
208/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
209/// be repeatedly called in order for the shard to maintain its connection and
210/// update its internal state.
211///
212/// Shards go through an [identify queue][`queue`] that rate limits concurrent
213/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
214/// invalidates the shard's session and it is therefore **very important** to
215/// reuse the same queue for all shards.
216///
217/// # Sharding
218///
219/// A shard may not be connected to more than 2500 guilds, so large bots must
220/// split themselves across multiple shards. See the
221/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
222/// more info.
223///
224/// # Examples
225///
226/// Create and start a shard and print new and deleted messages:
227///
228/// ```no_run
229/// use std::env;
230/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
231///
232/// # #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> {
233/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
234/// // token. Of course, this value may be passed into the program however is
235/// // preferred.
236/// let token = env::var("DISCORD_TOKEN")?;
237/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
238///
239/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
240///
241/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
242/// let Ok(event) = item else {
243/// tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
244///
245/// continue;
246/// };
247///
248/// match event {
249/// Event::MessageCreate(message) => {
250/// println!("message received with content: {}", message.content);
251/// }
252/// Event::MessageDelete(message) => {
253/// println!("message with ID {} deleted", message.id);
254/// }
255/// _ => {}
256/// }
257/// }
258/// # Ok(()) }
259/// ```
260///
261/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
262/// [gateway commands]: Shard::command
263/// [`poll_next`]: Shard::poll_next
264/// [`queue`]: crate::queue
265#[derive(Debug)]
266pub struct Shard<Q = InMemoryQueue> {
267 /// User provided configuration.
268 ///
269 /// Configurations are provided or created in shard initializing via
270 /// [`Shard::new`] or [`Shard::with_config`].
271 config: Config<Q>,
272 /// Future to establish a WebSocket connection with the Gateway.
273 connection_future: Option<ConnectionFuture>,
274 /// Websocket connection, which may be connected to Discord's gateway.
275 ///
276 /// The connection should only be dropped after it has returned `Ok(None)`
277 /// to comply with the WebSocket protocol.
278 connection: Option<Connection>,
279 /// Event decompressor.
280 #[cfg(any(feature = "zlib", feature = "zstd"))]
281 decompressor: Decompressor,
282 /// Interval of how often the gateway would like the shard to send
283 /// heartbeats.
284 ///
285 /// The interval is received in the [`GatewayEvent::Hello`] event when
286 /// first opening a new [connection].
287 ///
288 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
289 /// [connection]: Self::connection
290 heartbeat_interval: Option<Interval>,
291 /// Whether an event has been received in the current heartbeat interval.
292 heartbeat_interval_event: bool,
293 /// ID of the shard.
294 id: ShardId,
295 /// Identify queue receiver.
296 identify_rx: Option<oneshot::Receiver<()>>,
297 /// Potentially pending outgoing message.
298 pending: Option<Pending>,
299 /// Recent heartbeat latency statistics.
300 ///
301 /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
302 /// may have changed, invalidating previous latency statistic.
303 ///
304 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
305 latency: Latency,
306 /// Command ratelimiter, if it was enabled via
307 /// [`Config::ratelimit_messages`].
308 ratelimiter: Option<CommandRatelimiter>,
309 /// Used for resuming connections.
310 resume_url: Option<Box<str>>,
311 /// Active session of the shard.
312 ///
313 /// The shard may not have an active session if it hasn't yet identified and
314 /// received a `READY` dispatch event response.
315 session: Option<Session>,
316 /// Current state of the shard.
317 state: ShardState,
318 /// Messages from the user to be relayed and sent over the Websocket
319 /// connection.
320 user_channel: MessageChannel,
321}
322
323impl Shard {
324 /// Create a new shard with the default configuration.
325 pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
326 Self::with_config(id, Config::new(token, intents))
327 }
328}
329
330impl<Q> Shard<Q> {
331 /// Create a new shard with the provided configuration.
332 pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
333 let session = config.take_session();
334 let mut resume_url = config.take_resume_url();
335 //ensure resume_url is only used if we have a session to resume
336 if session.is_none() {
337 resume_url = None;
338 }
339
340 Self {
341 config,
342 connection_future: None,
343 connection: None,
344 #[cfg(any(feature = "zlib", feature = "zstd"))]
345 decompressor: Decompressor::new(),
346 heartbeat_interval: None,
347 heartbeat_interval_event: false,
348 id: shard_id,
349 identify_rx: None,
350 pending: None,
351 latency: Latency::new(),
352 ratelimiter: None,
353 resume_url,
354 session,
355 state: ShardState::Disconnected {
356 reconnect_attempts: 0,
357 },
358 user_channel: MessageChannel::new(),
359 }
360 }
361
362 /// Immutable reference to the configuration used to instantiate this shard.
363 pub const fn config(&self) -> &Config<Q> {
364 &self.config
365 }
366
367 /// ID of the shard.
368 pub const fn id(&self) -> ShardId {
369 self.id
370 }
371
372 /// State of the shard.
373 pub const fn state(&self) -> ShardState {
374 self.state
375 }
376
377 /// Shard latency statistics, including average latency and recent heartbeat
378 /// latency times.
379 ///
380 /// Reset when reconnecting to the gateway.
381 pub const fn latency(&self) -> &Latency {
382 &self.latency
383 }
384
385 /// Statistics about the number of available commands and when the command
386 /// ratelimiter will refresh.
387 ///
388 /// This won't be present if ratelimiting was disabled via
389 /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
390 ///
391 /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
392 pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
393 self.ratelimiter.as_ref()
394 }
395
396 /// Immutable reference to the gateways current resume URL.
397 ///
398 /// A resume URL might not be present if the shard had its session
399 /// invalidated and has not yet reconnected.
400 pub fn resume_url(&self) -> Option<&str> {
401 self.resume_url.as_deref()
402 }
403
404 /// Immutable reference to the active gateway session.
405 ///
406 /// An active session may not be present if the shard had its session
407 /// invalidated and has not yet reconnected.
408 pub const fn session(&self) -> Option<&Session> {
409 self.session.as_ref()
410 }
411
412 /// Queue a command to be sent to the gateway.
413 ///
414 /// Serializes the command and then calls [`send`].
415 ///
416 /// [`send`]: Self::send
417 #[allow(clippy::missing_panics_doc)]
418 pub fn command(&self, command: &impl Command) {
419 self.send(json::to_string(command).expect("serialization cannot fail"));
420 }
421
422 /// Queue a JSON encoded gateway event to be sent to the gateway.
423 #[allow(clippy::missing_panics_doc)]
424 pub fn send(&self, json: String) {
425 self.user_channel
426 .command_tx
427 .send(json)
428 .expect("channel open");
429 }
430
431 /// Queue a websocket close frame.
432 ///
433 /// Invalidates the session and shows the application's bot as offline if
434 /// the close frame code is `1000` or `1001`. Otherwise Discord will
435 /// continue showing the bot as online until its presence times out.
436 ///
437 /// To read all remaining messages, continue calling [`poll_next`] until it
438 /// returns [`Message::Close`].
439 ///
440 /// # Example
441 ///
442 /// Close the shard and process remaining messages:
443 ///
444 /// ```no_run
445 /// # use twilight_gateway::{Intents, Shard, ShardId};
446 /// # #[tokio::main] async fn main() {
447 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
448 /// use tokio_stream::StreamExt;
449 /// use twilight_gateway::{CloseFrame, Message, error::ReceiveMessageErrorType};
450 ///
451 /// shard.close(CloseFrame::NORMAL);
452 ///
453 /// while let Some(item) = shard.next().await {
454 /// match item {
455 /// Ok(Message::Close(_)) => break,
456 /// Ok(Message::Text(_)) => unimplemented!(),
457 /// Err(source) => unimplemented!(),
458 /// }
459 /// }
460 /// # }
461 /// ```
462 ///
463 /// [`poll_next`]: Shard::poll_next
464 pub fn close(&self, close_frame: CloseFrame<'static>) {
465 _ = self.user_channel.close_tx.try_send(close_frame);
466 }
467
468 /// Retrieve a channel to send messages over the shard to the gateway.
469 ///
470 /// This is primarily useful for sending to other tasks and threads where
471 /// the shard won't be available.
472 ///
473 /// # Example
474 ///
475 /// Queue a command in another process:
476 ///
477 /// ```no_run
478 /// # use twilight_gateway::{Intents, Shard, ShardId};
479 /// # #[tokio::main] async fn main() {
480 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
481 /// use tokio_stream::StreamExt;
482 ///
483 /// while let Some(item) = shard.next().await {
484 /// match item {
485 /// Ok(message) => {
486 /// let sender = shard.sender();
487 /// tokio::spawn(async move {
488 /// let command = unimplemented!();
489 /// sender.send(command);
490 /// });
491 /// }
492 /// Err(source) => unimplemented!(),
493 /// }
494 /// }
495 /// # }
496 /// ```
497 pub fn sender(&self) -> MessageSender {
498 self.user_channel.sender()
499 }
500
501 /// Update internal state from gateway disconnect.
502 fn disconnect(&mut self, initiator: CloseInitiator) {
503 // May not send any additional WebSocket messages.
504 self.heartbeat_interval = None;
505 self.ratelimiter = None;
506 // Abort identify.
507 self.identify_rx = None;
508 self.state = match initiator {
509 CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
510 _ => ShardState::Disconnected {
511 reconnect_attempts: 0,
512 },
513 };
514 if let CloseInitiator::Shard(frame) = initiator {
515 // Not resuming, drop session and resume URL.
516 // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
517 if matches!(frame.code, 1000 | 1001) {
518 self.resume_url = None;
519 self.session = None;
520 }
521 self.pending = Some(Pending {
522 gateway_event: Some(Message::Close(Some(frame))),
523 is_heartbeat: false,
524 });
525 }
526 }
527
528 /// Parse a JSON message into an event with minimal data for [processing].
529 ///
530 /// # Errors
531 ///
532 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
533 /// event isn't a recognized structure, which may be the case for new or
534 /// undocumented events.
535 ///
536 /// [processing]: Self::process
537 fn parse_event<T: DeserializeOwned>(
538 json: &str,
539 ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
540 json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
541 kind: ReceiveMessageErrorType::Deserializing {
542 event: json.to_owned(),
543 },
544 source: Some(Box::new(source)),
545 })
546 }
547
548 /// Attempts to connect to the gateway.
549 ///
550 /// # Returns
551 ///
552 /// * `Poll::Pending` if connection is in progress
553 /// * `Poll::Ready(Ok)` if connected
554 /// * `Poll::Ready(Err)` if connecting to the gateway failed.
555 fn poll_connect(
556 &mut self,
557 cx: &mut Context<'_>,
558 attempt: u8,
559 ) -> Poll<Result<(), ReceiveMessageError>> {
560 let fut = self.connection_future.get_or_insert_with(|| {
561 let base_url = self
562 .resume_url
563 .as_deref()
564 .or_else(|| self.config.proxy_url())
565 .unwrap_or(GATEWAY_URL);
566 let base_url_len = base_url.len();
567 let uri = format!("{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}");
568
569 let tls = Arc::clone(&self.config.tls);
570 ConnectionFuture(Box::pin(async move {
571 let delay = Duration::from_secs(2u8.saturating_pow(attempt.into()).into());
572 time::sleep(delay).await;
573 tracing::debug!(url = &uri[..base_url_len], "connecting");
574
575 let builder = ClientBuilder::new()
576 .uri(&uri)
577 .expect("valid URL")
578 .limits(Limits::unlimited())
579 .connector(&tls);
580 Ok(time::timeout(CONNECT_TIMEOUT, builder.connect()).await??.0)
581 }))
582 });
583
584 let res = ready!(Pin::new(&mut fut.0).poll(cx));
585 self.connection_future = None;
586 match res {
587 Ok(connection) => {
588 self.connection = Some(connection);
589 self.state = ShardState::Identifying;
590 #[cfg(any(feature = "zlib", feature = "zstd"))]
591 self.decompressor.reset();
592 }
593 Err(source) => {
594 self.resume_url = None;
595 self.state = ShardState::Disconnected {
596 reconnect_attempts: attempt + 1,
597 };
598
599 return Poll::Ready(Err(ReceiveMessageError {
600 kind: ReceiveMessageErrorType::Reconnect,
601 source: Some(source),
602 }));
603 }
604 }
605
606 Poll::Ready(Ok(()))
607 }
608}
609
impl<Q: Queue> Shard<Q> {
    /// Attempts to send due commands to the gateway.
    ///
    /// Outgoing sources are checked in this priority order. Whenever one of
    /// them schedules a new message, the loop restarts so that message is sent
    /// first:
    ///
    /// 1. finish sending the pending message, if any;
    /// 2. a close frame from the user channel (unless already disconnected);
    /// 3. a heartbeat, when the heartbeat interval ticks;
    /// 4. an identify, once the identify queue grants permission;
    /// 5. a command from the user channel (only once identified).
    ///
    /// Identify and user commands are additionally held back while the command
    /// ratelimiter reports no available capacity.
    ///
    /// # Returns
    ///
    /// * `Poll::Pending` if sending is in progress
    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
    /// * `Poll::Ready(Err)` if sending a command failed, or if a heartbeat was
    ///   previously sent and no event arrived during the last heartbeat
    ///   interval (the connection is considered "zombied").
    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
        loop {
            // Finish sending the current pending message before anything else.
            if let Some(pending) = self.pending.as_mut() {
                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;

                // Only non-heartbeat text messages consume ratelimiter
                // capacity; heartbeats and close frames bypass it.
                let is_ratelimited = pending.gateway_event.as_ref().is_some_and(Message::is_text)
                    && !pending.is_heartbeat;
                if is_ratelimited && let Some(ratelimiter) = &mut self.ratelimiter {
                    ready!(ratelimiter.poll_acquire(cx));
                }

                // `take` ensures the message is queued only once, even if the
                // flush below returns `Pending` and this branch runs again.
                if let Some(msg) = pending.gateway_event.take().map(Message::into_websocket) {
                    Pin::new(self.connection.as_mut().unwrap()).start_send(msg)?;
                }

                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;

                // Heartbeat latency is measured from the completed flush.
                if pending.is_heartbeat {
                    self.latency.record_sent();
                }
                self.pending = None;
            }

            // A user-requested close is ignored once already disconnected.
            if !self.state.is_disconnected()
                && let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx)
            {
                let frame = frame.expect("shard owns channel");

                tracing::debug!("sending close frame from user channel");
                self.disconnect(CloseInitiator::Shard(frame));

                continue;
            }

            if let Some(heartbeater) = &mut self.heartbeat_interval
                && heartbeater.poll_tick(cx).is_ready()
            {
                // Discord never responded after the last heartbeat, connection
                // is failed or "zombied", see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
                // Note that unlike documented *any* event is okay; it does not
                // have to be a heartbeat ACK.
                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
                    tracing::info!("connection is failed or \"zombied\"");

                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
                }

                tracing::debug!("sending heartbeat");
                self.pending =
                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
                // Start tracking events for the new heartbeat interval.
                self.heartbeat_interval_event = false;

                continue;
            }

            // Without a ratelimiter, identify and commands are never held back.
            let not_ratelimited = self
                .ratelimiter
                .as_mut()
                .is_none_or(|ratelimiter| ratelimiter.poll_available(cx).is_ready());

            if not_ratelimited
                && let Some(rx) = &mut self.identify_rx
                && let Poll::Ready(canceled) = Pin::new(rx).poll(cx).map(|r| r.is_err())
            {
                // The queue dropped the sender without granting permission;
                // enqueue again.
                if canceled {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                    continue;
                }

                tracing::debug!("sending identify");

                self.pending = Pending::event(
                    Identify::new(IdentifyInfo {
                        compress: false,
                        intents: self.config.intents(),
                        large_threshold: self.config.large_threshold(),
                        presence: self.config.presence().cloned(),
                        properties: self
                            .config
                            .identify_properties()
                            .cloned()
                            .unwrap_or_else(default_identify_properties),
                        shard: Some(self.id),
                        token: self.config.token().to_owned(),
                    }),
                    false,
                );
                self.identify_rx = None;

                continue;
            }

            // User commands are only forwarded once the shard is identified.
            if not_ratelimited
                && self.state.is_identified()
                && let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx)
            {
                let command = command.expect("shard owns channel");

                tracing::debug!("sending command from user channel");
                self.pending = Some(Pending {
                    gateway_event: Some(Message::Text(command)),
                    is_heartbeat: false,
                });

                continue;
            }

            return Poll::Ready(Ok(()));
        }
    }

    /// Updates the shard's internal state from a gateway event by recording
    /// and/or responding to certain Discord events.
    ///
    /// Handled opcodes:
    /// - dispatch: tracks the sequence number and handles `READY` (stores the
    ///   session and resume URL) and `RESUMED`, both of which activate the shard;
    /// - heartbeat: queues an immediate heartbeat;
    /// - heartbeat ACK: records the latency if a heartbeat was awaiting one;
    /// - hello: starts heartbeating (with jitter), creates the ratelimiter if
    ///   enabled, resets latency, and then resumes or enqueues an identify;
    /// - invalid session and reconnect: disconnect, keeping the session only
    ///   when it is resumable.
    ///
    /// Unknown opcodes are logged and ignored.
    ///
    /// # Errors
    ///
    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
    /// gateway event isn't a recognized structure.
    #[allow(clippy::too_many_lines)]
    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
        let (raw_opcode, maybe_sequence, maybe_event_type) =
            GatewayEventDeserializer::from_json(event)
                .ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing opcode".into()),
                })?
                .into_parts();

        // Any received event counts as proof of life for the zombie check in
        // `poll_send`, once heartbeating has started.
        if self.latency.sent().is_some() {
            self.heartbeat_interval_event = true;
        }

        match OpCode::from(raw_opcode) {
            Some(OpCode::Dispatch) => {
                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing dispatch event type".into()),
                })?;
                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing sequence".into()),
                })?;
                tracing::debug!(%event_type, %sequence, "received dispatch");

                match event_type.as_ref() {
                    "READY" => {
                        let event = Self::parse_event::<MinimalReady>(event)?;

                        self.resume_url = Some(event.data.resume_gateway_url);
                        self.session = Some(Session::new(sequence, event.data.session_id));
                        self.state = ShardState::Active;
                    }
                    "RESUMED" => self.state = ShardState::Active,
                    _ => {}
                }

                // Track the latest sequence for heartbeats and resumes.
                if let Some(session) = self.session.as_mut() {
                    session.set_sequence(sequence);
                }
            }
            Some(OpCode::Heartbeat) => {
                tracing::debug!("received heartbeat");
                self.pending =
                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
            }
            Some(OpCode::HeartbeatAck) => {
                // Only the first ACK after a sent heartbeat is recorded.
                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
                if requested {
                    tracing::debug!("received heartbeat ack");
                    self.latency.record_received();
                } else {
                    tracing::info!("received unrequested heartbeat ack");
                }
            }
            Some(OpCode::Hello) => {
                let event = Self::parse_event::<Hello>(event)?;
                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
                // First heartbeat should have some jitter, see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

                if self.config().ratelimit_messages() {
                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
                }

                // `Delay` shifts later ticks instead of bursting missed ones.
                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
                self.heartbeat_interval = Some(interval);

                // Reset `Latency` since the shard might have connected to a new
                // remote which invalidates the recorded latencies.
                self.latency = Latency::new();

                // Resume an existing session, otherwise wait for the identify
                // queue before identifying.
                if let Some(session) = &self.session {
                    self.pending = Pending::event(
                        Resume::new(session.sequence(), session.id(), self.config.token()),
                        false,
                    );
                    self.state = ShardState::Resuming;
                } else {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                }
            }
            Some(OpCode::InvalidSession) => {
                let resumable = Self::parse_event(event)?.data;
                tracing::debug!(resumable, "received invalid session");
                // `NORMAL` drops the session (see `disconnect`); `RESUME`
                // keeps it for a later resume.
                if resumable {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                } else {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
                }
            }
            Some(OpCode::Reconnect) => {
                tracing::debug!("received reconnect");
                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
            }
            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
        }

        Ok(())
    }
}
848
impl<Q: Queue + Unpin> Stream for Shard<Q> {
    type Item = Result<Message, ReceiveMessageError>;

    /// Drive the shard and yield the next gateway message.
    ///
    /// Each poll reconnects if the shard is disconnected and has no
    /// connection. It then sends any due outgoing messages and reads from the
    /// WebSocket. Text messages are processed to update the shard's state
    /// before being returned. Close messages are also returned; they update
    /// the state through `disconnect` unless the shard is already
    /// disconnected (e.g. because it initiated the close).
    ///
    /// Send and transport failures are reported as
    /// [`Message::ABNORMAL_CLOSE`] rather than as errors. Returns `None` once
    /// the shard is fatally closed and its connection has been closed. Polling
    /// again after that panics.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let message = loop {
            match self.state {
                ShardState::FatallyClosed => {
                    _ = ready!(
                        Pin::new(
                            self.connection
                                .as_mut()
                                .expect("poll_next called after Poll::Ready(None)")
                        )
                        .poll_close(cx)
                    );
                    self.connection = None;
                    return Poll::Ready(None);
                }
                // Only reconnect once the previous connection is gone.
                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
                    ready!(self.poll_connect(cx, reconnect_attempts))?;
                }
                _ => {}
            }

            if ready!(self.poll_send(cx)).is_err() {
                self.disconnect(CloseInitiator::Transport);
                // The connection is considered unusable; drop it without a
                // closing handshake.
                self.connection = None;

                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
            }

            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
                Some(Ok(message)) => {
                    #[cfg(any(feature = "zlib", feature = "zstd"))]
                    if message.is_binary() {
                        match self.decompressor.decompress(message.as_payload()) {
                            #[cfg(feature = "zstd")]
                            Ok(message) => break Message::Text(message),
                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
                            Ok(Some(message)) => break Message::Text(message),
                            // No message produced yet; keep reading.
                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
                            Ok(None) => continue,
                            Err(source) => {
                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                                return Poll::Ready(Some(Err(
                                    ReceiveMessageError::from_compression(source),
                                )));
                            }
                        }
                    }
                    // Messages that don't map to a `Message` are skipped.
                    if let Some(message) = Message::from_websocket_msg(&message) {
                        break message;
                    }
                }
                // Errors while already disconnected are ignored; keep reading
                // until the stream ends.
                Some(Err(_)) if self.state.is_disconnected() => {}
                Some(Err(_)) => {
                    self.disconnect(CloseInitiator::Transport);
                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
                }
                None => {
                    // Stream ended: finish closing before dropping the connection.
                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
                    tracing::debug!("gateway WebSocket connection closed");
                    // Unclean closure.
                    if !self.state.is_disconnected() {
                        self.disconnect(CloseInitiator::Transport);
                    }
                    self.connection = None;
                }
            }
        };

        match &message {
            Message::Close(frame) => {
                // tokio-websockets automatically replies to the close message.
                tracing::debug!(?frame, "received WebSocket close message");
                // Don't run `disconnect` if we initiated the close.
                if !self.state.is_disconnected() {
                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
                }
            }
            Message::Text(event) => {
                self.process(event)?;
            }
        }

        Poll::Ready(Some(Ok(message)))
    }
}
937
938/// Default identify properties to use when the user hasn't customized it in
939/// [`Config::identify_properties`].
940///
941/// [`Config::identify_properties`]: Config::identify_properties
942fn default_identify_properties() -> IdentifyProperties {
943 IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
944}
945
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // A shard can be moved to another thread and debug-printed, but it is
    // intentionally not `Sync`.
    assert_impl_all!(Shard: Send, Debug);
    assert_not_impl_any!(Shard: Sync);
}