1use crate::{
21 model::{IncomingEvent, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22 player::PlayerManager,
23};
24use futures_util::{
25 lock::BiLock,
26 sink::SinkExt,
27 stream::{Stream, StreamExt},
28};
29use http::header::{AUTHORIZATION, HeaderName, HeaderValue};
30use http_body_util::Full;
31use hyper::{Method, Request, Uri, body::Bytes, header};
32use hyper_util::{
33 client::legacy::{Client as HyperClient, connect::HttpConnector},
34 rt::TokioExecutor,
35};
36use std::{
37 borrow::Borrow,
38 error::Error,
39 fmt::{Debug, Display, Formatter, Result as FmtResult, Write as _},
40 net::SocketAddr,
41 pin::Pin,
42 task::{Context, Poll},
43 time::Duration,
44};
45use tokio::{
46 net::TcpStream,
47 sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
48 time as tokio_time,
49};
50use tokio_websockets::{
51 ClientBuilder, Error as WebsocketError, MaybeTlsStream, Message, WebSocketStream, upgrade,
52};
53use twilight_model::id::{Id, marker::UserMarker};
54
/// An error that occurred while connecting to or communicating with a
/// Lavalink node.
///
/// Returned by [`Node::connect`] and by the background connection task.
#[derive(Debug)]
pub struct NodeError {
    /// Category of the error.
    kind: NodeErrorType,
    /// Underlying cause, if one was captured.
    source: Option<Box<dyn Error + Send + Sync>>,
}
62
63impl NodeError {
64 #[must_use = "retrieving the type has no effect if left unused"]
66 pub const fn kind(&self) -> &NodeErrorType {
67 &self.kind
68 }
69
70 #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
72 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
73 self.source
74 }
75
76 #[must_use = "consuming the error into its parts has no effect if left unused"]
78 pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
79 (self.kind, self.source)
80 }
81}
82
83impl Display for NodeError {
84 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
85 match &self.kind {
86 NodeErrorType::BuildingConnectionRequest => {
87 f.write_str("failed to build connection request")
88 }
89 NodeErrorType::HttpRequestFailed => {
90 f.write_str("failed to send http request to lavalink server")
91 }
92 NodeErrorType::Connecting => f.write_str("Failed to connect to the node"),
93 NodeErrorType::OutgoingEventHasNoSession => {
94 f.write_str("no session id found for connection to lavalink api")
95 }
96 NodeErrorType::SerializingMessage { message: _ } => {
97 f.write_str("failed to serialize outgoing message as json")
98 }
99 NodeErrorType::Unauthorized { address, .. } => {
100 f.write_str("the authorization used to connect to node ")?;
101 Display::fmt(address, f)?;
102
103 f.write_str(" is invalid")
104 }
105 }
106 }
107}
108
109impl Error for NodeError {
110 fn source(&self) -> Option<&(dyn Error + 'static)> {
111 self.source
112 .as_ref()
113 .map(|source| &**source as &(dyn Error + 'static))
114 }
115}
116
117#[derive(Debug)]
119#[non_exhaustive]
120pub enum NodeErrorType {
121 BuildingConnectionRequest,
123 HttpRequestFailed,
125 Connecting,
127 OutgoingEventHasNoSession,
131 SerializingMessage {
133 message: OutgoingEvent,
135 },
136 Unauthorized {
138 address: SocketAddr,
140 authorization: String,
142 },
143}
144
/// Error returned by [`NodeSender::send`] and [`Node::send`] when an event
/// could not be handed to the node's connection task.
#[derive(Debug)]
pub struct NodeSenderError {
    /// Category of the error.
    kind: NodeSenderErrorType,
    /// Underlying cause, if one was captured.
    source: Option<Box<dyn Error + Send + Sync>>,
}
151
152impl NodeSenderError {
153 pub const fn kind(&self) -> &NodeSenderErrorType {
155 &self.kind
156 }
157
158 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
160 self.source
161 }
162
163 #[must_use = "consuming the error into its parts has no effect if left unused"]
165 pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
166 (self.kind, self.source)
167 }
168}
169
170impl Display for NodeSenderError {
171 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
172 match &self.kind {
173 NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
174 }
175 }
176}
177
178impl Error for NodeSenderError {
179 fn source(&self) -> Option<&(dyn Error + 'static)> {
180 self.source
181 .as_ref()
182 .map(|source| &**source as &(dyn Error + 'static))
183 }
184}
185
/// Type of [`NodeSenderError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// Sending the event over the channel to the node's connection task
    /// failed because the channel is closed.
    Sending,
}
193
/// Stream of events received from a Lavalink node.
///
/// Returned alongside the [`Node`] by [`Node::connect`]; yields every
/// [`IncomingEvent`] the connection task forwards.
pub struct IncomingEvents {
    /// Receiving half of the channel fed by the connection task.
    inner: UnboundedReceiver<IncomingEvent>,
}
198
199impl IncomingEvents {
200 pub fn close(&mut self) {
202 self.inner.close();
203 }
204}
205
206impl Stream for IncomingEvents {
207 type Item = IncomingEvent;
208
209 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
210 self.inner.poll_recv(cx)
211 }
212}
213
/// Sender for [`OutgoingEvent`]s destined for a node's connection task.
///
/// Retrieved via [`Node::sender`].
pub struct NodeSender {
    /// Sending half of the channel read by the connection task.
    inner: UnboundedSender<OutgoingEvent>,
}
218
219impl NodeSender {
220 pub fn is_closed(&self) -> bool {
222 self.inner.is_closed()
223 }
224
225 pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
236 self.inner.send(msg).map_err(|source| NodeSenderError {
237 kind: NodeSenderErrorType::Sending,
238 source: Some(Box::new(source)),
239 })
240 }
241}
242
/// Configuration for connecting to a Lavalink node.
///
/// The `Debug` implementation redacts `authorization`.
#[derive(Clone, Eq, PartialEq)]
#[non_exhaustive]
pub struct NodeConfig {
    /// Address of the node, used for both the websocket and REST API.
    pub address: SocketAddr,
    /// Value sent in the `Authorization` header when connecting and with
    /// every REST request.
    pub authorization: String,
    /// Resume configuration; when set, a `resume-key` header is sent when
    /// connecting.
    pub resume: Option<Resume>,
    /// User ID of the bot, sent in the `user-id` header when connecting.
    pub user_id: Id<UserMarker>,
    /// Whether to connect to the websocket over `wss` instead of `ws`.
    pub enable_tls: bool,
}
261
262impl Debug for NodeConfig {
263 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
264 struct Redacted;
268
269 impl Debug for Redacted {
270 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
271 f.write_str("<redacted>")
272 }
273 }
274
275 f.debug_struct("NodeConfig")
276 .field("address", &self.address)
277 .field("authorization", &Redacted)
278 .field("resume", &self.resume)
279 .field("user_id", &self.user_id)
280 .field("enable_tls", &self.enable_tls)
281 .finish()
282 }
283}
284
/// Configuration for resuming a node session after the connection drops.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Resume {
    /// Resume timeout in seconds; sent as `timeout` in the resume
    /// configuration payload when a session was not resumed.
    pub timeout: u64,
}
295
296impl Resume {
297 pub const fn new(seconds: u64) -> Self {
300 Self { timeout: seconds }
301 }
302}
303
304impl Default for Resume {
305 fn default() -> Self {
306 Self { timeout: 60 }
307 }
308}
309
310impl NodeConfig {
311 pub fn new(
319 user_id: Id<UserMarker>,
320 address: impl Into<SocketAddr>,
321 authorization: impl Into<String>,
322 resume: impl Into<Option<Resume>>,
323 enable_tls: bool,
324 ) -> Self {
325 Self::_new(
326 user_id,
327 address.into(),
328 authorization.into(),
329 resume.into(),
330 enable_tls,
331 )
332 }
333
334 const fn _new(
335 user_id: Id<UserMarker>,
336 address: SocketAddr,
337 authorization: String,
338 resume: Option<Resume>,
339 enable_tls: bool,
340 ) -> Self {
341 Self {
342 address,
343 authorization,
344 resume,
345 user_id,
346 enable_tls,
347 }
348 }
349}
350
/// Handle to a connection with a single Lavalink node.
///
/// Events sent through the handle are forwarded to a background connection
/// task, which sends them to the node's REST API. The most recent stats
/// reported by the node are shared with that task through a [`BiLock`].
#[derive(Debug)]
pub struct Node {
    /// Configuration used to connect to the node.
    config: NodeConfig,
    /// Sender for outgoing events, read by the connection task.
    lavalink_tx: UnboundedSender<OutgoingEvent>,
    /// Player manager associated with the node.
    players: PlayerManager,
    /// This handle's half of the stats lock; the connection task holds the
    /// other half and overwrites it on every stats event.
    stats: BiLock<Stats>,
}
364
365impl Node {
366 pub async fn connect(
391 config: NodeConfig,
392 players: PlayerManager,
393 ) -> Result<(Self, IncomingEvents), NodeError> {
394 let (bilock_left, bilock_right) = BiLock::new(Stats {
395 cpu: StatsCpu {
396 cores: 0,
397 lavalink_load: 0f64,
398 system_load: 0f64,
399 },
400 frame_stats: None,
401 memory: StatsMemory {
402 allocated: 0,
403 free: 0,
404 used: 0,
405 reservable: 0,
406 },
407 players: 0,
408 playing_players: 0,
409 uptime: 0,
410 });
411
412 tracing::debug!("starting connection to {}", config.address);
413
414 let (conn_loop, lavalink_tx, lavalink_rx) =
415 Connection::connect(config.clone(), players.clone(), bilock_right).await?;
416
417 tracing::debug!("started connection to {}", config.address);
418
419 tokio::spawn(conn_loop.run());
420
421 Ok((
422 Self {
423 config,
424 lavalink_tx,
425 players,
426 stats: bilock_left,
427 },
428 IncomingEvents { inner: lavalink_rx },
429 ))
430 }
431
432 pub const fn config(&self) -> &NodeConfig {
434 &self.config
435 }
436
437 pub const fn players(&self) -> &PlayerManager {
439 &self.players
440 }
441
442 pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
452 self.sender().send(event)
453 }
454
455 pub fn sender(&self) -> NodeSender {
460 NodeSender {
461 inner: self.lavalink_tx.clone(),
462 }
463 }
464
465 pub async fn stats(&self) -> Stats {
467 (*self.stats.lock().await).clone()
468 }
469
470 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
475 pub async fn penalty(&self) -> i32 {
476 let stats = self.stats.lock().await;
477 let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
478
479 let (deficit_frame, null_frame) = (
480 1.03f64.powf(
481 500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64),
482 ) * 300f64
483 - 300f64,
484 (1.03f64.powf(
485 500f64 * (stats.frame_stats.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64),
486 ) * 300f64
487 - 300f64)
488 * 2f64,
489 );
490
491 stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
492 }
493}
494
/// State owned by the background task that services a single node.
struct Connection {
    /// Configuration of the node this connection belongs to.
    config: NodeConfig,
    /// Websocket stream events are received from; replaced on reconnect.
    stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
    /// HTTP client used to forward outgoing events to the node's REST API.
    lavalink_http: HyperClient<HttpConnector, Full<Bytes>>,
    /// Receives outgoing events sent through the [`Node`] handle.
    node_from: UnboundedReceiver<OutgoingEvent>,
    /// Sends incoming events to the [`IncomingEvents`] stream.
    node_to: UnboundedSender<IncomingEvent>,
    /// Clone of the [`Node`]'s player manager.
    players: PlayerManager,
    /// This task's half of the stats lock shared with the [`Node`].
    stats: BiLock<Stats>,
    /// Session ID from the node's `Ready` event; needed to build REST paths.
    lavalink_session_id: Option<Box<str>>,
}
505
506impl Connection {
507 async fn connect(
508 config: NodeConfig,
509 players: PlayerManager,
510 stats: BiLock<Stats>,
511 ) -> Result<
512 (
513 Self,
514 UnboundedSender<OutgoingEvent>,
515 UnboundedReceiver<IncomingEvent>,
516 ),
517 NodeError,
518 > {
519 let stream = reconnect(&config).await?;
520
521 let (to_node, from_lavalink) = mpsc::unbounded_channel();
522 let (to_lavalink, from_node) = mpsc::unbounded_channel();
523
524 let mut client_builder = HyperClient::builder(TokioExecutor::new());
525
526 if config.enable_tls {
527 client_builder.http2_only(config.enable_tls);
528 }
529
530 let lavalink_http = client_builder.build_http();
531
532 Ok((
533 Self {
534 config,
535 stream,
536 lavalink_http,
537 node_from: from_node,
538 node_to: to_node,
539 players,
540 stats,
541 lavalink_session_id: None,
542 },
543 to_lavalink,
544 from_lavalink,
545 ))
546 }
547
548 async fn run(mut self) -> Result<(), NodeError> {
549 loop {
550 tokio::select! {
551 incoming = self.stream.next() => {
552 if let Some(Ok(incoming)) = incoming {
553 self.incoming(incoming).await?;
554 } else {
555 tracing::debug!("connection to {} closed, reconnecting", self.config.address);
556 self.stream = reconnect(&self.config).await?;
557 }
558 }
559 outgoing = self.node_from.recv() => {
560 if let Some(outgoing) = outgoing {
561 self.outgoing(outgoing).await?;
562 } else {
563 tracing::debug!("node {} closed, ending connection", self.config.address);
564 break;
565 }
566 }
567 }
568 }
569
570 Ok(())
571 }
572
573 fn get_outgoing_endpoint_based_on_event(
574 &mut self,
575 outgoing: &OutgoingEvent,
576 ) -> Result<(Method, hyper::Uri), NodeError> {
577 let address = self.config.address;
578 tracing::debug!("forwarding event to {address}: {outgoing:?}");
579
580 let guild_id = outgoing.guild_id();
581 let no_replace = outgoing.no_replace();
582
583 if let Some(session) = &self.lavalink_session_id {
584 let mut path = format!("/v4/sessions/{session}/players/{guild_id}");
585 if !matches!(outgoing, OutgoingEvent::Destroy(_)) {
586 let _ = write!(path, "?noReplace={no_replace}");
587 }
588 let uri = Uri::builder()
589 .scheme("http")
590 .authority(address.to_string())
591 .path_and_query(path)
592 .build()
593 .expect("uri is valid");
594 return if matches!(outgoing, OutgoingEvent::Destroy(_)) {
595 Ok((Method::DELETE, uri))
596 } else {
597 Ok((Method::PATCH, uri))
598 };
599 }
600
601 tracing::error!("no session id is found");
602
603 Err(NodeError {
604 kind: NodeErrorType::OutgoingEventHasNoSession,
605 source: None,
606 })
607 }
608
609 async fn outgoing(&mut self, outgoing: OutgoingEvent) -> Result<(), NodeError> {
610 let (method, url) = self.get_outgoing_endpoint_based_on_event(&outgoing)?;
611 let payload = serde_json::to_string(&outgoing).expect("serialization cannot fail");
612
613 let authority = url.authority().expect("authority comes from endpoint");
614
615 let req = Request::builder()
616 .uri(url.borrow())
617 .method(method)
618 .header(header::HOST, authority.as_str())
619 .header(header::AUTHORIZATION, self.config.authorization.as_str())
620 .header(header::CONTENT_TYPE, "application/json")
621 .body(Full::from(payload))
622 .map_err(|source| NodeError {
623 kind: NodeErrorType::BuildingConnectionRequest,
624 source: Some(Box::new(source)),
625 })?;
626
627 self.lavalink_http
628 .request(req)
629 .await
630 .map_err(|source| NodeError {
631 kind: NodeErrorType::HttpRequestFailed,
632 source: Some(Box::new(source)),
633 })?;
634
635 Ok(())
636 }
637
638 async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
639 tracing::debug!(
640 "received message from {}: {incoming:?}",
641 self.config.address,
642 );
643
644 let text = if incoming.is_text() {
645 incoming.as_text().expect("message is text")
646 } else if incoming.is_close() {
647 tracing::debug!("got close, closing connection");
648
649 return Ok(false);
650 } else {
651 tracing::debug!("got ping, pong or binary payload: {incoming:?}");
652
653 return Ok(true);
654 };
655
656 let Ok(event) = serde_json::from_str(text) else {
657 tracing::warn!("unknown message from lavalink node: {text}");
658
659 return Ok(true);
660 };
661
662 match &event {
663 IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
664 IncomingEvent::Ready(ready) => {
665 self.lavalink_session_id = Some(ready.session_id.clone().into_boxed_str());
666 }
667 IncomingEvent::Stats(stats) => self.stats(stats).await?,
668 IncomingEvent::Event(_) => {}
669 }
670
671 if !self.node_to.is_closed() {
674 let _result = self.node_to.send(event);
675 }
676
677 Ok(true)
678 }
679
680 fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
681 let Some(player) = self.players.get(&update.guild_id) else {
682 tracing::warn!(
683 "invalid player update for guild {}: {update:?}",
684 update.guild_id,
685 );
686
687 return Ok(());
688 };
689
690 player.set_position(update.state.position);
691 player.set_time(update.state.time);
692
693 Ok(())
694 }
695
696 async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
697 *self.stats.lock().await = stats.clone();
698
699 Ok(())
700 }
701}
702
703impl Drop for Connection {
704 fn drop(&mut self) {
705 self.players
707 .players
708 .retain(|_, v| v.node().config().address != self.config.address);
709 }
710}
711
/// Value of the `client-name` header sent when connecting to a node:
/// `twilight-lavalink/` followed by this crate's version.
const TWILIGHT_CLIENT_NAME: &str = concat!("twilight-lavalink/", env!("CARGO_PKG_VERSION"));
713
714fn connect_request(state: &NodeConfig) -> Result<ClientBuilder<'_>, NodeError> {
715 let websocket_protocol = if state.enable_tls { "wss" } else { "ws" };
716
717 let mut builder = ClientBuilder::new()
718 .uri(&format!(
719 "{}://{}/v4/websocket",
720 websocket_protocol, state.address
721 ))
722 .map_err(|source| NodeError {
723 kind: NodeErrorType::BuildingConnectionRequest,
724 source: Some(Box::new(source)),
725 })?
726 .add_header(
727 AUTHORIZATION,
728 state.authorization.parse().map_err(|source| NodeError {
729 kind: NodeErrorType::BuildingConnectionRequest,
730 source: Some(Box::new(source)),
731 })?,
732 )
733 .expect("Unable to crate authorization header")
734 .add_header(
735 HeaderName::from_static("user-id"),
736 state.user_id.get().into(),
737 )
738 .expect("Unable to add user-id")
739 .add_header(
740 HeaderName::from_static("client-name"),
741 HeaderValue::from_static(TWILIGHT_CLIENT_NAME),
742 )
743 .expect("Unable to crate builder");
744
745 if state.resume.is_some() {
746 builder = builder
747 .add_header(
748 HeaderName::from_static("resume-key"),
749 state
750 .address
751 .to_string()
752 .parse()
753 .map_err(|source| NodeError {
754 kind: NodeErrorType::BuildingConnectionRequest,
755 source: Some(Box::new(source)),
756 })?,
757 )
758 .expect("Unable to build state header");
759 }
760
761 Ok(builder)
762}
763
764async fn reconnect(
765 config: &NodeConfig,
766) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
767 let (mut stream, res) = backoff(config).await?;
768
769 let headers = res.headers();
770
771 if let Some(resume) = config.resume.as_ref() {
772 let header = HeaderName::from_static("session-resumed");
773
774 if let Some(value) = headers.get(header) {
775 if value.as_bytes() == b"false" {
776 tracing::debug!("session to node {} didn't resume", config.address);
777
778 let payload = serde_json::json!({
779 "op": "configureResuming",
780 "key": config.address,
781 "timeout": resume.timeout,
782 });
783 let msg = Message::text(
784 serde_json::to_string(&payload).expect("serialize can't panic here"),
785 );
786
787 stream.send(msg).await.map_err(|source| NodeError {
788 kind: NodeErrorType::Connecting,
789 source: Some(Box::new(source)),
790 })?;
791 } else {
792 tracing::debug!("session to {} resumed", config.address);
793 }
794 }
795 }
796
797 Ok(stream)
798}
799
800async fn backoff(
801 config: &NodeConfig,
802) -> Result<
803 (
804 WebSocketStream<MaybeTlsStream<TcpStream>>,
805 upgrade::Response,
806 ),
807 NodeError,
808> {
809 let mut seconds = 1;
810
811 loop {
812 let request = connect_request(config)?;
813
814 match request.connect().await {
815 Ok((stream, response)) => return Ok((stream, response)),
816 Err(source) => {
817 tracing::warn!("failed to connect to node {source}: {:?}", config.address);
818
819 if matches!(
820 &source,
821 WebsocketError::Upgrade(upgrade::Error::DidNotSwitchProtocols(401))
822 ) {
823 return Err(NodeError {
824 kind: NodeErrorType::Unauthorized {
825 address: config.address,
826 authorization: config.authorization.clone(),
827 },
828 source: None,
829 });
830 }
831
832 if seconds > 64 {
833 tracing::debug!("no longer trying to connect to node {}", config.address);
834
835 return Err(NodeError {
836 kind: NodeErrorType::Connecting,
837 source: Some(Box::new(source)),
838 });
839 }
840
841 tracing::debug!(
842 "waiting {seconds} seconds before attempting to connect to node {} again",
843 config.address,
844 );
845 tokio_time::sleep(Duration::from_secs(seconds)).await;
846
847 seconds *= 2;
848 }
849 }
850 }
851}
852
#[cfg(test)]
mod tests {
    use super::{Node, NodeConfig, NodeError, NodeErrorType, Resume};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::Debug,
        net::{Ipv4Addr, SocketAddr, SocketAddrV4},
    };
    use twilight_model::id::Id;

    // Public configuration types.
    assert_fields!(NodeConfig: address, authorization, resume, user_id, enable_tls);
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);
    assert_fields!(Resume: timeout);
    assert_impl_all!(Resume: Clone, Debug, Default, Eq, PartialEq, Send, Sync);

    // Error types.
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);

    // Node handle.
    assert_impl_all!(Node: Debug, Send, Sync);

    /// `NodeConfig`'s `Debug` output must show a placeholder and never the
    /// authorization itself.
    #[test]
    fn node_config_debug() {
        let config = NodeConfig {
            address: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 1312)),
            authorization: "some auth".to_owned(),
            resume: None,
            user_id: Id::new(123),
            enable_tls: false,
        };

        let formatted = format!("{config:?}");

        assert!(formatted.contains("authorization: <redacted>"));
        assert!(!formatted.contains("some auth"));
    }
}