twilight_model/gateway/
id.rs

1use serde::{Deserialize, Serialize};
2use std::{
3    error::Error,
4    fmt::{Display, Formatter, Result as FmtResult},
5    num::NonZeroU32,
6};
7
/// Building a [`ShardId`] from its `[number, total]` parts failed.
///
/// Returned by the `TryFrom<[u32; 2]>` implementation of [`ShardId`], which is
/// also what serde uses when deserializing a shard ID.
#[derive(Debug)]
pub struct ShardIdParseError {
    /// Type of error that occurred.
    kind: ShardIdParseErrorType,
}

impl ShardIdParseError {
    /// Immutable reference to the type of error that occurred.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &ShardIdParseErrorType {
        &self.kind
    }

    /// Consume the error, returning the source error if there is any.
    ///
    /// This error never wraps another error, so this always returns `None`.
    // `self` is taken by value for symmetry with `into_parts`; the body has
    // nothing to read from it because there is never a source.
    #[allow(clippy::unused_self)]
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        None
    }

    /// Consume the error, returning the owned error type and the source error.
    ///
    /// The source is always `None`; see [`into_source`](Self::into_source).
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (ShardIdParseErrorType, Option<Box<dyn Error + Send + Sync>>) {
        (self.kind, None)
    }
}

impl Display for ShardIdParseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self.kind {
            ShardIdParseErrorType::NumberGreaterOrEqualTotal { number, total } => {
                f.write_str("ShardId's number (")?;
                Display::fmt(&number, f)?;
                f.write_str(") was greater or equal to its total (")?;
                Display::fmt(&total, f)?;

                f.write_str(")")
            }
        }
    }
}

// The error has no underlying cause, so the default `source()` (returning
// `None`) is correct.
impl Error for ShardIdParseError {}

