1use crate::{
21 model::{IncomingEvent, Opcode, OutgoingEvent, PlayerUpdate, Stats, StatsCpu, StatsMemory},
22 player::PlayerManager,
23};
24use futures_util::{
25 lock::BiLock,
26 sink::SinkExt,
27 stream::{Stream, StreamExt},
28};
29use http::header::{HeaderName, AUTHORIZATION};
30use std::{
31 error::Error,
32 fmt::{Debug, Display, Formatter, Result as FmtResult},
33 net::SocketAddr,
34 pin::Pin,
35 task::{Context, Poll},
36 time::Duration,
37};
38use tokio::{
39 net::TcpStream,
40 sync::mpsc::{self, UnboundedReceiver, UnboundedSender},
41 time as tokio_time,
42};
43use tokio_websockets::{
44 upgrade, ClientBuilder, Error as WebsocketError, MaybeTlsStream, Message, WebSocketStream,
45};
46use twilight_model::id::{marker::UserMarker, Id};
47
48#[derive(Debug)]
51pub struct NodeError {
52 kind: NodeErrorType,
53 source: Option<Box<dyn Error + Send + Sync>>,
54}
55
56impl NodeError {
57 #[must_use = "retrieving the type has no effect if left unused"]
59 pub const fn kind(&self) -> &NodeErrorType {
60 &self.kind
61 }
62
63 #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
65 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
66 self.source
67 }
68
69 #[must_use = "consuming the error into its parts has no effect if left unused"]
71 pub fn into_parts(self) -> (NodeErrorType, Option<Box<dyn Error + Send + Sync>>) {
72 (self.kind, self.source)
73 }
74}
75
76impl Display for NodeError {
77 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
78 match &self.kind {
79 NodeErrorType::BuildingConnectionRequest { .. } => {
80 f.write_str("failed to build connection request")
81 }
82 NodeErrorType::Connecting { .. } => f.write_str("Failed to connect to the node"),
83 NodeErrorType::SerializingMessage { .. } => {
84 f.write_str("failed to serialize outgoing message as json")
85 }
86 NodeErrorType::Unauthorized { address, .. } => {
87 f.write_str("the authorization used to connect to node ")?;
88 Display::fmt(address, f)?;
89
90 f.write_str(" is invalid")
91 }
92 }
93 }
94}
95
96impl Error for NodeError {
97 fn source(&self) -> Option<&(dyn Error + 'static)> {
98 self.source
99 .as_ref()
100 .map(|source| &**source as &(dyn Error + 'static))
101 }
102}
103
104#[derive(Debug)]
106#[non_exhaustive]
107pub enum NodeErrorType {
108 BuildingConnectionRequest,
110 Connecting,
112 SerializingMessage {
114 message: OutgoingEvent,
116 },
117 Unauthorized {
119 address: SocketAddr,
121 authorization: String,
123 },
124}
125
/// Error returned when an event could not be queued for sending to a node.
#[derive(Debug)]
pub struct NodeSenderError {
    kind: NodeSenderErrorType,
    source: Option<Box<dyn Error + Send + Sync>>,
}

impl NodeSenderError {
    /// Immutable reference to the type of error that occurred.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &NodeSenderErrorType {
        &self.kind
    }

    /// Consume the error, returning the source error if there is any.
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        self.source
    }

    /// Consume the error, returning the owned error type and the source error.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (NodeSenderErrorType, Option<Box<dyn Error + Send + Sync>>) {
        (self.kind, self.source)
    }
}

impl Display for NodeSenderError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match &self.kind {
            NodeSenderErrorType::Sending => f.write_str("failed to send over channel"),
        }
    }
}

impl Error for NodeSenderError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        self.source
            .as_ref()
            .map(|source| &**source as &(dyn Error + 'static))
    }
}

