1use std::{
9 error::Error,
10 fmt::{Debug, Display, Formatter, Result as FmtResult},
11 str::{self, FromStr, Utf8Error},
12};
13
14#[derive(Debug)]
16pub struct HeaderParsingError {
17 pub(super) kind: HeaderParsingErrorType,
19 pub(super) source: Option<Box<dyn Error + Send + Sync>>,
21}
22
23impl HeaderParsingError {
24 #[must_use = "retrieving the type has no effect if left unused"]
26 pub const fn kind(&self) -> &HeaderParsingErrorType {
27 &self.kind
28 }
29
30 #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
32 pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
33 self.source
34 }
35
36 #[must_use = "consuming the error into its parts has no effect if left unused"]
38 pub fn into_parts(self) -> (HeaderParsingErrorType, Option<Box<dyn Error + Send + Sync>>) {
39 (self.kind, self.source)
40 }
41
42 pub(super) fn missing(name: HeaderName) -> Self {
44 Self {
45 kind: HeaderParsingErrorType::Missing { name },
46 source: None,
47 }
48 }
49
50 pub(super) fn not_utf8(name: HeaderName, value: Vec<u8>, source: Utf8Error) -> Self {
52 Self {
53 kind: HeaderParsingErrorType::NotUtf8 { name, value },
54 source: Some(Box::new(source)),
55 }
56 }
57}
58
59impl Display for HeaderParsingError {
60 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
61 match &self.kind {
62 HeaderParsingErrorType::Missing { name } => {
63 f.write_str("at least one header, '")?;
64 f.write_str(name.name())?;
65
66 f.write_str("', is missing")
67 }
68 HeaderParsingErrorType::NotUtf8 { name, value } => {
69 f.write_str("header '")?;
70 f.write_str(name.name())?;
71 f.write_str("' contains invalid UTF-16: ")?;
72
73 Debug::fmt(value, f)
74 }
75 HeaderParsingErrorType::Parsing { kind, name, value } => {
76 f.write_str("header '")?;
77 f.write_str(name.name())?;
78 f.write_str("' can not be parsed as a ")?;
79 f.write_str(kind.name())?;
80 f.write_str(": '")?;
81 f.write_str(value)?;
82
83 f.write_str("'")
84 }
85 }
86 }
87}
88
89impl Error for HeaderParsingError {
90 fn source(&self) -> Option<&(dyn Error + 'static)> {
91 self.source
92 .as_ref()
93 .map(|source| &**source as &(dyn Error + 'static))
94 }
95}
96
97#[derive(Debug)]
99#[non_exhaustive]
100pub enum HeaderParsingErrorType {
101 Missing {
103 name: HeaderName,
105 },
106 NotUtf8 {
108 name: HeaderName,
110 value: Vec<u8>,
112 },
113 Parsing {
115 kind: HeaderType,
117 name: HeaderName,
119 value: String,
121 },
122}
123
/// Ratelimit-related response header that the parser recognises.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum HeaderName {
    /// The `x-ratelimit-bucket` header.
    Bucket,
    /// The `x-ratelimit-global` header.
    Global,
    /// The `x-ratelimit-limit` header.
    Limit,
    /// The `x-ratelimit-remaining` header.
    Remaining,
    /// The `x-ratelimit-reset-after` header.
    ResetAfter,
    /// The `x-ratelimit-reset` header.
    Reset,
    /// The `retry-after` header.
    RetryAfter,
    /// The `x-ratelimit-scope` header.
    Scope,
}

impl HeaderName {
    /// Lowercase name of the [`HeaderName::Bucket`] header.
    pub const BUCKET: &'static str = "x-ratelimit-bucket";

    /// Lowercase name of the [`HeaderName::Global`] header.
    pub const GLOBAL: &'static str = "x-ratelimit-global";

    /// Lowercase name of the [`HeaderName::Limit`] header.
    pub const LIMIT: &'static str = "x-ratelimit-limit";

    /// Lowercase name of the [`HeaderName::Remaining`] header.
    pub const REMAINING: &'static str = "x-ratelimit-remaining";

    /// Lowercase name of the [`HeaderName::ResetAfter`] header.
    pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";

    /// Lowercase name of the [`HeaderName::Reset`] header.
    pub const RESET: &'static str = "x-ratelimit-reset";

    /// Lowercase name of the [`HeaderName::RetryAfter`] header.
    ///
    /// Unlike the others this is not `x-ratelimit-` prefixed.
    pub const RETRY_AFTER: &'static str = "retry-after";

    /// Lowercase name of the [`HeaderName::Scope`] header.
    pub const SCOPE: &'static str = "x-ratelimit-scope";

    /// Lowercase name of this header as it appears in a response.
    #[must_use]
    pub const fn name(self) -> &'static str {
        match self {
            Self::Bucket => Self::BUCKET,
            Self::Global => Self::GLOBAL,
            Self::Limit => Self::LIMIT,
            Self::Remaining => Self::REMAINING,
            Self::ResetAfter => Self::RESET_AFTER,
            Self::Reset => Self::RESET,
            Self::RetryAfter => Self::RETRY_AFTER,
            Self::Scope => Self::SCOPE,
        }
    }
}