/// Type of [`ShardIdParseError`] that occurred.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ShardIdParseErrorType {
    /// `ShardId`'s number was greater or equal to its total.
    NumberGreaterOrEqualTotal {
        /// Value of number.
        number: u32,
        /// Value of total.
        total: u32,
    },
}
59
60/// Shard identifier to calculate if it receivies a given event.
61///
62/// A shard ID consist of two fields: `number` and `total`. These values do not
63/// need to be unique, and are used by Discord for calculating which events to
64/// send to which shard. Shards should in general share the same `total` value
65/// and have an unique `number` value, but users may deviate from this when
66/// resharding/migrating to a new set of shards.
67///
68/// # Advanced use
69///
70/// Incoming events are split by their originating guild and are received by the
71/// shard with the id calculated from the following formula:
72///
73/// > `number = (guild_id >> 22) % total`.
74///
75/// `total` is in other words unrelated to the total number of shards and is
76/// only used to specify the share of events a shard will receive. The formula
77/// is independently calculated for all shards, which means that events may be
78/// duplicated or lost if it's determined that an event should be sent to
79/// multiple or no shard.
80///
81/// It may be helpful to visualize the logic in code:
82///
83/// ```ignore
84/// for shard in shards {
85///     if shard.id().number() == (guild_id >> 22) % shard.id().total() {
86///         unimplemented!("send event to shard");
87///     }
88/// }
89/// ```
90///
91/// See [Discord Docs/Sharding].
92///
93/// [Discord Docs/Sharding]: https://discord.com/developers/docs/topics/gateway#sharding
94#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
95#[serde(try_from = "[u32; 2]", into = "[u32; 2]")]
96pub struct ShardId {
97    /// Number of the shard, 0-indexed.
98    number: u32,
99    /// Total number of shards used by the bot, 1-indexed.
100    total: NonZeroU32,
101}
102
103impl ShardId {
104    /// ID of a bot that has only one shard.
105    ///
106    /// Should *only* be used by small bots in under one or two thousand guilds.
107    pub const ONE: ShardId = ShardId::new(0, 1);
108
109    /// Create a new shard identifier.
110    ///
111    /// The shard number is 0-indexed while the total number of shards is
112    /// 1-indexed. A shard number of 7 with a total of 8 is therefore valid,
113    /// whilst a shard number of 8 out of 8 total shards is invalid.
114    ///
115    /// # Examples
116    ///
117    /// Create a new shard with a shard number of 13 out of a total of 24
118    /// shards:
119    ///
120    /// ```
121    /// use twilight_model::gateway::ShardId;
122    ///
123    /// let id = ShardId::new(13, 24);
124    /// ```
125    ///
126    /// # Panics
127    ///
128    /// Panics if the shard number is greater than or equal to the total number
129    /// of shards.
130    pub const fn new(number: u32, total: u32) -> Self {
131        assert!(number < total, "number must be less than total");
132        if let Some(total) = NonZeroU32::new(total) {
133            Self { number, total }
134        } else {
135            panic!("unreachable: total is at least 1")
136        }
137    }
138
139    /// Create a new shard identifier if the shard indexes are valid.
140    #[allow(clippy::missing_panics_doc)]
141    pub const fn new_checked(number: u32, total: u32) -> Option<Self> {
142        if number >= total {
143            return None;
144        }
145
146        if let Some(total) = NonZeroU32::new(total) {
147            Some(Self { number, total })
148        } else {
149            panic!("unreachable: total is at least 1")
150        }
151    }
152
153    /// Identifying number of the shard, 0-indexed.
154    pub const fn number(self) -> u32 {
155        self.number
156    }
157
158    /// Total number of shards, 1-indexed.
159    pub const fn total(self) -> u32 {
160        self.total.get()
161    }
162}
163
164/// Display the shard ID.
165///
166/// Formats as `[{number}, {total}]`.
167impl Display for ShardId {
168    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
169        f.debug_list()
170            .entries(Into::<[u32; 2]>::into(*self))
171            .finish()
172    }
173}
174
175impl TryFrom<[u32; 2]> for ShardId {
176    type Error = ShardIdParseError;
177
178    fn try_from([number, total]: [u32; 2]) -> Result<Self, Self::Error> {
179        Self::new_checked(number, total).ok_or(ShardIdParseError {
180            kind: ShardIdParseErrorType::NumberGreaterOrEqualTotal { number, total },
181        })
182    }
183}
184
185impl From<ShardId> for [u32; 2] {
186    fn from(id: ShardId) -> Self {
187        [id.number(), id.total()]
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::{ShardId, ShardIdParseErrorType};
    use serde::{de::DeserializeOwned, Serialize};
    use serde_test::Token;
    use static_assertions::{assert_impl_all, const_assert_eq};
    use std::{fmt::Debug, hash::Hash};

    const_assert_eq!(ShardId::ONE.number(), 0);
    const_assert_eq!(ShardId::ONE.total(), 1);
    assert_impl_all!(
        ShardId: Clone,
        Copy,
        Debug,
        DeserializeOwned,
        Eq,
        Hash,
        PartialEq,
        Send,
        Serialize,
        Sync
    );

    #[test]
    const fn checked_invalid() {
        assert!(ShardId::new_checked(0, 1).is_some());
        assert!(ShardId::new_checked(1, 1).is_none());
        assert!(ShardId::new_checked(2, 1).is_none());
        assert!(ShardId::new_checked(0, 0).is_none());
    }

    #[test]
    const fn getters() {
        let id = ShardId::new(2, 4);

        assert!(id.number() == 2);
        assert!(id.total() == 4);
    }

    /// `Display` renders the documented `[number, total]` form.
    #[test]
    fn display() {
        assert_eq!(ShardId::ONE.to_string(), "[0, 1]");
        assert_eq!(ShardId::new(13, 24).to_string(), "[13, 24]");
    }

    /// The array conversions round-trip valid IDs and report both values when
    /// the number is not below the total.
    #[test]
    fn try_from_array() {
        assert_eq!(
            ShardId::try_from([13_u32, 24_u32]).ok(),
            Some(ShardId::new(13, 24))
        );
        assert_eq!(<[u32; 2]>::from(ShardId::new(13, 24)), [13, 24]);

        let error = ShardId::try_from([1_u32, 1_u32]).unwrap_err();
        assert!(matches!(
            error.kind(),
            ShardIdParseErrorType::NumberGreaterOrEqualTotal {
                number: 1,
                total: 1
            }
        ));
        assert_eq!(
            error.to_string(),
            "ShardId's number (1) was greater or equal to its total (1)"
        );
    }

    #[test]
    fn serde() {
        let value = ShardId::new(0, 1);

        serde_test::assert_tokens(
            &value,
            &[
                Token::Tuple { len: 2 },
                Token::U32(0),
                Token::U32(1),
                Token::TupleEnd,
            ],
        );
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn number_equal() {
        ShardId::new(1, 1);
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn number_greater() {
        ShardId::new(2, 1);
    }
}