use std::{net::IpAddr, sync::Arc, time::Duration};
use governor::{clock::QuantaClock, state::keyed::DashMapStateStore, RateLimiter};
use mas_config::RateLimitingConfig;
use mas_data_model::User;
use ulid::Ulid;
/// Error returned by [`Limiter::check_account_recovery`] when one of the
/// account recovery rate limits is exceeded.
#[derive(Debug, Clone, thiserror::Error)]
pub enum AccountRecoveryLimitedError {
/// The per-requester limit was exceeded; carries the requester that was
/// checked.
#[error("Too many account recovery requests for requester {0}")]
Requester(RequesterFingerprint),
/// The per-e-mail limit was exceeded; carries the lower-cased address that
/// was used as the limiter key.
#[error("Too many account recovery requests for e-mail {0}")]
Email(String),
}
/// Error returned by [`Limiter::check_password`] when one of the password
/// check rate limits is exceeded.
#[derive(Debug, Clone, Copy, thiserror::Error)]
pub enum PasswordCheckLimitedError {
/// The per-requester limit was exceeded; carries the requester that was
/// checked.
#[error("Too many password checks for requester {0}")]
Requester(RequesterFingerprint),
/// The per-user limit was exceeded; carries the ID of the user that was
/// checked.
#[error("Too many password checks for user {0}")]
User(Ulid),
}
/// Error returned by [`Limiter::check_registration`] when the registration
/// rate limit is exceeded.
#[derive(Debug, Clone, thiserror::Error)]
pub enum RegistrationLimitedError {
/// The per-requester limit was exceeded; carries the requester that was
/// checked.
#[error("Too many account registration requests for requester {0}")]
Requester(RequesterFingerprint),
}
/// Key identifying who is making a request, for the per-requester limiters.
///
/// The only component is the optional client IP address. Two fingerprints
/// compare (and hash) equal exactly when their addresses are equal, so every
/// request without a known address shares the [`RequesterFingerprint::EMPTY`]
/// key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequesterFingerprint {
    /// Client IP address, or `None` when it is not known.
    ip: Option<IpAddr>,
}

impl std::fmt::Display for RequesterFingerprint {
    /// Writes the IP address, or the `(NO CLIENT IP)` placeholder when the
    /// fingerprint has none.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.ip {
            Some(address) => write!(f, "{address}"),
            None => f.write_str("(NO CLIENT IP)"),
        }
    }
}

impl RequesterFingerprint {
    /// Fingerprint carrying no client IP address.
    pub const EMPTY: Self = Self { ip: None };