/// Writes the lowercase header name, ignoring width/fill flags.
impl Display for HeaderName {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let text = self.name();

        f.write_str(text)
    }
}
193
/// Type a header value is expected to parse as.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum HeaderType {
    /// A `bool`.
    Bool,
    /// A floating point number.
    Float,
    /// An unsigned integer.
    Integer,
    /// A string.
    String,
}

impl HeaderType {
    /// Lowercase name of the type, used in error messages.
    const fn name(self) -> &'static str {
        let label = match self {
            Self::Bool => "bool",
            Self::Float => "float",
            Self::Integer => "integer",
            Self::String => "string",
        };

        label
    }
}

/// Writes the lowercase type name, ignoring width/fill flags.
impl Display for HeaderType {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let label = self.name();

        f.write_str(label)
    }
}
225
226#[derive(Clone, Debug, Eq, Hash, PartialEq)]
228pub struct Global {
229 retry_after: u64,
231 scope: Option<RatelimitScope>,
236}
237
238impl Global {
239 #[must_use]
241 pub const fn retry_after(&self) -> u64 {
242 self.retry_after
243 }
244
245 #[must_use]
250 pub const fn scope(&self) -> Option<RatelimitScope> {
251 self.scope
252 }
253}
254
255#[derive(Clone, Debug, Eq, Hash, PartialEq)]
257pub struct Present {
258 bucket: Option<String>,
260 limit: u64,
262 remaining: u64,
264 reset_after: u64,
266 reset: u64,
268 scope: Option<RatelimitScope>,
270}
271
272impl Present {
273 #[must_use]
275 pub fn bucket(&self) -> Option<&str> {
276 self.bucket.as_deref()
277 }
278
279 #[allow(clippy::missing_const_for_fn)]
281 #[must_use]
282 pub fn into_bucket(self) -> Option<String> {
283 self.bucket
284 }
285
286 #[must_use]
288 pub const fn limit(&self) -> u64 {
289 self.limit
290 }
291
292 #[must_use]
294 pub const fn remaining(&self) -> u64 {
295 self.remaining
296 }
297
298 #[must_use]
300 pub const fn reset_after(&self) -> u64 {
301 self.reset_after
302 }
303
304 #[must_use]
306 pub const fn reset(&self) -> u64 {
307 self.reset
308 }
309
310 #[must_use]
312 pub const fn scope(&self) -> Option<RatelimitScope> {
313 self.scope
314 }
315}
316
/// Value of the `x-ratelimit-scope` header.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RatelimitScope {
    /// Header value `global`.
    Global,
    /// Header value `shared`.
    Shared,
    /// Header value `user`.
    User,
}

/// Writes the header value this scope is parsed from.
impl Display for RatelimitScope {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let text = match self {
            Self::Global => "global",
            Self::Shared => "shared",
            Self::User => "user",
        };

