twilight_gateway/shard.rs
1//! Primary logic and implementation details of Discord gateway websocket
2//! connections.
3//!
4//! Shards are, at their heart, a websocket connection with some state for
5//! maintaining an identified session with the Discord gateway. For more
6//! information about what a shard is in the context of Discord's gateway API,
7//! refer to the documentation for [`Shard`].
8
9#[cfg(any(feature = "zlib", feature = "zstd"))]
10use crate::compression::Decompressor;
11use crate::{
12 API_VERSION, Command, Config, Message, ShardId,
13 channel::{MessageChannel, MessageSender},
14 error::{ReceiveMessageError, ReceiveMessageErrorType},
15 json,
16 latency::Latency,
17 queue::{InMemoryQueue, Queue},
18 ratelimiter::CommandRatelimiter,
19 session::Session,
20};
21use futures_core::Stream;
22use futures_sink::Sink;
23use serde::{Deserialize, Serialize, de::DeserializeOwned};
24use std::{
25 env::consts::OS,
26 error::Error,
27 fmt,
28 future::Future,
29 io,
30 pin::Pin,
31 str,
32 sync::Arc,
33 task::{Context, Poll, ready},
34};
35use tokio::{
36 net::TcpStream,
37 sync::oneshot,
38 time::{self, Duration, Instant, Interval, MissedTickBehavior},
39};
40use tokio_websockets::{ClientBuilder, Error as WebsocketError, Limits, MaybeTlsStream};
41use twilight_model::gateway::{
42 CloseCode, CloseFrame, Intents, OpCode,
43 event::GatewayEventDeserializer,
44 payload::{
45 incoming::Hello,
46 outgoing::{
47 Heartbeat, Identify, Resume,
48 identify::{IdentifyInfo, IdentifyProperties},
49 },
50 },
51};
52
/// Base URL of Discord's gateway, used when neither a resume URL nor a proxy
/// URL is available.
const GATEWAY_URL: &str = "wss://gateway.discord.gg";
55
/// Compression query argument appended to the gateway URL.
///
/// Zstandard wins when both compression features are enabled; with neither
/// feature enabled no compression is requested.
const COMPRESSION_FEATURES: &str = match (cfg!(feature = "zstd"), cfg!(feature = "zlib")) {
    (true, _) => "&compress=zstd-stream",
    (false, true) => "&compress=zlib-stream",
    (false, false) => "",
};
64
/// Upper bound on how long opening the gateway WebSocket connection may take
/// within a single connection attempt.
const CONNECT_TIMEOUT: Duration = Duration::from_millis(10_000);
67
68/// [`tokio_websockets`] library Websocket connection.
69type Connection = tokio_websockets::WebSocketStream<MaybeTlsStream<TcpStream>>;
70
/// Boxed, type-erased [`Error`] that is safe to send and share across threads.
type GenericError = Box<dyn Error + Sync + Send>;
73
74/// Wrapper struct around an `async fn` with a `Debug` implementation.
75struct ConnectionFuture(Pin<Box<dyn Future<Output = Result<Connection, GenericError>> + Send>>);
76
77impl fmt::Debug for ConnectionFuture {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 f.debug_tuple("ConnectionFuture")
80 .field(&"<async fn>")
81 .finish()
82 }
83}
84
85/// Close initiator of a websocket connection.
86#[derive(Clone, Debug)]
87enum CloseInitiator {
88 /// Gateway initiated the close.
89 ///
90 /// Contains an optional close code.
91 Gateway(Option<u16>),
92 /// Shard initiated the close.
93 ///
94 /// Contains a close code.
95 Shard(CloseFrame<'static>),
96 /// Transport error initiated the close.
97 Transport,
98}
99
/// Connection state of a [`Shard`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardState {
    /// Connected to the gateway with an active session.
    Active,
    /// Not connected to the gateway, but may connect again later.
    ///
    /// The WebSocket connection may not have been fully closed yet.
    Disconnected {
        /// How many reconnection attempts have been made so far.
        reconnect_attempts: u8,
    },
    /// Closed for good; the shard will not reconnect.
    ///
    /// This happens for close codes such as [failed authentication] or
    /// [invalid intents]; consult [`CloseCode`] for other possible reasons.
    ///
    /// [failed authentication]: CloseCode::AuthenticationFailed
    /// [invalid intents]: CloseCode::InvalidIntents
    FatallyClosed,
    /// Waiting to establish a new session or resume an existing one.
    Identifying,
    /// Replaying dispatch events that were missed while disconnected.
    ///
    /// A resuming shard already counts as identified.
    Resuming,
}
128
129impl ShardState {
130 /// Determine the connection status from the close code.
131 ///
132 /// Defers to [`CloseCode::can_reconnect`] to determine whether the
133 /// connection can be reconnected, defaulting to [`Self::Disconnected`] if
134 /// the close code is unknown.
135 fn from_close_code(close_code: Option<u16>) -> Self {
136 match close_code.map(CloseCode::try_from) {
137 Some(Ok(close_code)) if !close_code.can_reconnect() => Self::FatallyClosed,
138 _ => Self::Disconnected {
139 reconnect_attempts: 0,
140 },
141 }
142 }
143
144 /// Whether the shard has disconnected but may reconnect in the future.
145 const fn is_disconnected(self) -> bool {
146 matches!(self, Self::Disconnected { .. })
147 }
148
149 /// Whether the shard is identified with an active session.
150 ///
151 /// `true` if the status is [`Active`] or [`Resuming`].
152 ///
153 /// [`Active`]: Self::Active
154 /// [`Resuming`]: Self::Resuming
155 pub const fn is_identified(self) -> bool {
156 matches!(self, Self::Active | Self::Resuming)
157 }
158}
159
160/// Gateway event with only minimal required data.
161#[derive(Deserialize)]
162struct MinimalEvent<T> {
163 /// Attached data of the gateway event.
164 #[serde(rename = "d")]
165 data: T,
166}
167
168/// Minimal [`Ready`] for light deserialization.
169///
170/// [`Ready`]: twilight_model::gateway::payload::incoming::Ready
171#[derive(Deserialize)]
172struct MinimalReady {
173 /// Used for resuming connections.
174 resume_gateway_url: Box<str>,
175 /// ID of the new identified session.
176 session_id: String,
177}
178
179/// Pending outgoing message indicator.
180#[derive(Debug)]
181struct Pending {
182 /// The pending message, if not already sent.
183 gateway_event: Option<Message>,
184 /// Whether the pending gateway event is a heartbeat.
185 is_heartbeat: bool,
186}
187
188impl Pending {
189 /// Constructor for a pending gateway event.
190 fn event<T: Serialize>(event: T, is_heartbeat: bool) -> Option<Self> {
191 Some(Self {
192 gateway_event: Some(Message::Text(
193 json::to_string(&event).expect("json serialization is infallible"),
194 )),
195 is_heartbeat,
196 })
197 }
198}
199
200/// Gateway API client responsible for up to 2500 guilds.
201///
202/// Shards are responsible for maintaining the gateway connection by processing
203/// events relevant to the operation of shards---such as requests from the
204/// gateway to re-connect or invalidate a session---and then to pass them on to
205/// the user.
206///
207/// Shards start out disconnected, but will on the first successful call to
208/// [`poll_next`] try to reconnect to the gateway. [`poll_next`] must then
209/// be repeatedly called in order for the shard to maintain its connection and
210/// update its internal state.
211///
212/// Shards go through an [identify queue][`queue`] that rate limits concurrent
213/// `Identify` events (across all shards) per 5 seconds. Exceeding this limit
214/// invalidates the shard's session and it is therefore **very important** to
215/// reuse the same queue for all shards.
216///
217/// # Sharding
218///
219/// A shard may not be connected to more than 2500 guilds, so large bots must
220/// split themselves across multiple shards. See the
221/// [Discord Docs/Sharding][docs:sharding] and [`ShardId`] documentation for
222/// more info.
223///
224/// # Examples
225///
226/// Create and start a shard and print new and deleted messages:
227///
228/// ```no_run
229/// use std::env;
230/// use twilight_gateway::{Event, EventTypeFlags, Intents, Shard, ShardId, StreamExt as _};
231///
232/// # #[tokio::main(flavor = "current_thread")] async fn main() -> Result<(), Box<dyn std::error::Error>> {
233/// // Use the value of the "DISCORD_TOKEN" environment variable as the bot's
234/// // token. Of course, this value may be passed into the program however is
235/// // preferred.
236/// let token = env::var("DISCORD_TOKEN")?;
237/// let wanted_event_types = EventTypeFlags::MESSAGE_CREATE | EventTypeFlags::MESSAGE_DELETE;
238///
239/// let mut shard = Shard::new(ShardId::ONE, token, Intents::GUILD_MESSAGES);
240///
241/// while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
242/// let Ok(event) = item else {
243/// tracing::warn!(source = ?item.unwrap_err(), "error receiving event");
244///
245/// continue;
246/// };
247///
248/// match event {
249/// Event::MessageCreate(message) => {
250/// println!("message received with content: {}", message.content);
251/// }
252/// Event::MessageDelete(message) => {
253/// println!("message with ID {} deleted", message.id);
254/// }
255/// _ => {}
256/// }
257/// }
258/// # Ok(()) }
259/// ```
260///
261/// [docs:sharding]: https://discord.com/developers/docs/topics/gateway#sharding
262/// [gateway commands]: Shard::command
263/// [`poll_next`]: Shard::poll_next
264/// [`queue`]: crate::queue
265#[derive(Debug)]
266pub struct Shard<Q = InMemoryQueue> {
267 /// User provided configuration.
268 ///
269 /// Configurations are provided or created in shard initializing via
270 /// [`Shard::new`] or [`Shard::with_config`].
271 config: Config<Q>,
272 /// Future to establish a WebSocket connection with the Gateway.
273 connection_future: Option<ConnectionFuture>,
274 /// Websocket connection, which may be connected to Discord's gateway.
275 ///
276 /// The connection should only be dropped after it has returned `Ok(None)`
277 /// to comply with the WebSocket protocol.
278 connection: Option<Connection>,
279 /// Event decompressor.
280 #[cfg(any(feature = "zlib", feature = "zstd"))]
281 decompressor: Decompressor,
282 /// Interval of how often the gateway would like the shard to send
283 /// heartbeats.
284 ///
285 /// The interval is received in the [`GatewayEvent::Hello`] event when
286 /// first opening a new [connection].
287 ///
288 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
289 /// [connection]: Self::connection
290 heartbeat_interval: Option<Interval>,
291 /// Whether an event has been received in the current heartbeat interval.
292 heartbeat_interval_event: bool,
293 /// ID of the shard.
294 id: ShardId,
295 /// Identify queue receiver.
296 identify_rx: Option<oneshot::Receiver<()>>,
297 /// Potentially pending outgoing message.
298 pending: Option<Pending>,
299 /// Recent heartbeat latency statistics.
300 ///
301 /// The latency is reset on receiving [`GatewayEvent::Hello`] as the host
302 /// may have changed, invalidating previous latency statistic.
303 ///
304 /// [`GatewayEvent::Hello`]: twilight_model::gateway::event::GatewayEvent::Hello
305 latency: Latency,
306 /// Command ratelimiter, if it was enabled via
307 /// [`Config::ratelimit_messages`].
308 ratelimiter: Option<CommandRatelimiter>,
309 /// Used for resuming connections.
310 resume_url: Option<Box<str>>,
311 /// Active session of the shard.
312 ///
313 /// The shard may not have an active session if it hasn't yet identified and
314 /// received a `READY` dispatch event response.
315 session: Option<Session>,
316 /// Current state of the shard.
317 state: ShardState,
318 /// Messages from the user to be relayed and sent over the Websocket
319 /// connection.
320 user_channel: MessageChannel,
321}
322
323impl Shard {
324 /// Create a new shard with the default configuration.
325 pub fn new(id: ShardId, token: String, intents: Intents) -> Self {
326 Self::with_config(id, Config::new(token, intents))
327 }
328}
329
330impl<Q> Shard<Q> {
331 /// Create a new shard with the provided configuration.
332 pub fn with_config(shard_id: ShardId, mut config: Config<Q>) -> Self {
333 let session = config.take_session();
334 let mut resume_url = config.take_resume_url();
335 //ensure resume_url is only used if we have a session to resume
336 if session.is_none() {
337 resume_url = None;
338 }
339
340 Self {
341 config,
342 connection_future: None,
343 connection: None,
344 #[cfg(any(feature = "zlib", feature = "zstd"))]
345 decompressor: Decompressor::new(),
346 heartbeat_interval: None,
347 heartbeat_interval_event: false,
348 id: shard_id,
349 identify_rx: None,
350 pending: None,
351 latency: Latency::new(),
352 ratelimiter: None,
353 resume_url,
354 session,
355 state: ShardState::Disconnected {
356 reconnect_attempts: 0,
357 },
358 user_channel: MessageChannel::new(),
359 }
360 }
361
362 /// Immutable reference to the configuration used to instantiate this shard.
363 pub const fn config(&self) -> &Config<Q> {
364 &self.config
365 }
366
367 /// ID of the shard.
368 pub const fn id(&self) -> ShardId {
369 self.id
370 }
371
372 /// State of the shard.
373 pub const fn state(&self) -> ShardState {
374 self.state
375 }
376
377 /// Shard latency statistics, including average latency and recent heartbeat
378 /// latency times.
379 ///
380 /// Reset when reconnecting to the gateway.
381 pub const fn latency(&self) -> &Latency {
382 &self.latency
383 }
384
385 /// Statistics about the number of available commands and when the command
386 /// ratelimiter will refresh.
387 ///
388 /// This won't be present if ratelimiting was disabled via
389 /// [`ConfigBuilder::ratelimit_messages`] or if the shard is disconnected.
390 ///
391 /// [`ConfigBuilder::ratelimit_messages`]: crate::ConfigBuilder::ratelimit_messages
392 pub const fn ratelimiter(&self) -> Option<&CommandRatelimiter> {
393 self.ratelimiter.as_ref()
394 }
395
396 /// Immutable reference to the gateways current resume URL.
397 ///
398 /// A resume URL might not be present if the shard had its session
399 /// invalidated and has not yet reconnected.
400 pub fn resume_url(&self) -> Option<&str> {
401 self.resume_url.as_deref()
402 }
403
404 /// Immutable reference to the active gateway session.
405 ///
406 /// An active session may not be present if the shard had its session
407 /// invalidated and has not yet reconnected.
408 pub const fn session(&self) -> Option<&Session> {
409 self.session.as_ref()
410 }
411
412 /// Queue a command to be sent to the gateway.
413 ///
414 /// Serializes the command and then calls [`send`].
415 ///
416 /// [`send`]: Self::send
417 #[allow(clippy::missing_panics_doc)]
418 pub fn command(&self, command: &impl Command) {
419 self.send(json::to_string(command).expect("serialization cannot fail"));
420 }
421
422 /// Queue a JSON encoded gateway event to be sent to the gateway.
423 #[allow(clippy::missing_panics_doc)]
424 pub fn send(&self, json: String) {
425 self.user_channel
426 .command_tx
427 .send(json)
428 .expect("channel open");
429 }
430
431 /// Queue a websocket close frame.
432 ///
433 /// Invalidates the session and shows the application's bot as offline if
434 /// the close frame code is `1000` or `1001`. Otherwise Discord will
435 /// continue showing the bot as online until its presence times out.
436 ///
437 /// To read all remaining messages, continue calling [`poll_next`] until it
438 /// returns [`Message::Close`].
439 ///
440 /// # Example
441 ///
442 /// Close the shard and process remaining messages:
443 ///
444 /// ```no_run
445 /// # use twilight_gateway::{Intents, Shard, ShardId};
446 /// # #[tokio::main(flavor = "current_thread")] async fn main() {
447 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
448 /// use tokio_stream::StreamExt;
449 /// use twilight_gateway::{CloseFrame, Message, error::ReceiveMessageErrorType};
450 ///
451 /// shard.close(CloseFrame::NORMAL);
452 ///
453 /// while let Some(item) = shard.next().await {
454 /// match item {
455 /// Ok(Message::Close(_)) => break,
456 /// Ok(Message::Text(_)) => unimplemented!(),
457 /// Err(source) => unimplemented!(),
458 /// }
459 /// }
460 /// # }
461 /// ```
462 ///
463 /// [`poll_next`]: Shard::poll_next
464 pub fn close(&self, close_frame: CloseFrame<'static>) {
465 _ = self.user_channel.close_tx.try_send(close_frame);
466 }
467
468 /// Retrieve a channel to send messages over the shard to the gateway.
469 ///
470 /// This is primarily useful for sending to other tasks and threads where
471 /// the shard won't be available.
472 ///
473 /// # Example
474 ///
475 /// Queue a command in another process:
476 ///
477 /// ```no_run
478 /// # use twilight_gateway::{Intents, Shard, ShardId};
479 /// # #[tokio::main(flavor = "current_thread")] async fn main() {
480 /// # let mut shard = Shard::new(ShardId::ONE, String::new(), Intents::empty());
481 /// use tokio_stream::StreamExt;
482 ///
483 /// while let Some(item) = shard.next().await {
484 /// match item {
485 /// Ok(message) => {
486 /// let sender = shard.sender();
487 /// tokio::spawn(async move {
488 /// let command = unimplemented!();
489 /// sender.send(command);
490 /// });
491 /// }
492 /// Err(source) => unimplemented!(),
493 /// }
494 /// }
495 /// # }
496 /// ```
497 pub fn sender(&self) -> MessageSender {
498 self.user_channel.sender()
499 }
500
501 /// Update internal state from gateway disconnect.
502 fn disconnect(&mut self, initiator: CloseInitiator) {
503 // May not send any additional WebSocket messages.
504 self.heartbeat_interval = None;
505 self.ratelimiter = None;
506 // Abort identify.
507 self.identify_rx = None;
508 self.state = match initiator {
509 CloseInitiator::Gateway(close_code) => ShardState::from_close_code(close_code),
510 _ => ShardState::Disconnected {
511 reconnect_attempts: 0,
512 },
513 };
514 if let CloseInitiator::Shard(frame) = initiator {
515 // Not resuming, drop session and resume URL.
516 // https://discord.com/developers/docs/topics/gateway#initiating-a-disconnect
517 if matches!(frame.code, 1000 | 1001) {
518 self.resume_url = None;
519 self.session = None;
520 }
521 self.pending = Some(Pending {
522 gateway_event: Some(Message::Close(Some(frame))),
523 is_heartbeat: false,
524 });
525 }
526 }
527
528 /// Parse a JSON message into an event with minimal data for [processing].
529 ///
530 /// # Errors
531 ///
532 /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the gateway
533 /// event isn't a recognized structure, which may be the case for new or
534 /// undocumented events.
535 ///
536 /// [processing]: Self::process
537 fn parse_event<T: DeserializeOwned>(
538 json: &str,
539 ) -> Result<MinimalEvent<T>, ReceiveMessageError> {
540 json::from_str::<MinimalEvent<T>>(json).map_err(|source| ReceiveMessageError {
541 kind: ReceiveMessageErrorType::Deserializing {
542 event: json.to_owned(),
543 },
544 source: Some(Box::new(source)),
545 })
546 }
547
548 /// Attempts to connect to the gateway.
549 ///
550 /// # Returns
551 ///
552 /// * `Poll::Pending` if connection is in progress
553 /// * `Poll::Ready(Ok)` if connected
554 /// * `Poll::Ready(Err)` if connecting to the gateway failed.
555 fn poll_connect(
556 &mut self,
557 cx: &mut Context<'_>,
558 attempt: u8,
559 ) -> Poll<Result<(), ReceiveMessageError>> {
560 let fut = self.connection_future.get_or_insert_with(|| {
561 let base_url = self
562 .resume_url
563 .as_deref()
564 .or_else(|| self.config.proxy_url())
565 .unwrap_or(GATEWAY_URL);
566 let base_url_len = base_url.len();
567 let uri = format!("{base_url}/?v={API_VERSION}&encoding=json{COMPRESSION_FEATURES}");
568
569 let tls = Arc::clone(&self.config.tls);
570 ConnectionFuture(Box::pin(async move {
571 if attempt != 0 {
572 let secs = 2u8.saturating_pow(u32::from(attempt) - 1);
573 time::sleep(Duration::from_secs(secs.into())).await;
574 }
575 tracing::debug!(url = &uri[..base_url_len], "connecting");
576
577 let builder = ClientBuilder::new()
578 .uri(&uri)
579 .expect("valid URL")
580 .limits(Limits::unlimited())
581 .connector(&tls);
582 Ok(time::timeout(CONNECT_TIMEOUT, builder.connect()).await??.0)
583 }))
584 });
585
586 let res = ready!(Pin::new(&mut fut.0).poll(cx));
587 self.connection_future = None;
588 match res {
589 Ok(connection) => {
590 self.connection = Some(connection);
591 self.state = ShardState::Identifying;
592 #[cfg(any(feature = "zlib", feature = "zstd"))]
593 self.decompressor.reset();
594 }
595 Err(source) => {
596 self.resume_url = None;
597 self.state = ShardState::Disconnected {
598 reconnect_attempts: attempt.saturating_add(1),
599 };
600
601 return Poll::Ready(Err(ReceiveMessageError {
602 kind: ReceiveMessageErrorType::Reconnect,
603 source: Some(source),
604 }));
605 }
606 }
607
608 Poll::Ready(Ok(()))
609 }
610}
611
impl<Q: Queue> Shard<Q> {
    /// Attempts to send due commands to the gateway.
    ///
    /// Messages are sent one at a time through the `pending` slot. Once it is
    /// empty, the slot is refilled in this priority order:
    ///
    /// 1. a close frame from the user channel (unless already disconnected),
    /// 2. a heartbeat when the heartbeat interval ticks,
    /// 3. an identify once the identify queue resolves,
    /// 4. user commands, while the shard is identified.
    ///
    /// Items 3 and 4 wait while the command ratelimiter has no capacity.
    ///
    /// # Returns
    ///
    /// * `Poll::Pending` if sending is in progress
    /// * `Poll::Ready(Ok)` if no more scheduled commands remain
    /// * `Poll::Ready(Err)` if sending a command failed, or with a
    ///   `TimedOut` I/O error if the connection is failed ("zombied").
    fn poll_send(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), WebsocketError>> {
        loop {
            // Finish the in-flight message before scheduling another one.
            // NOTE(review): the `unwrap`s assume a connection exists whenever
            // `pending` is set.
            if let Some(pending) = self.pending.as_mut() {
                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_ready(cx))?;

                // Heartbeats and non-text messages (close frames) bypass the
                // command ratelimiter.
                let is_ratelimited = pending.gateway_event.as_ref().is_some_and(Message::is_text)
                    && !pending.is_heartbeat;
                if is_ratelimited && let Some(ratelimiter) = &mut self.ratelimiter {
                    ready!(ratelimiter.poll_acquire(cx));
                }

                // The message is taken once handed to the sink, so re-polling
                // after `Poll::Pending` only resumes flushing.
                if let Some(msg) = pending.gateway_event.take().map(Message::into_websocket) {
                    Pin::new(self.connection.as_mut().unwrap()).start_send(msg)?;
                }

                ready!(Pin::new(self.connection.as_mut().unwrap()).poll_flush(cx))?;

                // Latency is measured from when the heartbeat was flushed.
                if pending.is_heartbeat {
                    self.latency.record_sent();
                }
                self.pending = None;
            }

            // User-requested close; ignored once disconnected since no further
            // messages may be sent.
            if !self.state.is_disconnected()
                && let Poll::Ready(frame) = self.user_channel.close_rx.poll_recv(cx)
            {
                let frame = frame.expect("shard owns channel");

                tracing::debug!("sending close frame from user channel");
                self.disconnect(CloseInitiator::Shard(frame));

                continue;
            }

            if let Some(heartbeater) = &mut self.heartbeat_interval
                && heartbeater.poll_tick(cx).is_ready()
            {
                // Discord never responded after the last heartbeat, connection
                // is failed or "zombied", see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval-example-heartbeat-ack
                // Note that unlike documented *any* event is okay; it does not
                // have to be a heartbeat ACK.
                if self.latency.sent().is_some() && !self.heartbeat_interval_event {
                    tracing::info!("connection is failed or \"zombied\"");

                    return Poll::Ready(Err(WebsocketError::Io(io::ErrorKind::TimedOut.into())));
                }

                tracing::debug!("sending heartbeat");
                self.pending =
                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
                // Start tracking received events for the new interval.
                self.heartbeat_interval_event = false;

                continue;
            }

            // Polling availability also registers the waker for when the
            // ratelimiter refreshes.
            let not_ratelimited = self
                .ratelimiter
                .as_mut()
                .is_none_or(|ratelimiter| ratelimiter.poll_available(cx).is_ready());

            if not_ratelimited
                && let Some(rx) = &mut self.identify_rx
                && let Poll::Ready(canceled) = Pin::new(rx).poll(cx).map(|r| r.is_err())
            {
                // The queue dropped its sender without granting permission;
                // wait in the queue again.
                if canceled {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                    continue;
                }

                tracing::debug!("sending identify");

                self.pending = Pending::event(
                    Identify::new(IdentifyInfo {
                        compress: false,
                        intents: self.config.intents(),
                        large_threshold: self.config.large_threshold(),
                        presence: self.config.presence().cloned(),
                        properties: self
                            .config
                            .identify_properties()
                            .cloned()
                            .unwrap_or_else(default_identify_properties),
                        shard: Some(self.id),
                        token: self.config.token().to_owned(),
                    }),
                    false,
                );
                self.identify_rx = None;

                continue;
            }

            // User commands are only relayed once the shard is identified.
            if not_ratelimited
                && self.state.is_identified()
                && let Poll::Ready(command) = self.user_channel.command_rx.poll_recv(cx)
            {
                let command = command.expect("shard owns channel");

                tracing::debug!("sending command from user channel");
                self.pending = Some(Pending {
                    gateway_event: Some(Message::Text(command)),
                    is_heartbeat: false,
                });

                continue;
            }

            return Poll::Ready(Ok(()));
        }
    }

    /// Updates the shard's internal state from a gateway event by recording
    /// and/or responding to certain Discord events.
    ///
    /// Tracks the session and sequence from dispatches, answers heartbeat
    /// requests, records heartbeat ACK latency, starts heartbeating and
    /// resumes/identifies on `Hello`, and disconnects on `InvalidSession` and
    /// `Reconnect`. Unknown opcodes are only logged.
    ///
    /// # Errors
    ///
    /// Returns a [`ReceiveMessageErrorType::Deserializing`] error type if the
    /// gateway event isn't a recognized structure.
    #[allow(clippy::too_many_lines)]
    fn process(&mut self, event: &str) -> Result<(), ReceiveMessageError> {
        // Extract only the opcode, sequence and event type.
        let (raw_opcode, maybe_sequence, maybe_event_type) =
            GatewayEventDeserializer::from_json(event)
                .ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing opcode".into()),
                })?
                .into_parts();

        // Any event after a heartbeat was sent shows the connection is alive;
        // read by the zombie check in `poll_send`.
        if self.latency.sent().is_some() {
            self.heartbeat_interval_event = true;
        }

        match OpCode::from(raw_opcode) {
            Some(OpCode::Dispatch) => {
                let event_type = maybe_event_type.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing dispatch event type".into()),
                })?;
                let sequence = maybe_sequence.ok_or_else(|| ReceiveMessageError {
                    kind: ReceiveMessageErrorType::Deserializing {
                        event: event.to_owned(),
                    },
                    source: Some("missing sequence".into()),
                })?;
                tracing::debug!(%event_type, %sequence, "received dispatch");

                match event_type.as_ref() {
                    // A new session was established; remember how to resume it.
                    "READY" => {
                        let event = Self::parse_event::<MinimalReady>(event)?;

                        self.resume_url = Some(event.data.resume_gateway_url);
                        self.session = Some(Session::new(sequence, event.data.session_id));
                        self.state = ShardState::Active;
                    }
                    "RESUMED" => self.state = ShardState::Active,
                    _ => {}
                }

                // Keep the sequence current for heartbeats and resuming.
                if let Some(session) = self.session.as_mut() {
                    session.set_sequence(sequence);
                }
            }
            // The gateway requested an immediate heartbeat.
            Some(OpCode::Heartbeat) => {
                tracing::debug!("received heartbeat");
                self.pending =
                    Pending::event(Heartbeat::new(self.session().map(Session::sequence)), true);
            }
            Some(OpCode::HeartbeatAck) => {
                // Only the first ACK after a sent heartbeat is recorded.
                let requested = self.latency.received().is_none() && self.latency.sent().is_some();
                if requested {
                    tracing::debug!("received heartbeat ack");
                    self.latency.record_received();
                } else {
                    tracing::info!("received unrequested heartbeat ack");
                }
            }
            Some(OpCode::Hello) => {
                let event = Self::parse_event::<Hello>(event)?;
                let heartbeat_interval = Duration::from_millis(event.data.heartbeat_interval);
                // First heartbeat should have some jitter, see
                // https://discord.com/developers/docs/topics/gateway#heartbeat-interval
                let jitter = heartbeat_interval.mul_f64(fastrand::f64());
                tracing::debug!(?heartbeat_interval, ?jitter, "received hello");

                if self.config().ratelimit_messages() {
                    self.ratelimiter = Some(CommandRatelimiter::new(heartbeat_interval));
                }

                let mut interval = time::interval_at(Instant::now() + jitter, heartbeat_interval);
                interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
                self.heartbeat_interval = Some(interval);

                // Reset `Latency` since the shard might have connected to a new
                // remote which invalidates the recorded latencies.
                self.latency = Latency::new();

                // Resume an existing session; otherwise wait in the identify
                // queue before identifying (see `poll_send`).
                if let Some(session) = &self.session {
                    self.pending = Pending::event(
                        Resume::new(session.sequence(), session.id(), self.config.token()),
                        false,
                    );
                    self.state = ShardState::Resuming;
                } else {
                    self.identify_rx = Some(self.config.queue().enqueue(self.id.number()));
                }
            }
            Some(OpCode::InvalidSession) => {
                let resumable = Self::parse_event(event)?.data;
                tracing::debug!(resumable, "received invalid session");
                // A `NORMAL` close drops the session, forcing a new identify.
                if resumable {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                } else {
                    self.disconnect(CloseInitiator::Shard(CloseFrame::NORMAL));
                }
            }
            Some(OpCode::Reconnect) => {
                tracing::debug!("received reconnect");
                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
            }
            _ => tracing::info!("received an unknown opcode: {raw_opcode}"),
        }

        Ok(())
    }
}
850
impl<Q: Queue + Unpin> Stream for Shard<Q> {
    type Item = Result<Message, ReceiveMessageError>;

