twilight_gateway/
compression.rs

1//! Efficiently decompress Discord gateway messages.
2
3use std::{error::Error, fmt};
4
5/// An operation relating to compression failed.
6#[derive(Debug)]
7pub struct CompressionError {
8    /// Type of error.
9    pub(crate) kind: CompressionErrorType,
10    /// Source error if available.
11    pub(crate) source: Option<Box<dyn Error + Send + Sync>>,
12}
13
14impl CompressionError {
15    /// Immutable reference to the type of error that occurred.
16    #[must_use = "retrieving the type has no effect if left unused"]
17    pub const fn kind(&self) -> &CompressionErrorType {
18        &self.kind
19    }
20
21    /// Consume the error, returning the source error if there is any.
22    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
23    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
24        self.source
25    }
26
27    /// Consume the error, returning the owned error type and the source error.
28    #[must_use = "consuming the error into its parts has no effect if left unused"]
29    pub fn into_parts(self) -> (CompressionErrorType, Option<Box<dyn Error + Send + Sync>>) {
30        (self.kind, None)
31    }
32
33    /// Shortcut to create a new error for a not UTF-8 message.
34    pub(crate) fn from_utf8_error(source: std::string::FromUtf8Error) -> Self {
35        Self {
36            kind: CompressionErrorType::NotUtf8,
37            source: Some(Box::new(source)),
38        }
39    }
40
41    /// Shortcut to create a new error for an erroneous status code.
42    #[cfg(feature = "zstd")]
43    pub(crate) fn from_code(code: usize) -> Self {
44        Self {
45            kind: CompressionErrorType::Decompressing,
46            source: Some(zstd_safe::get_error_name(code).into()),
47        }
48    }
49}
50
51impl fmt::Display for CompressionError {
52    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53        match self.kind {
54            CompressionErrorType::Decompressing => f.write_str("message could not be decompressed"),
55            CompressionErrorType::NotUtf8 => f.write_str("decompressed message is not UTF-8"),
56        }
57    }
58}
59
60impl Error for CompressionError {
61    fn source(&self) -> Option<&(dyn Error + 'static)> {
62        self.source
63            .as_ref()
64            .map(|source| &**source as &(dyn Error + 'static))
65    }
66}
67
/// Category of a [`CompressionError`].
#[non_exhaustive]
#[derive(Debug)]
pub enum CompressionErrorType {
    /// A compressed frame could not be decompressed.
    Decompressing,
    /// The decompressed bytes are not valid UTF-8.
    NotUtf8,
}
77
/// Decompressor for zstd-compressed gateway events.
///
/// Owns a reusable scratch buffer and a zstd decompression context. Both are
/// kept between calls, so consecutive messages share a single stream state.
#[cfg(feature = "zstd")]
pub struct Decompressor {
    /// Scratch space that each decompression step writes into before the bytes
    /// are appended to the message being assembled.
    buffer: Box<[u8]>,
    /// zstd decompression context, reused across messages.
    ctx: zstd_safe::DCtx<'static>,
}
86
#[cfg(feature = "zstd")]
impl Decompressor {
    /// Length in bytes of [`Self::buffer`], the scratch space filled by each
    /// decompression step.
    const BUFFER_SIZE: usize = 32 * 1024;

    /// Create a new decompressor for a shard.
    ///
    /// Allocates a zeroed scratch buffer of [`Self::BUFFER_SIZE`] bytes and a
    /// fresh zstd decompression context.
    pub fn new() -> Self {
        let buffer = vec![0; Self::BUFFER_SIZE].into_boxed_slice();
        let ctx = zstd_safe::DCtx::create();

        Self { buffer, ctx }
    }

    /// Decompress a message.
    ///
    /// The zstd context keeps its state between calls. Data that does not
    /// continue the current stream may be rejected until [`Self::reset`] is
    /// called.
    ///
    /// # Errors
    ///
    /// Returns a [`CompressionErrorType::Decompressing`] error type if the
    /// message could not be decompressed.
    ///
    /// Returns a [`CompressionErrorType::NotUtf8`] error type if the
    /// decompressed message is not UTF-8.
    pub fn decompress(&mut self, message: &[u8]) -> Result<String, CompressionError> {
        let mut input = zstd_safe::InBuffer::around(message);

        // Grows only as much as the decompressed output requires.
        let mut decompressed = Vec::new();

        loop {
            let mut output = zstd_safe::OutBuffer::around(self.buffer.as_mut());

            if let Err(code) = self.ctx.decompress_stream(&mut output, &mut input) {
                return Err(CompressionError::from_code(code));
            }

            decompressed.extend_from_slice(output.as_slice());

            // Stop only once all input has been consumed AND the last step did
            // not fill the scratch buffer. A full buffer may mean zstd still
            // holds buffered output that has yet to be flushed.
            let input_drained = input.pos == input.src.len();
            let output_not_full = output.pos() != output.capacity();
            if input_drained && output_not_full {
                break;
            }
        }

        String::from_utf8(decompressed).map_err(CompressionError::from_utf8_error)
    }

    /// Reset the decompressor's internal state.
    ///
    /// Only the session is reset (`ResetDirective::SessionOnly`).
    pub fn reset(&mut self) {
        let directive = zstd_safe::ResetDirective::SessionOnly;

        self.ctx
            .reset(directive)
            .expect("resetting session is infallible");
    }
}
141
// Summarise the decompressor without dumping its scratch buffer.
//
// `buffer` holds `BUFFER_SIZE` (32 * 1024) bytes of leftover output from the
// previous decompression step. Printing it element by element would emit tens
// of thousands of numbers with no diagnostic value, so only its length is shown.
// The zstd context is opaque and is rendered as a placeholder.
#[cfg(feature = "zstd")]
impl fmt::Debug for Decompressor {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Decompressor")
            .field("buffer", &format_args!("<{} bytes>", self.buffer.len()))
            .field("ctx", &"<decompression context>")
            .finish()
    }
}
151
#[cfg(all(feature = "zstd", test))]
mod tests {
    use super::Decompressor;

    /// A zstd-compressed gateway payload that decompresses to [`OUTPUT`].
    const MESSAGE: [u8; 117] = [
        40, 181, 47, 253, 0, 64, 100, 3, 0, 66, 7, 25, 28, 112, 137, 115, 116, 40, 208, 203, 85,
        255, 167, 74, 75, 126, 203, 222, 231, 255, 151, 18, 211, 212, 171, 144, 151, 210, 255, 51,
        4, 49, 34, 71, 98, 2, 36, 253, 122, 141, 99, 203, 225, 11, 162, 47, 133, 241, 6, 201, 82,
        245, 91, 206, 247, 164, 226, 156, 92, 108, 130, 123, 11, 95, 199, 15, 61, 179, 117, 157,
        28, 37, 65, 64, 25, 250, 182, 8, 199, 205, 44, 73, 47, 19, 218, 45, 27, 14, 245, 202, 81,
        82, 122, 167, 121, 71, 173, 61, 140, 190, 15, 3, 1, 0, 36, 74, 18,
    ];
    /// Expected plain-text result of decompressing [`MESSAGE`].
    const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-us-east1-c-7s4x\",{\"micros\":0.0}]"]}}"#;

    #[test]
    fn decompress_single_segment() {
        let mut decompressor = Decompressor::new();
        let text = decompressor.decompress(&MESSAGE).unwrap();

        assert_eq!(text, OUTPUT);
    }

    #[test]
    fn reset() {
        let mut decompressor = Decompressor::new();

        // Feed all but the last two bytes, leaving the context partway
        // through the frame.
        let truncated = &MESSAGE[..MESSAGE.len() - 2];
        decompressor.decompress(truncated).unwrap();

        // Starting the message over on that context fails...
        assert!(decompressor.decompress(&MESSAGE).is_err());

        // ...until the session is reset.
        decompressor.reset();
        assert_eq!(decompressor.decompress(&MESSAGE).unwrap(), OUTPUT);
    }
}