use std::{
error::Error,
fmt::{Debug, Display, Formatter, Result as FmtResult},
str::{self, FromStr, Utf8Error},
};
/// Error returned when ratelimit headers could not be parsed.
///
/// Use [`HeaderParsingError::kind`] to see what went wrong and
/// [`Error::source`] (or [`HeaderParsingError::into_source`]) for the
/// underlying decoding/parsing error, when there is one.
#[derive(Debug)]
pub struct HeaderParsingError {
    /// Category of the failure, including the offending header.
    pub(super) kind: HeaderParsingErrorType,
    /// Lower-level error that caused this one, if any.
    pub(super) source: Option<Box<dyn Error + Send + Sync>>,
}
impl HeaderParsingError {
    /// Borrow the category of this error.
    #[must_use = "retrieving the type has no effect if left unused"]
    pub const fn kind(&self) -> &HeaderParsingErrorType {
        &self.kind
    }

    /// Consume the error, yielding only its underlying cause, if any.
    #[must_use = "consuming the error and retrieving the source has no effect if left unused"]
    pub fn into_source(self) -> Option<Box<dyn Error + Send + Sync>> {
        let (_, source) = self.into_parts();
        source
    }

    /// Consume the error, yielding its category and its underlying cause.
    #[must_use = "consuming the error into its parts has no effect if left unused"]
    pub fn into_parts(self) -> (HeaderParsingErrorType, Option<Box<dyn Error + Send + Sync>>) {
        let Self { kind, source } = self;
        (kind, source)
    }

    /// Build an error reporting that the header `name` was absent.
    pub(super) fn missing(name: HeaderName) -> Self {
        Self {
            source: None,
            kind: HeaderParsingErrorType::Missing { name },
        }
    }

