twilight_lavalink/
node.rs

//! Nodes for communicating with a Lavalink server.
//!
//! Using nodes, you can send events to a server and receive events from it.
//!
//! This is a bit more low-level than using the [`Lavalink`] client because you
//! will need to provide your own `VoiceUpdate` events when your bot joins
//! channels, meaning you will have to accumulate and combine voice state update
//! and voice server update events from the Discord gateway to send them to
//! a node.
//!
//! Additionally, you will have to create and manage your own [`PlayerManager`]
//! and make your own players for guilds when your bot joins voice channels.
//!
//! This can be a lot of work, and there's not really much reason to do it
//! yourself. For that reason, you should almost always use the `Lavalink`
//! client, which does all of this for you.
//!
//! [`Lavalink`]: crate::client::Lavalink
19
20use crate::{
21    model::{IncomingEvent, Opcode, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22    player::PlayerManager,
23};
24use futures_util::{
25    lock::BiLock,
26    sink::SinkExt,
27    stream::{Stream, StreamExt},
28};
29use http::header::{HeaderName, AUTHORIZATION};
30use std::{
31    error::Error,
32    fmt::{Debug, Display, Formatter, Result as FmtResult},
33    net::SocketAddr,
34    pin::Pin,
35    task::{Context, Poll},
36    time::Duration,
37};
38use tokio::{
39    net::TcpStream,
40    sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
41    time as tokio_time,
42};
43use tokio_websockets::{
44    upgrade, ClientBuilder, Error as WebsocketError, MaybeTlsStream, Message, WebSocketStream,
45};
46use twilight_model::id::{marker::UserMarker, Id};
47
48/// An error occurred while either initializing a connection or while running
49/// its event loop.
50#[derive(Debug)]
51pub struct NodeError {
52    kind: NodeErrorType,
53    source: Option<Box<dyn Error + Send + Sync>>,
54}
55
56impl NodeError {
57    /// Immutable reference to the type of error that occurred.
58    #[must_use = "retrieving the type has no effect if left unused"]
59    pub const fn kind(&self) -> &NodeErrorType {
60        &self.kind
61    }
62
63    /// Consume the error, returning the source error if there is any.
64    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
65    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
66        self.source
67    }
68
69    /// Consume the error, returning the owned error type and the source error.
70    #[must_use = "consuming the error into its parts has no effect if left unused"]
71    pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
72        (self.kind, self.source)
73    }
74}
75
76impl Display for NodeError {
77    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
78        match &self.kind {
79            NodeErrorType::BuildingConnectionRequest { .. } => {
80                f.write_str("failed to build connection request")
81            }
82            NodeErrorType::Connecting { .. } => f.write_str("Failed to connect to the node"),
83            NodeErrorType::SerializingMessage { .. } => {
84                f.write_str("failed to serialize outgoing message as json")
85            }
86            NodeErrorType::Unauthorized { address, .. } => {
87                f.write_str("the authorization used to connect to node ")?;
88                Display::fmt(address, f)?;
89
90                f.write_str(" is invalid")
91            }
92        }
93    }
94}
95
96impl Error for NodeError {
97    fn source(&self) -> Option<&(dyn Error + 'static)> {
98        self.source
99            .as_ref()
100            .map(|source| &**source as &(dyn Error + 'static))
101    }
102}
103
104/// Type of [`NodeError`] that occurred.
105#[derive(Debug)]
106#[non_exhaustive]
107pub enum NodeErrorType {
108    /// Building the HTTP request to initialize a connection failed.
109    BuildingConnectionRequest,
110    /// Connecting to the Lavalink server failed after several backoff attempts.
111    Connecting,
112    /// Serializing a JSON message to be sent to a Lavalink node failed.
113    SerializingMessage {
114        /// The message that couldn't be serialized.
115        message: OutgoingEvent,
116    },
117    /// The given authorization for the node is incorrect.
118    Unauthorized {
119        /// The address of the node that failed to authorize.
120        address: SocketAddr,
121        /// The authorization used to connect to the node.
122        authorization: String,
123    },
124}
125
126/// An error that can occur while sending an event over a node.
127#[derive(Debug)]
128pub struct NodeSenderError {
129    kind: NodeSenderErrorType,
130    source: Option<Box<dyn Error + Send + Sync>>,
131}
132
133impl NodeSenderError {
134    /// Immutable reference to the type of error that occurred.
135    pub const fn kind(&self) -> &NodeSenderErrorType {
136        &self.kind
137    }
138
139    /// Consume the error, returning the source error if there is any.
140    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
141        self.source
142    }
143
144    /// Consume the error, returning the owned error type and the source error.
145    #[must_use = "consuming the error into its parts has no effect if left unused"]
146    pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
147        (self.kind, self.source)
148    }
149}
150
151impl Display for NodeSenderError {
152    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
153        match &self.kind {
154            NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
155        }
156    }
157}
158
159impl Error for NodeSenderError {
160    fn source(&self) -> Option<&(dyn Error + 'static)> {
161        self.source
162            .as_ref()
163            .map(|source| &**source as &(dyn Error + 'static))
164    }
165}
166
/// Kind of [`NodeSenderError`] that was produced.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// Sending the event over the node's channel failed.
    Sending,
}
174
175/// Stream of incoming events from a node.
176pub struct IncomingEvents {
177    inner: UnboundedReceiver<IncomingEvent>,
178}
179
180impl IncomingEvents {
181    /// Closes the receiving half of a channel without dropping it.
182    pub fn close(&mut self) {
183        self.inner.close();
184    }
185}
186
187impl Stream for IncomingEvents {
188    type Item = IncomingEvent;
189
190    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
191        self.inner.poll_recv(cx)
192    }
193}
194
195/// Send outgoing events to the associated node.
196pub struct NodeSender {
197    inner: UnboundedSender<OutgoingEvent>,
198}
199
200impl NodeSender {
201    /// Returns whether this channel is closed without needing a context.
202    pub fn is_closed(&self) -> bool {
203        self.inner.is_closed()
204    }
205
206    /// Sends a message along this channel.
207    ///
208    /// This is an unbounded sender, so this function differs from `Sink::send`
209    /// by ensuring the return type reflects that the channel is always ready to
210    /// receive messages.
211    ///
212    /// # Errors
213    ///
214    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
215    /// longer connected.
216    pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
217        self.inner.send(msg).map_err(|source| NodeSenderError {
218            kind: NodeSenderErrorType::Sending,
219            source: Some(Box::new(source)),
220        })
221    }
222}
223
224/// The configuration that a [`Node`] uses to connect to a Lavalink server.
225#[derive(Clone, Eq, PartialEq)]
226#[non_exhaustive]
227// Keep fields in sync with its Debug implementation.
228pub struct NodeConfig {
229    /// The address of the node.
230    pub address: SocketAddr,
231    /// The password to use when authenticating.
232    pub authorization: String,
233    /// The details for resuming a Lavalink session, if any.
234    ///
235    /// Set this to `None` to disable resume capability.
236    pub resume: Option<Resume>,
237    /// The user ID of the bot.
238    pub user_id: Id<UserMarker>,
239}
240
241impl Debug for NodeConfig {
242    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
243        /// Debug as `<redacted>`. Necessary because debugging a struct field
244        /// with a value of of `"<redacted>"` will insert quotations in the
245        /// string, which doesn't align with other token debugs.
246        struct Redacted;
247
248        impl Debug for Redacted {
249            fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
250                f.write_str("<redacted>")
251            }
252        }
253
254        f.debug_struct("NodeConfig")
255            .field("address", &self.address)
256            .field("authorization", &Redacted)
257            .field("resume", &self.resume)
258            .field("user_id", &self.user_id)
259            .finish()
260    }
261}
262
/// Settings for a Lavalink session that can be resumed.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Resume {
    /// How many seconds the Lavalink server keeps the session resumable after
    /// a disconnect.
    ///
    /// Defaults to 60.
    pub timeout: u64,
}
273
274impl Resume {
275    /// Configure resume capability, providing the number of seconds that the
276    /// Lavalink server should queue events for when the connection is resumed.
277    pub const fn new(seconds: u64) -> Self {
278        Self { timeout: seconds }
279    }
280}
281
282impl Default for Resume {
283    fn default() -> Self {
284        Self { timeout: 60 }
285    }
286}
287
288impl NodeConfig {
289    /// Create a new configuration for connecting to a node via
290    /// [`Node::connect`].
291    ///
292    /// If adding a node through the [`Lavalink`] client then you don't need to
293    /// do this yourself.
294    ///
295    /// [`Lavalink`]: crate::client::Lavalink
296    pub fn new(
297        user_id: Id<UserMarker>,
298        address: impl Into<SocketAddr>,
299        authorization: impl Into<String>,
300        resume: impl Into<Option<Resume>>,
301    ) -> Self {
302        Self::_new(user_id, address.into(), authorization.into(), resume.into())
303    }
304
305    const fn _new(
306        user_id: Id<UserMarker>,
307        address: SocketAddr,
308        authorization: String,
309        resume: Option<Resume>,
310    ) -> Self {
311        Self {
312            address,
313            authorization,
314            resume,
315            user_id,
316        }
317    }
318}
319
320/// A connection to a single Lavalink server. It receives events and forwards
321/// events from players to the server.
322///
323/// Please refer to the [module] documentation.
324///
325/// [module]: crate
326#[derive(Debug)]
327pub struct Node {
328    config: NodeConfig,
329    lavalink_tx: UnboundedSender<OutgoingEvent>,
330    players: PlayerManager,
331    stats: BiLock<Stats>,
332}
333
334impl Node {
335    /// Connect to a node, providing a player manager so that the node can
336    /// update player details.
337    ///
338    /// Please refer to the [module] documentation for some additional
339    /// information about directly creating and using nodes. You are encouraged
340    /// to use the [`Lavalink`] client instead.
341    ///
342    /// [`Lavalink`]: crate::client::Lavalink
343    /// [module]: crate
344    ///
345    /// # Errors
346    ///
347    /// Returns an error of type [`Connecting`] if the connection fails after
348    /// several backoff attempts.
349    ///
350    /// Returns an error of type [`BuildingConnectionRequest`] if the request
351    /// failed to build.
352    ///
353    /// Returns an error of type [`Unauthorized`] if the supplied authorization
354    /// is rejected by the node.
355    ///
356    /// [`Connecting`]: crate::node::NodeErrorType::Connecting
357    /// [`BuildingConnectionRequest`]: crate::node::NodeErrorType::BuildingConnectionRequest
358    /// [`Unauthorized`]: crate::node::NodeErrorType::Unauthorized
359    pub async fn connect(
360        config: NodeConfig,
361        players: PlayerManager,
362    ) -> Result<(Self, IncomingEvents), NodeError> {
363        let (bilock_left, bilock_right) = BiLock::new(Stats {
364            cpu: StatsCpu {
365                cores: 0,
366                lavalink_load: 0f64,
367                system_load: 0f64,
368            },
369            frames: None,
370            memory: StatsMemory {
371                allocated: 0,
372                free: 0,
373                used: 0,
374                reservable: 0,
375            },
376            players: 0,
377            playing_players: 0,
378            op: Opcode::Stats,
379            uptime: 0,
380        });
381
382        tracing::debug!("starting connection to {}", config.address);
383
384        let (conn_loop, lavalink_tx, lavalink_rx) =
385            Connection::connect(config.clone(), players.clone(), bilock_right).await?;
386
387        tracing::debug!("started connection to {}", config.address);
388
389        tokio::spawn(conn_loop.run());
390
391        Ok((
392            Self {
393                config,
394                lavalink_tx,
395                players,
396                stats: bilock_left,
397            },
398            IncomingEvents { inner: lavalink_rx },
399        ))
400    }
401
402    /// Retrieve an immutable reference to the node's configuration.
403    pub const fn config(&self) -> &NodeConfig {
404        &self.config
405    }
406
407    /// Retrieve an immutable reference to the player manager used by the node.
408    pub const fn players(&self) -> &PlayerManager {
409        &self.players
410    }
411
412    /// Retrieve an immutable reference to the node's configuration.
413    ///
414    /// Note that sending player events through the node's sender won't update
415    /// player states, such as whether it's paused.
416    ///
417    /// # Errors
418    ///
419    /// Returns a [`NodeSenderErrorType::Sending`] error type if node is no
420    /// longer connected.
421    pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
422        self.sender().send(event)
423    }
424
425    /// Retrieve a unique sender to send events to the Lavalink server.
426    ///
427    /// Note that sending player events through the node's sender won't update
428    /// player states, such as whether it's paused.
429    pub fn sender(&self) -> NodeSender {
430        NodeSender {
431            inner: self.lavalink_tx.clone(),
432        }
433    }
434
435    /// Retrieve a copy of the node's stats.
436    pub async fn stats(&self) -> Stats {
437        (*self.stats.lock().await).clone()
438    }
439
440    /// Retrieve the calculated penalty score of the node.
441    ///
442    /// This score can be used to calculate how loaded the server is. A higher
443    /// number means it is more heavily loaded.
444    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
445    pub async fn penalty(&self) -> i32 {
446        let stats = self.stats.lock().await;
447        let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
448
449        let (deficit_frame, null_frame) = (
450            1.03f64
451                .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64))
452                * 300f64
453                - 300f64,
454            (1.03f64
455                .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64))
456                * 300f64
457                - 300f64)
458                * 2f64,
459        );
460
461        stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
462    }
463}
464
465struct Connection {
466    config: NodeConfig,
467    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
468    node_from: UnboundedReceiver<OutgoingEvent>,
469    node_to: UnboundedSender<IncomingEvent>,
470    players: PlayerManager,
471    stats: BiLock<Stats>,
472}
473
474impl Connection {
475    async fn connect(
476        config: NodeConfig,
477        players: PlayerManager,
478        stats: BiLock<Stats>,
479    ) -> Result<
480        (
481            Self,
482            UnboundedSender<OutgoingEvent>,
483            UnboundedReceiver<IncomingEvent>,
484        ),
485        NodeError,
486    > {
487        let stream = reconnect(&config).await?;
488
489        let (to_node, from_lavalink) = mpsc::unbounded_channel();
490        let (to_lavalink, from_node) = mpsc::unbounded_channel();
491
492        Ok((
493            Self {
494                config,
495                stream,
496                node_from: from_node,
497                node_to: to_node,
498                players,
499                stats,
500            },
501            to_lavalink,
502            from_lavalink,
503        ))
504    }
505
506    async fn run(mut self) -> Result<(), NodeError> {
507        loop {
508            tokio::select! {
509                incoming = self.stream.next() => {
510                    if let Some(Ok(incoming)) = incoming {
511                        self.incoming(incoming).await?;
512                    } else {
513                        tracing::debug!("connection to {} closed, reconnecting", self.config.address);
514                        self.stream = reconnect(&self.config).await?;
515                    }
516                }
517                outgoing = self.node_from.recv() => {
518                    if let Some(outgoing) = outgoing {
519                        tracing::debug!(
520                            "forwarding event to {}: {outgoing:?}",
521                            self.config.address,
522                        );
523
524                        let payload = serde_json::to_string(&outgoing).map_err(|source| NodeError {
525                            kind: NodeErrorType::SerializingMessage { message: outgoing },
526                            source: Some(Box::new(source)),
527                        })?;
528                        let msg = Message::text(payload);
529                        self.stream.send(msg).await.unwrap();
530                    } else {
531                        tracing::debug!("node {} closed, ending connection", self.config.address);
532
533                        break;
534                    }
535                }
536            }
537        }
538
539        Ok(())
540    }
541
542    async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
543        tracing::debug!(
544            "received message from {}: {incoming:?}",
545            self.config.address,
546        );
547
548        let text = if incoming.is_text() {
549            incoming.as_text().expect("message is text")
550        } else if incoming.is_close() {
551            tracing::debug!("got close, closing connection");
552
553            return Ok(false);
554        } else {
555            tracing::debug!("got ping, pong or binary payload: {incoming:?}");
556
557            return Ok(true);
558        };
559
560        let Ok(event) = serde_json::from_str(text) else {
561            tracing::warn!("unknown message from lavalink node: {text}");
562
563            return Ok(true);
564        };
565
566        match &event {
567            IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
568            IncomingEvent::Stats(stats) => self.stats(stats).await?,
569            _ => {}
570        }
571
572        // It's fine if the rx end dropped, often users don't need to care about
573        // these events.
574        if !self.node_to.is_closed() {
575            let _result = self.node_to.send(event);
576        }
577
578        Ok(true)
579    }
580
581    fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
582        let Some(player) = self.players.get(&update.guild_id) else {
583            tracing::warn!(
584                "invalid player update for guild {}: {update:?}",
585                update.guild_id,
586            );
587
588            return Ok(());
589        };
590
591        player.set_position(update.state.position.unwrap_or(0));
592        player.set_time(update.state.time);
593
594        Ok(())
595    }
596
597    async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
598        *self.stats.lock().await = stats.clone();
599
600        Ok(())
601    }
602}
603
604impl Drop for Connection {
605    fn drop(&mut self) {
606        // Cleanup local players associated with the node
607        self.players
608            .players
609            .retain(|_, v| v.node().config().address != self.config.address);
610    }
611}
612
613fn connect_request(state: &NodeConfig) -> Result<ClientBuilder, NodeError> {
614    let mut builder = ClientBuilder::new()
615        .uri(&format!("ws://{}", state.address))
616        .map_err(|source| NodeError {
617            kind: NodeErrorType::BuildingConnectionRequest,
618            source: Some(Box::new(source)),
619        })?
620        .add_header(AUTHORIZATION, state.authorization.parse().unwrap())
621        .expect("allowed header")
622        .add_header(
623            HeaderName::from_static("user-id"),
624            state.user_id.get().into(),
625        )
626        .expect("allowed header");
627
628    if state.resume.is_some() {
629        builder = builder
630            .add_header(
631                HeaderName::from_static("resume-key"),
632                state.address.to_string().parse().unwrap(),
633            )
634            .expect("allowed header");
635    }
636
637    Ok(builder)
638}
639
640async fn reconnect(
641    config: &NodeConfig,
642) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
643    let (mut stream, res) = backoff(config).await?;
644
645    let headers = res.headers();
646
647    if let Some(resume) = config.resume.as_ref() {
648        let header = HeaderName::from_static("session-resumed");
649
650        if let Some(value) = headers.get(header) {
651            if value.as_bytes() == b"false" {
652                tracing::debug!("session to node {} didn't resume", config.address);
653
654                let payload = serde_json::json!({
655                    "op": "configureResuming",
656                    "key": config.address,
657                    "timeout": resume.timeout,
658                });
659                let msg = Message::text(serde_json::to_string(&payload).unwrap());
660
661                stream.send(msg).await.unwrap();
662            } else {
663                tracing::debug!("session to {} resumed", config.address);
664            }
665        }
666    }
667
668    Ok(stream)
669}
670
671async fn backoff(
672    config: &NodeConfig,
673) -> Result<
674    (
675        WebSocketStream<MaybeTlsStream<TcpStream>>,
676        upgrade::Response,
677    ),
678    NodeError,
679> {
680    let mut seconds = 1;
681
682    loop {
683        let request = connect_request(config)?;
684
685        match request.connect().await {
686            Ok((stream, response)) => return Ok((stream, response)),
687            Err(source) => {
688                tracing::warn!("failed to connect to node {source}: {:?}", config.address);
689
690                if matches!(
691                    &source,
692                    WebsocketError::Upgrade(upgrade::Error::DidNotSwitchProtocols(401))
693                ) {
694                    return Err(NodeError {
695                        kind: NodeErrorType::Unauthorized {
696                            address: config.address,
697                            authorization: config.authorization.clone(),
698                        },
699                        source: None,
700                    });
701                }
702
703                if seconds > 64 {
704                    tracing::debug!("no longer trying to connect to node {}", config.address);
705
706                    return Err(NodeError {
707                        kind: NodeErrorType::Connecting,
708                        source: Some(Box::new(source)),
709                    });
710                }
711
712                tracing::debug!(
713                    "waiting {seconds} seconds before attempting to connect to node {} again",
714                    config.address,
715                );
716                tokio_time::sleep(Duration::from_secs(seconds)).await;
717
718                seconds *= 2;
719
720                continue;
721            }
722        }
723    }
724}
725
#[cfg(test)]
mod tests {
    use super::{Node, NodeConfig, NodeError, NodeErrorType, Resume};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::Debug,
        net::{Ipv4Addr, SocketAddr},
    };
    use twilight_model::id::Id;

    assert_impl_all!(Node: Debug, Send, Sync);
    assert_fields!(NodeConfig: address, authorization, resume, user_id);
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_fields!(Resume: timeout);
    assert_impl_all!(Resume: Clone, Debug, Default, Eq, PartialEq, Send, Sync);

    /// `NodeConfig`'s debug output shows the authorization as `<redacted>`.
    #[test]
    fn node_config_debug() {
        let address = SocketAddr::from((Ipv4Addr::new(127, 0, 0, 1), 1312));
        let config = NodeConfig {
            address,
            authorization: String::from("some auth"),
            resume: None,
            user_id: Id::new(123),
        };

        let debugged = format!("{config:?}");

        assert!(debugged.contains("authorization: <redacted>"));
    }
}