1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

use chrono::{DateTime, Duration, Utc};
use data_encoding::{DecodeError, BASE64URL_NOPAD};
use mas_storage::Clock;
use rand::{Rng, RngCore};
use serde::{Deserialize, Serialize};
use serde_with::{serde_as, TimestampSeconds};
use thiserror::Error;

use crate::cookies::{CookieDecodeError, CookieJar};

/// Failed to validate CSRF token
#[derive(Debug, Error)]
pub enum CsrfError {
    /// The token in the form did not match the token in the cookie
    #[error("CSRF token mismatch")]
    Mismatch,

    /// The token in the form did not match the token in the cookie
    #[error("Missing CSRF cookie")]
    Missing,

    /// Failed to decode the token
    #[error("could not decode CSRF cookie")]
    DecodeCookie(#[from] CookieDecodeError),

    /// The token expired
    #[error("CSRF token expired")]
    Expired,

    /// Failed to decode the token
    #[error("could not decode CSRF token")]
    Decode(#[from] DecodeError),
}

/// A CSRF token
#[serde_as]
#[derive(Serialize, Deserialize, Debug)]
pub struct CsrfToken {
    #[serde_as(as = "TimestampSeconds<i64>")]
    expiration: DateTime<Utc>,
    token: [u8; 32],
}

impl CsrfToken {
    /// Create a new token from a defined value valid for a specified duration
    fn new(token: [u8; 32], now: DateTime<Utc>, ttl: Duration) -> Self {
        let expiration = now + ttl;
        Self { expiration, token }
    }

    /// Generate a new random token valid for a specified duration
    fn generate(now: DateTime<Utc>, mut rng: impl Rng, ttl: Duration) -> Self {
        let token = rng.gen();
        Self::new(token, now, ttl)
    }

    /// Generate a new token with the same value but an up to date expiration
    fn refresh(self, now: DateTime<Utc>, ttl: Duration) -> Self {
        Self::new(self.token, now, ttl)
    }

    /// Get the value to include in HTML forms
    #[must_use]
    pub fn form_value(&self) -> String {
        BASE64URL_NOPAD.encode(&self.token[..])
    }

    /// Verifies that the value got from an HTML form matches this token
    ///
    /// # Errors
    ///
    /// Returns an error if the value in the form does not match this token
    pub fn verify_form_value(&self, form_value: &str) -> Result<(), CsrfError> {
        let form_value = BASE64URL_NOPAD.decode(form_value.as_bytes())?;
        if self.token[..] == form_value {
            Ok(())
        } else {
            Err(CsrfError::Mismatch)
        }
    }

    fn verify_expiration(self, now: DateTime<Utc>) -> Result<Self, CsrfError> {
        if now < self.expiration {
            Ok(self)
        } else {
            Err(CsrfError::Expired)
        }
    }
}

// A CSRF-protected form
#[derive(Deserialize)]
pub struct ProtectedForm<T> {
    csrf: String,

    #[serde(flatten)]
    inner: T,
}

pub trait CsrfExt {
    /// Get the current CSRF token out of the cookie jar, generating a new one
    /// if necessary
    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
    where
        R: RngCore,
        C: Clock;

    /// Verify that the given CSRF-protected form is valid, returning the inner
    /// value
    ///
    /// # Errors
    ///
    /// Returns an error if the CSRF cookie is missing or if the value in the
    /// form is invalid
    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
    where
        C: Clock;
}

impl CsrfExt for CookieJar {
    fn csrf_token<C, R>(self, clock: &C, rng: R) -> (CsrfToken, Self)
    where
        R: RngCore,
        C: Clock,
    {
        let now = clock.now();
        let maybe_token = match self.load::<CsrfToken>("csrf") {
            Ok(Some(token)) => {
                let token = token.verify_expiration(now);

                // If the token is expired, just ignore it
                token.ok()
            }
            Ok(None) => None,
            Err(e) => {
                tracing::warn!("Failed to decode CSRF cookie: {}", e);
                None
            }
        };

        let token = maybe_token.map_or_else(
            || CsrfToken::generate(now, rng, Duration::try_hours(1).unwrap()),
            |token| token.refresh(now, Duration::try_hours(1).unwrap()),
        );

        let jar = self.save("csrf", &token, false);
        (token, jar)
    }

    fn verify_form<C, T>(&self, clock: &C, form: ProtectedForm<T>) -> Result<T, CsrfError>
    where
        C: Clock,
    {
        let token: CsrfToken = self.load("csrf")?.ok_or(CsrfError::Missing)?;
        let token = token.verify_expiration(clock.now())?;
        token.verify_form_value(&form.csrf)?;
        Ok(form.inner)
    }
}