use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use mas_storage::Clock;
use rand::{Rng, RngCore};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error;
use crate::cookies::{CookieDecodeError, CookieJar};
/// Errors that can occur while loading or checking a CSRF token.
#[derive(Debug, Error)]
pub enum CsrfError {
    /// The token submitted with the form does not match the token stored in the cookie.
    #[error("CSRF token mismatch")]
    Mismatch,
    /// No CSRF cookie was present when one was required.
    #[error("Missing CSRF cookie")]
    Missing,
    /// The CSRF cookie was present but could not be decoded.
    #[error("could not decode CSRF cookie")]
    DecodeCookie(#[from] CookieDecodeError),
    /// The stored token's expiration time has been reached.
    #[error("CSRF token expired")]
    Expired,
    /// The submitted form value was not valid unpadded URL-safe base64.
    #[error("could not decode CSRF token")]
    Decode(#[from] DecodeError),
}
/// A CSRF token: 32 secret bytes plus the instant at which the token stops being valid.
///
/// The token is serialized with serde (the expiration as integer seconds since the Unix
/// epoch) so it can be stored in the `csrf` cookie.
///
/// NOTE(review): the derived `Debug` prints the raw secret bytes; avoid logging this type.
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct CsrfToken {
    /// Instant from which the token is rejected as expired.
    #[serde_as(as = "TimestampSeconds<i64>")]
    expiration: DateTime<Utc>,
    /// Secret token bytes; exposed to HTML forms as unpadded URL-safe base64.
    token: [u8; 32],
}
impl CsrfToken {
    /// Builds a token from raw bytes that is valid until `now + ttl`.
    fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
        let expiration = now + ttl;
        Self { expiration, token }
    }

    /// Generates a fresh token whose 32 bytes are drawn from `rng`, valid until `now + ttl`.
    fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
        let token = rng.gen();
        Self::new(token, now, ttl)
    }

    /// Keeps the same token bytes but moves the expiration to `now + ttl`.
    fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
        Self::new(self.token, now, ttl)
    }

    /// Returns the token bytes encoded as unpadded URL-safe base64, for embedding in a form.
    #[must_use]
    pub fn form_value(&self) -> String {
        BASE64URL_NOPAD.encode(&self.token[..])
    }

    /// Checks a submitted form value against this token.
    ///
    /// # Errors
    ///
    /// Returns [`CsrfError::Decode`] if `form_value` is not valid unpadded URL-safe base64,
    /// or [`CsrfError::Mismatch`] if the decoded bytes differ from the token.
    pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
        let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
        // `form_value` is attacker-controlled. A plain `==` stops at the first differing byte,
        // so its timing would reveal how much of the secret prefix was guessed correctly.
        if Self::constant_time_eq(&self.token, &form_value) {
            Ok(())
        } else {
            Err(CsrfError::Mismatch)
        }
    }

    /// Compares two byte slices without exiting early on the first difference.
    ///
    /// A length mismatch returns immediately. That leaks nothing secret because the token
    /// length is fixed (32 bytes). This is a best-effort constant-time comparison: it avoids
    /// data-dependent early exits in the source, but it is not a formally verified
    /// constant-time primitive.
    fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
        if a.len() != b.len() {
            return false;
        }
        a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y)) == 0
    }

    /// Returns the token unchanged if `now` is strictly before its expiration.
    ///
    /// # Errors
    ///
    /// Returns [`CsrfError::Expired`] once `now` has reached or passed the expiration.
    fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
        if now < self.expiration {
            Ok(self)
        } else {
            Err(CsrfError::Expired)
        }
    }
}
/// A deserialized form body that carries a CSRF token alongside its real fields.
///
/// The `csrf` field holds the submitted token value. Every other field of the form is
/// deserialized into `T` through `#[serde(flatten)]`.
#[derive(Deserialize)]
pub struct ProtectedForm<T> {
    /// Submitted CSRF token; expected to match the value produced by `CsrfToken::form_value`.
    csrf: String,
    /// The remaining form fields, returned to the caller once the CSRF check passes.
    #[serde(flatten)]
    inner: T,
}
/// Extension methods for issuing and checking CSRF tokens (implemented for `CookieJar`).
pub trait CsrfExt {
    /// Obtains a CSRF token using `clock` for the current time and `rng` for randomness.
    /// Returns the token together with the (possibly updated) receiver.
    fn csrf_token<C: Clock, R: RngCore>(self, clock: &C, rng: R) -> (CsrfToken, Self);

    /// Checks the CSRF token carried by `form` and, on success, returns the form's payload.
    ///
    /// # Errors
    ///
    /// Returns a [`CsrfError`] describing why the check failed.
    fn verify_form<C: Clock, T>(
        &self,
        clock: &C,
        form: ProtectedForm<T>,
    ) -> Result<T, CsrfError>;
}
impl CsrfExt for CookieJar {
    /// Returns the CSRF token for this jar, creating or refreshing it as needed.
    ///
    /// If the `csrf` cookie holds an unexpired token, that token keeps its bytes and gets a
    /// new expiration. Otherwise a new token is generated from `rng`; this covers a missing,
    /// undecodable or expired cookie. Either way the token is saved back under `csrf` in
    /// the returned jar.
    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
    where
        R: RngCore,
        C: Clock,
    {
        // One TTL shared by the "generate" and "refresh" paths so the two cannot drift apart.
        let ttl = Duration::try_hours(1).expect("one hour is a valid chrono duration");
        let now = clock.now();

        let maybe_token = match self.load::<CsrfToken>("csrf") {
            // An expired token is dropped here and replaced by a freshly generated one below.
            Ok(Some(token)) => token.verify_expiration(now).ok(),
            Ok(None) => None,
            Err(e) => {
                // Best-effort: a corrupt cookie is logged and replaced rather than failing.
                tracing::warn!("Failed to decode CSRF cookie: {}", e);
                None
            }
        };

        let token = maybe_token.map_or_else(
            || CsrfToken::generate(now, rng, ttl),
            |token| token.refresh(now, ttl),
        );

        // NOTE(review): meaning of the `false` flag depends on `CookieJar::save`; confirm.
        let jar = self.save("csrf", &token, false);
        (token, jar)
    }

    /// Checks the token submitted in `form` against the one stored in the `csrf` cookie.
    /// Returns the form's inner payload when they match.
    ///
    /// # Errors
    ///
    /// - [`CsrfError::DecodeCookie`] if the cookie cannot be decoded.
    /// - [`CsrfError::Missing`] if there is no `csrf` cookie.
    /// - [`CsrfError::Expired`] if the stored token has expired.
    /// - [`CsrfError::Decode`] or [`CsrfError::Mismatch`] if the submitted value is malformed
    ///   or does not match.
    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
    where
        C: Clock,
    {
        let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
        let token = token.verify_expiration(clock.now())?;
        token.verify_form_value(&form.csrf)?;
        Ok(form.inner)
    }
}