/// Type of [`NodeSenderError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum NodeSenderErrorType {
    /// The channel to the connection task is closed, so the event was not queued.
    Sending,
}
174
175pub struct IncomingEvents {
177 inner: UnboundedReceiver<IncomingEvent>,
178}
179
180impl IncomingEvents {
181 pub fn close(&mut self) {
183 self.inner.close();
184 }
185}
186
187impl Stream for IncomingEvents {
188 type Item = IncomingEvent;
189
190 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
191 self.inner.poll_recv(cx)
192 }
193}
194
195pub struct NodeSender {
197 inner: UnboundedSender<OutgoingEvent>,
198}
199
200impl NodeSender {
201 pub fn is_closed(&self) -> bool {
203 self.inner.is_closed()
204 }
205
206 pub fn send(&self, msg: OutgoingEvent) -> Result<(), NodeSenderError> {
217 self.inner.send(msg).map_err(|source| NodeSenderError {
218 kind: NodeSenderErrorType::Sending,
219 source: Some(Box::new(source)),
220 })
221 }
222}
223
224#[derive(Clone, Eq, PartialEq)]
226#[non_exhaustive]
227pub struct NodeConfig {
229 pub address: SocketAddr,
231 pub authorization: String,
233 pub resume: Option<Resume>,
237 pub user_id: Id<UserMarker>,
239}
240
241impl Debug for NodeConfig {
242 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
243 struct Redacted;
247
248 impl Debug for Redacted {
249 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
250 f.write_str("<redacted>")
251 }
252 }
253
254 f.debug_struct("NodeConfig")
255 .field("address", &self.address)
256 .field("authorization", &Redacted)
257 .field("resume", &self.resume)
258 .field("user_id", &self.user_id)
259 .finish()
260 }
261}
262
/// Session resume settings for a node connection.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct Resume {
    /// Timeout in seconds, sent as `timeout` in the `configureResuming`
    /// request.
    pub timeout: u64,
}

impl Resume {
    /// Create resume settings with the given timeout in seconds.
    pub const fn new(seconds: u64) -> Self {
        Self { timeout: seconds }
    }
}

