Skip to main content

twilight_model/gateway/
id.rs

1use serde::{Deserialize, Serialize};
2use std::{
3    error::Error,
4    fmt::{Display, Formatter, Result as FmtResult},
5    num::NonZero,
6};
7
/// Parsing a [`ShardId`] failed.
///
/// Returned by `ShardId`'s `TryFrom<[u32; 2]>` implementation, which is also
/// the path deserialization goes through.
#[derive(Debug)]
pub struct ShardIdParseError {
    /// Type of error that occurred.
    kind: ShardIdParseErrorType,
}

impl ShardIdParseError {
    /// Immutable reference to the type of error that occurred.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &ShardIdParseErrorType {
        &self.kind
    }

    /// Consume the error, returning the source error if there is any.
    ///
    /// This error never wraps another error, so this always returns `None`.
    // `self` is taken by value only to mirror `into_parts`; nothing is read from it.
    #[allow(clippy::unused_self)]
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        None
    }

    /// Consume the error, returning the owned error type and the source error.
    ///
    /// The source half of the tuple is always `None`.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (ShardIdParseErrorType, Option<Box<dyn Error + Send + Sync>>) {
        (self.kind, None)
    }
}

/// Formats as `ShardId's number ({number}) was greater or equal to its total ({total})`.
impl Display for ShardIdParseError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match self.kind {
            ShardIdParseErrorType::NumberGreaterOrEqualTotal { number, total } => {
                f.write_str("ShardId's number (")?;
                Display::fmt(&number, f)?;
                f.write_str(") was greater or equal to its total (")?;
                Display::fmt(&total, f)?;

                f.write_str(")")
            }
        }
    }
}

// The error carries no source, so the default `source()` (returning `None`)
// is correct.
impl Error for ShardIdParseError {}

