1use crate::{
21 model::{IncomingEvent, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22 player::PlayerManager,
23};
24use futures_util::{
25 lock::BiLock,
26 stream::{Stream, StreamExt},
27};
28use http::header::{AUTHORIZATION, HeaderName, HeaderValue};
29use http_body_util::Full;
30use hyper::{Method, Request, Uri, body::Bytes, header};
31use hyper_util::{
32 client::legacy::{Client as HyperClient, connect::HttpConnector},
33 rt::TokioExecutor,
34};
35use std::{
36 borrow::Borrow,
37 error::Error,
38 fmt::{Debug, Display, Formatter, Result as FmtResult, Write as _},
39 net::SocketAddr,
40 pin::Pin,
41 task::{Context, Poll},
42 time::Duration,
43};
44use tokio::{
45 net::TcpStream,
46 sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
47 time as tokio_time,
48};
49use tokio_websockets::{
50 ClientBuilder, Error as WebsocketError, MaybeTlsStream, Message, WebSocketStream, upgrade,
51};
52use twilight_model::id::{Id, marker::UserMarker};
53
/// An error that can occur while connecting to or communicating with a
/// Lavalink node.
///
/// Inspect the failure with [`NodeError::kind`], or take it apart with
/// [`NodeError::into_parts`].
#[derive(Debug)]
pub struct NodeError {
    /// What went wrong.
    kind: NodeErrorType,
    /// Underlying error that caused this one, if any.
    source: Option<Box<dyn Error + Send + Sync>>,
}
61
62impl NodeError {
63 #[must_use = "retrieving the type has no effect if left unused"]
65 pub const fn kind(&self) -> &NodeErrorType {
66 &self.kind
67 }
68
69 #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
71 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
72 self.source
73 }
74
75 #[must_use = "consuming the error into its parts has no effect if left unused"]
77 pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
78 (self.kind, self.source)
79 }
80}
81
82impl Display for NodeError {
83 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
84 match &self.kind {
85 NodeErrorType::BuildingConnectionRequest => {
86 f.write_str("failed to build connection request")
87 }
88 NodeErrorType::HttpRequestFailed => {
89 f.write_str("failed to send http request to lavalink server")
90 }
91 NodeErrorType::Connecting => f.write_str("Failed to connect to the node"),
92 NodeErrorType::OutgoingEventHasNoSession => {
93 f.write_str("no session id found for connection to lavalink api")
94 }
95 NodeErrorType::SerializingMessage { message: _ } => {
96 f.write_str("failed to serialize outgoing message as json")
97 }
98 NodeErrorType::Unauthorized { address, .. } => {
99 f.write_str("the authorization used to connect to node ")?;
100 Display::fmt(address, f)?;
101
102 f.write_str(" is invalid")
103 }
104 }
105 }
106}
107
108impl Error for NodeError {
109 fn source(&self) -> Option<&(dyn Error + 'static)> {
110 self.source
111 .as_ref()
112 .map(|source| &**source as &(dyn Error + 'static))
113 }
114}
115
/// Type of [`NodeError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeErrorType {
    /// Building the websocket handshake request or an HTTP request to the
    /// node failed.
    BuildingConnectionRequest,
    /// Sending an HTTP request to the node's REST API failed.
    HttpRequestFailed,
    /// Connecting to the node failed, either because retries were exhausted
    /// or because a requested session resume was not honored.
    Connecting,
    /// An outgoing event was forwarded before the node reported a session ID.
    OutgoingEventHasNoSession,
    /// Serializing an outgoing event as JSON failed.
    SerializingMessage {
        /// The event that could not be serialized.
        message: OutgoingEvent,
    },
    /// The node rejected the websocket handshake with HTTP 401.
    Unauthorized {
        /// Address of the node that rejected the authorization.
        address: SocketAddr,
        /// Authorization value that was rejected.
        authorization: String,
    },
}
143
/// An error that can occur while queuing an event for a node.
#[derive(Debug)]
pub struct NodeSenderError {
    /// What went wrong.
    kind: NodeSenderErrorType,
    /// Underlying error that caused this one, if any.
    source: Option<Box<dyn Error + Send + Sync>>,
}
150
151impl NodeSenderError {
152 pub const fn kind(&self) -> &NodeSenderErrorType {
154 &self.kind
155 }
156
157 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
159 self.source
160 }
161
162 #[must_use = "consuming the error into its parts has no effect if left unused"]
164 pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
165 (self.kind, self.source)
166 }
167}
168
169impl Display for NodeSenderError {
170 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
171 match &self.kind {
172 NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
173 }
174 }
175}
176
177impl Error for NodeSenderError {
178 fn source(&self) -> Option<&(dyn Error + 'static)> {
179 self.source
180 .as_ref()
181 .map(|source| &**source as &(dyn Error + 'static))
182 }
183}
184
/// Type of [`NodeSenderError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// The event could not be queued because the receiving half of the
    /// channel has been closed.
    Sending,
}
192
/// Stream of [`IncomingEvent`]s received from a node, as forwarded by its
/// connection loop.
pub struct IncomingEvents {
    /// Receiving half of the channel the connection loop sends events into.
    inner: UnboundedReceiver<IncomingEvent>,
}
197
198impl IncomingEvents {
199 pub fn close(&mut self) {
201 self.inner.close();
202 }
203}
204
205impl Stream for IncomingEvents {
206 type Item = IncomingEvent;
207
208 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
209 self.inner.poll_recv(cx)
210 }
211}
212
/// Handle used to queue [`OutgoingEvent`]s for a node's connection loop.
pub struct NodeSender {
    /// Sending half of the channel the connection loop reads outgoing events from.
    inner: UnboundedSender<OutgoingEvent>,
}
217
218impl NodeSender {
219 pub fn is_closed(&self) -> bool {
221 self.inner.is_closed()
222 }
223
224 pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
235 self.inner.send(msg).map_err(|source| NodeSenderError {
236 kind: NodeSenderErrorType::Sending,
237 source: Some(Box::new(source)),
238 })
239 }
240}
241
/// Configuration for connecting to a Lavalink node.
#[derive(Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct NodeConfig {
    /// Socket address of the node.
    pub address: SocketAddr,
    /// Value sent as the `Authorization` header; redacted in `Debug` output.
    pub authorization: String,
    /// ID of the bot user, sent as the `user-id` handshake header.
    pub user_id: Id<UserMarker>,
    /// Whether to connect to the websocket using `wss` instead of `ws`.
    pub enable_tls: bool,
    /// Session to resume, sent as the `session-id` handshake header.
    ///
    /// When set, connecting fails if the node does not report the session as
    /// resumed.
    pub session_id: Option<String>,
}
261
262impl Debug for NodeConfig {
263 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
264 struct Redacted;
268
269 impl Debug for Redacted {
270 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
271 f.write_str("<redacted>")
272 }
273 }
274
275 f.debug_struct("NodeConfig")
276 .field("address", &self.address)
277 .field("authorization", &Redacted)
278 .field("user_id", &self.user_id)
279 .field("enable_tls", &self.enable_tls)
280 .field("session_id", &self.session_id)
281 .finish()
282 }
283}
284
285impl NodeConfig {
286 pub fn new(
294 user_id: Id<UserMarker>,
295 address: impl Into<SocketAddr>,
296 authorization: impl Into<String>,
297 enable_tls: bool,
298 ) -> Self {
299 Self::_new(user_id, address.into(), authorization.into(), enable_tls)
300 }
301
302 const fn _new(
303 user_id: Id<UserMarker>,
304 address: SocketAddr,
305 authorization: String,
306 enable_tls: bool,
307 ) -> Self {
308 Self {
309 address,
310 authorization,
311 user_id,
312 enable_tls,
313 session_id: None,
314 }
315 }
316}
317
/// A connected Lavalink node.
///
/// Created with [`Node::connect`], which spawns a background task that drives
/// the connection.
#[derive(Debug)]
pub struct Node {
    /// Configuration the node was connected with.
    config: NodeConfig,
    /// Sending half of the channel read by the connection loop.
    lavalink_tx: UnboundedSender<OutgoingEvent>,
    /// Player manager passed in at connect time.
    players: PlayerManager,
    /// Latest stats reported by the node; the other half of the lock is held
    /// by the connection loop, which updates it.
    stats: BiLock<Stats>,
}
331
332impl Node {
333 pub async fn connect(
358 config: NodeConfig,
359 players: PlayerManager,
360 ) -> Result<(Self, IncomingEvents), NodeError> {
361 let (bilock_left, bilock_right) = BiLock::new(Stats {
362 cpu: StatsCpu {
363 cores: 0,
364 lavalink_load: 0f64,
365 system_load: 0f64,
366 },
367 frame_stats: None,
368 memory: StatsMemory {
369 allocated: 0,
370 free: 0,
371 used: 0,
372 reservable: 0,
373 },
374 players: 0,
375 playing_players: 0,
376 uptime: 0,
377 });
378
379 tracing::debug!("starting connection to {}", config.address);
380
381 let (conn_loop, lavalink_tx, lavalink_rx) =
382 Connection::connect(config.clone(), players.clone(), bilock_right).await?;
383
384 tracing::debug!("started connection to {}", config.address);
385
386 tokio::spawn(conn_loop.run());
387
388 Ok((
389 Self {
390 config,
391 lavalink_tx,
392 players,
393 stats: bilock_left,
394 },
395 IncomingEvents { inner: lavalink_rx },
396 ))
397 }
398
399 pub const fn config(&self) -> &NodeConfig {
401 &self.config
402 }
403
404 pub const fn players(&self) -> &PlayerManager {
406 &self.players
407 }
408
409 pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
419 self.sender().send(event)
420 }
421
422 pub fn sender(&self) -> NodeSender {
427 NodeSender {
428 inner: self.lavalink_tx.clone(),
429 }
430 }
431
432 pub async fn stats(&self) -> Stats {
434 (*self.stats.lock().await).clone()
435 }
436
437 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
442 pub async fn penalty(&self) -> i32 {
443 let stats = self.stats.lock().await;
444 let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
445
446 let (deficit_frame, null_frame) = (
447 1.03f64.powf(
448 500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64),
449 ) * 300f64
450 - 300f64,
451 (1.03f64.powf(
452 500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64),
453 ) * 300f64
454 - 300f64)
455 * 2f64,
456 );
457
458 stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
459 }
460}
461
/// State owned by a node's background connection loop.
// Field order is left as-is: it determines the order fields are dropped in
// after `Drop for Connection` runs.
struct Connection {
    /// Configuration used to (re)connect.
    config: NodeConfig,
    /// Websocket the node pushes events over.
    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
    /// HTTP client used to forward outgoing events to the REST API.
    lavalink_http: HyperClient<HttpConnector, Full<Bytes>>,
    /// Outgoing events queued by `Node`/`NodeSender`.
    node_from: UnboundedReceiver<OutgoingEvent>,
    /// Incoming events forwarded to `IncomingEvents`.
    node_to: UnboundedSender<IncomingEvent>,
    /// Players updated by `PlayerUpdate` events.
    players: PlayerManager,
    /// Connection-loop half of the shared stats lock.
    stats: BiLock<Stats>,
    /// Session ID from the most recent `Ready` event; `None` until one arrives.
    lavalink_session_id: Option<Box<str>>,
}
472
473impl Connection {
474 async fn connect(
475 config: NodeConfig,
476 players: PlayerManager,
477 stats: BiLock<Stats>,
478 ) -> Result<
479 (
480 Self,
481 UnboundedSender<OutgoingEvent>,
482 UnboundedReceiver<IncomingEvent>,
483 ),
484 NodeError,
485 > {
486 let stream = reconnect(&config).await?;
487
488 let (to_node, from_lavalink) = mpsc::unbounded_channel();
489 let (to_lavalink, from_node) = mpsc::unbounded_channel();
490
491 let mut client_builder = HyperClient::builder(TokioExecutor::new());
492
493 if config.enable_tls {
494 client_builder.http2_only(config.enable_tls);
495 }
496
497 let lavalink_http = client_builder.build_http();
498
499 Ok((
500 Self {
501 config,
502 stream,
503 lavalink_http,
504 node_from: from_node,
505 node_to: to_node,
506 players,
507 stats,
508 lavalink_session_id: None,
509 },
510 to_lavalink,
511 from_lavalink,
512 ))
513 }
514
515 async fn run(mut self) -> Result<(), NodeError> {
516 loop {
517 tokio::select! {
518 incoming = self.stream.next() => {
519 if let Some(Ok(incoming)) = incoming {
520 self.incoming(incoming).await?;
521 } else {
522 tracing::debug!("connection to {} closed, reconnecting", self.config.address);
523 self.stream = reconnect(&self.config).await?;
524 }
525 }
526 outgoing = self.node_from.recv() => {
527 if let Some(outgoing) = outgoing {
528 self.outgoing(outgoing).await?;
529 } else {
530 tracing::debug!("node {} closed, ending connection", self.config.address);
531 break;
532 }
533 }
534 }
535 }
536
537 Ok(())
538 }
539
540 fn get_outgoing_endpoint_based_on_event(
541 &mut self,
542 outgoing: &OutgoingEvent,
543 ) -> Result<(Method, hyper::Uri), NodeError> {
544 let address = self.config.address;
545 tracing::debug!("forwarding event to {address}: {outgoing:?}");
546
547 let guild_id = outgoing.guild_id();
548 let no_replace = outgoing.no_replace();
549
550 if let Some(session) = &self.lavalink_session_id {
551 let mut path = format!("/v4/sessions/{session}/players/{guild_id}");
552 if !matches!(outgoing, OutgoingEvent::Destroy(_)) {
553 let _ = write!(path, "?noReplace={no_replace}");
554 }
555 let uri = Uri::builder()
556 .scheme("http")
557 .authority(address.to_string())
558 .path_and_query(path)
559 .build()
560 .expect("uri is valid");
561 return if matches!(outgoing, OutgoingEvent::Destroy(_)) {
562 Ok((Method::DELETE, uri))
563 } else {
564 Ok((Method::PATCH, uri))
565 };
566 }
567
568 tracing::error!("no session id is found");
569
570 Err(NodeError {
571 kind: NodeErrorType::OutgoingEventHasNoSession,
572 source: None,
573 })
574 }
575
576 async fn outgoing(&mut self, outgoing: OutgoingEvent) -> Result<(), NodeError> {
577 let (method, url) = self.get_outgoing_endpoint_based_on_event(&outgoing)?;
578 let payload = serde_json::to_string(&outgoing).expect("serialization cannot fail");
579
580 let authority = url.authority().expect("authority comes from endpoint");
581
582 let req = Request::builder()
583 .uri(url.borrow())
584 .method(method)
585 .header(header::HOST, authority.as_str())
586 .header(header::AUTHORIZATION, self.config.authorization.as_str())
587 .header(header::CONTENT_TYPE, "application/json")
588 .body(Full::from(payload))
589 .map_err(|source| NodeError {
590 kind: NodeErrorType::BuildingConnectionRequest,
591 source: Some(Box::new(source)),
592 })?;
593
594 self.lavalink_http
595 .request(req)
596 .await
597 .map_err(|source| NodeError {
598 kind: NodeErrorType::HttpRequestFailed,
599 source: Some(Box::new(source)),
600 })?;
601
602 Ok(())
603 }
604
605 async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
606 tracing::debug!(
607 "received message from {}: {incoming:?}",
608 self.config.address,
609 );
610
611 let text = if incoming.is_text() {
612 incoming.as_text().expect("message is text")
613 } else if incoming.is_close() {
614 tracing::debug!("got close, closing connection");
615
616 return Ok(false);
617 } else {
618 tracing::debug!("got ping, pong or binary payload: {incoming:?}");
619
620 return Ok(true);
621 };
622
623 let Ok(event) = serde_json::from_str(text) else {
624 tracing::warn!("unknown message from lavalink node: {text}");
625
626 return Ok(true);
627 };
628
629 match &event {
630 IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
631 IncomingEvent::Ready(ready) => {
632 self.lavalink_session_id = Some(ready.session_id.clone().into_boxed_str());
633 }
634 IncomingEvent::Stats(stats) => self.stats(stats).await?,
635 IncomingEvent::Event(_) => {}
636 }
637
638 if !self.node_to.is_closed() {
641 let _result = self.node_to.send(event);
642 }
643
644 Ok(true)
645 }
646
647 fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
648 let Some(player) = self.players.get(&update.guild_id) else {
649 tracing::warn!(
650 "invalid player update for guild {}: {update:?}",
651 update.guild_id,
652 );
653
654 return Ok(());
655 };
656
657 player.set_position(update.state.position);
658 player.set_time(update.state.time);
659
660 Ok(())
661 }
662
663 async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
664 *self.stats.lock().await = stats.clone();
665
666 Ok(())
667 }
668}
669
670impl Drop for Connection {
671 fn drop(&mut self) {
672 self.players
674 .players
675 .retain(|_, v| v.node().config().address != self.config.address);
676 }
677}
678
/// Value of the `client-name` header sent in the websocket handshake:
/// `twilight-lavalink/<crate version>`.
const TWILIGHT_CLIENT_NAME: &str = concat!("twilight-lavalink/", env!("CARGO_PKG_VERSION"));
680
681fn connect_request(state: &NodeConfig) -> Result<ClientBuilder<'_>, NodeError> {
682 let websocket_protocol = if state.enable_tls { "wss" } else { "ws" };
683
684 let mut builder = ClientBuilder::new()
685 .uri(&format!(
686 "{}://{}/v4/websocket",
687 websocket_protocol, state.address
688 ))
689 .map_err(|source| NodeError {
690 kind: NodeErrorType::BuildingConnectionRequest,
691 source: Some(Box::new(source)),
692 })?
693 .add_header(
694 AUTHORIZATION,
695 state.authorization.parse().map_err(|source| NodeError {
696 kind: NodeErrorType::BuildingConnectionRequest,
697 source: Some(Box::new(source)),
698 })?,
699 )
700 .expect("Unable to create authorization header")
701 .add_header(
702 HeaderName::from_static("user-id"),
703 state.user_id.get().into(),
704 )
705 .expect("Unable to add user-id")
706 .add_header(
707 HeaderName::from_static("client-name"),
708 HeaderValue::from_static(TWILIGHT_CLIENT_NAME),
709 )
710 .expect("Unable to create builder");
711
712 if let Some(session_id) = &state.session_id {
714 builder = builder
715 .add_header(
716 HeaderName::from_static("session-id"),
717 session_id.parse().map_err(|source| NodeError {
718 kind: NodeErrorType::BuildingConnectionRequest,
719 source: Some(Box::new(source)),
720 })?,
721 )
722 .expect("Unable to add Session-Id header");
723 }
724
725 Ok(builder)
726}
727
728async fn reconnect(
743 config: &NodeConfig,
744) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
745 let (stream, res) = backoff(config).await?;
746
747 let headers = res.headers();
748
749 let header = HeaderName::from_static("session-resumed");
751 if let Some(value) = headers.get(header) {
752 if value.as_bytes() == b"true" {
753 tracing::info!(
754 "Successfully resumed Lavalink session for node {}",
755 config.address
756 );
757 } else if config.session_id.is_some() {
758 tracing::warn!(
759 "Failed to resume Lavalink session for node {} (session not resumed)",
760 config.address,
761 );
762 return Err(NodeError {
763 kind: NodeErrorType::Connecting,
764 source: None,
765 });
766 } else {
767 tracing::debug!("New Lavalink session created for node {}", config.address);
768 }
769 } else if config.session_id.is_some() {
770 tracing::warn!(
771 "Session-Resumed header not present for node {}; resume may have failed",
772 config.address,
773 );
774 return Err(NodeError {
775 kind: NodeErrorType::Connecting,
776 source: None,
777 });
778 }
779
780 Ok(stream)
781}
782
783async fn backoff(
784 config: &NodeConfig,
785) -> Result<
786 (
787 WebSocketStream<MaybeTlsStream<TcpStream>>,
788 upgrade::Response,
789 ),
790 NodeError,
791> {
792 let mut seconds = 1;
793
794 loop {
795 let request = connect_request(config)?;
796
797 match request.connect().await {
798 Ok((stream, response)) => return Ok((stream, response)),
799 Err(source) => {
800 tracing::warn!("failed to connect to node {source}: {:?}", config.address);
801
802 if matches!(
803 &source,
804 WebsocketError::Upgrade(upgrade::Error::DidNotSwitchProtocols(401))
805 ) {
806 return Err(NodeError {
807 kind: NodeErrorType::Unauthorized {
808 address: config.address,
809 authorization: config.authorization.clone(),
810 },
811 source: None,
812 });
813 }
814
815 if seconds > 64 {
816 tracing::debug!("no longer trying to connect to node {}", config.address);
817
818 return Err(NodeError {
819 kind: NodeErrorType::Connecting,
820 source: Some(Box::new(source)),
821 });
822 }
823
824 tracing::debug!(
825 "waiting {seconds} seconds before attempting to connect to node {} again",
826 config.address,
827 );
828 tokio_time::sleep(Duration::from_secs(seconds)).await;
829
830 seconds *= 2;
831 }
832 }
833 }
834}
835
#[cfg(test)]
mod tests {
    use super::{Node, NodeConfig, NodeError, NodeErrorType};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::Debug,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };
    use twilight_model::id::Id;

    // Public field layout.
    assert_fields!(NodeConfig: address, authorization, user_id, enable_tls, session_id);
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);

    // Trait guarantees.
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);
    assert_impl_all!(Node: Debug, Send, Sync);

    /// The authorization secret must never show up in `Debug` output.
    #[test]
    fn node_config_debug() {
        let address = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1312);
        let config = NodeConfig {
            address: SocketAddr::V4(address),
            authorization: "some auth".to_owned(),
            user_id: Id::new(123),
            enable_tls: false,
            session_id: None,
        };

        let rendered = format!("{config:?}");

        assert!(rendered.contains("authorization: <redacted>"));
    }
}