impl Default for Resume {
    /// Resume settings with a 60 second timeout.
    fn default() -> Self {
        Self::new(60)
    }
}
287
288impl NodeConfig {
289 pub fn new(
297 user_id: Id<UserMarker>,
298 address: impl Into<SocketAddr>,
299 authorization: impl Into<String>,
300 resume: impl Into<Option<Resume>>,
301 ) -> Self {
302 Self::_new(user_id, address.into(), authorization.into(), resume.into())
303 }
304
305 const fn _new(
306 user_id: Id<UserMarker>,
307 address: SocketAddr,
308 authorization: String,
309 resume: Option<Resume>,
310 ) -> Self {
311 Self {
312 address,
313 authorization,
314 resume,
315 user_id,
316 }
317 }
318}
319
320#[derive(Debug)]
327pub struct Node {
328 config: NodeConfig,
329 lavalink_tx: UnboundedSender<OutgoingEvent>,
330 players: PlayerManager,
331 stats: BiLock<Stats>,
332}
333
334impl Node {
335 pub async fn connect(
360 config: NodeConfig,
361 players: PlayerManager,
362 ) -> Result<(Self, IncomingEvents), NodeError> {
363 let (bilock_left, bilock_right) = BiLock::new(Stats {
364 cpu: StatsCpu {
365 cores: 0,
366 lavalink_load: 0f64,
367 system_load: 0f64,
368 },
369 frames: None,
370 memory: StatsMemory {
371 allocated: 0,
372 free: 0,
373 used: 0,
374 reservable: 0,
375 },
376 players: 0,
377 playing_players: 0,
378 op: Opcode::Stats,
379 uptime: 0,
380 });
381
382 tracing::debug!("starting connection to {}", config.address);
383
384 let (conn_loop, lavalink_tx, lavalink_rx) =
385 Connection::connect(config.clone(), players.clone(), bilock_right).await?;
386
387 tracing::debug!("started connection to {}", config.address);
388
389 tokio::spawn(conn_loop.run());
390
391 Ok((
392 Self {
393 config,
394 lavalink_tx,
395 players,
396 stats: bilock_left,
397 },
398 IncomingEvents { inner: lavalink_rx },
399 ))
400 }
401
402 pub const fn config(&self) -> &NodeConfig {
404 &self.config
405 }
406
407 pub const fn players(&self) -> &PlayerManager {
409 &self.players
410 }
411
412 pub fn send(&self, event: OutgoingEvent) -> Result<(), NodeSenderError> {
422 self.sender().send(event)
423 }
424
425 pub fn sender(&self) -> NodeSender {
430 NodeSender {
431 inner: self.lavalink_tx.clone(),
432 }
433 }
434
435 pub async fn stats(&self) -> Stats {
437 (*self.stats.lock().await).clone()
438 }
439
440 #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
445 pub async fn penalty(&self) -> i32 {
446 let stats = self.stats.lock().await;
447 let cpu = 1.05f64.powf(100f64 * stats.cpu.system_load) * 10f64 - 10f64;
448
449 let (deficit_frame, null_frame) = (
450 1.03f64
451 .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.deficit) as f64 / 3000f64))
452 * 300f64
453 - 300f64,
454 (1.03f64
455 .powf(500f64 * (stats.frames.as_ref().map_or(0, |f| f.nulled) as f64 / 3000f64))
456 * 300f64
457 - 300f64)
458 * 2f64,
459 );
460
461 stats.playing_players as i32 + cpu as i32 + deficit_frame as i32 + null_frame as i32
462 }
463}
464
465struct Connection {
466 config: NodeConfig,
467 stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
468 node_from: UnboundedReceiver<OutgoingEvent>,
469 node_to: UnboundedSender<IncomingEvent>,
470 players: PlayerManager,
471 stats: BiLock<Stats>,
472}
473
474impl Connection {
475 async fn connect(
476 config: NodeConfig,
477 players: PlayerManager,
478 stats: BiLock<Stats>,
479 ) -> Result<
480 (
481 Self,
482 UnboundedSender<OutgoingEvent>,
483 UnboundedReceiver<IncomingEvent>,
484 ),
485 NodeError,
486 > {
487 let stream = reconnect(&config).await?;
488
489 let (to_node, from_lavalink) = mpsc::unbounded_channel();
490 let (to_lavalink, from_node) = mpsc::unbounded_channel();
491
492 Ok((
493 Self {
494 config,
495 stream,
496 node_from: from_node,
497 node_to: to_node,
498 players,
499 stats,
500 },
501 to_lavalink,
502 from_lavalink,
503 ))
504 }
505
506 async fn run(mut self) -> Result<(), NodeError> {
507 loop {
508 tokio::select! {
509 incoming = self.stream.next() => {
510 if let Some(Ok(incoming)) = incoming {
511 self.incoming(incoming).await?;
512 } else {
513 tracing::debug!("connection to {} closed, reconnecting", self.config.address);
514 self.stream = reconnect(&self.config).await?;
515 }
516 }
517 outgoing = self.node_from.recv() => {
518 if let Some(outgoing) = outgoing {
519 tracing::debug!(
520 "forwarding event to {}: {outgoing:?}",
521 self.config.address,
522 );
523
524 let payload = serde_json::to_string(&outgoing).map_err(|source| NodeError {
525 kind: NodeErrorType::SerializingMessage { message: outgoing },
526 source: Some(Box::new(source)),
527 })?;
528 let msg = Message::text(payload);
529 self.stream.send(msg).await.unwrap();
530 } else {
531 tracing::debug!("node {} closed, ending connection", self.config.address);
532
533 break;
534 }
535 }
536 }
537 }
538
539 Ok(())
540 }
541
542 async fn incoming(&mut self, incoming: Message) -> Result<bool, NodeError> {
543 tracing::debug!(
544 "received message from {}: {incoming:?}",
545 self.config.address,
546 );
547
548 let text = if incoming.is_text() {
549 incoming.as_text().expect("message is text")
550 } else if incoming.is_close() {
551 tracing::debug!("got close, closing connection");
552
553 return Ok(false);
554 } else {
555 tracing::debug!("got ping, pong or binary payload: {incoming:?}");
556
557 return Ok(true);
558 };
559
560 let Ok(event) = serde_json::from_str(text) else {
561 tracing::warn!("unknown message from lavalink node: {text}");
562
563 return Ok(true);
564 };
565
566 match &event {
567 IncomingEvent::PlayerUpdate(update) => self.player_update(update)?,
568 IncomingEvent::Stats(stats) => self.stats(stats).await?,
569 _ => {}
570 }
571
572 if !self.node_to.is_closed() {
575 let _result = self.node_to.send(event);
576 }
577
578 Ok(true)
579 }
580
581 fn player_update(&self, update: &PlayerUpdate) -> Result<(), NodeError> {
582 let Some(player) = self.players.get(&update.guild_id) else {
583 tracing::warn!(
584 "invalid player update for guild {}: {update:?}",
585 update.guild_id,
586 );
587
588 return Ok(());
589 };
590
591 player.set_position(update.state.position.unwrap_or(0));
592 player.set_time(update.state.time);
593
594 Ok(())
595 }
596
597 async fn stats(&self, stats: &Stats) -> Result<(), NodeError> {
598 *self.stats.lock().await = stats.clone();
599
600 Ok(())
601 }
602}
603
604impl Drop for Connection {
605 fn drop(&mut self) {
606 self.players
608 .players
609 .retain(|_, v| v.node().config().address != self.config.address);
610 }
611}
612
613fn connect_request(state: &NodeConfig) -> Result<ClientBuilder, NodeError> {
614 let mut builder = ClientBuilder::new()
615 .uri(&format!("ws://{}", state.address))
616 .map_err(|source| NodeError {
617 kind: NodeErrorType::BuildingConnectionRequest,
618 source: Some(Box::new(source)),
619 })?
620 .add_header(AUTHORIZATION, state.authorization.parse().unwrap())
621 .expect("allowed header")
622 .add_header(
623 HeaderName::from_static("user-id"),
624 state.user_id.get().into(),
625 )
626 .expect("allowed header");
627
628 if state.resume.is_some() {
629 builder = builder
630 .add_header(
631 HeaderName::from_static("resume-key"),
632 state.address.to_string().parse().unwrap(),
633 )
634 .expect("allowed header");
635 }
636
637 Ok(builder)
638}
639
640async fn reconnect(
641 config: &NodeConfig,
642) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, NodeError> {
643 let (mut stream, res) = backoff(config).await?;
644
645 let headers = res.headers();
646
647 if let Some(resume) = config.resume.as_ref() {
648 let header = HeaderName::from_static("session-resumed");
649
650 if let Some(value) = headers.get(header) {
651 if value.as_bytes() == b"false" {
652 tracing::debug!("session to node {} didn't resume", config.address);
653
654 let payload = serde_json::json!({
655 "op": "configureResuming",
656 "key": config.address,
657 "timeout": resume.timeout,
658 });
659 let msg = Message::text(serde_json::to_string(&payload).unwrap());
660
661 stream.send(msg).await.unwrap();
662 } else {
663 tracing::debug!("session to {} resumed", config.address);
664 }
665 }
666 }
667
668 Ok(stream)
669}
670
671async fn backoff(
672 config: &NodeConfig,
673) -> Result<
674 (
675 WebSocketStream<MaybeTlsStream<TcpStream>>,
676 upgrade::Response,
677 ),
678 NodeError,
679> {
680 let mut seconds = 1;
681
682 loop {
683 let request = connect_request(config)?;
684
685 match request.connect().await {
686 Ok((stream, response)) => return Ok((stream, response)),
687 Err(source) => {
688 tracing::warn!("failed to connect to node {source}: {:?}", config.address);
689
690 if matches!(
691 &source,
692 WebsocketError::Upgrade(upgrade::Error::DidNotSwitchProtocols(401))
693 ) {
694 return Err(NodeError {
695 kind: NodeErrorType::Unauthorized {
696 address: config.address,
697 authorization: config.authorization.clone(),
698 },
699 source: None,
700 });
701 }
702
703 if seconds > 64 {
704 tracing::debug!("no longer trying to connect to node {}", config.address);
705
706 return Err(NodeError {
707 kind: NodeErrorType::Connecting,
708 source: Some(Box::new(source)),
709 });
710 }
711
712 tracing::debug!(
713 "waiting {seconds} seconds before attempting to connect to node {} again",
714 config.address,
715 );
716 tokio_time::sleep(Duration::from_secs(seconds)).await;
717
718 seconds *= 2;
719
720 continue;
721 }
722 }
723 }
724}
725
#[cfg(test)]
mod tests {
    use super::{Node, NodeConfig, NodeError, NodeErrorType, Resume};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{error::Error, fmt::Debug, net::SocketAddr};
    use twilight_model::id::Id;

    // Configuration types.
    assert_fields!(NodeConfig: address, authorization, resume, user_id);
    assert_impl_all!(NodeConfig: Clone, Debug, Send, Sync);
    assert_fields!(Resume: timeout);
    assert_impl_all!(Resume: Clone, Debug, Default, Eq, PartialEq, Send, Sync);

    // Error types.
    assert_fields!(NodeErrorType::SerializingMessage: message);
    assert_fields!(NodeErrorType::Unauthorized: address, authorization);
    assert_impl_all!(NodeErrorType: Debug, Send, Sync);
    assert_impl_all!(NodeError: Error, Send, Sync);

    // Node handle.
    assert_impl_all!(Node: Debug, Send, Sync);

    /// The `Debug` output of a config hides the authorization.
    #[test]
    fn node_config_debug() {
        let node_config = NodeConfig {
            address: SocketAddr::from(([127, 0, 0, 1], 1312)),
            authorization: "some auth".to_owned(),
            resume: None,
            user_id: Id::new(123),
        };

        let rendered = format!("{node_config:?}");
        assert!(rendered.contains("authorization: <redacted>"));
    }
}