twilight_lavalink/
node.rs

1//! Nodes for communicating with a Lavalink server.
2//!
3//! Using nodes, you can send events to a server and receive events.
4//!
5//! This is a bit more low level than using the [`Lavalink`] client because you
6//! will need to provide your own `VoiceUpdate` events when your bot joins
7//! channels, meaning you will have to accumulate and combine voice state update
8//! and voice server update events from the Discord gateway to send them to
9//! a node.
10//!
11//! Additionally, you will have to create and manage your own [`PlayerManager`]
12//! and make your own players for guilds when your bot joins voice channels.
13//!
14//! This can be a lot of work, and there's not really much reason to do it
15//! yourself. For that reason, you should almost always use the `Lavalink`
16//! client which does all of this for you.
17//!
18//! [`Lavalink`]: crate::client::Lavalink
19
20use crate::{
21    model::{IncomingEvent, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22    player::PlayerManager,
23};
24use futures_util::{
25    lock::BiLock,
26    sink::SinkExt,
27    stream::{Stream, StreamExt},
28};
29use http::header::{AUTHORIZATION, HeaderName, HeaderValue};
30use http_body_util::Full;
31use hyper::{Method, Request, Uri, body::Bytes, header};
32use hyper_util::{
33    client::legacy::{Client as HyperClient, connect::HttpConnector},
34    rt::TokioExecutor,
35};
36use std::{
37    borrow::Borrow,
38    error::Error,
39    fmt::{Debug, Display, Formatter, Result as FmtResult, Write as _},
40    net::SocketAddr,
41    pin::Pin,
42    task::{Context, Poll},
43    time::Duration,
44};
45use tokio::{
46    net::TcpStream,
47    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
48    time as tokio_time,
49};
50use tokio_websockets::{
51    ClientBuilder, Error as WebsocketError, MaybeTlsStream, Message, WebSocketStream, upgrade,
52};
53use twilight_model::id::{Id, marker::UserMarker};
54
55/// An error occurred while either initializing a connection or while running
56/// its event loop.
57#[derive(Debug)]
58pub struct NodeError {
59    kind: NodeErrorType,
60    source: Option<Box<dyn Error + Send + Sync>>,
61}
62
63impl NodeError {
64    /// Immutable reference to the type of error that occurred.
65    #[must_use = "retrieving the type has no effect if left unused"]
66    pub const fn kind(&self) -> &NodeErrorType {
67        &self.kind
68    }
69
70    /// Consume the error, returning the source error if there is any.
71    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
72    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
73        self.source
74    }
75
76    /// Consume the error, returning the owned error type and the source error.
77    #[must_use = "consuming the error into its parts has no effect if left unused"]
78    pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
79        (self.kind, self.source)
80    }
81}
82
83impl Display for NodeError {
84    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
85        match &self.kind {
86            NodeErrorType::BuildingConnectionRequest => {
87                f.write_str("failed to build connection request")
88            }
89            NodeErrorType::HttpRequestFailed => {
90                f.write_str("failed to send http request to lavalink server")
91            }
92            NodeErrorType::Connecting => f.write_str("Failed to connect to the node"),
93            NodeErrorType::OutgoingEventHasNoSession => {
94                f.write_str("no session id found for connection to lavalink api")
95            }
96            NodeErrorType::SerializingMessage { message: _ } => {
97                f.write_str("failed to serialize outgoing message as json")
98            }
99            NodeErrorType::Unauthorized { address, .. } => {
100                f.write_str("the authorization used to connect to node ")?;
101                Display::fmt(address, f)?;
102
103                f.write_str(" is invalid")
104            }
105        }
106    }
107}
108
109impl Error for NodeError {
110    fn source(&self) -> Option<&(dyn Error + 'static)> {
111        self.source
112            .as_ref()
113            .map(|source| &**source as &(dyn Error + 'static))
114    }
115}
116
117/// Type of [`NodeError`] that occurred.
118#[derive(Debug)]
119#[non_exhaustive]
120pub enum NodeErrorType {
121    /// Building the HTTP request to initialize a connection failed.
122    BuildingConnectionRequest,
123    /// Sending the HTTP request to Lavalink failed.
124    HttpRequestFailed,
125    /// Connecting to the Lavalink server failed after several backoff attempts.
126    Connecting,
127    /// There are potentially no valid session before trying to send outgoing
128    /// events. The session id is obtained in the startup sequence of the node.
129    /// If there is an attempt to send events before connecting, it will error out.
130    OutgoingEventHasNoSession,
131    /// Serializing a JSON message to be sent to a Lavalink node failed.
132    SerializingMessage {
133        /// The message that couldn't be serialized.
134        message: OutgoingEvent,
135    },
136    /// The given authorization for the node is incorrect.
137    Unauthorized {
138        /// The address of the node that failed to authorize.
139        address: SocketAddr,
140        /// The authorization used to connect to the node.
141        authorization: String,
142    },
143}
144
145/// An error that can occur while sending an event over a node.
146#[derive(Debug)]
147pub struct NodeSenderError {
148    kind: NodeSenderErrorType,
149    source: Option<Box<dyn Error + Send + Sync>>,
150}
151
152impl NodeSenderError {
153    /// Immutable reference to the type of error that occurred.
154    pub const fn kind(&self) -> &NodeSenderErrorType {
155        &self.kind
156    }
157
158    /// Consume the error, returning the source error if there is any.
159    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
160        self.source
161    }
162
163    /// Consume the error, returning the owned error type and the source error.
164    #[must_use = "consuming the error into its parts has no effect if left unused"]
165    pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
166        (self.kind, self.source)
167    }
168}
169
170impl Display for NodeSenderError {
171    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
172        match &self.kind {
173            NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
174        }
175    }
176}
177
178impl Error for NodeSenderError {
179    fn source(&self) -> Option<&(dyn Error + 'static)> {
180        self.source
181            .as_ref()
182            .map(|source| &**source as &(dyn Error + 'static))
183    }
184}
185
/// Kind of [`NodeSenderError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// The event could not be pushed onto the node's channel.
    Sending,
}
193
194/// Stream of incoming events from a node.
195pub struct IncomingEvents {
196    inner: UnboundedReceiver<IncomingEvent>,
197}
198
199impl IncomingEvents {
200    /// Closes the receiving half of a channel without dropping it.
201    pub fn close(&mut self) {
202        self.inner.close();
203    }
204}
205
206impl Stream for IncomingEvents {
207    type Item = IncomingEvent;
208
209    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
210        self.inner.poll_recv(cx)
211    }
212}
213
214/// Send outgoing events to the associated node.
215pub struct NodeSender {
216    inner: UnboundedSender<OutgoingEvent>,
217}
218
219impl NodeSender {
220    /// Returns whether this channel is closed without needing a context.
221    pub fn is_closed(&self) -> bool {
222        self.inner.is_closed()
223    }
224
225    /// Sends a message along this channel.
226    ///
227    /// This is an unbounded sender, so this function differs from `Sink::send`
228    /// by ensuring the return type reflects that the channel is always ready to
229    /// receive messages.
230    ///
231    /// # Errors
232    ///
233    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
234    /// longer connected.
235    pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
236        self.inner.send(msg).map_err(|source| NodeSenderError {
237            kind: NodeSenderErrorType::Sending,
238            source: Some(Box::new(source)),
239        })
240    }
241}
242
243/// The configuration that a [`Node`] uses to connect to a Lavalink server.
244#[derive(Clone, Eq, PartialEq)]
245#[non_exhaustive]
246// Keep fields in sync with its Debug implementation.
247pub struct NodeConfig {
248    /// The address of the node.
249    pub address: SocketAddr,
250    /// The password to use when authenticating.
251    pub authorization: String,
252    /// The details for resuming a Lavalink session, if any.
253    ///
254    /// Set this to `None` to disable resume capability.
255    pub resume: Option<Resume>,
256    /// The user ID of the bot.
257    pub user_id: Id<UserMarker>,
258    /// Weather or not to enable TLS
259    pub enable_tls: bool,
260}
261
262impl Debug for NodeConfig {
263    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
264        /// Debug as `<redacted>`. Necessary because debugging a struct field
265        /// with a value of of `"<redacted>"` will insert quotations in the
266        /// string, which doesn't align with other token debugs.
267        struct Redacted;
268
269        impl Debug for Redacted {
270            fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
271                f.write_str("<redacted>")
272            }
273        }
274
275        f.debug_struct("NodeConfig")
276            .field("address", &self.address)
277            .field("authorization", &Redacted)
278            .field("resume", &self.resume)
279            .field("user_id", &self.user_id)
280            .field("enable_tls", &self.enable_tls)
281            .finish()
282    }
283}
284
/// Settings for a Lavalink session that may be resumed after a disconnect.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Resume {
    /// How long, in seconds, the Lavalink server keeps a disconnected session
    /// available for resuming.
    ///
    /// Defaults to 60.
    pub timeout: u64,
}
295
296impl Resume {
297    /// Configure resume capability, providing the number of seconds that the
298    /// Lavalink server should queue events for when the connection is resumed.
299    pub const fn new(seconds: u64) -> Self {
300        Self { timeout: seconds }
301    }
302}
303
304impl Default for Resume {
305    fn default() -> Self {
306        Self { timeout: 60 }
307    }
308}
309
310impl NodeConfig {
311    /// Create a new configuration for connecting to a node via
312    /// [`Node::connect`].
313    ///
314    /// If adding a node through the [`Lavalink`] client then you don't need to
315    /// do this yourself.
316    ///
317    /// [`Lavalink`]: crate::client::Lavalink
318    pub fn new(
319        user_id: Id<UserMarker>,
320        address: impl Into<SocketAddr>,
321        authorization: impl Into<String>,
322        resume: impl Into<Option<Resume>>,
323        enable_tls: bool,
324    ) -> Self {
325        Self::_new(
326            user_id,
327            address.into(),
328            authorization.into(),
329            resume.into(),
330            enable_tls,
331        )
332    }
333
334    const fn _new(
335        user_id: Id<UserMarker>,
336        address: SocketAddr,
337        authorization: String,
338        resume: Option<Resume>,
339        enable_tls: bool,
340    ) -> Self {
341        Self {
342            address,
343            authorization,
344            resume,
345            user_id,
346            enable_tls,
347        }
348    }
349}
350
351/// A connection to a single Lavalink server. It receives events and forwards
352/// events from players to the server.
353///
354/// Please refer to the [module] documentation.
355///
356/// [module]: crate
357#[derive(Debug)]
358pub struct Node {
359    config: NodeConfig,
360    lavalink_tx: UnboundedSender<OutgoingEvent>,
361    players: PlayerManager,
362    stats: BiLock<Stats>,
363}
364
365impl Node {
366    /// Connect to a node, providing a player manager so that the node can
367    /// update player details.
368    ///
369    /// Please refer to the [module] documentation for some additional
370    /// information about directly creating and using nodes. You are encouraged
371    /// to use the [`Lavalink`] client instead.
372    ///
373    /// [`Lavalink`]: crate::client::Lavalink
374    /// [module]: crate
375    ///
376    /// # Errors
377    ///
378    /// Returns an error of type [`Connecting`] if the connection fails after
379    /// several backoff attempts.
380    ///
381    /// Returns an error of type [`BuildingConnectionRequest`] if the request
382    /// failed to build.
383    ///
384    /// Returns an error of type [`Unauthorized`] if the supplied authorization
385    /// is rejected by the node.
386    ///
387    /// [`Connecting`]: crate::node::NodeErrorType::Connecting
388    /// [`BuildingConnectionRequest`]: crate::node::NodeErrorType::BuildingConnectionRequest
389    /// [`Unauthorized`]: crate::node::NodeErrorType::Unauthorized
390    pub async fn connect(
391        config: NodeConfig,
392        players: PlayerManager,
393    ) -> Result<(Self, IncomingEvents), NodeError> {
394        let (bilock_left, bilock_right) = BiLock::new(Stats {
395            cpu: StatsCpu {
396                cores: 0,
397                lavalink_load: 0f64,
398                system_load: 0f64,
399            },
400            frame_stats: None,
401            memory: StatsMemory {
402                allocated: 0,
403                free: 0,
404                used: 0,
405                reservable: 0,
406            },
407            players: 0,
408            playing_players: 0,
409            uptime: 0,
410        });
411
412        tracing::debug!("starting connection to {}", config.address);
413
414        let (conn_loop, lavalink_tx, lavalink_rx) =
415            Connection::connect(config.clone(), players.clone(), bilock_right).await?;
416
417        tracing::debug!("started connection to {}", config.address);
418
419        tokio::spawn(conn_loop.run());
420
421        Ok((
422            Self {
423                config,
424                lavalink_tx,
425                players,
426                stats: bilock_left,
427            },
428            IncomingEvents { inner: lavalink_rx },
429        ))
430    }
431
432    /// Retrieve an immutable reference to the node's configuration.
433    pub const fn config(&self) -> &NodeConfig {
434        &self.config
435    }
436
437    /// Retrieve an immutable reference to the player manager used by the node.
438    pub const fn players(&self) -> &PlayerManager {
439        &self.players
440    }
441
442    /// Retrieve an immutable reference to the node's configuration.
443    ///
444    /// Note that sending player events through the node's sender won't update
445    /// player states, such as whether it's paused.
446    ///
447    /// # Errors
448    ///
449    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
450    /// longer connected.
451    pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
452        self.sender().send(event)
453    }
454
455    /// Retrieve a unique sender to send events to the Lavalink server.
456    ///
457    /// Note that sending player events through the node's sender won't update
458    /// player states, such as whether it's paused.
459    pub fn sender(&self) -> NodeSender {
460        NodeSender {
461            inner: self.lavalink_tx.clone(),
462        }
463    }
464
465    /// Retrieve a copy of the node's stats.
466    pub async fn stats(&self) -> Stats {
467        (*self.stats.lock().await).clone()
468    }
469
470    /// Retrieve the calculated penalty score of the node.
471    ///
472    /// This score can be used to calculate how loaded the server is. A higher
473    /// number means it is more heavily loaded.
474    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
475    pub async fn penalty(&self) -> i32 {
476        let stats = self.stats.lock().await;
477        let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
478
479        let (deficit_frame, null_frame) = (
480            1.03f64.powf(
481                500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64),
482            ) * 300f64
483                - 300f64,
484            (1.03f64.powf(
485                500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64),
486            ) * 300f64
487                - 300f64)
488                * 2f64,
489        );
490
491        stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
492    }
493}
494
495struct Connection {
496    config: NodeConfig,
497    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
498    lavalink_http: HyperClient<HttpConnector, Full<Bytes>>,
499    node_from: UnboundedReceiver<OutgoingEvent>,
500    node_to: UnboundedSender<IncomingEvent>,
501    players: PlayerManager,
502    stats: BiLock<Stats>,
503    lavalink_session_id: Option<Box<str>>,
504}
505
506impl Connection {
507    async fn connect(
508        config: NodeConfig,
509        players: PlayerManager,
510        stats: BiLock<Stats>,
511    ) -> Result<
512        (
513            Self,
514            UnboundedSender<OutgoingEvent>,
515            UnboundedReceiver<IncomingEvent>,
516        ),
517        NodeError,
518    > {
519        let stream = reconnect(&config).await?;
520
521        let (to_node, from_lavalink) = mpsc::unbounded_channel();
522        let (to_lavalink, from_node) = mpsc::unbounded_channel();
523
524        let mut client_builder = HyperClient::builder(TokioExecutor::new());
525
526        if config.enable_tls {
527            client_builder.http2_only(config.enable_tls);
528        }
529
530        let lavalink_http = client_builder.build_http();
531
532        Ok((
533            Self {
534                config,
535                stream,
536                lavalink_http,
537                node_from: from_node,
538                node_to: to_node,
539                players,
540                stats,
541                lavalink_session_id: None,
542            },
543            to_lavalink,
544            from_lavalink,
545        ))
546    }
547
548    async fn run(mut self) -> Result<(), NodeError> {
549        loop {
550            tokio::select! {
551                incoming = self.stream.next() => {
552                    if let Some(Ok(incoming)) = incoming {
553                        self.incoming(incoming).await?;
554                    } else {
555                        tracing::debug!("connection to {} closed, reconnecting", self.config.address);
556                        self.stream = reconnect(&self.config).await?;
557                    }
558                }
559                outgoing = self.node_from.recv() => {
560                    if let Some(outgoing) = outgoing {
561                        self.outgoing(outgoing).await?;
562                    } else {
563                        tracing::debug!("node {} closed, ending connection", self.config.address);
564                        break;
565                    }
566                }
567            }
568        }
569
570        Ok(())
571    }
572
573    fn get_outgoing_endpoint_based_on_event(
574        &mut self,
575        outgoing: &OutgoingEvent,
576    ) -> Result<(Method, hyper::Uri), NodeError> {
577        let address = self.config.address;
578        tracing::debug!("forwarding event to {address}: {outgoing:?}");
579
580        let guild_id = outgoing.guild_id();
581        let no_replace = outgoing.no_replace();
582
583        if let Some(session) = &self.lavalink_session_id {
584            let mut path = format!("/v4/sessions/{session}/players/{guild_id}");
585            if !matches!(outgoing, OutgoingEvent::Destroy(_)) {
586                let _ = write!(path, "?noReplace={no_replace}");
587            }
588            let uri = Uri::builder()
589                .scheme("http")
590                .authority(address.to_string())
591                .path_and_query(path)
592                .build()
593                .expect("uri is valid");
594            return if matches!(outgoing, OutgoingEvent::Destroy(_)) {
595                Ok((Method::DELETE, uri))
596            } else {
597                Ok((Method::PATCH, uri))
598            };
599        }
600
601        tracing::error!("no session id is found");
602
603        Err(NodeError {
604            kind: NodeErrorType::OutgoingEventHasNoSession,
605            source: None,
606        })
607    }
608
609    async fn outgoing(&mut self, outgoing: OutgoingEvent) -> Result<(), NodeError> {
610        let (method, url) = self.get_outgoing_endpoint_based_on_event(&outgoing)?;
611        let payload = serde_json::to_string(&outgoing).expect("serialization cannot fail");
612
613        let authority = url.authority().expect("authority comes from endpoint");
614
615        let req = Request::builder()
616            .uri(url.borrow())
617            .method(method)
618            .header(header::HOST, authority.as_str())
619            .header(header::AUTHORIZATION, self.config.authorization.as_str())
620            .header(header::CONTENT_TYPE, "application/json")
621            .body(Full::from(payload))
622            .map_err(|source| NodeError {
623                kind: NodeErrorType::BuildingConnectionRequest,
624                source: Some(Box::new(source)),
625            })?;
626
627        self.lavalink_http
628            .request(req)
629            .await
630            .map_err(|source| NodeError {
631                kind: NodeErrorType::HttpRequestFailed,
632                source: Some(Box::new(source)),
633            })?;
634
635        Ok(())
636    }
637
638    async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
639        tracing::debug!(
640            "received message from {}: {incoming:?}",
641            self.config.address,
642        );
643
644        let text = if incoming.is_text() {
645            incoming.as_text().expect("message is text")
646        } else if incoming.is_close() {
647            tracing::debug!("got close, closing connection");
648
649            return Ok(false);
650        } else {
651            tracing::debug!("got ping, pong or binary payload: {incoming:?}");
652
653            return Ok(true);
654        };
655
656        let Ok(event) = serde_json::from_str(text) else {
657            tracing::warn!("unknown message from lavalink node: {text}");
658
659            return Ok(true);
660        };
661
662        match &event {
663            IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
664            IncomingEvent::Ready(ready) => {
665                self.lavalink_session_id = Some(ready.session_id.clone().into_boxed_str());
666            }
667            IncomingEvent::Stats(stats) => self.stats(stats).await?,
668            IncomingEvent::Event(_) => {}
669        }
670
671        // It's fine if the rx end dropped, often users don't need to care about
672        // these events.
673        if !self.node_to.is_closed() {
674            let _result = self.node_to.send(event);
675        }
676
677        Ok(true)
678    }
679
680    fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
681        let Some(player) = self.players.get(&update.guild_id) else {
682            tracing::warn!(
683                "invalid player update for guild {}: {update:?}",
684                update.guild_id,
685            );
686
687            return Ok(());
688        };
689
690        player.set_position(update.state.position);
691        player.set_time(update.state.time);
692
693        Ok(())
694    }
695
696    async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
697        *self.stats.lock().await = stats.clone();
698
699        Ok(())
700    }
701}
702
703impl Drop for Connection {
704    fn drop(&mut self) {
705        // Cleanup local players associated with the node
706        self.players
707            .players
708            .retain(|_, v| v.node().config().address != self.config.address);
709    }
710}
711
712const TWILIGHT_CLIENT_NAME: &str = concat!("twilight-lavalink/", env!("CARGO_PKG_VERSION"));
713
714fn connect_request(state: &NodeConfig) -> Result<ClientBuilder<'_>, NodeError> {
715    let websocket_protocol = if state.enable_tls { "wss" } else { "ws" };
716
717    let mut builder = ClientBuilder::new()
718        .uri(&format!(
719            "{}://{}/v4/websocket",
720            websocket_protocol, state.address
721        ))
722        .map_err(|source| NodeError {
723            kind: NodeErrorType::BuildingConnectionRequest,
724            source: Some(Box::new(source)),
725        })?
726        .add_header(
727            AUTHORIZATION,
728            state.authorization.parse().map_err(|source| NodeError {
729                kind: NodeErrorType::BuildingConnectionRequest,
730                source: Some(Box::new(source)),
731            })?,
732        )
733        .expect("Unable to crate authorization header")
734        .add_header(
735            HeaderName::from_static("user-id"),
736            state.user_id.get().into(),
737        )
738        .expect("Unable to add user-id")
739        .add_header(
740            HeaderName::from_static("client-name"),
741            HeaderValue::from_static(TWILIGHT_CLIENT_NAME),
742        )
743        .expect("Unable to crate builder");
744
745    if state.resume.is_some() {
746        builder = builder
747            .add_header(
748                HeaderName::from_static("resume-key"),
749                state
750                    .address
751                    .to_string()
752                    .parse()
753                    .map_err(|source| NodeError {
754                        kind: NodeErrorType::BuildingConnectionRequest,
755                        source: Some(Box::new(source)),
756                    })?,
757            )
758            .expect("Unable to build state header");
759    }
760
761    Ok(builder)
762}
763
764async fn reconnect(
765    config: &NodeConfig,
766) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
767    let (mut stream, res) = backoff(config).await?;
768
769    let headers = res.headers();
770
771    if let Some(resume) = config.resume.as_ref() {
772        let header = HeaderName::from_static("session-resumed");
773
774        if let Some(value) = headers.get(header) {
775            if value.as_bytes() == b"false" {
776                tracing::debug!("session to node {} didn't resume", config.address);
777
778                let payload = serde_json::json!({
779                    "op": "configureResuming",
780                    "key": config.address,
781                    "timeout": resume.timeout,
782                });
783                let msg = Message::text(
784                    serde_json::to_string(&payload).expect("serialize can't panic here"),
785                );
786
787                stream.send(msg).await.map_err(|source| NodeError {
788                    kind: NodeErrorType::Connecting,
789                    source: Some(Box::new(source)),
790                })?;
791            } else {
792                tracing::debug!("session to {} resumed", config.address);
793            }
794        }
795    }
796
797    Ok(stream)
798}
799
800async fn backoff(
801    config: &NodeConfig,
802) -> Result<
803    (
804        WebSocketStream<MaybeTlsStream<TcpStream>>,
805        upgrade::Response,
806    ),
807    NodeError,
808> {
809    let mut seconds = 1;
810
811    loop {
812        let request = connect_request(config)?;
813
814        match request.connect().await {
815            Ok((stream, response)) => return Ok((stream, response)),
816            Err(source) => {
817                tracing::warn!("failed to connect to node {source}: {:?}", config.address);
818
819                if matches!(
820                    &source,
821                    WebsocketError::Upgrade(upgrade::Error::DidNotSwitchProtocols(401))
822                ) {
823                    return Err(NodeError {
824                        kind: NodeErrorType::Unauthorized {
825                            address: config.address,
826                            authorization: config.authorization.clone(),
827                        },
828                        source: None,
829                    });
830                }
831
832                if seconds > 64 {
833                    tracing::debug!("no longer trying to connect to node {}", config.address);
834
835                    return Err(NodeError {
836                        kind: NodeErrorType::Connecting,
837                        source: Some(Box::new(source)),
838                    });
839                }
840
841                tracing::debug!(
842                    "waiting {seconds} seconds before attempting to connect to node {} again",
843                    config.address,
844                );
845                tokio_time::sleep(Duration::from_secs(seconds)).await;
846
847                seconds *= 2;
848            }
849        }
850    }
851}
852
#[cfg(test)]
mod tests {
    use super::{
        Node, NodeConfig, NodeError, NodeErrorType, NodeSender, NodeSenderError,
        NodeSenderErrorType, Resume,
    };
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::Debug,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };
    use twilight_model::id::Id;

    assert_fields!(NodeConfig: address, authorization, resume, user_id, enable_tls);
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);
    assert_impl_all!(NodeSender: Send, Sync);
    assert_impl_all!(NodeSenderErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeSenderError: Error, Send, Sync);
    assert_impl_all!(Node: Debug, Send, Sync);
    assert_fields!(Resume: timeout);
    assert_impl_all!(Resume: Clone, Debug, Default, Eq, PartialEq, Send, Sync);

    /// Address shared by the tests below.
    fn localhost() -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1312))
    }

    #[test]
    fn node_config_debug() {
        let config = NodeConfig {
            address: localhost(),
            authorization: "some auth".to_owned(),
            resume: None,
            user_id: Id::new(123),
            enable_tls: false,
        };

        let debug = format!("{config:?}");
        assert!(debug.contains("authorization: <redacted>"));
        // The secret itself must never appear in debug output.
        assert!(!debug.contains("some auth"));
    }

    #[test]
    fn resume_timeouts() {
        assert_eq!(Resume::default(), Resume::new(60));
        assert_eq!(Resume::new(30).timeout, 30);
    }

    #[test]
    fn unauthorized_display_omits_authorization() {
        let error = NodeError {
            kind: NodeErrorType::Unauthorized {
                address: localhost(),
                authorization: "secret".to_owned(),
            },
            source: None,
        };

        assert_eq!(
            error.to_string(),
            "the authorization used to connect to node 127.0.0.1:1312 is invalid"
        );
    }
}