/// Type of [`ShardIdParseError`] that occurred.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShardIdParseErrorType {
    /// `ShardId`'s number was greater or equal to its total.
    NumberGreaterOrEqualTotal {
        /// Value of number.
        number: u32,
        /// Value of total.
        total: u32,
    },
}
59
60/// Shard identifier to calculate if it receivies a given event.
61///
62/// A shard ID consist of two fields: `number` and `total`. These values do not
63/// need to be unique, and are used by Discord for calculating which events to
64/// send to which shard. Shards should in general share the same `total` value
65/// and have an unique `number` value, but users may deviate from this when
66/// resharding/migrating to a new set of shards.
67///
68/// # Advanced use
69///
70/// Incoming events are split by their originating guild and are received by the
71/// shard with the id calculated from the following formula:
72///
73/// > `number = (guild_id >> 22) % total`.
74///
75/// `total` is in other words unrelated to the total number of shards and is
76/// only used to specify the share of events a shard will receive. The formula
77/// is independently calculated for all shards, which means that events may be
78/// duplicated or lost if it's determined that an event should be sent to
79/// multiple or no shard.
80///
81/// It may be helpful to visualize the logic in code:
82///
83/// ```ignore
84/// for shard in shards {
85///     if shard.id().number() == (guild_id >> 22) % shard.id().total() {
86///         unimplemented!("send event to shard");
87///     }
88/// }
89/// ```
90///
91/// See [Discord Docs/Sharding].
92///
93/// [Discord Docs/Sharding]: https://discord.com/developers/docs/topics/gateway#sharding
94#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
95#[serde(try_from = "[u32; 2]", into = "[u32; 2]")]
96pub struct ShardId {
97    /// Number of the shard, 0-indexed.
98    number: u32,
99    /// Total number of shards used by the bot, 1-indexed.
100    total: NonZero<u32>,
101}
102
103impl ShardId {
104    /// ID of a bot that has only one shard.
105    ///
106    /// Should *only* be used by small bots in under one or two thousand guilds.
107    pub const ONE: ShardId = ShardId::new(0, 1);
108
109    /// Create a new shard identifier.
110    ///
111    /// The shard number is 0-indexed while the total number of shards is
112    /// 1-indexed. A shard number of 7 with a total of 8 is therefore valid,
113    /// whilst a shard number of 8 out of 8 total shards is invalid.
114    ///
115    /// # Examples
116    ///
117    /// Create a new shard with a shard number of 13 out of a total of 24
118    /// shards:
119    ///
120    /// ```
121    /// use twilight_model::gateway::ShardId;
122    ///
123    /// let id = ShardId::new(13, 24);
124    /// ```
125    ///
126    /// # Panics
127    ///
128    /// Panics if the shard number is greater than or equal to the total number
129    /// of shards.
130    pub const fn new(number: u32, total: u32) -> Self {
131        assert!(number < total, "number must be less than total");
132        Self {
133            number,
134            total: NonZero::new(total).expect("total is at least 1"),
135        }
136    }
137
138    /// Create a new shard identifier if the shard indexes are valid.
139    #[allow(clippy::missing_panics_doc)]
140    pub const fn new_checked(number: u32, total: u32) -> Option<Self> {
141        if number < total {
142            Some(Self {
143                number,
144                total: NonZero::new(total).expect("total is at least 1"),
145            })
146        } else {
147            None
148        }
149    }
150
151    /// Identifying number of the shard, 0-indexed.
152    pub const fn number(self) -> u32 {
153        self.number
154    }
155
156    /// Total number of shards, 1-indexed.
157    pub const fn total(self) -> u32 {
158        self.total.get()
159    }
160}
161
162/// Display the shard ID.
163///
164/// Formats as `[{number}, {total}]`.
165impl Display for ShardId {
166    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
167        f.debug_list().entries(<[u32; 2]>::from(*self)).finish()
168    }
169}
170
171impl PartialOrd for ShardId {
172    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
173        (self.total == other.total).then(|| self.number.cmp(&other.number))
174    }
175}
176
177impl TryFrom<[u32; 2]> for ShardId {
178    type Error = ShardIdParseError;
179
180    fn try_from([number, total]: [u32; 2]) -> Result<Self, Self::Error> {
181        Self::new_checked(number, total).ok_or(ShardIdParseError {
182            kind: ShardIdParseErrorType::NumberGreaterOrEqualTotal { number, total },
183        })
184    }
185}
186
187impl From<ShardId> for [u32; 2] {
188    fn from(id: ShardId) -> Self {
189        [id.number(), id.total()]
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::{ShardId, ShardIdParseErrorType};
    use serde::{Serialize, de::DeserializeOwned};
    use serde_test::Token;
    use static_assertions::{assert_impl_all, const_assert_eq};
    use std::{cmp::Ordering, fmt::Debug, hash::Hash};

    const_assert_eq!(ShardId::ONE.number(), 0);
    const_assert_eq!(ShardId::ONE.total(), 1);
    assert_impl_all!(
        ShardId: Clone,
        Copy,
        Debug,
        DeserializeOwned,
        Eq,
        Hash,
        PartialEq,
        PartialOrd,
        Send,
        Serialize,
        Sync
    );

    #[test]
    const fn checked_invalid() {
        assert!(ShardId::new_checked(0, 1).is_some());
        assert!(ShardId::new_checked(1, 1).is_none());
        assert!(ShardId::new_checked(2, 1).is_none());
        assert!(ShardId::new_checked(0, 0).is_none());
    }

    #[test]
    const fn getters() {
        let id = ShardId::new(2, 4);

        assert!(id.number() == 2);
        assert!(id.total() == 4);
    }

    #[test]
    fn display() {
        assert_eq!(ShardId::ONE.to_string(), "[0, 1]");
        assert_eq!(ShardId::new(13, 24).to_string(), "[13, 24]");
    }

    #[test]
    fn partial_ord() {
        assert_eq!(
            ShardId::new(1, 4).partial_cmp(&ShardId::new(2, 4)),
            Some(Ordering::Less)
        );
        assert_eq!(
            ShardId::new(2, 4).partial_cmp(&ShardId::new(2, 4)),
            Some(Ordering::Equal)
        );
        // IDs with different totals are not comparable.
        assert_eq!(ShardId::new(1, 4).partial_cmp(&ShardId::new(1, 8)), None);
    }

    #[test]
    fn try_from() {
        assert_eq!(ShardId::try_from([2, 4]).ok(), Some(ShardId::new(2, 4)));

        let Err(error) = ShardId::try_from([4, 4]) else {
            panic!("a number equal to the total must be rejected");
        };
        assert!(matches!(
            error.kind(),
            ShardIdParseErrorType::NumberGreaterOrEqualTotal {
                number: 4,
                total: 4
            }
        ));
        assert_eq!(
            error.to_string(),
            "ShardId's number (4) was greater or equal to its total (4)"
        );
    }

    #[test]
    fn serde() {
        let value = ShardId::new(0, 1);

        serde_test::assert_tokens(
            &value,
            &[
                Token::Tuple { len: 2 },
                Token::U32(0),
                Token::U32(1),
                Token::TupleEnd,
            ],
        );
    }

    #[test]
    fn serde_invalid() {
        // Deserialization goes through `TryFrom<[u32; 2]>`, so the error text is
        // the `Display` output of `ShardIdParseError`.
        serde_test::assert_de_tokens_error::<ShardId>(
            &[
                Token::Tuple { len: 2 },
                Token::U32(1),
                Token::U32(1),
                Token::TupleEnd,
            ],
            "ShardId's number (1) was greater or equal to its total (1)",
        );
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn number_equal() {
        ShardId::new(1, 1);
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn number_greater() {
        ShardId::new(2, 1);
    }

    #[should_panic(expected = "number must be less than total")]
    #[test]
    const fn total_zero() {
        ShardId::new(0, 0);
    }
}