1use std::{num::NonZeroU32, time::Duration};
8
9use governor::Quota;
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize, de::Error as _};
12
13use crate::ConfigurationSection;
14
/// Rate limiting configuration, read from the `rate_limiting` configuration section.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimitingConfig {
    /// Rate limiters for account recovery. Missing values fall back to the
    /// `AccountRecoveryRateLimitingConfig` defaults.
    #[serde(default)]
    pub account_recovery: AccountRecoveryRateLimitingConfig,

    /// Rate limiters for login. Missing values fall back to the
    /// `LoginRateLimitingConfig` defaults.
    #[serde(default)]
    pub login: LoginRateLimitingConfig,

    /// Rate limiter for registration. Defaults to a burst of 3, replenished at
    /// 3 per hour.
    #[serde(default = "default_registration")]
    pub registration: RateLimiterConfiguration,

    /// Rate limiters for email authentication. Missing values fall back to the
    /// `EmailauthenticationRateLimitingConfig` defaults.
    #[serde(default)]
    pub email_authentication: EmailauthenticationRateLimitingConfig,
}
35
/// Rate limiters applied to login.
// NOTE(review): which key (IP, account) each limiter is applied to is decided
// by the code enforcing these limits, which is not part of this file.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct LoginRateLimitingConfig {
    /// Per-IP login limiter. Defaults to a burst of 3, replenished at 3 per minute.
    #[serde(default = "default_login_per_ip")]
    pub per_ip: RateLimiterConfiguration,

    /// Per-account login limiter. Defaults to a burst of 1800, replenished at
    /// 1800 per hour.
    #[serde(default = "default_login_per_account")]
    pub per_account: RateLimiterConfiguration,
}
58
/// Rate limiters applied to account recovery.
// NOTE(review): which key (IP, address) each limiter is applied to is decided
// by the code enforcing these limits, which is not part of this file.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AccountRecoveryRateLimitingConfig {
    /// Per-IP account recovery limiter. Defaults to a burst of 3, replenished
    /// at 3 per hour.
    #[serde(default = "default_account_recovery_per_ip")]
    pub per_ip: RateLimiterConfiguration,

    /// Per-address account recovery limiter. Defaults to a burst of 3,
    /// replenished at 1 per hour.
    #[serde(default = "default_account_recovery_per_address")]
    pub per_address: RateLimiterConfiguration,
}
77
/// Rate limiters applied to email authentication.
// NOTE(review): which key (IP, address, session) each limiter is applied to is
// decided by the code enforcing these limits, which is not part of this file.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct EmailauthenticationRateLimitingConfig {
    /// Per-IP limiter. Defaults to a burst of 5, replenished at 1 per minute.
    #[serde(default = "default_email_authentication_per_ip")]
    pub per_ip: RateLimiterConfiguration,

    /// Per-address limiter. Defaults to a burst of 3, replenished at 1 per hour.
    #[serde(default = "default_email_authentication_per_address")]
    pub per_address: RateLimiterConfiguration,

    /// Limiter on emails sent per session. Defaults to a burst of 2,
    /// replenished at 1 every 5 minutes.
    #[serde(default = "default_email_authentication_emails_per_session")]
    pub emails_per_session: RateLimiterConfiguration,

    /// Limiter on attempts per session. Defaults to a burst of 10, replenished
    /// at 1 per minute.
    #[serde(default = "default_email_authentication_attempt_per_session")]
    pub attempt_per_session: RateLimiterConfiguration,
}
106
/// Settings of a single rate limiter: how many requests may be made at once and
/// how fast that allowance is replenished.
#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimiterConfiguration {
    /// Maximum number of requests allowed in a single burst.
    pub burst: NonZeroU32,

    /// Replenishment rate, in requests per second. One request is replenished
    /// every `1 / per_second` seconds.
    pub per_second: f64,
}
116
117impl ConfigurationSection for RateLimitingConfig {
118    const PATH: Option<&'static str> = Some("rate_limiting");
119
120    fn validate(
121        &self,
122        figment: &figment::Figment,
123    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
124        let metadata = figment.find_metadata(Self::PATH.unwrap());
125
126        let error_on_field = |mut error: figment::error::Error, field: &'static str| {
127            error.metadata = metadata.cloned();
128            error.profile = Some(figment::Profile::Default);
129            error.path = vec![Self::PATH.unwrap().to_owned(), field.to_owned()];
130            error
131        };
132
133        let error_on_nested_field =
134            |mut error: figment::error::Error, container: &'static str, field: &'static str| {
135                error.metadata = metadata.cloned();
136                error.profile = Some(figment::Profile::Default);
137                error.path = vec![
138                    Self::PATH.unwrap().to_owned(),
139                    container.to_owned(),
140                    field.to_owned(),
141                ];
142                error
143            };
144
145        let error_on_limiter =
147            |limiter: &RateLimiterConfiguration| -> Option<figment::error::Error> {
148                let recip = limiter.per_second.recip();
149                if recip < 1.0e-9 || !recip.is_finite() {
151                    return Some(figment::error::Error::custom(
152                        "`per_second` must be a number that is more than zero and less than 1_000_000_000 (1e9)",
153                    ));
154                }
155
156                None
157            };
158
159        if let Some(error) = error_on_limiter(&self.account_recovery.per_ip) {
160            return Err(error_on_nested_field(error, "account_recovery", "per_ip").into());
161        }
162        if let Some(error) = error_on_limiter(&self.account_recovery.per_address) {
163            return Err(error_on_nested_field(error, "account_recovery", "per_address").into());
164        }
165
166        if let Some(error) = error_on_limiter(&self.registration) {
167            return Err(error_on_field(error, "registration").into());
168        }
169
170        if let Some(error) = error_on_limiter(&self.login.per_ip) {
171            return Err(error_on_nested_field(error, "login", "per_ip").into());
172        }
173        if let Some(error) = error_on_limiter(&self.login.per_account) {
174            return Err(error_on_nested_field(error, "login", "per_account").into());
175        }
176
177        Ok(())
178    }
179}
180
181impl RateLimitingConfig {
182    pub(crate) fn is_default(config: &RateLimitingConfig) -> bool {
183        config == &RateLimitingConfig::default()
184    }
185}
186
187impl RateLimiterConfiguration {
188    pub fn to_quota(self) -> Option<Quota> {
189        let reciprocal = self.per_second.recip();
190        if !reciprocal.is_finite() {
191            return None;
192        }
193        Some(Quota::with_period(Duration::from_secs_f64(reciprocal))?.allow_burst(self.burst))
194    }
195}
196
197fn default_login_per_ip() -> RateLimiterConfiguration {
198    RateLimiterConfiguration {
199        burst: NonZeroU32::new(3).unwrap(),
200        per_second: 3.0 / 60.0,
201    }
202}
203
204fn default_login_per_account() -> RateLimiterConfiguration {
205    RateLimiterConfiguration {
206        burst: NonZeroU32::new(1800).unwrap(),
207        per_second: 1800.0 / 3600.0,
208    }
209}
210
211fn default_registration() -> RateLimiterConfiguration {
212    RateLimiterConfiguration {
213        burst: NonZeroU32::new(3).unwrap(),
214        per_second: 3.0 / 3600.0,
215    }
216}
217
218fn default_account_recovery_per_ip() -> RateLimiterConfiguration {
219    RateLimiterConfiguration {
220        burst: NonZeroU32::new(3).unwrap(),
221        per_second: 3.0 / 3600.0,
222    }
223}
224
225fn default_account_recovery_per_address() -> RateLimiterConfiguration {
226    RateLimiterConfiguration {
227        burst: NonZeroU32::new(3).unwrap(),
228        per_second: 1.0 / 3600.0,
229    }
230}
231
232fn default_email_authentication_per_ip() -> RateLimiterConfiguration {
233    RateLimiterConfiguration {
234        burst: NonZeroU32::new(5).unwrap(),
235        per_second: 1.0 / 60.0,
236    }
237}
238
239fn default_email_authentication_per_address() -> RateLimiterConfiguration {
240    RateLimiterConfiguration {
241        burst: NonZeroU32::new(3).unwrap(),
242        per_second: 1.0 / 3600.0,
243    }
244}
245
246fn default_email_authentication_emails_per_session() -> RateLimiterConfiguration {
247    RateLimiterConfiguration {
248        burst: NonZeroU32::new(2).unwrap(),
249        per_second: 1.0 / 300.0,
250    }
251}
252
253fn default_email_authentication_attempt_per_session() -> RateLimiterConfiguration {
254    RateLimiterConfiguration {
255        burst: NonZeroU32::new(10).unwrap(),
256        per_second: 1.0 / 60.0,
257    }
258}
259
260impl Default for RateLimitingConfig {
261    fn default() -> Self {
262        RateLimitingConfig {
263            login: LoginRateLimitingConfig::default(),
264            registration: default_registration(),
265            account_recovery: AccountRecoveryRateLimitingConfig::default(),
266            email_authentication: EmailauthenticationRateLimitingConfig::default(),
267        }
268    }
269}
270
271impl Default for LoginRateLimitingConfig {
272    fn default() -> Self {
273        LoginRateLimitingConfig {
274            per_ip: default_login_per_ip(),
275            per_account: default_login_per_account(),
276        }
277    }
278}
279
280impl Default for AccountRecoveryRateLimitingConfig {
281    fn default() -> Self {
282        AccountRecoveryRateLimitingConfig {
283            per_ip: default_account_recovery_per_ip(),
284            per_address: default_account_recovery_per_address(),
285        }
286    }
287}
288
289impl Default for EmailauthenticationRateLimitingConfig {
290    fn default() -> Self {
291        EmailauthenticationRateLimitingConfig {
292            per_ip: default_email_authentication_per_ip(),
293            per_address: default_email_authentication_per_address(),
294            emails_per_session: default_email_authentication_emails_per_session(),
295            attempt_per_session: default_email_authentication_attempt_per_session(),
296        }
297    }
298}