    /// Build an error reporting that the header `name` carried bytes that
    /// failed UTF-8 decoding; the raw bytes and the decode error are kept.
    pub(super) fn not_utf8(name: HeaderName, value: Vec<u8>, source: Utf8Error) -> Self {
        let cause: Box<dyn Error + Send + Sync> = Box::new(source);

        Self {
            source: Some(cause),
            kind: HeaderParsingErrorType::NotUtf8 { name, value },
        }
    }
}
/// Human-readable description of a header parsing failure, naming the header
/// involved.
impl Display for HeaderParsingError {
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        match &self.kind {
            HeaderParsingErrorType::Missing { name } => {
                f.write_str("at least one header, '")?;
                f.write_str(name.name())?;
                f.write_str("', is missing")
            }
            HeaderParsingErrorType::NotUtf8 { name, value } => {
                f.write_str("header '")?;
                f.write_str(name.name())?;
                // This variant is only produced from a `Utf8Error`, so the
                // encoding named here must be UTF-8.
                f.write_str("' contains invalid UTF-8: ")?;
                // The bytes are not valid text, so print them as a byte list.
                Debug::fmt(value, f)
            }
            HeaderParsingErrorType::Parsing { kind, name, value } => {
                f.write_str("header '")?;
                f.write_str(name.name())?;
                f.write_str("' can not be parsed as a ")?;
                f.write_str(kind.name())?;
                f.write_str(": '")?;
                f.write_str(value)?;
                f.write_str("'")
            }
        }
    }
}
/// Exposes the boxed lower-level error (UTF-8 or value-parsing failure), if
/// one was recorded.
impl Error for HeaderParsingError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let inner = self.source.as_deref()?;

        // Drop the `Send + Sync` bounds to match the trait's signature.
        Some(inner as &(dyn Error + 'static))
    }
}
/// Category of a [`HeaderParsingError`].
#[derive(Debug)]
#[non_exhaustive]
pub enum HeaderParsingErrorType {
    /// A header required to build the result was not sent.
    Missing {
        /// Header that was expected.
        name: HeaderName,
    },
    /// A header's value was not valid UTF-8.
    NotUtf8 {
        /// Header whose value failed to decode.
        name: HeaderName,
        /// Raw bytes of the value.
        value: Vec<u8>,
    },
    /// A header's value was text but not of the expected type.
    Parsing {
        /// Type the value was expected to have.
        kind: HeaderType,
        /// Header whose value failed to parse.
        name: HeaderName,
        /// The text that could not be parsed.
        value: String,
    },
}
/// The ratelimit-related headers this module understands.
///
/// The wire name of each header is available through `HeaderName::name`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum HeaderName {
    /// `x-ratelimit-bucket`
    Bucket,
    /// `x-ratelimit-global`
    Global,
    /// `x-ratelimit-limit`
    Limit,
    /// `x-ratelimit-remaining`
    Remaining,
    /// `x-ratelimit-reset-after`
    ResetAfter,
    /// `x-ratelimit-reset`
    Reset,
    /// `retry-after`
    RetryAfter,
    /// `x-ratelimit-scope`
    Scope,
}
impl HeaderName {
pub const BUCKET: &'static str = "x-ratelimit-bucket";
pub const GLOBAL: &'static str = "x-ratelimit-global";
pub const LIMIT: &'static str = "x-ratelimit-limit";
pub const REMAINING: &'static str = "x-ratelimit-remaining";
pub const RESET_AFTER: &'static str = "x-ratelimit-reset-after";
pub const RESET: &'static str = "x-ratelimit-reset";
pub const RETRY_AFTER: &'static str = "retry-after";
pub const SCOPE: &'static str = "x-ratelimit-scope";
#[must_use]
pub const fn name(self) -> &'static str {
match self {
Self::Bucket => Self::BUCKET,
Self::Global => Self::GLOBAL,
Self::Limit => Self::LIMIT,
Self::Remaining => Self::REMAINING,
Self::ResetAfter => Self::RESET_AFTER,
Self::Reset => Self::RESET,
Self::RetryAfter => Self::RETRY_AFTER,
Self::Scope => Self::SCOPE,
}
}
}
impl Display for HeaderName {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.write_str(self.name())
}
}
/// Type a header value was expected to parse as; reported in
/// [`HeaderParsingErrorType::Parsing`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum HeaderType {
    /// A `true`/`false` value.
    Bool,
    /// A floating-point number.
    Float,
    /// An unsigned integer.
    Integer,
    /// A string from a fixed set of accepted values.
    String,
}
impl HeaderType {
const fn name(self) -> &'static str {
match self {
Self::Bool => "bool",
Self::Float => "float",
Self::Integer => "integer",
Self::String => "string",
}
}
}
impl Display for HeaderType {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.write_str(self.name())
}
}
/// Headers describing a global ratelimit (`x-ratelimit-global: true`).
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Global {
    /// Integer value of the `retry-after` header, unscaled.
    retry_after: u64,
    /// Value of the `x-ratelimit-scope` header, if it was sent.
    scope: Option<RatelimitScope>,
}
impl Global {
    /// Integer value of the `retry-after` header, as sent (no unit scaling
    /// is applied).
    #[must_use]
    pub const fn retry_after(&self) -> u64 {
        let Self { retry_after, .. } = *self;
        retry_after
    }

    /// Scope reported by `x-ratelimit-scope`, or `None` if it was not sent.
    #[must_use]
    pub const fn scope(&self) -> Option<RatelimitScope> {
        let Self { scope, .. } = *self;
        scope
    }
}
/// Per-bucket ratelimit headers.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct Present {
    /// `x-ratelimit-bucket` value, if sent.
    bucket: Option<String>,
    /// `x-ratelimit-limit` value.
    limit: u64,
    /// `x-ratelimit-remaining` value.
    remaining: u64,
    /// `x-ratelimit-reset-after` value multiplied by 1000 and rounded up
    /// (milliseconds, assuming the header is in seconds).
    reset_after: u64,
    /// `x-ratelimit-reset` value multiplied by 1000 and rounded up
    /// (milliseconds, assuming the header is in seconds).
    reset: u64,
    /// `x-ratelimit-scope` value, if sent.
    scope: Option<RatelimitScope>,
}
impl Present {
    /// Borrow the `x-ratelimit-bucket` value, if it was sent.
    #[must_use]
    pub fn bucket(&self) -> Option<&str> {
        self.bucket.as_ref().map(String::as_str)
    }

