1use std::{error::Error, fmt};
4
5#[derive(Debug)]
7pub struct CompressionError {
8 pub(crate) kind: CompressionErrorType,
10 pub(crate) source: Option<Box<dyn Error + Send + Sync>>,
12}
13
14impl CompressionError {
15 #[must_use = "retrieving the type has no effect if left unused"]
17 pub const fn kind(&self) -> &CompressionErrorType {
18 &self.kind
19 }
20
21 #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
23 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
24 self.source
25 }
26
27 #[must_use = "consuming the error into its parts has no effect if left unused"]
29 pub fn into_parts(self) -> (CompressionErrorType, Option<Box<dyn Error + Send + Sync>>) {
30 (self.kind, None)
31 }
32
33 pub(crate) fn from_utf8_error(source: std::string::FromUtf8Error) -> Self {
35 Self {
36 kind: CompressionErrorType::NotUtf8,
37 source: Some(Box::new(source)),
38 }
39 }
40
41 #[cfg(feature = "zstd")]
43 pub(crate) fn from_code(code: usize) -> Self {
44 Self {
45 kind: CompressionErrorType::Decompressing,
46 source: Some(zstd_safe::get_error_name(code).into()),
47 }
48 }
49}
50
51impl fmt::Display for CompressionError {
52 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
53 match self.kind {
54 CompressionErrorType::Decompressing => f.write_str("message could not be decompressed"),
55 CompressionErrorType::NotUtf8 => f.write_str("decompressed message is not UTF-8"),
56 }
57 }
58}
59
60impl Error for CompressionError {
61 fn source(&self) -> Option<&(dyn Error + 'static)> {
62 self.source
63 .as_ref()
64 .map(|source| &**source as &(dyn Error + 'static))
65 }
66}
67
/// Which stage of turning a compressed message into text failed.
#[derive(Debug)]
#[non_exhaustive]
pub enum CompressionErrorType {
    /// The decompressor reported an error while processing the message bytes.
    Decompressing,
    /// The message decompressed, but the resulting bytes are not valid UTF-8.
    NotUtf8,
}
77
/// Streaming zstd decompressor whose decoding state is carried from one
/// `decompress` call to the next.
#[cfg(feature = "zstd")]
pub struct Decompressor {
    /// Fixed-size scratch space each streaming step decodes into before the
    /// bytes are appended to the message being built.
    buffer: Box<[u8]>,
    /// zstd decompression context; it is kept between calls and only cleared
    /// by `reset`.
    ctx: zstd_safe::DCtx<'static>,
}
86
#[cfg(feature = "zstd")]
impl Decompressor {
    /// Size, in bytes, of the scratch buffer each streaming step writes into.
    const BUFFER_SIZE: usize = 32 * 1024;

    /// Creates a decompressor with a zeroed scratch buffer and a freshly
    /// created zstd decompression context.
    pub fn new() -> Self {
        let buffer = vec![0; Self::BUFFER_SIZE].into_boxed_slice();
        let ctx = zstd_safe::DCtx::create();

        Self { buffer, ctx }
    }

    /// Feeds `message` into the zstd stream and returns everything it decodes
    /// to as a `String`.
    ///
    /// The context is not reset here, so stream state left by earlier calls
    /// (for example an unfinished frame) affects this one until
    /// [`Self::reset`] is called.
    ///
    /// # Errors
    ///
    /// Returns a [`CompressionErrorType::Decompressing`] error if zstd reports
    /// a failure, or [`CompressionErrorType::NotUtf8`] if the decoded bytes are
    /// not valid UTF-8.
    pub fn decompress(&mut self, message: &[u8]) -> Result<String, CompressionError> {
        let mut input = zstd_safe::InBuffer::around(message);
        // Grows only as much as the decoded output requires.
        let mut decompressed = Vec::new();

        loop {
            let mut output = zstd_safe::OutBuffer::around(self.buffer.as_mut());

            if let Err(code) = self.ctx.decompress_stream(&mut output, &mut input) {
                return Err(CompressionError::from_code(code));
            }

            decompressed.extend_from_slice(output.as_slice());

            // A completely filled scratch buffer may mean zstd still has output
            // pending, so only stop once all input is consumed AND the last
            // step left room to spare.
            let input_consumed = input.pos == input.src.len();
            let output_filled = output.pos() == output.capacity();
            if input_consumed && !output_filled {
                break;
            }
        }

        match String::from_utf8(decompressed) {
            Ok(text) => Ok(text),
            Err(source) => Err(CompressionError::from_utf8_error(source)),
        }
    }

    /// Clears the session state of the zstd context so the next
    /// [`Self::decompress`] call starts from a clean stream.
    ///
    /// # Panics
    ///
    /// Panics if zstd rejects the session-only reset, which is not expected.
    pub fn reset(&mut self) {
        let directive = zstd_safe::ResetDirective::SessionOnly;

        self.ctx
            .reset(directive)
            .expect("resetting session is infallible");
    }
}
141
#[cfg(feature = "zstd")]
impl fmt::Debug for Decompressor {
    /// Formats the decompressor without dumping its scratch buffer.
    ///
    /// The buffer (`BUFFER_SIZE` bytes) only holds whatever the most recent
    /// streaming step wrote into it, so printing every byte would flood any
    /// log that debug-prints this value without conveying useful state.
    /// Only its size is shown. The zstd context is opaque and printed as a
    /// placeholder.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Decompressor")
            .field("buffer", &format_args!("[u8; {}]", self.buffer.len()))
            .field("ctx", &"<decompression context>")
            .finish()
    }
}
151
#[cfg(all(feature = "zstd", test))]
mod tests {
    use super::Decompressor;

    /// One complete zstd frame (it begins with the zstd magic bytes
    /// `0x28 0xB5 0x2F 0xFD`).
    const MESSAGE: [u8; 117] = [
        40, 181, 47, 253, 0, 64, 100, 3, 0, 66, 7, 25, 28, 112, 137, 115, 116, 40, 208, 203, 85,
        255, 167, 74, 75, 126, 203, 222, 231, 255, 151, 18, 211, 212, 171, 144, 151, 210, 255, 51,
        4, 49, 34, 71, 98, 2, 36, 253, 122, 141, 99, 203, 225, 11, 162, 47, 133, 241, 6, 201, 82,
        245, 91, 206, 247, 164, 226, 156, 92, 108, 130, 123, 11, 95, 199, 15, 61, 179, 117, 157,
        28, 37, 65, 64, 25, 250, 182, 8, 199, 205, 44, 73, 47, 19, 218, 45, 27, 14, 245, 202, 81,
        82, 122, 167, 121, 71, 173, 61, 140, 190, 15, 3, 1, 0, 36, 74, 18,
    ];
    /// The JSON text that `MESSAGE` decompresses to.
    const OUTPUT: &str = r#"{"t":null,"s":null,"op":10,"d":{"heartbeat_interval":41250,"_trace":["[\"gateway-prd-us-east1-c-7s4x\",{\"micros\":0.0}]"]}}"#;

    #[test]
    fn decompress_single_segment() {
        let mut decompressor = Decompressor::new();

        let text = decompressor.decompress(&MESSAGE).unwrap();
        assert_eq!(text, OUTPUT);
    }

    #[test]
    fn reset() {
        let mut decompressor = Decompressor::new();

        // Leave the context in the middle of a frame by withholding its tail.
        let truncated = &MESSAGE[..MESSAGE.len() - 2];
        decompressor.decompress(truncated).unwrap();

        // A fresh frame is rejected while the previous one is unfinished...
        assert!(decompressor.decompress(&MESSAGE).is_err());

        // ...but decodes normally once the session has been reset.
        decompressor.reset();
        let text = decompressor.decompress(&MESSAGE).unwrap();
        assert_eq!(text, OUTPUT);
    }
}