twilight_lavalink/client.rs
1//! Client to manage nodes and players.
2
3use crate::{
4 model::VoiceUpdate,
5 node::{IncomingEvents, Node, NodeConfig, NodeError},
6 player::{Player, PlayerManager},
7};
8use dashmap::DashMap;
9use std::{
10 error::Error,
11 fmt::{Display, Formatter, Result as FmtResult},
12 net::SocketAddr,
13 sync::Arc,
14};
15use twilight_model::{
16 gateway::{ShardId, event::Event, payload::incoming::VoiceServerUpdate},
17 id::{
18 Id,
19 marker::{ChannelMarker, GuildMarker, UserMarker},
20 },
21};
22
/// An error that can occur while interacting with the client.
///
/// The category of the failure is available through [`ClientError::kind`].
/// The underlying cause, when there is one, can be borrowed through
/// [`Error::source`] or taken by consuming the error.
#[derive(Debug)]
pub struct ClientError {
    /// Category of the failure; also selects the `Display` message.
    kind: ClientErrorType,
    /// Underlying cause, if any. `None` for failures that have no underlying
    /// error, such as no node being configured.
    source: Option<Box<dyn Error + Send + Sync>>,
}
29
30impl ClientError {
31 /// Immutable reference to the type of error that occurred.
32 pub const fn kind(&self) -> &ClientErrorType {
33 &self.kind
34 }
35
36 /// Consume the error, returning the source error if there is any.
37 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
38 self.source
39 }
40
41 /// Consume the error, returning the owned error type and the source error.
42 #[must_use = "consuming the error into its parts has no effect if left unused"]
43 pub fn into_parts(self) -> (ClientErrorType, Option<Box<dyn Error + Send + Sync>>) {
44 (self.kind, self.source)
45 }
46}
47
48impl Display for ClientError {
49 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
50 match &self.kind {
51 ClientErrorType::NodesUnconfigured => f.write_str("no node has been configured"),
52 ClientErrorType::SendingVoiceUpdate => {
53 f.write_str("couldn't send voice update to node")
54 }
55 }
56 }
57}
58
59impl Error for ClientError {
60 fn source(&self) -> Option<&(dyn Error + 'static)> {
61 self.source
62 .as_ref()
63 .map(|source| &**source as &(dyn Error + 'static))
64 }
65}
66
/// Type of [`ClientError`] that occurred.
///
/// The enum is marked `#[non_exhaustive]`, so code outside this crate must
/// include a wildcard arm when matching on it.
#[derive(Debug)]
#[non_exhaustive]
pub enum ClientErrorType {
    /// A node isn't configured, so the operation isn't possible to fulfill.
    NodesUnconfigured,
    /// Sending a voice update event to the node failed because the node's
    /// connection was shut down.
    SendingVoiceUpdate,
}
77
/// The voice-state half of a voice connection, recorded per guild from the
/// client's own Voice State Update events.
///
/// It is combined with the guild's stored Voice Server Update to build the
/// voice update that is sent to the node.
#[derive(Debug)]
struct DiscordSession {
    /// Voice channel reported by the Voice State Update.
    channel_id: Option<Id<ChannelMarker>>,
    /// Session ID reported by the Voice State Update.
    id: Box<str>,
}
83
/// The Lavalink client that manages nodes and players, and processes events
/// from Discord to tie it all together.
///
/// **Note**: You must call the [`process`] method with every Voice State Update
/// and Voice Server Update event you receive from Discord. It will
/// automatically forward these events to Lavalink. See its documentation for
/// more information.
///
/// You can retrieve players using the [`player`] method. Players hold the
/// playback state of a guild and allow you to send events to the connected
/// node, such as [`Play`] events.
///
/// # Using a Lavalink client in multiple tasks
///
/// To use a Lavalink client instance in multiple tasks, consider wrapping it in
/// an [`std::sync::Arc`] or [`std::rc::Rc`].
///
/// [`Play`]: crate::model::outgoing::Play
/// [`player`]: Self::player
/// [`process`]: Self::process
#[derive(Debug)]
pub struct Lavalink {
    /// Managed nodes, keyed by the address they were added with.
    nodes: DashMap<SocketAddr, Arc<Node>>,
    /// Player manager; a clone of it is handed to every node on connect.
    players: PlayerManager,
    /// Number of shards, used to work out which shard a guild belongs to when
    /// clearing pending state.
    shard_count: u32,
    /// ID of the client's own user; voice state updates of other users are
    /// ignored.
    user_id: Id<UserMarker>,
    /// Most recent Voice Server Update per guild.
    server_updates: DashMap<Id<GuildMarker>, VoiceServerUpdate>,
    /// Most recent session and channel of the client's user per guild.
    discord_sessions: DashMap<Id<GuildMarker>, DiscordSession>,
}
113
114impl Lavalink {
115 /// Create a new Lavalink client instance.
116 ///
117 /// The user ID and number of shards provided may not be modified during
118 /// runtime, and the client must be re-created. These parameters are
119 /// automatically passed to new nodes created via [`add`].
120 ///
121 /// [`add`]: Self::add
122 pub fn new(user_id: Id<UserMarker>, shard_count: u32) -> Self {
123 Self {
124 nodes: DashMap::new(),
125 players: PlayerManager::new(),
126 shard_count,
127 user_id,
128 server_updates: DashMap::new(),
129 discord_sessions: DashMap::new(),
130 }
131 }
132
133 /// Process an event into the Lavalink client.
134 ///
135 /// **Note**: calling this method in your event loop is required. See the
136 /// [crate documentation] for an example.
137 ///
138 /// This requires the `VoiceServerUpdate` and `VoiceStateUpdate` events that
139 /// you receive from Discord over the gateway to send voice updates to
140 /// nodes. For simplicity in some applications' event loops, any event can
141 /// be provided, but they will just be ignored.
142 ///
143 /// The Ready event can optionally be provided to do some cleaning of
144 /// stalled voice states that never received their voice server update half
145 /// or vice versa. It is recommended that you process Ready events.
146 ///
147 /// # Errors
148 ///
149 /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if no nodes
150 /// have been added to the client when attempting to retrieve a guild's
151 /// player.
152 ///
153 /// [crate documentation]: crate#examples
154 pub async fn process(&self, event: &Event) -> Result<(), ClientError> {
155 tracing::trace!("processing event: {event:?}");
156
157 let guild_id = match event {
158 Event::Ready(e) => {
159 let shard_id = e.shard.map_or(0, ShardId::number);
160
161 self.clear_shard_states(shard_id);
162
163 return Ok(());
164 }
165 Event::VoiceServerUpdate(e) => {
166 self.server_updates.insert(e.guild_id, e.clone());
167 e.guild_id
168 }
169 Event::VoiceStateUpdate(e) => {
170 if e.user_id != self.user_id {
171 tracing::trace!("got voice state update from another user");
172
173 return Ok(());
174 }
175
176 if let Some(guild_id) = e.guild_id {
177 // Update player if it exists and update the connected channel ID.
178 if let Some(player) = self.players.get(&guild_id) {
179 player.set_channel_id(e.channel_id);
180 }
181
182 if e.channel_id.is_none() {
183 self.discord_sessions.remove(&guild_id);
184 self.server_updates.remove(&guild_id);
185 } else {
186 let session = DiscordSession {
187 channel_id: e.channel_id,
188 id: e.session_id.clone().into_boxed_str(),
189 };
190 self.discord_sessions.insert(guild_id, session);
191 }
192 guild_id
193 } else {
194 tracing::trace!("event has no guild ID: {e:?}");
195 return Ok(());
196 }
197 }
198 _ => return Ok(()),
199 };
200
201 tracing::debug!("got voice server/state update for {guild_id:?}: {event:?}");
202
203 let update = {
204 let server = self.server_updates.get(&guild_id);
205 let session = self.discord_sessions.get(&guild_id);
206 match (server, session) {
207 (Some(server), Some(session)) => {
208 let server = server.value();
209 let DiscordSession {
210 channel_id,
211 id: session_id,
212 } = session.value();
213 tracing::debug!(
214 "got both halves for {guild_id}: {server:?}; Session ID: {session_id:?}; Channel ID: {channel_id:?}",
215 );
216 VoiceUpdate::new(guild_id, session_id.as_ref(), *channel_id, server.clone())
217 }
218 (Some(server), None) => {
219 tracing::debug!(
220 "guild {guild_id} is now waiting for other half; got: {:?}",
221 server.value()
222 );
223 return Ok(());
224 }
225 (None, Some(session)) => {
226 tracing::debug!(
227 "guild {guild_id} is now waiting for other half; got session ID: {:?}",
228 session.value()
229 );
230 return Ok(());
231 }
232 _ => return Ok(()),
233 }
234 };
235
236 tracing::debug!("getting player for guild {guild_id}");
237
238 let player = self.player(guild_id).await?;
239
240 tracing::debug!("sending voice update for guild {guild_id}: {update:?}");
241
242 player.send(update).map_err(|source| ClientError {
243 kind: ClientErrorType::SendingVoiceUpdate,
244 source: Some(Box::new(source)),
245 })?;
246
247 tracing::debug!("sent voice update for guild {guild_id}");
248
249 Ok(())
250 }
251
252 /// Add a new node to be managed by the Lavalink client.
253 ///
254 /// If a node already exists with the provided address, then it will be
255 /// replaced.
256 ///
257 /// Pass `None` for `session_id` to create a fresh session. Pass
258 /// `Some(id)` to attempt resuming an existing Lavalink session. Resume
259 /// success is indicated by the `resumed` field on the
260 /// [`Ready`] event received via the returned [`IncomingEvents`] stream.
261 ///
262 /// If a `session_id` is provided and the session cannot be resumed,
263 /// the connection will fail with a [`Connecting`] error. The caller
264 /// can then retry with `session_id` set to `None` for a fresh session.
265 ///
266 /// [`Ready`]: crate::model::incoming::Ready
267 /// [`Connecting`]: crate::node::NodeErrorType::Connecting
268 ///
269 /// # Example
270 ///
271 /// ```no_run
272 /// # use twilight_lavalink::Lavalink;
273 /// # use twilight_model::id::Id;
274 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
275 /// # let lavalink = Lavalink::new(Id::new(1), 1);
276 /// # let address = "127.0.0.1:2333".parse()?;
277 /// # let auth = "youshallnotpass";
278 /// // Fresh session (no resume)
279 /// let (node, events) = lavalink.add(address, auth, None).await?;
280 ///
281 /// // Resume an existing session
282 /// let old_session_id = Some("existing-session-id".to_string());
283 /// let (node, events) = lavalink.add(address, auth, old_session_id).await?;
284 /// # Ok(())
285 /// # }
286 /// ```
287 ///
288 /// # Errors
289 ///
290 /// See the errors section of [`Node::connect`].
291 pub async fn add(
292 &self,
293 address: SocketAddr,
294 authorization: impl Into<String>,
295 session_id: Option<String>,
296 ) -> Result<(Arc<Node>, IncomingEvents), NodeError> {
297 let config = NodeConfig {
298 address,
299 authorization: authorization.into(),
300 user_id: self.user_id,
301 enable_tls: cfg!(feature = "tls"),
302 session_id,
303 };
304
305 let (node, rx) = Node::connect(config, self.players.clone()).await?;
306 let node = Arc::new(node);
307 self.nodes.insert(address, Arc::clone(&node));
308
309 Ok((node, rx))
310 }
311
    /// Remove a node from the list of nodes being managed by the Lavalink
    /// client.
    ///
    /// This does not disconnect the node. Use [`Lavalink::disconnect`] instead,
    /// or drop all [`Node`]s.
    ///
    /// The removed entry, consisting of the address and the node, is returned
    /// if it existed.
    pub fn remove(&self, address: SocketAddr) -> Option<(SocketAddr, Arc<Node>)> {
        self.nodes.remove(&address)
    }
322
323 /// Remove a node from the list of nodes being managed by the Lavalink
324 /// client and terminates the connection.
325 ///
326 /// Use [`Lavalink::remove`] if detaching a node from a Lavalink instance
327 /// is required without closing the underlying connection.
328 ///
329 /// Returns whether the node has been removed and disconnected.
330 pub fn disconnect(&self, address: SocketAddr) -> bool {
331 self.nodes.remove(&address).is_some()
332 }
333
334 /// Determine the "best" node for new players according to available nodes'
335 /// penalty scores. Disconnected nodes will not be considered.
336 ///
337 /// Refer to [`Node::penalty`] for how this is calculated.
338 ///
339 /// # Errors
340 ///
341 /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if there are
342 /// no connected nodes available in the client.
343 ///
344 /// [`Node::penalty`]: crate::node::Node::penalty
345 pub async fn best(&self) -> Result<Arc<Node>, ClientError> {
346 let mut lowest = i32::MAX;
347 let mut best = None;
348
349 for node in &self.nodes {
350 if node.sender().is_closed() {
351 continue;
352 }
353
354 let penalty = node.value().penalty().await;
355
356 if penalty < lowest {
357 lowest = penalty;
358 best.replace(node.clone());
359 }
360 }
361
362 best.ok_or(ClientError {
363 kind: ClientErrorType::NodesUnconfigured,
364 source: None,
365 })
366 }
367
    /// Retrieve an immutable reference to the player manager.
    ///
    /// This is the same manager that [`Lavalink::player`] reads from and
    /// inserts into.
    pub const fn players(&self) -> &PlayerManager {
        &self.players
    }
372
373 /// Retrieve a player for the guild.
374 ///
375 /// Creates a player configured to use the best available node if a player
376 /// for the guild doesn't already exist. Use [`PlayerManager::get`] to only
377 /// retrieve and not create.
378 ///
379 /// # Errors
380 ///
381 /// Returns a [`ClientError`] with a [`ClientErrorType::NodesUnconfigured`]
382 /// type if no node has been configured via [`add`].
383 ///
384 /// [`PlayerManager::get`]: crate::player::PlayerManager::get
385 /// [`add`]: Self::add
386 pub async fn player(&self, guild_id: Id<GuildMarker>) -> Result<Arc<Player>, ClientError> {
387 if let Some(player) = self.players().get(&guild_id) {
388 return Ok(player);
389 }
390
391 let node = self.best().await?;
392
393 Ok(self.players().get_or_insert(guild_id, node))
394 }
395
396 /// Clear out the map of guild states/updates for a shard that are waiting
397 /// for their other half.
398 ///
399 /// We can do this by iterating over the map and removing the ones that we
400 /// can calculate came from a shard.
401 ///
402 /// This map should be small or empty, and if it isn't, then it needs to be
403 /// cleared out anyway.
404 fn clear_shard_states(&self, shard_id: u32) {
405 let shard_count = u64::from(self.shard_count);
406
407 self.server_updates
408 .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
409 self.discord_sessions
410 .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::{ClientError, ClientErrorType, Lavalink};
    use static_assertions::assert_impl_all;
    use std::{error::Error, fmt::Debug};

    // Compile-time checks only: the build fails if any of the public types
    // stops implementing one of the listed traits.
    assert_impl_all!(ClientErrorType: Debug, Send, Sync);
    assert_impl_all!(ClientError: Error, Send, Sync);
    assert_impl_all!(Lavalink: Debug, Send, Sync);
}