twilight_lavalink/
client.rs

1//! Client to manage nodes and players.
2
3use crate::{
4    model::VoiceUpdate,
5    node::{IncomingEvents, Node, NodeConfig, NodeError, Resume},
6    player::{Player, PlayerManager},
7};
8use dashmap::DashMap;
9use std::{
10    error::Error,
11    fmt::{Display, Formatter, Result as FmtResult},
12    net::SocketAddr,
13    sync::Arc,
14};
15use twilight_model::{
16    gateway::{event::Event, payload::incoming::VoiceServerUpdate, ShardId},
17    id::{
18        marker::{GuildMarker, UserMarker},
19        Id,
20    },
21};
22
23/// An error that can occur while interacting with the client.
24#[derive(Debug)]
25pub struct ClientError {
26    kind: ClientErrorType,
27    source: Option<Box<dyn Error + Send + Sync>>,
28}
29
30impl ClientError {
31    /// Immutable reference to the type of error that occurred.
32    pub const fn kind(&self) -> &ClientErrorType {
33        &self.kind
34    }
35
36    /// Consume the error, returning the source error if there is any.
37    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
38        self.source
39    }
40
41    /// Consume the error, returning the owned error type and the source error.
42    #[must_use = "consuming the error into its parts has no effect if left unused"]
43    pub fn into_parts(self) -> (ClientErrorType, Option<Box<dyn Error + Send + Sync>>) {
44        (self.kind, self.source)
45    }
46}
47
48impl Display for ClientError {
49    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
50        match &self.kind {
51            ClientErrorType::NodesUnconfigured => f.write_str("no node has been configured"),
52            ClientErrorType::SendingVoiceUpdate => {
53                f.write_str("couldn't send voice update to node")
54            }
55        }
56    }
57}
58
59impl Error for ClientError {
60    fn source(&self) -> Option<&(dyn Error + 'static)> {
61        self.source
62            .as_ref()
63            .map(|source| &**source as &(dyn Error + 'static))
64    }
65}
66
/// Category of failure carried by a [`ClientError`].
#[non_exhaustive]
#[derive(Debug)]
pub enum ClientErrorType {
    /// No usable node is configured, so the operation cannot be fulfilled.
    NodesUnconfigured,
    /// The voice update event could not be delivered to the node because the
    /// node's connection was shut down.
    SendingVoiceUpdate,
}
77
78/// The lavalink client that manages nodes, players, and processes events from
79/// Discord to tie it all together.
80///
81/// **Note**: You must call the [`process`] method with every Voice State Update
82/// and Voice Server Update event you receive from Discord. It will
83/// automatically forward these events to Lavalink. See its documentation for
84/// more information.
85///
86/// You can retrieve players using the [`player`] method. Players contain
87/// information about the active playing information of a guild and allows you to send events to the
88/// connected node, such as [`Play`] events.
89///
90/// # Using a Lavalink client in multiple tasks
91///
92/// To use a Lavalink client instance in multiple tasks, consider wrapping it in
93/// an [`std::sync::Arc`] or [`std::rc::Rc`].
94///
95/// [`Play`]: crate::model::outgoing::Play
96/// [`player`]: Self::player
97/// [`process`]: Self::process
98#[derive(Debug)]
99pub struct Lavalink {
100    nodes: DashMap<SocketAddr, Arc<Node>>,
101    players: PlayerManager,
102    resume: Option<Resume>,
103    shard_count: u32,
104    user_id: Id<UserMarker>,
105    server_updates: DashMap<Id<GuildMarker>, VoiceServerUpdate>,
106    sessions: DashMap<Id<GuildMarker>, Box<str>>,
107}
108
109impl Lavalink {
110    /// Create a new Lavalink client instance.
111    ///
112    /// The user ID and number of shards provided may not be modified during
113    /// runtime, and the client must be re-created. These parameters are
114    /// automatically passed to new nodes created via [`add`].
115    ///
116    /// See also [`new_with_resume`], which allows you to specify session resume
117    /// capability.
118    ///
119    /// [`add`]: Self::add
120    /// [`new_with_resume`]: Self::new_with_resume
121    pub fn new(user_id: Id<UserMarker>, shard_count: u32) -> Self {
122        Self::_new_with_resume(user_id, shard_count, None)
123    }
124
125    /// Like [`new`], but allows you to specify resume capability (if any).
126    ///
127    /// Provide `None` for the `resume` parameter to disable session resume
128    /// capability. See the [`Resume`] documentation for defaults.
129    ///
130    /// [`Resume`]: crate::node::Resume
131    /// [`new`]: Self::new
132    pub fn new_with_resume(
133        user_id: Id<UserMarker>,
134        shard_count: u32,
135        resume: impl Into<Option<Resume>>,
136    ) -> Self {
137        Self::_new_with_resume(user_id, shard_count, resume.into())
138    }
139
140    fn _new_with_resume(user_id: Id<UserMarker>, shard_count: u32, resume: Option<Resume>) -> Self {
141        Self {
142            nodes: DashMap::new(),
143            players: PlayerManager::new(),
144            resume,
145            shard_count,
146            user_id,
147            server_updates: DashMap::new(),
148            sessions: DashMap::new(),
149        }
150    }
151
152    /// Process an event into the Lavalink client.
153    ///
154    /// **Note**: calling this method in your event loop is required. See the
155    /// [crate documentation] for an example.
156    ///
157    /// This requires the `VoiceServerUpdate` and `VoiceStateUpdate` events that
158    /// you receive from Discord over the gateway to send voice updates to
159    /// nodes. For simplicity in some applications' event loops, any event can
160    /// be provided, but they will just be ignored.
161    ///
162    /// The Ready event can optionally be provided to do some cleaning of
163    /// stalled voice states that never received their voice server update half
164    /// or vice versa. It is recommended that you process Ready events.
165    ///
166    /// # Errors
167    ///
168    /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if no nodes
169    /// have been added to the client when attempting to retrieve a guild's
170    /// player.
171    ///
172    /// [crate documentation]: crate#examples
173    pub async fn process(&self, event: &Event) -> Result<(), ClientError> {
174        tracing::trace!("processing event: {event:?}");
175
176        let guild_id = match event {
177            Event::Ready(e) => {
178                let shard_id = e.shard.map_or(0, ShardId::number);
179
180                self.clear_shard_states(shard_id);
181
182                return Ok(());
183            }
184            Event::VoiceServerUpdate(e) => {
185                self.server_updates.insert(e.guild_id, e.clone());
186                e.guild_id
187            }
188            Event::VoiceStateUpdate(e) => {
189                if e.user_id != self.user_id {
190                    tracing::trace!("got voice state update from another user");
191
192                    return Ok(());
193                }
194
195                if let Some(guild_id) = e.guild_id {
196                    // Update player if it exists and update the connected channel ID.
197                    if let Some(player) = self.players.get(&guild_id) {
198                        player.set_channel_id(e.channel_id);
199                    }
200
201                    if e.channel_id.is_none() {
202                        self.sessions.remove(&guild_id);
203                        self.server_updates.remove(&guild_id);
204                    } else {
205                        self.sessions
206                            .insert(guild_id, e.session_id.clone().into_boxed_str());
207                    }
208                    guild_id
209                } else {
210                    tracing::trace!("event has no guild ID: {e:?}");
211                    return Ok(());
212                }
213            }
214            _ => return Ok(()),
215        };
216
217        tracing::debug!("got voice server/state update for {guild_id:?}: {event:?}");
218
219        let update = {
220            let server = self.server_updates.get(&guild_id);
221            let session = self.sessions.get(&guild_id);
222            match (server, session) {
223                (Some(server), Some(session)) => {
224                    let server = server.value();
225                    let session = session.value();
226                    tracing::debug!(
227                        "got both halves for {guild_id}: {server:?}; Session ID: {session:?}",
228                    );
229                    VoiceUpdate::new(guild_id, session.as_ref(), server.clone())
230                }
231                (Some(server), None) => {
232                    tracing::debug!(
233                        "guild {guild_id} is now waiting for other half; got: {:?}",
234                        server.value()
235                    );
236                    return Ok(());
237                }
238                (None, Some(session)) => {
239                    tracing::debug!(
240                        "guild {guild_id} is now waiting for other half; got session ID: {:?}",
241                        session.value()
242                    );
243                    return Ok(());
244                }
245                _ => return Ok(()),
246            }
247        };
248
249        tracing::debug!("getting player for guild {guild_id}");
250
251        let player = self.player(guild_id).await?;
252
253        tracing::debug!("sending voice update for guild {guild_id}: {update:?}");
254
255        player.send(update).map_err(|source| ClientError {
256            kind: ClientErrorType::SendingVoiceUpdate,
257            source: Some(Box::new(source)),
258        })?;
259
260        tracing::debug!("sent voice update for guild {guild_id}");
261
262        Ok(())
263    }
264
265    /// Add a new node to be managed by the Lavalink client.
266    ///
267    /// If a node already exists with the provided address, then it will be
268    /// replaced.
269    ///
270    /// # Errors
271    ///
272    /// See the errors section of [`Node::connect`].
273    pub async fn add(
274        &self,
275        address: SocketAddr,
276        authorization: impl Into<String>,
277    ) -> Result<(Arc<Node>, IncomingEvents), NodeError> {
278        let config = NodeConfig {
279            address,
280            authorization: authorization.into(),
281            resume: self.resume.clone(),
282            user_id: self.user_id,
283        };
284
285        let (node, rx) = Node::connect(config, self.players.clone()).await?;
286        let node = Arc::new(node);
287        self.nodes.insert(address, Arc::clone(&node));
288
289        Ok((node, rx))
290    }
291
292    /// Remove a node from the list of nodes being managed by the Lavalink
293    /// client.
294    ///
295    /// This does not disconnect the node. Use [`Lavalink::disconnect`] instead.
296    /// or drop all [`Node`]s.
297    ///
298    /// The node is returned if it existed.
299    pub fn remove(&self, address: SocketAddr) -> Option<(SocketAddr, Arc<Node>)> {
300        self.nodes.remove(&address)
301    }
302
303    /// Remove a node from the list of nodes being managed by the Lavalink
304    /// client and terminates the connection.
305    ///
306    /// Use [`Lavalink::remove`] if detaching a node from a Lavalink instance
307    /// is required without closing the underlying connection.
308    ///
309    /// Returns whether the node has been removed and disconnected.
310    pub fn disconnect(&self, address: SocketAddr) -> bool {
311        self.nodes.remove(&address).is_some()
312    }
313
314    /// Determine the "best" node for new players according to available nodes'
315    /// penalty scores. Disconnected nodes will not be considered.
316    ///
317    /// Refer to [`Node::penalty`] for how this is calculated.
318    ///
319    /// # Errors
320    ///
321    /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if there are
322    /// no connected nodes available in the client.
323    ///
324    /// [`Node::penalty`]: crate::node::Node::penalty
325    pub async fn best(&self) -> Result<Arc<Node>, ClientError> {
326        let mut lowest = i32::MAX;
327        let mut best = None;
328
329        for node in &self.nodes {
330            if node.sender().is_closed() {
331                continue;
332            }
333
334            let penalty = node.value().penalty().await;
335
336            if penalty < lowest {
337                lowest = penalty;
338                best.replace(node.clone());
339            }
340        }
341
342        best.ok_or(ClientError {
343            kind: ClientErrorType::NodesUnconfigured,
344            source: None,
345        })
346    }
347
348    /// Retrieve an immutable reference to the player manager.
349    pub const fn players(&self) -> &PlayerManager {
350        &self.players
351    }
352
353    /// Retrieve a player for the guild.
354    ///
355    /// Creates a player configured to use the best available node if a player
356    /// for the guild doesn't already exist. Use [`PlayerManager::get`] to only
357    /// retrieve and not create.
358    ///
359    /// # Errors
360    ///
361    /// Returns a [`ClientError`] with a [`ClientErrorType::NodesUnconfigured`]
362    /// type if no node has been configured via [`add`].
363    ///
364    /// [`PlayerManager::get`]: crate::player::PlayerManager::get
365    /// [`add`]: Self::add
366    pub async fn player(&self, guild_id: Id<GuildMarker>) -> Result<Arc<Player>, ClientError> {
367        if let Some(player) = self.players().get(&guild_id) {
368            return Ok(player);
369        }
370
371        let node = self.best().await?;
372
373        Ok(self.players().get_or_insert(guild_id, node))
374    }
375
376    /// Clear out the map of guild states/updates for a shard that are waiting
377    /// for their other half.
378    ///
379    /// We can do this by iterating over the map and removing the ones that we
380    /// can calculate came from a shard.
381    ///
382    /// This map should be small or empty, and if it isn't, then it needs to be
383    /// cleared out anyway.
384    fn clear_shard_states(&self, shard_id: u32) {
385        let shard_count = u64::from(self.shard_count);
386
387        self.server_updates
388            .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
389        self.sessions
390            .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::{ClientError, ClientErrorType, Lavalink};
    use static_assertions::assert_impl_all;
    use std::error::Error;
    use std::fmt::Debug;

    // Compile-time guarantees that the public types stay debuggable and can be
    // shared across threads.
    assert_impl_all!(Lavalink: Debug, Send, Sync);
    assert_impl_all!(ClientError: Error, Send, Sync);
    assert_impl_all!(ClientErrorType: Debug, Send, Sync);
}