    /// Drive the shard: connect when disconnected, send due messages, and
    /// yield the next message received from the gateway.
    ///
    /// Transport failures are surfaced as [`Message::ABNORMAL_CLOSE`]. Returns
    /// `None` only once the shard has fatally closed and its connection has
    /// been closed; polling again afterwards panics.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let message = loop {
            match self.state {
                // Finish closing the WebSocket, then end the stream.
                ShardState::FatallyClosed => {
                    _ = ready!(
                        Pin::new(
                            self.connection
                                .as_mut()
                                .expect("poll_next called after Poll::Ready(None)")
                        )
                        .poll_close(cx)
                    );
                    self.connection = None;
                    return Poll::Ready(None);
                }
                // (Re)connect only once the previous connection is gone.
                ShardState::Disconnected { reconnect_attempts } if self.connection.is_none() => {
                    ready!(self.poll_connect(cx, reconnect_attempts))?;
                }
                _ => {}
            }

            if ready!(self.poll_send(cx)).is_err() {
                self.disconnect(CloseInitiator::Transport);
                self.connection = None;

                return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
            }

            match ready!(Pin::new(self.connection.as_mut().unwrap()).poll_next(cx)) {
                Some(Ok(message)) => {
                    // Binary frames carry compressed payloads.
                    #[cfg(any(feature = "zlib", feature = "zstd"))]
                    if message.is_binary() {
                        match self.decompressor.decompress(message.as_payload()) {
                            #[cfg(feature = "zstd")]
                            Ok(message) => break Message::Text(message),
                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
                            Ok(Some(message)) => break Message::Text(message),
                            // zlib: no complete message yet, keep reading.
                            #[cfg(all(not(feature = "zstd"), feature = "zlib"))]
                            Ok(None) => continue,
                            Err(source) => {
                                self.disconnect(CloseInitiator::Shard(CloseFrame::RESUME));
                                return Poll::Ready(Some(Err(
                                    ReceiveMessageError::from_compression(source),
                                )));
                            }
                        }
                    }
                    // Frames without a `Message` equivalent are skipped.
                    if let Some(message) = Message::from_websocket_msg(&message) {
                        break message;
                    }
                }
                // Errors while already disconnected are ignored; keep reading
                // until the connection ends.
                Some(Err(_)) if self.state.is_disconnected() => {}
                Some(Err(_)) => {
                    self.disconnect(CloseInitiator::Transport);
                    return Poll::Ready(Some(Ok(Message::ABNORMAL_CLOSE)));
                }
                None => {
                    _ = ready!(Pin::new(self.connection.as_mut().unwrap()).poll_close(cx));
                    tracing::debug!("gateway WebSocket connection closed");
                    // Unclean closure.
                    if !self.state.is_disconnected() {
                        self.disconnect(CloseInitiator::Transport);
                    }
                    // Dropping the connection lets the next iteration reconnect.
                    self.connection = None;
                }
            }
        };

        match &message {
            Message::Close(frame) => {
                // tokio-websockets automatically replies to the close message.
                tracing::debug!(?frame, "received WebSocket close message");
                // Don't run `disconnect` if we initiated the close.
                if !self.state.is_disconnected() {
                    self.disconnect(CloseInitiator::Gateway(frame.as_ref().map(|f| f.code)));
                }
            }
            Message::Text(event) => {
                self.process(event)?;
            }
        }

        Poll::Ready(Some(Ok(message)))
    }
}
939
940/// Default identify properties to use when the user hasn't customized it in
941/// [`Config::identify_properties`].
942///
943/// [`Config::identify_properties`]: Config::identify_properties
944fn default_identify_properties() -> IdentifyProperties {
945 IdentifyProperties::new("twilight.rs", "twilight.rs", OS)
946}
947
#[cfg(test)]
mod tests {
    use super::Shard;
    use static_assertions::{assert_impl_all, assert_not_impl_any};
    use std::fmt::Debug;

    // `Shard` can be moved between threads but not shared by reference; for
    // instance, its boxed connection future is only `Send`.
    assert_not_impl_any!(Shard: Sync);
    assert_impl_all!(Shard: Debug, Send);
}