    /// Consume the headers, returning the owned bucket value, if any.
    // Opts out of clippy's `missing_const_for_fn` suggestion.
    // NOTE(review): the reason isn't recorded here. It is likely a
    // const-evaluation limit on moving a `String` out of `self`; confirm before
    // removing.
    #[allow(clippy::missing_const_for_fn)]
    #[must_use]
    pub fn into_bucket(self) -> Option<String> {
        let Self { bucket, .. } = self;
        bucket
    }

    /// `x-ratelimit-limit` value.
    #[must_use]
    pub const fn limit(&self) -> u64 {
        let Self { limit, .. } = *self;
        limit
    }

    /// `x-ratelimit-remaining` value.
    #[must_use]
    pub const fn remaining(&self) -> u64 {
        let Self { remaining, .. } = *self;
        remaining
    }

    /// `x-ratelimit-reset-after` value scaled by 1000 and rounded up.
    #[must_use]
    pub const fn reset_after(&self) -> u64 {
        let Self { reset_after, .. } = *self;
        reset_after
    }

    /// `x-ratelimit-reset` value scaled by 1000 and rounded up.
    #[must_use]
    pub const fn reset(&self) -> u64 {
        let Self { reset, .. } = *self;
        reset
    }

    /// `x-ratelimit-scope` value, if it was sent.
    #[must_use]
    pub const fn scope(&self) -> Option<RatelimitScope> {
        let Self { scope, .. } = *self;
        scope
    }
}
/// Value of the `x-ratelimit-scope` header.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RatelimitScope {
    /// Header value `global`.
    Global,
    /// Header value `shared`.
    Shared,
    /// Header value `user`.
    User,
}
impl Display for RatelimitScope {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
f.write_str(match self {
Self::Global => "global",
Self::Shared => "shared",
Self::User => "user",
})
}
}
impl FromStr for RatelimitScope {
type Err = HeaderParsingError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"global" => Self::Global,
"shared" => Self::Shared,
"user" => Self::User,
_ => {
return Err(HeaderParsingError {
kind: HeaderParsingErrorType::Parsing {
kind: HeaderType::String,
name: HeaderName::Scope,
value: s.to_owned(),
},
source: None,
})
}
})
}
}
/// Same as the `FromStr` implementation; provided for `try_from`/`try_into`
/// call sites.
impl TryFrom<&'_ str> for RatelimitScope {
    type Error = HeaderParsingError;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        value.parse::<Self>()
    }
}
/// Outcome of parsing a response's ratelimit headers.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum RatelimitHeaders {
    /// `x-ratelimit-global` was `true`.
    Global(Global),
    /// None of the per-bucket ratelimit headers were sent.
    None,
    /// Per-bucket ratelimit headers were sent.
    Present(Present),
}
impl RatelimitHeaders {
#[must_use]
pub const fn is_global(&self) -> bool {
matches!(self, Self::Global(_))
}
#[must_use]
pub const fn is_none(&self) -> bool {
matches!(self, Self::None)
}
#[must_use]
pub const fn is_present(&self) -> bool {
matches!(self, Self::Present(_))
}
pub fn from_pairs<'a>(
headers: impl Iterator<Item = (&'a str, &'a [u8])>,
) -> Result<Self, HeaderParsingError> {
let mut bucket = None;
let mut global = false;
let mut limit = None;
let mut remaining = None;
let mut reset = None;
let mut reset_after = None;
let mut retry_after = None;
let mut scope = None;
for (name, value) in headers {
match name {
HeaderName::BUCKET => {
bucket.replace(header_str(HeaderName::Bucket, value)?);
}
HeaderName::GLOBAL => {
global = header_bool(HeaderName::Global, value)?;
}
HeaderName::LIMIT => {
limit.replace(header_int(HeaderName::Limit, value)?);
}
HeaderName::REMAINING => {
remaining.replace(header_int(HeaderName::Remaining, value)?);
}
HeaderName::RESET => {
let reset_value = header_float(HeaderName::Reset, value)?;
#[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
reset.replace((reset_value * 1000.).ceil() as u64);
}
HeaderName::RESET_AFTER => {
let reset_after_value = header_float(HeaderName::ResetAfter, value)?;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
reset_after.replace((reset_after_value * 1000.).ceil() as u64);
}
HeaderName::RETRY_AFTER => {
let retry_after_value = header_int(HeaderName::RetryAfter, value)?;
retry_after.replace(retry_after_value);
}
HeaderName::SCOPE => {
let scope_value = header_str(HeaderName::Scope, value)?;
let scope_parsed = RatelimitScope::try_from(scope_value)?;
scope.replace(scope_parsed);
}
_ => continue,
}
}
if global {
let retry_after =
retry_after.ok_or_else(|| HeaderParsingError::missing(HeaderName::RetryAfter))?;
return Ok(RatelimitHeaders::Global(Global { retry_after, scope }));
}
if bucket.is_none()
&& limit.is_none()
&& remaining.is_none()
&& reset.is_none()
&& reset_after.is_none()
{
return Ok(RatelimitHeaders::None);
}
Ok(RatelimitHeaders::Present(Present {
bucket: bucket.map(Into::into),
limit: limit.ok_or_else(|| HeaderParsingError::missing(HeaderName::Limit))?,
remaining: remaining
.ok_or_else(|| HeaderParsingError::missing(HeaderName::Remaining))?,
reset: reset.ok_or_else(|| HeaderParsingError::missing(HeaderName::Reset))?,
reset_after: reset_after
.ok_or_else(|| HeaderParsingError::missing(HeaderName::ResetAfter))?,
scope,
}))
}
}
fn header_bool(name: HeaderName, value: &[u8]) -> Result<bool, HeaderParsingError> {
let text = header_str(name, value)?;
let end = text.parse().map_err(|source| HeaderParsingError {
kind: HeaderParsingErrorType::Parsing {
kind: HeaderType::Bool,
name,
value: text.to_owned(),
},
source: Some(Box::new(source)),
})?;
Ok(end)
}
fn header_float(name: HeaderName, value: &[u8]) -> Result<f64, HeaderParsingError> {
let text = header_str(name, value)?;
let end = text.parse().map_err(|source| HeaderParsingError {
kind: HeaderParsingErrorType::Parsing {
kind: HeaderType::Float,
name,
value: text.to_owned(),
},
source: Some(Box::new(source)),
})?;
Ok(end)
}
fn header_int(name: HeaderName, value: &[u8]) -> Result<u64, HeaderParsingError> {
let text = header_str(name, value)?;
let end = text.parse().map_err(|source| HeaderParsingError {
kind: HeaderParsingErrorType::Parsing {
kind: HeaderType::Integer,
name,
value: text.to_owned(),
},
source: Some(Box::new(source)),
})?;
Ok(end)
}
/// Borrow a raw header value as UTF-8 text.
///
/// # Errors
///
/// Returns a `NotUtf8` error for header `name` if `value` is not valid UTF-8.
/// The error holds a copy of the raw bytes, with the [`Utf8Error`] as its
/// source.
fn header_str(name: HeaderName, value: &[u8]) -> Result<&str, HeaderParsingError> {
    match str::from_utf8(value) {
        Ok(text) => Ok(text),
        Err(source) => Err(HeaderParsingError::not_utf8(name, value.to_vec(), source)),
    }
}
#[cfg(test)]
mod tests {
    use super::{
        Global, HeaderName, HeaderParsingError, HeaderParsingErrorType, HeaderType, Present,
        RatelimitHeaders,
    };
    use crate::headers::RatelimitScope;
    use http::header::{HeaderMap, HeaderName as HttpHeaderName, HeaderValue};
    use static_assertions::{assert_fields, assert_impl_all};
    use std::{
        error::Error,
        fmt::{Debug, Display},
    };