    /// Creates a fingerprint for the given client IP address.
    #[must_use]
    pub const fn new(ip: IpAddr) -> Self {
        Self { ip: Some(ip) }
    }
}
/// Handle to the set of rate limiters.
///
/// The state lives behind an [`Arc`], so cloning a `Limiter` yields another
/// handle to the same underlying limiters rather than an independent copy.
#[derive(Debug, Clone)]
pub struct Limiter {
// Shared limiter state; every clone of the handle points at the same value.
inner: Arc<LimiterInner>,
}
/// A `governor` rate limiter keyed by `K`, keeping its per-key state in a
/// `DashMapStateStore` and using `QuantaClock` as its clock.
type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;
/// The individual keyed rate limiters shared by every [`Limiter`] handle.
#[derive(Debug)]
struct LimiterInner {
// Account recovery attempts, keyed by requester
// (quota: `account_recovery.per_ip`).
account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
// Account recovery attempts, keyed by lower-cased e-mail address
// (quota: `account_recovery.per_address`).
account_recovery_per_email: KeyedRateLimiter<String>,
// Password checks, keyed by requester (quota: `login.per_ip`).
password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
// Password checks, keyed by user ID (quota: `login.per_account`).
password_check_for_user: KeyedRateLimiter<Ulid>,
// Registration attempts, keyed by requester (quota: `registration`).
registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
}
impl LimiterInner {
    /// Builds one keyed limiter per quota found in `config`.
    ///
    /// Returns `None` as soon as one of the `to_quota()` conversions yields
    /// `None`; the remaining quotas are then not evaluated.
    fn new(config: &RateLimitingConfig) -> Option<Self> {
        let recovery = &config.account_recovery;
        let login = &config.login;

        // Each limiter is created right after its own quota is resolved, in
        // field order.
        let account_recovery_per_requester = RateLimiter::keyed(recovery.per_ip.to_quota()?);
        let account_recovery_per_email = RateLimiter::keyed(recovery.per_address.to_quota()?);
        let password_check_for_requester = RateLimiter::keyed(login.per_ip.to_quota()?);
        let password_check_for_user = RateLimiter::keyed(login.per_account.to_quota()?);
        let registration_per_requester = RateLimiter::keyed(config.registration.to_quota()?);

        Some(Self {
            account_recovery_per_requester,
            account_recovery_per_email,
            password_check_for_requester,
            password_check_for_user,
            registration_per_requester,
        })
    }
}
impl Limiter {
#[must_use]
pub fn new(config: &RateLimitingConfig) -> Option<Self> {
Some(Self {
inner: Arc::new(LimiterInner::new(config)?),
})
}
pub fn start(&self) {
let this = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
this.inner.account_recovery_per_email.retain_recent();
this.inner.account_recovery_per_requester.retain_recent();
this.inner.password_check_for_requester.retain_recent();
this.inner.password_check_for_user.retain_recent();
this.inner.registration_per_requester.retain_recent();
interval.tick().await;
}
});
}
pub fn check_account_recovery(
&self,
requester: RequesterFingerprint,
email_address: &str,
) -> Result<(), AccountRecoveryLimitedError> {
self.inner
.account_recovery_per_requester
.check_key(&requester)
.map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
let canonical_email = email_address.to_lowercase();
self.inner
.account_recovery_per_email
.check_key(&canonical_email)
.map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
Ok(())
}
pub fn check_password(
&self,
key: RequesterFingerprint,
user: &User,
) -> Result<(), PasswordCheckLimitedError> {
self.inner
.password_check_for_requester
.check_key(&key)
.map_err(|_| PasswordCheckLimitedError::Requester(key))?;
self.inner
.password_check_for_user
.check_key(&user.id)
.map_err(|_| PasswordCheckLimitedError::User(user.id))?;
Ok(())
}
pub fn check_registration(
&self,
requester: RequesterFingerprint,
) -> Result<(), RegistrationLimitedError> {
self.inner
.registration_per_requester
.check_key(&requester)
.map_err(|_| RegistrationLimitedError::Requester(requester))?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use mas_data_model::User;
    use mas_storage::{clock::MockClock, Clock};
    use rand::SeedableRng;

    use super::*;

    /// Exercises both the per-requester and the per-user password check
    /// limits using the default configuration.
    #[test]
    fn test_password_check_limiter() {
        let now = MockClock::default().now();
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();

        // 256 * 3 = 768 distinct addresses of the form `a.a.b.b`, enough to
        // spread requests over many requesters.
        let requesters: [_; 768] = (0..=255)
            .flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();

        // Both users differ only by ID and username.
        let mut new_user = |username: &str| User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: username.to_owned(),
            sub: "123-456".to_owned(),
            primary_user_email_id: None,
            created_at: now,
            locked_at: None,
            can_request_admin: false,
        };
        let alice = new_user("alice");
        let bob = new_user("bob");

        // The same requester gets three attempts...
        let first = requesters[0];
        for _ in 0..3 {
            assert!(limiter.check_password(first, &alice).is_ok());
        }
        // ...then is rejected, whichever user it targets.
        assert!(limiter.check_password(first, &alice).is_err());
        assert!(limiter.check_password(first, &bob).is_err());

        // Another requester can still try the same user.
        assert!(limiter.check_password(requesters[1], &alice).is_ok());

        // So far 4 attempts reached alice's per-user limiter. Spread
        // 598 * 3 = 1794 more over other requesters; each fourth attempt is
        // rejected at the requester level.
        for &requester in &requesters[2..600] {
            for _ in 0..3 {
                assert!(limiter.check_password(requester, &alice).is_ok());
            }
            assert!(limiter.check_password(requester, &alice).is_err());
        }

        // After 1798 accepted attempts on alice, this test expects two more
        // to pass before the per-user limit rejects.
        assert!(limiter.check_password(requesters[600], &alice).is_ok());
        assert!(limiter.check_password(requesters[601], &alice).is_ok());
        assert!(limiter.check_password(requesters[602], &alice).is_err());

        // The other user is unaffected.
        assert!(limiter.check_password(requesters[603], &bob).is_ok());
    }
}