        f.write_str(text)
    }
}
339
340impl FromStr for RatelimitScope {
341 type Err = HeaderParsingError;
342
343 fn from_str(s: &str) -> Result<Self, Self::Err> {
344 Ok(match s {
345 "global" => Self::Global,
346 "shared" => Self::Shared,
347 "user" => Self::User,
348 _ => {
349 return Err(HeaderParsingError {
350 kind: HeaderParsingErrorType::Parsing {
351 kind: HeaderType::String,
352 name: HeaderName::Scope,
353 value: s.to_owned(),
354 },
355 source: None,
356 })
357 }
358 })
359 }
360}
361
362impl TryFrom<&'_ str> for RatelimitScope {
363 type Error = HeaderParsingError;
364
365 fn try_from(value: &str) -> Result<Self, Self::Error> {
366 Self::from_str(value)
367 }
368}
369
370#[derive(Clone, Debug)]
375#[non_exhaustive]
376pub enum RatelimitHeaders {
377 Global(Global),
379 None,
381 Present(Present),
383}
384
385impl RatelimitHeaders {
386 #[must_use]
388 pub const fn is_global(&self) -> bool {
389 matches!(self, Self::Global(_))
390 }
391
392 #[must_use]
394 pub const fn is_none(&self) -> bool {
395 matches!(self, Self::None)
396 }
397
398 #[must_use]
400 pub const fn is_present(&self) -> bool {
401 matches!(self, Self::Present(_))
402 }
403
404 pub fn from_pairs<'a>(
461 headers: impl Iterator<Item = (&'a str, &'a [u8])>,
462 ) -> Result<Self, HeaderParsingError> {
463 let mut bucket = None;
464 let mut global = false;
465 let mut limit = None;
466 let mut remaining = None;
467 let mut reset = None;
468 let mut reset_after = None;
469 let mut retry_after = None;
470 let mut scope = None;
471
472 for (name, value) in headers {
473 match name {
474 HeaderName::BUCKET => {
475 bucket.replace(header_str(HeaderName::Bucket, value)?);
476 }
477 HeaderName::GLOBAL => {
478 global = header_bool(HeaderName::Global, value)?;
479 }
480 HeaderName::LIMIT => {
481 limit.replace(header_int(HeaderName::Limit, value)?);
482 }
483 HeaderName::REMAINING => {
484 remaining.replace(header_int(HeaderName::Remaining, value)?);
485 }
486 HeaderName::RESET => {
487 let reset_value = header_float(HeaderName::Reset, value)?;
488
489 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
490 reset.replace((reset_value * 1000.).ceil() as u64);
491 }
492 HeaderName::RESET_AFTER => {
493 let reset_after_value = header_float(HeaderName::ResetAfter, value)?;
494
495 #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
496 reset_after.replace((reset_after_value * 1000.).ceil() as u64);
497 }
498 HeaderName::RETRY_AFTER => {
499 let retry_after_value = header_int(HeaderName::RetryAfter, value)?;
500
501 retry_after.replace(retry_after_value);
502 }
503 HeaderName::SCOPE => {
504 let scope_value = header_str(HeaderName::Scope, value)?;
505 let scope_parsed = RatelimitScope::try_from(scope_value)?;
506
507 scope.replace(scope_parsed);
508 }
509 _ => continue,
510 }
511 }
512
513 if global {
514 let retry_after =
515 retry_after.ok_or_else(|| HeaderParsingError::missing(HeaderName::RetryAfter))?;
516
517 return Ok(RatelimitHeaders::Global(Global { retry_after, scope }));
518 }
519
520 if bucket.is_none()
523 && limit.is_none()
524 && remaining.is_none()
525 && reset.is_none()
526 && reset_after.is_none()
527 {
528 return Ok(RatelimitHeaders::None);
529 }
530
531 Ok(RatelimitHeaders::Present(Present {
532 bucket: bucket.map(Into::into),
533 limit: limit.ok_or_else(|| HeaderParsingError::missing(HeaderName::Limit))?,
534 remaining: remaining
535 .ok_or_else(|| HeaderParsingError::missing(HeaderName::Remaining))?,
536 reset: reset.ok_or_else(|| HeaderParsingError::missing(HeaderName::Reset))?,
537 reset_after: reset_after
538 .ok_or_else(|| HeaderParsingError::missing(HeaderName::ResetAfter))?,
539 scope,
540 }))
541 }
542}
543
544fn header_bool(name: HeaderName, value: &[u8]) -> Result<bool, HeaderParsingError> {
546 let text = header_str(name, value)?;
547
548 let end = text.parse().map_err(|source| HeaderParsingError {
549 kind: HeaderParsingErrorType::Parsing {
550 kind: HeaderType::Bool,
551 name,
552 value: text.to_owned(),
553 },
554 source: Some(Box::new(source)),
555 })?;
556
557 Ok(end)
558}
559
560fn header_float(name: HeaderName, value: &[u8]) -> Result<f64, HeaderParsingError> {
562 let text = header_str(name, value)?;
563
564 let end = text.parse().map_err(|source| HeaderParsingError {
565 kind: HeaderParsingErrorType::Parsing {
566 kind: HeaderType::Float,
567 name,
568 value: text.to_owned(),
569 },
570 source: Some(Box::new(source)),
571 })?;
572
573 Ok(end)
574}
575
576fn header_int(name: HeaderName, value: &[u8]) -> Result<u64, HeaderParsingError> {
578 let text = header_str(name, value)?;
579
580 let end = text.parse().map_err(|source| HeaderParsingError {
581 kind: HeaderParsingErrorType::Parsing {
582 kind: HeaderType::Integer,
583 name,
584 value: text.to_owned(),
585 },
586 source: Some(Box::new(source)),
587 })?;
588
589 Ok(end)
590}
591
592fn header_str(name: HeaderName, value: &[u8]) -> Result<&str, HeaderParsingError> {
594 let text = str::from_utf8(value)
595 .map_err(|source| HeaderParsingError::not_utf8(name, value.to_owned(), source))?;
596
597 Ok(text)
598}
599
#[cfg(test)]
mod tests {
    use super::{
        Global, HeaderName, HeaderParsingError, HeaderParsingErrorType, HeaderType, Present,
        RatelimitHeaders, RatelimitScope,
    };
    use http::header::{HeaderMap, HeaderName as HttpHeaderName, HeaderValue};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::{Debug, Display},
    };

    assert_fields!(HeaderParsingErrorType::Missing: name);
    assert_fields!(HeaderParsingErrorType::NotUtf8: name, value);
    assert_fields!(HeaderParsingErrorType::Parsing: kind, name, value);
    assert_impl_all!(
        HeaderName: Clone,
        Copy,
        Debug,
        Display,
        Eq,
        PartialEq,
        Send,
        Sync
    );
    assert_impl_all!(HeaderParsingErrorType: Debug, Send, Sync);
    assert_impl_all!(HeaderParsingError: Error, Send, Sync);
    assert_impl_all!(
        HeaderType: Clone,
        Copy,
        Debug,
        Display,
        Eq,
        PartialEq,
        Send,
        Sync
    );
    assert_impl_all!(Global: Clone, Debug, Eq, PartialEq, Send, Sync);
    assert_impl_all!(Present: Clone, Debug, Eq, PartialEq, Send, Sync);
    assert_impl_all!(RatelimitHeaders: Clone, Debug, Send, Sync);
    assert_impl_all!(
        RatelimitScope: Clone,
        Copy,
        Debug,
        Display,
        Eq,
        PartialEq,
        Send,
        Sync
    );

    /// Run `from_pairs` over a slice of raw `(name, value)` pairs.
    fn parse<'a>(
        pairs: &[(&'a str, &'a [u8])],
    ) -> Result<RatelimitHeaders, HeaderParsingError> {
        RatelimitHeaders::from_pairs(pairs.iter().copied())
    }

    #[test]
    fn global() -> Result<(), Box<dyn Error>> {
        let map = {
            let mut map = HeaderMap::new();
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-global"),
                HeaderValue::from_static("true"),
            );
            map.insert(
                HttpHeaderName::from_static("retry-after"),
                HeaderValue::from_static("65"),
            );

            map
        };

        let iter = map.iter().map(|(k, v)| (k.as_str(), v.as_bytes()));
        let headers = RatelimitHeaders::from_pairs(iter)?;
        assert!(matches!(headers, RatelimitHeaders::Global(g) if g.retry_after() == 65));

        Ok(())
    }

    #[test]
    fn global_with_scope() -> Result<(), Box<dyn Error>> {
        let map = {
            let mut map = HeaderMap::new();
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-global"),
                HeaderValue::from_static("true"),
            );
            map.insert(
                HttpHeaderName::from_static("retry-after"),
                HeaderValue::from_static("65"),
            );
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-scope"),
                HeaderValue::from_static("global"),
            );

            map
        };

        let iter = map.iter().map(|(k, v)| (k.as_str(), v.as_bytes()));
        let headers = RatelimitHeaders::from_pairs(iter)?;
        assert!(matches!(
            &headers,
            RatelimitHeaders::Global(global)
            if global.retry_after() == 65
        ));
        assert!(matches!(
            headers,
            RatelimitHeaders::Global(global)
            if global.scope() == Some(RatelimitScope::Global)
        ));

        Ok(())
    }

    #[test]
    fn global_missing_retry_after() {
        let error = parse(&[(HeaderName::GLOBAL, &b"true"[..])]).unwrap_err();
        assert!(matches!(
            error.kind(),
            HeaderParsingErrorType::Missing {
                name: HeaderName::RetryAfter
            }
        ));
    }

    #[test]
    fn present() -> Result<(), Box<dyn Error>> {
        let map = {
            let mut map = HeaderMap::new();
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-limit"),
                HeaderValue::from_static("10"),
            );
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-remaining"),
                HeaderValue::from_static("9"),
            );
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-reset"),
                HeaderValue::from_static("1470173023.123"),
            );
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-reset-after"),
                HeaderValue::from_static("64.57"),
            );
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-bucket"),
                HeaderValue::from_static("abcd1234"),
            );
            map.insert(
                HttpHeaderName::from_static("x-ratelimit-scope"),
                HeaderValue::from_static("shared"),
            );

            map
        };

        let iter = map.iter().map(|(k, v)| (k.as_str(), v.as_bytes()));
        let headers = RatelimitHeaders::from_pairs(iter)?;
        assert!(matches!(
            &headers,
            RatelimitHeaders::Present(present)
            if present.bucket.as_deref() == Some("abcd1234")
        ));
        assert!(matches!(
            &headers,
            RatelimitHeaders::Present(present)
            if present.limit == 10
        ));
        assert!(matches!(
            &headers,
            RatelimitHeaders::Present(present)
            if present.remaining == 9
        ));
        assert!(matches!(
            &headers,
            RatelimitHeaders::Present(present)
            if present.reset_after == 64_570
        ));
        assert!(matches!(
            &headers,
            RatelimitHeaders::Present(present)
            if present.reset == 1_470_173_023_123
        ));
        assert!(matches!(
            headers,
            RatelimitHeaders::Present(present)
            if present.scope() == Some(RatelimitScope::Shared)
        ));

        Ok(())
    }

    #[test]
    fn none() -> Result<(), Box<dyn Error>> {
        let headers = parse(&[("content-type", &b"application/json"[..])])?;
        assert!(headers.is_none());

        let headers = parse(&[])?;
        assert!(headers.is_none());

        Ok(())
    }

    #[test]
    fn present_missing_header() {
        let error = parse(&[(HeaderName::LIMIT, &b"10"[..])]).unwrap_err();
        assert!(matches!(
            error.kind(),
            HeaderParsingErrorType::Missing {
                name: HeaderName::Remaining
            }
        ));
    }

    #[test]
    fn invalid_scope() {
        let error = parse(&[(HeaderName::SCOPE, &b"bogus"[..])]).unwrap_err();
        assert!(matches!(
            error.kind(),
            HeaderParsingErrorType::Parsing {
                kind: HeaderType::String,
                name: HeaderName::Scope,
                value,
            } if value == "bogus"
        ));
    }

    #[test]
    fn invalid_integer() {
        let error = parse(&[(HeaderName::LIMIT, &b"ten"[..])]).unwrap_err();
        assert!(matches!(
            error.kind(),
            HeaderParsingErrorType::Parsing {
                kind: HeaderType::Integer,
                name: HeaderName::Limit,
                ..
            }
        ));
        assert!(error.source().is_some());
    }

    #[test]
    fn not_utf8() {
        let error = parse(&[(HeaderName::BUCKET, &[0xff_u8][..])]).unwrap_err();
        assert!(matches!(
            error.kind(),
            HeaderParsingErrorType::NotUtf8 {
                name: HeaderName::Bucket,
                value,
            } if value[..] == [0xff_u8]
        ));
        assert!(error.source().is_some());
    }

    #[test]
    fn name() {
        assert_eq!("x-ratelimit-bucket", HeaderName::BUCKET);
        assert_eq!("x-ratelimit-global", HeaderName::GLOBAL);
        assert_eq!("x-ratelimit-limit", HeaderName::LIMIT);
        assert_eq!("x-ratelimit-remaining", HeaderName::REMAINING);
        assert_eq!("x-ratelimit-reset-after", HeaderName::RESET_AFTER);
        assert_eq!("x-ratelimit-reset", HeaderName::RESET);
        assert_eq!("retry-after", HeaderName::RETRY_AFTER);
        assert_eq!("x-ratelimit-scope", HeaderName::SCOPE);
        assert_eq!(HeaderName::BUCKET, HeaderName::Bucket.name());
        assert_eq!(HeaderName::GLOBAL, HeaderName::Global.name());
        assert_eq!(HeaderName::LIMIT, HeaderName::Limit.name());
        assert_eq!(HeaderName::REMAINING, HeaderName::Remaining.name());
        assert_eq!(HeaderName::RESET_AFTER, HeaderName::ResetAfter.name());
        assert_eq!(HeaderName::RESET, HeaderName::Reset.name());
        assert_eq!(HeaderName::RETRY_AFTER, HeaderName::RetryAfter.name());
        assert_eq!(HeaderName::SCOPE, HeaderName::Scope.name());
    }

    #[test]
    fn type_name() {
        assert_eq!("bool", HeaderType::Bool.name());
        assert_eq!("float", HeaderType::Float.name());
        assert_eq!("integer", HeaderType::Integer.name());
        assert_eq!("string", HeaderType::String.name());
    }
}