    assert_fields!(HeaderParsingErrorType::Missing: name);
    assert_fields!(HeaderParsingErrorType::NotUtf8: name, value);
    assert_fields!(HeaderParsingErrorType::Parsing: kind, name, value);
    assert_impl_all!(
        HeaderName: Clone,
        Copy,
        Debug,
        Display,
        Eq,
        PartialEq,
        Send,
        Sync
    );
    assert_impl_all!(HeaderParsingErrorType: Debug, Send, Sync);
    assert_impl_all!(HeaderParsingError: Error, Send, Sync);
    assert_impl_all!(
        HeaderType: Clone,
        Copy,
        Debug,
        Display,
        Eq,
        PartialEq,
        Send,
        Sync
    );
    assert_impl_all!(Global: Clone, Debug, Eq, PartialEq, Send, Sync);
    assert_impl_all!(Present: Clone, Debug, Eq, PartialEq, Send, Sync);
    assert_impl_all!(RatelimitHeaders: Clone, Debug, Send, Sync);

    /// Put the static `(name, value)` pairs into an `http` `HeaderMap` and
    /// parse them with `RatelimitHeaders::from_pairs`, as a caller holding an
    /// `http` response would.
    fn parse(
        pairs: &[(&'static str, &'static str)],
    ) -> Result<RatelimitHeaders, HeaderParsingError> {
        let mut map = HeaderMap::new();

        for &(name, value) in pairs {
            map.insert(
                HttpHeaderName::from_static(name),
                HeaderValue::from_static(value),
            );
        }

        RatelimitHeaders::from_pairs(
            map.iter()
                .map(|(name, value)| (name.as_str(), value.as_bytes())),
        )
    }

