twilight_model/util/
mustbe.rs

1//! A struct that only deserializes from one specific boolean value.
2//!
3//! This module is heavily based upon
4//! <https://github.com/dtolnay/monostate>.
5
6use std::fmt;
7
8use serde::{
9    de::{Error, Unexpected, Visitor},
10    Deserialize,
11};
12
/// Marker type that only deserializes from the boolean specified as `T`.
///
/// Deserializing the opposite boolean, or a value that is not a boolean at
/// all, fails. This makes it usable as a discriminating field in
/// `#[serde(untagged)]` enums, where the matching variant is selected by the
/// value of a boolean field.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct MustBeBool<const T: bool>;
15
16impl<'de, const T: bool> Deserialize<'de> for MustBeBool<T> {
17    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
18    where
19        D: serde::Deserializer<'de>,
20    {
21        struct MustBeBoolVisitor(bool);
22
23        impl Visitor<'_> for MustBeBoolVisitor {
24            type Value = ();
25
26            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
27                write!(formatter, "boolean `{}`", self.0)
28            }
29
30            fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
31            where
32                E: Error,
33            {
34                if v == self.0 {
35                    Ok(())
36                } else {
37                    Err(E::invalid_value(Unexpected::Bool(v), &self))
38                }
39            }
40        }
41
42        deserializer
43            .deserialize_any(MustBeBoolVisitor(T))
44            .map(|()| MustBeBool)
45    }
46}
47
#[cfg(test)]
mod tests {
    use super::MustBeBool;

    use serde::Deserialize;

    #[derive(Deserialize)]
    struct MTrue {
        // Only present to drive deserialization; the value is never read.
        #[allow(dead_code)]
        m: MustBeBool<true>,
    }

    #[derive(Deserialize)]
    struct MFalse {
        // Only present to drive deserialization; the value is never read.
        #[allow(dead_code)]
        m: MustBeBool<false>,
    }

    // Variant payloads are only inspected via pattern matching, never read.
    #[allow(dead_code)]
    #[derive(Deserialize)]
    #[serde(untagged)]
    enum TestEnum {
        VariantTrue(MTrue),
        VariantFalse(MFalse),
    }

    #[test]
    fn true_false_enum() {
        let json_1 = r#"{ "m": false }"#;
        let result_1 = serde_json::from_str::<TestEnum>(json_1).unwrap();
        assert!(matches!(result_1, TestEnum::VariantFalse(_)));

        let json_2 = r#"{ "m": true }"#;
        let result_2 = serde_json::from_str::<TestEnum>(json_2).unwrap();
        assert!(matches!(result_2, TestEnum::VariantTrue(_)));
    }

    #[test]
    fn rejects_opposite_bool() {
        assert!(serde_json::from_str::<MTrue>(r#"{ "m": false }"#).is_err());
        assert!(serde_json::from_str::<MFalse>(r#"{ "m": true }"#).is_err());

        let error = match serde_json::from_str::<MustBeBool<true>>("false") {
            Ok(_) => panic!("`false` must not deserialize into `MustBeBool<true>`"),
            Err(error) => error,
        };
        assert!(error.to_string().contains("expected boolean `true`"));
    }

    #[test]
    fn rejects_non_bool() {
        let inputs = &[
            r#"{ "m": "true" }"#,
            r#"{ "m": 1 }"#,
            r#"{ "m": null }"#,
            r#"{}"#,
        ];

        for json in inputs {
            assert!(
                serde_json::from_str::<TestEnum>(json).is_err(),
                "{} should not deserialize",
                json
            );
        }
    }
}