twilight_lavalink/client.rs
1//! Client to manage nodes and players.
2
3use crate::{
4 model::VoiceUpdate,
5 node::{IncomingEvents, Node, NodeConfig, NodeError, Resume},
6 player::{Player, PlayerManager},
7};
8use dashmap::DashMap;
9use std::{
10 error::Error,
11 fmt::{Display, Formatter, Result as FmtResult},
12 net::SocketAddr,
13 sync::Arc,
14};
15use twilight_model::{
16 gateway::{event::Event, payload::incoming::VoiceServerUpdate, ShardId},
17 id::{
18 marker::{GuildMarker, UserMarker},
19 Id,
20 },
21};
22
23/// An error that can occur while interacting with the client.
24#[derive(Debug)]
25pub struct ClientError {
26 kind: ClientErrorType,
27 source: Option<Box<dyn Error + Send + Sync>>,
28}
29
30impl ClientError {
31 /// Immutable reference to the type of error that occurred.
32 pub const fn kind(&self) -> &ClientErrorType {
33 &self.kind
34 }
35
36 /// Consume the error, returning the source error if there is any.
37 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
38 self.source
39 }
40
41 /// Consume the error, returning the owned error type and the source error.
42 #[must_use = "consuming the error into its parts has no effect if left unused"]
43 pub fn into_parts(self) -> (ClientErrorType, Option<Box<dyn Error + Send + Sync>>) {
44 (self.kind, self.source)
45 }
46}
47
48impl Display for ClientError {
49 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
50 match &self.kind {
51 ClientErrorType::NodesUnconfigured => f.write_str("no node has been configured"),
52 ClientErrorType::SendingVoiceUpdate => {
53 f.write_str("couldn't send voice update to node")
54 }
55 }
56 }
57}
58
59impl Error for ClientError {
60 fn source(&self) -> Option<&(dyn Error + 'static)> {
61 self.source
62 .as_ref()
63 .map(|source| &**source as &(dyn Error + 'static))
64 }
65}
66
/// Category of a [`ClientError`].
#[non_exhaustive]
#[derive(Debug)]
pub enum ClientErrorType {
    /// No usable node is available, so the operation cannot be fulfilled:
    /// either no node has been added, or none of the added nodes is connected.
    NodesUnconfigured,
    /// Sending a voice update event to the node failed because the node's
    /// connection was shut down.
    SendingVoiceUpdate,
}
77
78/// The lavalink client that manages nodes, players, and processes events from
79/// Discord to tie it all together.
80///
81/// **Note**: You must call the [`process`] method with every Voice State Update
82/// and Voice Server Update event you receive from Discord. It will
83/// automatically forward these events to Lavalink. See its documentation for
84/// more information.
85///
86/// You can retrieve players using the [`player`] method. Players contain
87/// information about the active playing information of a guild and allows you to send events to the
88/// connected node, such as [`Play`] events.
89///
90/// # Using a Lavalink client in multiple tasks
91///
92/// To use a Lavalink client instance in multiple tasks, consider wrapping it in
93/// an [`std::sync::Arc`] or [`std::rc::Rc`].
94///
95/// [`Play`]: crate::model::outgoing::Play
96/// [`player`]: Self::player
97/// [`process`]: Self::process
98#[derive(Debug)]
99pub struct Lavalink {
100 nodes: DashMap<SocketAddr, Arc<Node>>,
101 players: PlayerManager,
102 resume: Option<Resume>,
103 shard_count: u32,
104 user_id: Id<UserMarker>,
105 server_updates: DashMap<Id<GuildMarker>, VoiceServerUpdate>,
106 sessions: DashMap<Id<GuildMarker>, Box<str>>,
107}
108
109impl Lavalink {
110 /// Create a new Lavalink client instance.
111 ///
112 /// The user ID and number of shards provided may not be modified during
113 /// runtime, and the client must be re-created. These parameters are
114 /// automatically passed to new nodes created via [`add`].
115 ///
116 /// See also [`new_with_resume`], which allows you to specify session resume
117 /// capability.
118 ///
119 /// [`add`]: Self::add
120 /// [`new_with_resume`]: Self::new_with_resume
121 pub fn new(user_id: Id<UserMarker>, shard_count: u32) -> Self {
122 Self::_new_with_resume(user_id, shard_count, None)
123 }
124
125 /// Like [`new`], but allows you to specify resume capability (if any).
126 ///
127 /// Provide `None` for the `resume` parameter to disable session resume
128 /// capability. See the [`Resume`] documentation for defaults.
129 ///
130 /// [`Resume`]: crate::node::Resume
131 /// [`new`]: Self::new
132 pub fn new_with_resume(
133 user_id: Id<UserMarker>,
134 shard_count: u32,
135 resume: impl Into<Option<Resume>>,
136 ) -> Self {
137 Self::_new_with_resume(user_id, shard_count, resume.into())
138 }
139
140 fn _new_with_resume(user_id: Id<UserMarker>, shard_count: u32, resume: Option<Resume>) -> Self {
141 Self {
142 nodes: DashMap::new(),
143 players: PlayerManager::new(),
144 resume,
145 shard_count,
146 user_id,
147 server_updates: DashMap::new(),
148 sessions: DashMap::new(),
149 }
150 }
151
152 /// Process an event into the Lavalink client.
153 ///
154 /// **Note**: calling this method in your event loop is required. See the
155 /// [crate documentation] for an example.
156 ///
157 /// This requires the `VoiceServerUpdate` and `VoiceStateUpdate` events that
158 /// you receive from Discord over the gateway to send voice updates to
159 /// nodes. For simplicity in some applications' event loops, any event can
160 /// be provided, but they will just be ignored.
161 ///
162 /// The Ready event can optionally be provided to do some cleaning of
163 /// stalled voice states that never received their voice server update half
164 /// or vice versa. It is recommended that you process Ready events.
165 ///
166 /// # Errors
167 ///
168 /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if no nodes
169 /// have been added to the client when attempting to retrieve a guild's
170 /// player.
171 ///
172 /// [crate documentation]: crate#examples
173 pub async fn process(&self, event: &Event) -> Result<(), ClientError> {
174 tracing::trace!("processing event: {event:?}");
175
176 let guild_id = match event {
177 Event::Ready(e) => {
178 let shard_id = e.shard.map_or(0, ShardId::number);
179
180 self.clear_shard_states(shard_id);
181
182 return Ok(());
183 }
184 Event::VoiceServerUpdate(e) => {
185 self.server_updates.insert(e.guild_id, e.clone());
186 e.guild_id
187 }
188 Event::VoiceStateUpdate(e) => {
189 if e.user_id != self.user_id {
190 tracing::trace!("got voice state update from another user");
191
192 return Ok(());
193 }
194
195 if let Some(guild_id) = e.guild_id {
196 // Update player if it exists and update the connected channel ID.
197 if let Some(player) = self.players.get(&guild_id) {
198 player.set_channel_id(e.channel_id);
199 }
200
201 if e.channel_id.is_none() {
202 self.sessions.remove(&guild_id);
203 self.server_updates.remove(&guild_id);
204 } else {
205 self.sessions
206 .insert(guild_id, e.session_id.clone().into_boxed_str());
207 }
208 guild_id
209 } else {
210 tracing::trace!("event has no guild ID: {e:?}");
211 return Ok(());
212 }
213 }
214 _ => return Ok(()),
215 };
216
217 tracing::debug!("got voice server/state update for {guild_id:?}: {event:?}");
218
219 let update = {
220 let server = self.server_updates.get(&guild_id);
221 let session = self.sessions.get(&guild_id);
222 match (server, session) {
223 (Some(server), Some(session)) => {
224 let server = server.value();
225 let session = session.value();
226 tracing::debug!(
227 "got both halves for {guild_id}: {server:?}; Session ID: {session:?}",
228 );
229 VoiceUpdate::new(guild_id, session.as_ref(), server.clone())
230 }
231 (Some(server), None) => {
232 tracing::debug!(
233 "guild {guild_id} is now waiting for other half; got: {:?}",
234 server.value()
235 );
236 return Ok(());
237 }
238 (None, Some(session)) => {
239 tracing::debug!(
240 "guild {guild_id} is now waiting for other half; got session ID: {:?}",
241 session.value()
242 );
243 return Ok(());
244 }
245 _ => return Ok(()),
246 }
247 };
248
249 tracing::debug!("getting player for guild {guild_id}");
250
251 let player = self.player(guild_id).await?;
252
253 tracing::debug!("sending voice update for guild {guild_id}: {update:?}");
254
255 player.send(update).map_err(|source| ClientError {
256 kind: ClientErrorType::SendingVoiceUpdate,
257 source: Some(Box::new(source)),
258 })?;
259
260 tracing::debug!("sent voice update for guild {guild_id}");
261
262 Ok(())
263 }
264
265 /// Add a new node to be managed by the Lavalink client.
266 ///
267 /// If a node already exists with the provided address, then it will be
268 /// replaced.
269 ///
270 /// # Errors
271 ///
272 /// See the errors section of [`Node::connect`].
273 pub async fn add(
274 &self,
275 address: SocketAddr,
276 authorization: impl Into<String>,
277 ) -> Result<(Arc<Node>, IncomingEvents), NodeError> {
278 let config = NodeConfig {
279 address,
280 authorization: authorization.into(),
281 resume: self.resume.clone(),
282 user_id: self.user_id,
283 };
284
285 let (node, rx) = Node::connect(config, self.players.clone()).await?;
286 let node = Arc::new(node);
287 self.nodes.insert(address, Arc::clone(&node));
288
289 Ok((node, rx))
290 }
291
292 /// Remove a node from the list of nodes being managed by the Lavalink
293 /// client.
294 ///
295 /// This does not disconnect the node. Use [`Lavalink::disconnect`] instead.
296 /// or drop all [`Node`]s.
297 ///
298 /// The node is returned if it existed.
299 pub fn remove(&self, address: SocketAddr) -> Option<(SocketAddr, Arc<Node>)> {
300 self.nodes.remove(&address)
301 }
302
303 /// Remove a node from the list of nodes being managed by the Lavalink
304 /// client and terminates the connection.
305 ///
306 /// Use [`Lavalink::remove`] if detaching a node from a Lavalink instance
307 /// is required without closing the underlying connection.
308 ///
309 /// Returns whether the node has been removed and disconnected.
310 pub fn disconnect(&self, address: SocketAddr) -> bool {
311 self.nodes.remove(&address).is_some()
312 }
313
314 /// Determine the "best" node for new players according to available nodes'
315 /// penalty scores. Disconnected nodes will not be considered.
316 ///
317 /// Refer to [`Node::penalty`] for how this is calculated.
318 ///
319 /// # Errors
320 ///
321 /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if there are
322 /// no connected nodes available in the client.
323 ///
324 /// [`Node::penalty`]: crate::node::Node::penalty
325 pub async fn best(&self) -> Result<Arc<Node>, ClientError> {
326 let mut lowest = i32::MAX;
327 let mut best = None;
328
329 for node in &self.nodes {
330 if node.sender().is_closed() {
331 continue;
332 }
333
334 let penalty = node.value().penalty().await;
335
336 if penalty < lowest {
337 lowest = penalty;
338 best.replace(node.clone());
339 }
340 }
341
342 best.ok_or(ClientError {
343 kind: ClientErrorType::NodesUnconfigured,
344 source: None,
345 })
346 }
347
348 /// Retrieve an immutable reference to the player manager.
349 pub const fn players(&self) -> &PlayerManager {
350 &self.players
351 }
352
353 /// Retrieve a player for the guild.
354 ///
355 /// Creates a player configured to use the best available node if a player
356 /// for the guild doesn't already exist. Use [`PlayerManager::get`] to only
357 /// retrieve and not create.
358 ///
359 /// # Errors
360 ///
361 /// Returns a [`ClientError`] with a [`ClientErrorType::NodesUnconfigured`]
362 /// type if no node has been configured via [`add`].
363 ///
364 /// [`PlayerManager::get`]: crate::player::PlayerManager::get
365 /// [`add`]: Self::add
366 pub async fn player(&self, guild_id: Id<GuildMarker>) -> Result<Arc<Player>, ClientError> {
367 if let Some(player) = self.players().get(&guild_id) {
368 return Ok(player);
369 }
370
371 let node = self.best().await?;
372
373 Ok(self.players().get_or_insert(guild_id, node))
374 }
375
376 /// Clear out the map of guild states/updates for a shard that are waiting
377 /// for their other half.
378 ///
379 /// We can do this by iterating over the map and removing the ones that we
380 /// can calculate came from a shard.
381 ///
382 /// This map should be small or empty, and if it isn't, then it needs to be
383 /// cleared out anyway.
384 fn clear_shard_states(&self, shard_id: u32) {
385 let shard_count = u64::from(self.shard_count);
386
387 self.server_updates
388 .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
389 self.sessions
390 .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::{ClientError, ClientErrorType, Lavalink};
    use static_assertions::assert_impl_all;
    use std::error::Error;
    use std::fmt::Debug;

    // Compile-time checks: the client must be shareable across threads, and
    // its error types must be usable as standard, thread-safe errors.
    assert_impl_all!(Lavalink: Debug, Send, Sync);
    assert_impl_all!(ClientError: Error, Send, Sync);
    assert_impl_all!(ClientErrorType: Debug, Send, Sync);
}