    #[test]
    fn global() -> Result<(), Box<dyn Error>> {
        let headers = parse(&[("x-ratelimit-global", "true"), ("retry-after", "65")])?;

        match headers {
            RatelimitHeaders::Global(global) => assert_eq!(65, global.retry_after()),
            other => panic!("expected global ratelimit headers, got {:?}", other),
        }

        Ok(())
    }

    #[test]
    fn global_with_scope() -> Result<(), Box<dyn Error>> {
        let headers = parse(&[
            ("x-ratelimit-global", "true"),
            ("retry-after", "65"),
            ("x-ratelimit-scope", "global"),
        ])?;

        match headers {
            RatelimitHeaders::Global(global) => {
                assert_eq!(65, global.retry_after());
                assert_eq!(Some(RatelimitScope::Global), global.scope());
            }
            other => panic!("expected global ratelimit headers, got {:?}", other),
        }

        Ok(())
    }

    #[test]
    fn present() -> Result<(), Box<dyn Error>> {
        let headers = parse(&[
            ("x-ratelimit-limit", "10"),
            ("x-ratelimit-remaining", "9"),
            ("x-ratelimit-reset", "1470173023.123"),
            ("x-ratelimit-reset-after", "64.57"),
            ("x-ratelimit-bucket", "abcd1234"),
            ("x-ratelimit-scope", "shared"),
        ])?;

        let present = match headers {
            RatelimitHeaders::Present(present) => present,
            other => panic!("expected present ratelimit headers, got {:?}", other),
        };

        assert_eq!(Some("abcd1234"), present.bucket());
        assert_eq!(10, present.limit());
        assert_eq!(9, present.remaining());
        assert_eq!(64_570, present.reset_after());
        assert_eq!(1_470_173_023_123, present.reset());
        assert_eq!(Some(RatelimitScope::Shared), present.scope());

        Ok(())
    }

    #[test]
    fn name() {
        let table = [
            (HeaderName::Bucket, HeaderName::BUCKET, "x-ratelimit-bucket"),
            (HeaderName::Global, HeaderName::GLOBAL, "x-ratelimit-global"),
            (HeaderName::Limit, HeaderName::LIMIT, "x-ratelimit-limit"),
            (
                HeaderName::Remaining,
                HeaderName::REMAINING,
                "x-ratelimit-remaining",
            ),
            (
                HeaderName::ResetAfter,
                HeaderName::RESET_AFTER,
                "x-ratelimit-reset-after",
            ),
            (HeaderName::Reset, HeaderName::RESET, "x-ratelimit-reset"),
            (HeaderName::RetryAfter, HeaderName::RETRY_AFTER, "retry-after"),
            (HeaderName::Scope, HeaderName::SCOPE, "x-ratelimit-scope"),
        ];

        for &(header, constant, literal) in table.iter() {
            assert_eq!(literal, constant);
            assert_eq!(constant, header.name());
        }
    }

    #[test]
    fn type_name() {
        let table = [
            (HeaderType::Bool, "bool"),
            (HeaderType::Float, "float"),
            (HeaderType::Integer, "integer"),
            (HeaderType::String, "string"),
        ];

        for &(kind, label) in table.iter() {
            assert_eq!(label, kind.name());
        }
    }
}