Skip to main content

twilight_lavalink/
node.rs

1//! Nodes for communicating with a Lavalink server.
2//!
3//! Using nodes, you can send events to a server and receive events.
4//!
5//! This is a bit more low level than using the [`Lavalink`] client because you
6//! will need to provide your own `VoiceUpdate` events when your bot joins
7//! channels, meaning you will have to accumulate and combine voice state update
8//! and voice server update events from the Discord gateway to send them to
9//! a node.
10//!
11//! Additionally, you will have to create and manage your own [`PlayerManager`]
12//! and make your own players for guilds when your bot joins voice channels.
13//!
14//! This can be a lot of work, and there's not really much reason to do it
15//! yourself. For that reason, you should almost always use the `Lavalink`
16//! client which does all of this for you.
17//!
18//! [`Lavalink`]: crate::client::Lavalink
19
20use crate::{
21    model::{IncomingEvent, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22    player::PlayerManager,
23};
24use futures_util::{
25    lock::BiLock,
26    stream::{Stream, StreamExt},
27};
28use http::header::{AUTHORIZATION, HeaderName, HeaderValue};
29use http_body_util::Full;
30use hyper::{Method, Request, Uri, body::Bytes, header};
31use hyper_util::{
32    client::legacy::{Client as HyperClient, connect::HttpConnector},
33    rt::TokioExecutor,
34};
35use std::{
36    borrow::Borrow,
37    error::Error,
38    fmt::{Debug, Display, Formatter, Result as FmtResult, Write as _},
39    net::SocketAddr,
40    pin::Pin,
41    task::{Context, Poll},
42    time::Duration,
43};
44use tokio::{
45    net::TcpStream,
46    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
47    time as tokio_time,
48};
49use tokio_websockets::{
50    ClientBuilder, Error as WebsocketError, MaybeTlsStream, Message, WebSocketStream, upgrade,
51};
52use twilight_model::id::{Id, marker::UserMarker};
53
54/// An error occurred while either initializing a connection or while running
55/// its event loop.
56#[derive(Debug)]
57pub struct NodeError {
58    kind: NodeErrorType,
59    source: Option<Box<dyn Error + Send + Sync>>,
60}
61
62impl NodeError {
63    /// Immutable reference to the type of error that occurred.
64    #[must_use = "retrieving the type has no effect if left unused"]
65    pub const fn kind(&self) -> &NodeErrorType {
66        &self.kind
67    }
68
69    /// Consume the error, returning the source error if there is any.
70    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
71    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
72        self.source
73    }
74
75    /// Consume the error, returning the owned error type and the source error.
76    #[must_use = "consuming the error into its parts has no effect if left unused"]
77    pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
78        (self.kind, self.source)
79    }
80}
81
82impl Display for NodeError {
83    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
84        match &self.kind {
85            NodeErrorType::BuildingConnectionRequest => {
86                f.write_str("failed to build connection request")
87            }
88            NodeErrorType::HttpRequestFailed => {
89                f.write_str("failed to send http request to lavalink server")
90            }
91            NodeErrorType::Connecting => f.write_str("Failed to connect to the node"),
92            NodeErrorType::OutgoingEventHasNoSession => {
93                f.write_str("no session id found for connection to lavalink api")
94            }
95            NodeErrorType::SerializingMessage { message: _ } => {
96                f.write_str("failed to serialize outgoing message as json")
97            }
98            NodeErrorType::Unauthorized { address, .. } => {
99                f.write_str("the authorization used to connect to node ")?;
100                Display::fmt(address, f)?;
101
102                f.write_str(" is invalid")
103            }
104        }
105    }
106}
107
108impl Error for NodeError {
109    fn source(&self) -> Option<&(dyn Error + 'static)> {
110        self.source
111            .as_ref()
112            .map(|source| &**source as &(dyn Error + 'static))
113    }
114}
115
116/// Type of [`NodeError`] that occurred.
117#[derive(Debug)]
118#[non_exhaustive]
119pub enum NodeErrorType {
120    /// Building the HTTP request to initialize a connection failed.
121    BuildingConnectionRequest,
122    /// Sending the HTTP request to Lavalink failed.
123    HttpRequestFailed,
124    /// Connecting to the Lavalink server failed after several backoff attempts.
125    Connecting,
126    /// There are potentially no valid session before trying to send outgoing
127    /// events. The session id is obtained in the startup sequence of the node.
128    /// If there is an attempt to send events before connecting, it will error out.
129    OutgoingEventHasNoSession,
130    /// Serializing a JSON message to be sent to a Lavalink node failed.
131    SerializingMessage {
132        /// The message that couldn't be serialized.
133        message: OutgoingEvent,
134    },
135    /// The given authorization for the node is incorrect.
136    Unauthorized {
137        /// The address of the node that failed to authorize.
138        address: SocketAddr,
139        /// The authorization used to connect to the node.
140        authorization: String,
141    },
142}
143
144/// An error that can occur while sending an event over a node.
145#[derive(Debug)]
146pub struct NodeSenderError {
147    kind: NodeSenderErrorType,
148    source: Option<Box<dyn Error + Send + Sync>>,
149}
150
151impl NodeSenderError {
152    /// Immutable reference to the type of error that occurred.
153    pub const fn kind(&self) -> &NodeSenderErrorType {
154        &self.kind
155    }
156
157    /// Consume the error, returning the source error if there is any.
158    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
159        self.source
160    }
161
162    /// Consume the error, returning the owned error type and the source error.
163    #[must_use = "consuming the error into its parts has no effect if left unused"]
164    pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
165        (self.kind, self.source)
166    }
167}
168
169impl Display for NodeSenderError {
170    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
171        match &self.kind {
172            NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
173        }
174    }
175}
176
177impl Error for NodeSenderError {
178    fn source(&self) -> Option<&(dyn Error + 'static)> {
179        self.source
180            .as_ref()
181            .map(|source| &**source as &(dyn Error + 'static))
182    }
183}
184
/// Category of a [`NodeSenderError`].
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// The event could not be queued on the channel to the node's connection,
    /// for example because that connection has ended.
    Sending,
}
192
193/// Stream of incoming events from a node.
194pub struct IncomingEvents {
195    inner: UnboundedReceiver<IncomingEvent>,
196}
197
198impl IncomingEvents {
199    /// Closes the receiving half of a channel without dropping it.
200    pub fn close(&mut self) {
201        self.inner.close();
202    }
203}
204
205impl Stream for IncomingEvents {
206    type Item = IncomingEvent;
207
208    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
209        self.inner.poll_recv(cx)
210    }
211}
212
213/// Send outgoing events to the associated node.
214pub struct NodeSender {
215    inner: UnboundedSender<OutgoingEvent>,
216}
217
218impl NodeSender {
219    /// Returns whether this channel is closed without needing a context.
220    pub fn is_closed(&self) -> bool {
221        self.inner.is_closed()
222    }
223
224    /// Sends a message along this channel.
225    ///
226    /// This is an unbounded sender, so this function differs from `Sink::send`
227    /// by ensuring the return type reflects that the channel is always ready to
228    /// receive messages.
229    ///
230    /// # Errors
231    ///
232    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
233    /// longer connected.
234    pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
235        self.inner.send(msg).map_err(|source| NodeSenderError {
236            kind: NodeSenderErrorType::Sending,
237            source: Some(Box::new(source)),
238        })
239    }
240}
241
242/// The configuration that a [`Node`] uses to connect to a Lavalink server.
243#[derive(Clone, Eq, PartialEq)]
244#[non_exhaustive]
245// Keep fields in sync with its Debug implementation.
246pub struct NodeConfig {
247    /// The address of the node.
248    pub address: SocketAddr,
249    /// The password to use when authenticating.
250    pub authorization: String,
251    /// The user ID of the bot.
252    pub user_id: Id<UserMarker>,
253    /// Whether or not to enable TLS.
254    pub enable_tls: bool,
255    /// Optional session ID to resume an existing Lavalink session.
256    ///
257    /// If provided, the client will attempt to resume this session instead of creating a new one.
258    /// This is only applicable for Lavalink v4+.
259    pub session_id: Option<String>,
260}
261
262impl Debug for NodeConfig {
263    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
264        /// Debug as `<redacted>`. Necessary because debugging a struct field
265        /// with a value of of `"<redacted>"` will insert quotations in the
266        /// string, which doesn't align with other token debugs.
267        struct Redacted;
268
269        impl Debug for Redacted {
270            fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
271                f.write_str("<redacted>")
272            }
273        }
274
275        f.debug_struct("NodeConfig")
276            .field("address", &self.address)
277            .field("authorization", &Redacted)
278            .field("user_id", &self.user_id)
279            .field("enable_tls", &self.enable_tls)
280            .field("session_id", &self.session_id)
281            .finish()
282    }
283}
284
285impl NodeConfig {
286    /// Create a new configuration for connecting to a node via
287    /// [`Node::connect`].
288    ///
289    /// If adding a node through the [`Lavalink`] client then you don't need to
290    /// do this yourself.
291    ///
292    /// [`Lavalink`]: crate::client::Lavalink
293    pub fn new(
294        user_id: Id<UserMarker>,
295        address: impl Into<SocketAddr>,
296        authorization: impl Into<String>,
297        enable_tls: bool,
298    ) -> Self {
299        Self::_new(user_id, address.into(), authorization.into(), enable_tls)
300    }
301
302    const fn _new(
303        user_id: Id<UserMarker>,
304        address: SocketAddr,
305        authorization: String,
306        enable_tls: bool,
307    ) -> Self {
308        Self {
309            address,
310            authorization,
311            user_id,
312            enable_tls,
313            session_id: None,
314        }
315    }
316}
317
318/// A connection to a single Lavalink server. It receives events and forwards
319/// events from players to the server.
320///
321/// Please refer to the [module] documentation.
322///
323/// [module]: crate
324#[derive(Debug)]
325pub struct Node {
326    config: NodeConfig,
327    lavalink_tx: UnboundedSender<OutgoingEvent>,
328    players: PlayerManager,
329    stats: BiLock<Stats>,
330}
331
332impl Node {
333    /// Connect to a node, providing a player manager so that the node can
334    /// update player details.
335    ///
336    /// Please refer to the [module] documentation for some additional
337    /// information about directly creating and using nodes. You are encouraged
338    /// to use the [`Lavalink`] client instead.
339    ///
340    /// [`Lavalink`]: crate::client::Lavalink
341    /// [module]: crate
342    ///
343    /// # Errors
344    ///
345    /// Returns an error of type [`Connecting`] if the connection fails after
346    /// several backoff attempts.
347    ///
348    /// Returns an error of type [`BuildingConnectionRequest`] if the request
349    /// failed to build.
350    ///
351    /// Returns an error of type [`Unauthorized`] if the supplied authorization
352    /// is rejected by the node.
353    ///
354    /// [`Connecting`]: crate::node::NodeErrorType::Connecting
355    /// [`BuildingConnectionRequest`]: crate::node::NodeErrorType::BuildingConnectionRequest
356    /// [`Unauthorized`]: crate::node::NodeErrorType::Unauthorized
357    pub async fn connect(
358        config: NodeConfig,
359        players: PlayerManager,
360    ) -> Result<(Self, IncomingEvents), NodeError> {
361        let (bilock_left, bilock_right) = BiLock::new(Stats {
362            cpu: StatsCpu {
363                cores: 0,
364                lavalink_load: 0f64,
365                system_load: 0f64,
366            },
367            frame_stats: None,
368            memory: StatsMemory {
369                allocated: 0,
370                free: 0,
371                used: 0,
372                reservable: 0,
373            },
374            players: 0,
375            playing_players: 0,
376            uptime: 0,
377        });
378
379        tracing::debug!("starting connection to {}", config.address);
380
381        let (conn_loop, lavalink_tx, lavalink_rx) =
382            Connection::connect(config.clone(), players.clone(), bilock_right).await?;
383
384        tracing::debug!("started connection to {}", config.address);
385
386        tokio::spawn(conn_loop.run());
387
388        Ok((
389            Self {
390                config,
391                lavalink_tx,
392                players,
393                stats: bilock_left,
394            },
395            IncomingEvents { inner: lavalink_rx },
396        ))
397    }
398
399    /// Retrieve an immutable reference to the node's configuration.
400    pub const fn config(&self) -> &NodeConfig {
401        &self.config
402    }
403
404    /// Retrieve an immutable reference to the player manager used by the node.
405    pub const fn players(&self) -> &PlayerManager {
406        &self.players
407    }
408
409    /// Retrieve an immutable reference to the node's configuration.
410    ///
411    /// Note that sending player events through the node's sender won't update
412    /// player states, such as whether it's paused.
413    ///
414    /// # Errors
415    ///
416    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
417    /// longer connected.
418    pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
419        self.sender().send(event)
420    }
421
422    /// Retrieve a unique sender to send events to the Lavalink server.
423    ///
424    /// Note that sending player events through the node's sender won't update
425    /// player states, such as whether it's paused.
426    pub fn sender(&self) -> NodeSender {
427        NodeSender {
428            inner: self.lavalink_tx.clone(),
429        }
430    }
431
432    /// Retrieve a copy of the node's stats.
433    pub async fn stats(&self) -> Stats {
434        (*self.stats.lock().await).clone()
435    }
436
437    /// Retrieve the calculated penalty score of the node.
438    ///
439    /// This score can be used to calculate how loaded the server is. A higher
440    /// number means it is more heavily loaded.
441    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
442    pub async fn penalty(&self) -> i32 {
443        let stats = self.stats.lock().await;
444        let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
445
446        let (deficit_frame, null_frame) = (
447            1.03f64.powf(
448                500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64),
449            ) * 300f64
450                - 300f64,
451            (1.03f64.powf(
452                500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64),
453            ) * 300f64
454                - 300f64)
455                * 2f64,
456        );
457
458        stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
459    }
460}
461
462struct Connection {
463    config: NodeConfig,
464    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
465    lavalink_http: HyperClient<HttpConnector, Full<Bytes>>,
466    node_from: UnboundedReceiver<OutgoingEvent>,
467    node_to: UnboundedSender<IncomingEvent>,
468    players: PlayerManager,
469    stats: BiLock<Stats>,
470    lavalink_session_id: Option<Box<str>>,
471}
472
473impl Connection {
474    async fn connect(
475        config: NodeConfig,
476        players: PlayerManager,
477        stats: BiLock<Stats>,
478    ) -> Result<
479        (
480            Self,
481            UnboundedSender<OutgoingEvent>,
482            UnboundedReceiver<IncomingEvent>,
483        ),
484        NodeError,
485    > {
486        let stream = reconnect(&config).await?;
487
488        let (to_node, from_lavalink) = mpsc::unbounded_channel();
489        let (to_lavalink, from_node) = mpsc::unbounded_channel();
490
491        let mut client_builder = HyperClient::builder(TokioExecutor::new());
492
493        if config.enable_tls {
494            client_builder.http2_only(config.enable_tls);
495        }
496
497        let lavalink_http = client_builder.build_http();
498
499        Ok((
500            Self {
501                config,
502                stream,
503                lavalink_http,
504                node_from: from_node,
505                node_to: to_node,
506                players,
507                stats,
508                lavalink_session_id: None,
509            },
510            to_lavalink,
511            from_lavalink,
512        ))
513    }
514
515    async fn run(mut self) -> Result<(), NodeError> {
516        loop {
517            tokio::select! {
518                incoming = self.stream.next() => {
519                    if let Some(Ok(incoming)) = incoming {
520                        self.incoming(incoming).await?;
521                    } else {
522                        tracing::debug!("connection to {} closed, reconnecting", self.config.address);
523                        self.stream = reconnect(&self.config).await?;
524                    }
525                }
526                outgoing = self.node_from.recv() => {
527                    if let Some(outgoing) = outgoing {
528                        self.outgoing(outgoing).await?;
529                    } else {
530                        tracing::debug!("node {} closed, ending connection", self.config.address);
531                        break;
532                    }
533                }
534            }
535        }
536
537        Ok(())
538    }
539
540    fn get_outgoing_endpoint_based_on_event(
541        &mut self,
542        outgoing: &OutgoingEvent,
543    ) -> Result<(Method, hyper::Uri), NodeError> {
544        let address = self.config.address;
545        tracing::debug!("forwarding event to {address}: {outgoing:?}");
546
547        let guild_id = outgoing.guild_id();
548        let no_replace = outgoing.no_replace();
549
550        if let Some(session) = &self.lavalink_session_id {
551            let mut path = format!("/v4/sessions/{session}/players/{guild_id}");
552            if !matches!(outgoing, OutgoingEvent::Destroy(_)) {
553                let _ = write!(path, "?noReplace={no_replace}");
554            }
555            let uri = Uri::builder()
556                .scheme("http")
557                .authority(address.to_string())
558                .path_and_query(path)
559                .build()
560                .expect("uri is valid");
561            return if matches!(outgoing, OutgoingEvent::Destroy(_)) {
562                Ok((Method::DELETE, uri))
563            } else {
564                Ok((Method::PATCH, uri))
565            };
566        }
567
568        tracing::error!("no session id is found");
569
570        Err(NodeError {
571            kind: NodeErrorType::OutgoingEventHasNoSession,
572            source: None,
573        })
574    }
575
576    async fn outgoing(&mut self, outgoing: OutgoingEvent) -> Result<(), NodeError> {
577        let (method, url) = self.get_outgoing_endpoint_based_on_event(&outgoing)?;
578        let payload = serde_json::to_string(&outgoing).expect("serialization cannot fail");
579
580        let authority = url.authority().expect("authority comes from endpoint");
581
582        let req = Request::builder()
583            .uri(url.borrow())
584            .method(method)
585            .header(header::HOST, authority.as_str())
586            .header(header::AUTHORIZATION, self.config.authorization.as_str())
587            .header(header::CONTENT_TYPE, "application/json")
588            .body(Full::from(payload))
589            .map_err(|source| NodeError {
590                kind: NodeErrorType::BuildingConnectionRequest,
591                source: Some(Box::new(source)),
592            })?;
593
594        self.lavalink_http
595            .request(req)
596            .await
597            .map_err(|source| NodeError {
598                kind: NodeErrorType::HttpRequestFailed,
599                source: Some(Box::new(source)),
600            })?;
601
602        Ok(())
603    }
604
605    async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
606        tracing::debug!(
607            "received message from {}: {incoming:?}",
608            self.config.address,
609        );
610
611        let text = if incoming.is_text() {
612            incoming.as_text().expect("message is text")
613        } else if incoming.is_close() {
614            tracing::debug!("got close, closing connection");
615
616            return Ok(false);
617        } else {
618            tracing::debug!("got ping, pong or binary payload: {incoming:?}");
619
620            return Ok(true);
621        };
622
623        let Ok(event) = serde_json::from_str(text) else {
624            tracing::warn!("unknown message from lavalink node: {text}");
625
626            return Ok(true);
627        };
628
629        match &event {
630            IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
631            IncomingEvent::Ready(ready) => {
632                self.lavalink_session_id = Some(ready.session_id.clone().into_boxed_str());
633            }
634            IncomingEvent::Stats(stats) => self.stats(stats).await?,
635            IncomingEvent::Event(_) => {}
636        }
637
638        // It's fine if the rx end dropped, often users don't need to care about
639        // these events.
640        if !self.node_to.is_closed() {
641            let _result = self.node_to.send(event);
642        }
643
644        Ok(true)
645    }
646
647    fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
648        let Some(player) = self.players.get(&update.guild_id) else {
649            tracing::warn!(
650                "invalid player update for guild {}: {update:?}",
651                update.guild_id,
652            );
653
654            return Ok(());
655        };
656
657        player.set_position(update.state.position);
658        player.set_time(update.state.time);
659
660        Ok(())
661    }
662
663    async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
664        *self.stats.lock().await = stats.clone();
665
666        Ok(())
667    }
668}
669
670impl Drop for Connection {
671    fn drop(&mut self) {
672        // Cleanup local players associated with the node
673        self.players
674            .players
675            .retain(|_, v| v.node().config().address != self.config.address);
676    }
677}
678
679const TWILIGHT_CLIENT_NAME: &str = concat!("twilight-lavalink/", env!("CARGO_PKG_VERSION"));
680
681fn connect_request(state: &NodeConfig) -> Result<ClientBuilder<'_>, NodeError> {
682    let websocket_protocol = if state.enable_tls { "wss" } else { "ws" };
683
684    let mut builder = ClientBuilder::new()
685        .uri(&format!(
686            "{}://{}/v4/websocket",
687            websocket_protocol, state.address
688        ))
689        .map_err(|source| NodeError {
690            kind: NodeErrorType::BuildingConnectionRequest,
691            source: Some(Box::new(source)),
692        })?
693        .add_header(
694            AUTHORIZATION,
695            state.authorization.parse().map_err(|source| NodeError {
696                kind: NodeErrorType::BuildingConnectionRequest,
697                source: Some(Box::new(source)),
698            })?,
699        )
700        .expect("Unable to create authorization header")
701        .add_header(
702            HeaderName::from_static("user-id"),
703            state.user_id.get().into(),
704        )
705        .expect("Unable to add user-id")
706        .add_header(
707            HeaderName::from_static("client-name"),
708            HeaderValue::from_static(TWILIGHT_CLIENT_NAME),
709        )
710        .expect("Unable to create builder");
711
712    // Add Session-Id header if we have a previous session to resume (Lavalink v4)
713    if let Some(session_id) = &state.session_id {
714        builder = builder
715            .add_header(
716                HeaderName::from_static("session-id"),
717                session_id.parse().map_err(|source| NodeError {
718                    kind: NodeErrorType::BuildingConnectionRequest,
719                    source: Some(Box::new(source)),
720                })?,
721            )
722            .expect("Unable to add Session-Id header");
723    }
724
725    Ok(builder)
726}
727
728/// Reconnect to a Lavalink node via exponential backoff.
729///
730/// If `config.session_id` is set, the reconnection will attempt to resume
731/// the existing session. If the server indicates the session was not resumed
732/// (for example, via a `Session-Resumed: false` header or by omitting the
733/// `Session-Resumed` header entirely), this function returns a
734/// [`Connecting`] error. The caller can retry with `session_id` set to
735/// `None` to create a fresh session.
736///
737/// Note: Session resume configuration (timeout, etc.) is done via the
738/// Lavalink REST API `PATCH /v4/sessions/{sessionId}`. This should be
739/// called after receiving the `Ready` event with the new session ID.
740///
741/// [`Connecting`]: NodeErrorType::Connecting
742async fn reconnect(
743    config: &NodeConfig,
744) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
745    let (stream, res) = backoff(config).await?;
746
747    let headers = res.headers();
748
749    // Check if session was resumed via response header (Lavalink v4)
750    let header = HeaderName::from_static("session-resumed");
751    if let Some(value) = headers.get(header) {
752        if value.as_bytes() == b"true" {
753            tracing::info!(
754                "Successfully resumed Lavalink session for node {}",
755                config.address
756            );
757        } else if config.session_id.is_some() {
758            tracing::warn!(
759                "Failed to resume Lavalink session for node {} (session not resumed)",
760                config.address,
761            );
762            return Err(NodeError {
763                kind: NodeErrorType::Connecting,
764                source: None,
765            });
766        } else {
767            tracing::debug!("New Lavalink session created for node {}", config.address);
768        }
769    } else if config.session_id.is_some() {
770        tracing::warn!(
771            "Session-Resumed header not present for node {}; resume may have failed",
772            config.address,
773        );
774        return Err(NodeError {
775            kind: NodeErrorType::Connecting,
776            source: None,
777        });
778    }
779
780    Ok(stream)
781}
782
783async fn backoff(
784    config: &NodeConfig,
785) -> Result<
786    (
787        WebSocketStream<MaybeTlsStream<TcpStream>>,
788        upgrade::Response,
789    ),
790    NodeError,
791> {
792    let mut seconds = 1;
793
794    loop {
795        let request = connect_request(config)?;
796
797        match request.connect().await {
798            Ok((stream, response)) => return Ok((stream, response)),
799            Err(source) => {
800                tracing::warn!("failed to connect to node {source}: {:?}", config.address);
801
802                if matches!(
803                    &source,
804                    WebsocketError::Upgrade(upgrade::Error::DidNotSwitchProtocols(401))
805                ) {
806                    return Err(NodeError {
807                        kind: NodeErrorType::Unauthorized {
808                            address: config.address,
809                            authorization: config.authorization.clone(),
810                        },
811                        source: None,
812                    });
813                }
814
815                if seconds > 64 {
816                    tracing::debug!("no longer trying to connect to node {}", config.address);
817
818                    return Err(NodeError {
819                        kind: NodeErrorType::Connecting,
820                        source: Some(Box::new(source)),
821                    });
822                }
823
824                tracing::debug!(
825                    "waiting {seconds} seconds before attempting to connect to node {} again",
826                    config.address,
827                );
828                tokio_time::sleep(Duration::from_secs(seconds)).await;
829
830                seconds *= 2;
831            }
832        }
833    }
834}
835
#[cfg(test)]
mod tests {
    use super::{Node, NodeConfig, NodeError, NodeErrorType};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::Debug,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };
    use twilight_model::id::Id;

    assert_fields!(NodeConfig: address, authorization, user_id, enable_tls, session_id);
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);
    assert_impl_all!(Node: Debug, Send, Sync);

    /// The authorization secret must be replaced by a placeholder in debug
    /// output, and must never appear verbatim.
    #[test]
    fn node_config_debug() {
        let config = NodeConfig {
            address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1312)),
            authorization: "some auth".to_owned(),
            user_id: Id::new(123),
            enable_tls: false,
            session_id: None,
        };

        let debug = format!("{config:?}");

        assert!(debug.contains("authorization: <redacted>"));
        assert!(!debug.contains("some auth"));
    }
}