Skip to main content

twilight_lavalink/
client.rs

1//! Client to manage nodes and players.
2
3use crate::{
4    model::VoiceUpdate,
5    node::{IncomingEvents, Node, NodeConfig, NodeError},
6    player::{Player, PlayerManager},
7};
8use dashmap::DashMap;
9use std::{
10    error::Error,
11    fmt::{Display, Formatter, Result as FmtResult},
12    net::SocketAddr,
13    sync::Arc,
14};
15use twilight_model::{
16    gateway::{ShardId, event::Event, payload::incoming::VoiceServerUpdate},
17    id::{
18        Id,
19        marker::{ChannelMarker, GuildMarker, UserMarker},
20    },
21};
22
/// An error that can occur while interacting with the client.
///
/// Use [`kind`](Self::kind) to inspect what went wrong, and
/// [`into_source`](Self::into_source) or [`into_parts`](Self::into_parts) to
/// take ownership of the underlying error, if there is one.
#[derive(Debug)]
pub struct ClientError {
    /// Which kind of failure occurred.
    kind: ClientErrorType,
    /// Underlying error that caused this one, when available (for example the
    /// send failure behind [`ClientErrorType::SendingVoiceUpdate`]).
    source: Option<Box<dyn Error + Send + Sync>>,
}
29
30impl ClientError {
31    /// Immutable reference to the type of error that occurred.
32    pub const fn kind(&self) -> &ClientErrorType {
33        &self.kind
34    }
35
36    /// Consume the error, returning the source error if there is any.
37    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
38        self.source
39    }
40
41    /// Consume the error, returning the owned error type and the source error.
42    #[must_use = "consuming the error into its parts has no effect if left unused"]
43    pub fn into_parts(self) -> (ClientErrorType, Option<Box<dyn Error + Send + Sync>>) {
44        (self.kind, self.source)
45    }
46}
47
48impl Display for ClientError {
49    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
50        match &self.kind {
51            ClientErrorType::NodesUnconfigured => f.write_str("no node has been configured"),
52            ClientErrorType::SendingVoiceUpdate => {
53                f.write_str("couldn't send voice update to node")
54            }
55        }
56    }
57}
58
59impl Error for ClientError {
60    fn source(&self) -> Option<&(dyn Error + 'static)> {
61        self.source
62            .as_ref()
63            .map(|source| &**source as &(dyn Error + 'static))
64    }
65}
66
/// Type of [`ClientError`] that occurred.
#[derive(Debug)]
#[non_exhaustive]
pub enum ClientErrorType {
    /// No node is configured, or none of the configured nodes is currently
    /// connected, so the operation isn't possible to fulfill.
    NodesUnconfigured,
    /// Sending a voice update event to the node failed because the node's
    /// connection was shut down.
    SendingVoiceUpdate,
}
77
/// The voice-state half of a guild's voice connection, taken from the current
/// user's `VoiceStateUpdate`.
///
/// It is kept per guild and combined with the latest `VoiceServerUpdate` for
/// the same guild into a [`VoiceUpdate`] whenever either half changes.
#[derive(Debug)]
struct DiscordSession {
    /// Voice channel the current user is connected to.
    channel_id: Option<Id<ChannelMarker>>,
    /// Discord voice session ID from the voice state update.
    id: Box<str>,
}
83
/// The lavalink client that manages nodes, players, and processes events from
/// Discord to tie it all together.
///
/// **Note**: You must call the [`process`] method with every Voice State Update
/// and Voice Server Update event you receive from Discord. It will
/// automatically forward these events to Lavalink. See its documentation for
/// more information.
///
/// You can retrieve players using the [`player`] method. Players contain
/// information about what is currently playing in a guild and allow you to
/// send events to the connected node, such as [`Play`] events.
///
/// # Using a Lavalink client in multiple tasks
///
/// To use a Lavalink client instance in multiple tasks, consider wrapping it in
/// an [`std::sync::Arc`], or an [`std::rc::Rc`] in single-threaded code only
/// (`Rc` is not `Send`). All methods take `&self`, so no outer lock is
/// required.
///
/// [`Play`]: crate::model::outgoing::Play
/// [`player`]: Self::player
/// [`process`]: Self::process
#[derive(Debug)]
pub struct Lavalink {
    /// Managed nodes, keyed by the address they were added with.
    nodes: DashMap<SocketAddr, Arc<Node>>,
    /// Per-guild players; a clone is handed to every node created via `add`.
    players: PlayerManager,
    /// Total number of shards, used to map guilds to shards on `Ready`.
    shard_count: u32,
    /// The current user's ID; voice state updates for other users are ignored.
    user_id: Id<UserMarker>,
    /// Latest voice server update half received for each guild.
    server_updates: DashMap<Id<GuildMarker>, VoiceServerUpdate>,
    /// Latest voice state half (session ID and channel) for each guild.
    discord_sessions: DashMap<Id<GuildMarker>, DiscordSession>,
}
113
114impl Lavalink {
115    /// Create a new Lavalink client instance.
116    ///
117    /// The user ID and number of shards provided may not be modified during
118    /// runtime, and the client must be re-created. These parameters are
119    /// automatically passed to new nodes created via [`add`].
120    ///
121    /// [`add`]: Self::add
122    pub fn new(user_id: Id<UserMarker>, shard_count: u32) -> Self {
123        Self {
124            nodes: DashMap::new(),
125            players: PlayerManager::new(),
126            shard_count,
127            user_id,
128            server_updates: DashMap::new(),
129            discord_sessions: DashMap::new(),
130        }
131    }
132
133    /// Process an event into the Lavalink client.
134    ///
135    /// **Note**: calling this method in your event loop is required. See the
136    /// [crate documentation] for an example.
137    ///
138    /// This requires the `VoiceServerUpdate` and `VoiceStateUpdate` events that
139    /// you receive from Discord over the gateway to send voice updates to
140    /// nodes. For simplicity in some applications' event loops, any event can
141    /// be provided, but they will just be ignored.
142    ///
143    /// The Ready event can optionally be provided to do some cleaning of
144    /// stalled voice states that never received their voice server update half
145    /// or vice versa. It is recommended that you process Ready events.
146    ///
147    /// # Errors
148    ///
149    /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if no nodes
150    /// have been added to the client when attempting to retrieve a guild's
151    /// player.
152    ///
153    /// [crate documentation]: crate#examples
154    pub async fn process(&self, event: &Event) -> Result<(), ClientError> {
155        tracing::trace!("processing event: {event:?}");
156
157        let guild_id = match event {
158            Event::Ready(e) => {
159                let shard_id = e.shard.map_or(0, ShardId::number);
160
161                self.clear_shard_states(shard_id);
162
163                return Ok(());
164            }
165            Event::VoiceServerUpdate(e) => {
166                self.server_updates.insert(e.guild_id, e.clone());
167                e.guild_id
168            }
169            Event::VoiceStateUpdate(e) => {
170                if e.user_id != self.user_id {
171                    tracing::trace!("got voice state update from another user");
172
173                    return Ok(());
174                }
175
176                if let Some(guild_id) = e.guild_id {
177                    // Update player if it exists and update the connected channel ID.
178                    if let Some(player) = self.players.get(&guild_id) {
179                        player.set_channel_id(e.channel_id);
180                    }
181
182                    if e.channel_id.is_none() {
183                        self.discord_sessions.remove(&guild_id);
184                        self.server_updates.remove(&guild_id);
185                    } else {
186                        let session = DiscordSession {
187                            channel_id: e.channel_id,
188                            id: e.session_id.clone().into_boxed_str(),
189                        };
190                        self.discord_sessions.insert(guild_id, session);
191                    }
192                    guild_id
193                } else {
194                    tracing::trace!("event has no guild ID: {e:?}");
195                    return Ok(());
196                }
197            }
198            _ => return Ok(()),
199        };
200
201        tracing::debug!("got voice server/state update for {guild_id:?}: {event:?}");
202
203        let update = {
204            let server = self.server_updates.get(&guild_id);
205            let session = self.discord_sessions.get(&guild_id);
206            match (server, session) {
207                (Some(server), Some(session)) => {
208                    let server = server.value();
209                    let DiscordSession {
210                        channel_id,
211                        id: session_id,
212                    } = session.value();
213                    tracing::debug!(
214                        "got both halves for {guild_id}: {server:?}; Session ID: {session_id:?}; Channel ID: {channel_id:?}",
215                    );
216                    VoiceUpdate::new(guild_id, session_id.as_ref(), *channel_id, server.clone())
217                }
218                (Some(server), None) => {
219                    tracing::debug!(
220                        "guild {guild_id} is now waiting for other half; got: {:?}",
221                        server.value()
222                    );
223                    return Ok(());
224                }
225                (None, Some(session)) => {
226                    tracing::debug!(
227                        "guild {guild_id} is now waiting for other half; got session ID: {:?}",
228                        session.value()
229                    );
230                    return Ok(());
231                }
232                _ => return Ok(()),
233            }
234        };
235
236        tracing::debug!("getting player for guild {guild_id}");
237
238        let player = self.player(guild_id).await?;
239
240        tracing::debug!("sending voice update for guild {guild_id}: {update:?}");
241
242        player.send(update).map_err(|source| ClientError {
243            kind: ClientErrorType::SendingVoiceUpdate,
244            source: Some(Box::new(source)),
245        })?;
246
247        tracing::debug!("sent voice update for guild {guild_id}");
248
249        Ok(())
250    }
251
252    /// Add a new node to be managed by the Lavalink client.
253    ///
254    /// If a node already exists with the provided address, then it will be
255    /// replaced.
256    ///
257    /// Pass `None` for `session_id` to create a fresh session. Pass
258    /// `Some(id)` to attempt resuming an existing Lavalink session. Resume
259    /// success is indicated by the `resumed` field on the
260    /// [`Ready`] event received via the returned [`IncomingEvents`] stream.
261    ///
262    /// If a `session_id` is provided and the session cannot be resumed,
263    /// the connection will fail with a [`Connecting`] error. The caller
264    /// can then retry with `session_id` set to `None` for a fresh session.
265    ///
266    /// [`Ready`]: crate::model::incoming::Ready
267    /// [`Connecting`]: crate::node::NodeErrorType::Connecting
268    ///
269    /// # Example
270    ///
271    /// ```no_run
272    /// # use twilight_lavalink::Lavalink;
273    /// # use twilight_model::id::Id;
274    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
275    /// # let lavalink = Lavalink::new(Id::new(1), 1);
276    /// # let address = "127.0.0.1:2333".parse()?;
277    /// # let auth = "youshallnotpass";
278    /// // Fresh session (no resume)
279    /// let (node, events) = lavalink.add(address, auth, None).await?;
280    ///
281    /// // Resume an existing session
282    /// let old_session_id = Some("existing-session-id".to_string());
283    /// let (node, events) = lavalink.add(address, auth, old_session_id).await?;
284    /// # Ok(())
285    /// # }
286    /// ```
287    ///
288    /// # Errors
289    ///
290    /// See the errors section of [`Node::connect`].
291    pub async fn add(
292        &self,
293        address: SocketAddr,
294        authorization: impl Into<String>,
295        session_id: Option<String>,
296    ) -> Result<(Arc<Node>, IncomingEvents), NodeError> {
297        let config = NodeConfig {
298            address,
299            authorization: authorization.into(),
300            user_id: self.user_id,
301            enable_tls: cfg!(feature = "tls"),
302            session_id,
303        };
304
305        let (node, rx) = Node::connect(config, self.players.clone()).await?;
306        let node = Arc::new(node);
307        self.nodes.insert(address, Arc::clone(&node));
308
309        Ok((node, rx))
310    }
311
312    /// Remove a node from the list of nodes being managed by the Lavalink
313    /// client.
314    ///
315    /// This does not disconnect the node. Use [`Lavalink::disconnect`] instead.
316    /// or drop all [`Node`]s.
317    ///
318    /// The node is returned if it existed.
319    pub fn remove(&self, address: SocketAddr) -> Option<(SocketAddr, Arc<Node>)> {
320        self.nodes.remove(&address)
321    }
322
323    /// Remove a node from the list of nodes being managed by the Lavalink
324    /// client and terminates the connection.
325    ///
326    /// Use [`Lavalink::remove`] if detaching a node from a Lavalink instance
327    /// is required without closing the underlying connection.
328    ///
329    /// Returns whether the node has been removed and disconnected.
330    pub fn disconnect(&self, address: SocketAddr) -> bool {
331        self.nodes.remove(&address).is_some()
332    }
333
334    /// Determine the "best" node for new players according to available nodes'
335    /// penalty scores. Disconnected nodes will not be considered.
336    ///
337    /// Refer to [`Node::penalty`] for how this is calculated.
338    ///
339    /// # Errors
340    ///
341    /// Returns a [`ClientErrorType::NodesUnconfigured`] error type if there are
342    /// no connected nodes available in the client.
343    ///
344    /// [`Node::penalty`]: crate::node::Node::penalty
345    pub async fn best(&self) -> Result<Arc<Node>, ClientError> {
346        let mut lowest = i32::MAX;
347        let mut best = None;
348
349        for node in &self.nodes {
350            if node.sender().is_closed() {
351                continue;
352            }
353
354            let penalty = node.value().penalty().await;
355
356            if penalty < lowest {
357                lowest = penalty;
358                best.replace(node.clone());
359            }
360        }
361
362        best.ok_or(ClientError {
363            kind: ClientErrorType::NodesUnconfigured,
364            source: None,
365        })
366    }
367
368    /// Retrieve an immutable reference to the player manager.
369    pub const fn players(&self) -> &PlayerManager {
370        &self.players
371    }
372
373    /// Retrieve a player for the guild.
374    ///
375    /// Creates a player configured to use the best available node if a player
376    /// for the guild doesn't already exist. Use [`PlayerManager::get`] to only
377    /// retrieve and not create.
378    ///
379    /// # Errors
380    ///
381    /// Returns a [`ClientError`] with a [`ClientErrorType::NodesUnconfigured`]
382    /// type if no node has been configured via [`add`].
383    ///
384    /// [`PlayerManager::get`]: crate::player::PlayerManager::get
385    /// [`add`]: Self::add
386    pub async fn player(&self, guild_id: Id<GuildMarker>) -> Result<Arc<Player>, ClientError> {
387        if let Some(player) = self.players().get(&guild_id) {
388            return Ok(player);
389        }
390
391        let node = self.best().await?;
392
393        Ok(self.players().get_or_insert(guild_id, node))
394    }
395
396    /// Clear out the map of guild states/updates for a shard that are waiting
397    /// for their other half.
398    ///
399    /// We can do this by iterating over the map and removing the ones that we
400    /// can calculate came from a shard.
401    ///
402    /// This map should be small or empty, and if it isn't, then it needs to be
403    /// cleared out anyway.
404    fn clear_shard_states(&self, shard_id: u32) {
405        let shard_count = u64::from(self.shard_count);
406
407        self.server_updates
408            .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
409        self.discord_sessions
410            .retain(|k, _| (k.get() >> 22) % shard_count != u64::from(shard_id));
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::{ClientError, ClientErrorType, Lavalink};
    use static_assertions::assert_impl_all;
    use std::error::Error;
    use std::fmt::Debug;

    // Compile-time guarantees: the client and its error types can be shared
    // across threads and stay debuggable / usable as standard errors.
    assert_impl_all!(Lavalink: Debug, Send, Sync);
    assert_impl_all!(ClientError: Error, Send, Sync);
    assert_impl_all!(ClientErrorType: Debug, Send, Sync);
}