Skip to main content

mas_storage_pg/compat/
refresh_token.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{
10    Clock, CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
11    UlidExt as _,
12};
13use mas_storage::compat::CompatRefreshTokenRepository;
14use rand::RngCore;
15use sqlx::PgConnection;
16use ulid::Ulid;
17use uuid::Uuid;
18
19use crate::{DatabaseError, tracing::ExecuteExt};
20
21/// An implementation of [`CompatRefreshTokenRepository`] for a PostgreSQL
22/// connection
23pub struct PgCompatRefreshTokenRepository<'c> {
24    conn: &'c mut PgConnection,
25}
26
27impl<'c> PgCompatRefreshTokenRepository<'c> {
28    /// Create a new [`PgCompatRefreshTokenRepository`] from an active
29    /// PostgreSQL connection
30    pub fn new(conn: &'c mut PgConnection) -> Self {
31        Self { conn }
32    }
33}
34
35struct CompatRefreshTokenLookup {
36    compat_refresh_token_id: Uuid,
37    refresh_token: String,
38    created_at: DateTime<Utc>,
39    consumed_at: Option<DateTime<Utc>>,
40    compat_access_token_id: Uuid,
41    compat_session_id: Uuid,
42}
43
44impl From<CompatRefreshTokenLookup> for CompatRefreshToken {
45    fn from(value: CompatRefreshTokenLookup) -> Self {
46        let state = match value.consumed_at {
47            Some(consumed_at) => CompatRefreshTokenState::Consumed { consumed_at },
48            None => CompatRefreshTokenState::Valid,
49        };
50
51        Self {
52            id: value.compat_refresh_token_id.into(),
53            state,
54            session_id: value.compat_session_id.into(),
55            token: value.refresh_token,
56            created_at: value.created_at,
57            access_token_id: value.compat_access_token_id.into(),
58        }
59    }
60}
61
62#[async_trait]
63impl CompatRefreshTokenRepository for PgCompatRefreshTokenRepository<'_> {
64    type Error = DatabaseError;
65
66    #[tracing::instrument(
67        name = "db.compat_refresh_token.lookup",
68        skip_all,
69        fields(
70            db.query.text,
71            compat_refresh_token.id = %id,
72        ),
73        err,
74    )]
75    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatRefreshToken>, Self::Error> {
76        let res = sqlx::query_as!(
77            CompatRefreshTokenLookup,
78            r#"
79                SELECT compat_refresh_token_id
80                     , refresh_token
81                     , created_at
82                     , consumed_at
83                     , compat_session_id
84                     , compat_access_token_id
85
86                FROM compat_refresh_tokens
87
88                WHERE compat_refresh_token_id = $1
89            "#,
90            Uuid::from(id),
91        )
92        .traced()
93        .fetch_optional(&mut *self.conn)
94        .await?;
95
96        let Some(res) = res else { return Ok(None) };
97
98        Ok(Some(res.into()))
99    }
100
101    #[tracing::instrument(
102        name = "db.compat_refresh_token.find_by_token",
103        skip_all,
104        fields(
105            db.query.text,
106        ),
107        err,
108    )]
109    async fn find_by_token(
110        &mut self,
111        refresh_token: &str,
112    ) -> Result<Option<CompatRefreshToken>, Self::Error> {
113        let res = sqlx::query_as!(
114            CompatRefreshTokenLookup,
115            r#"
116                SELECT compat_refresh_token_id
117                     , refresh_token
118                     , created_at
119                     , consumed_at
120                     , compat_session_id
121                     , compat_access_token_id
122
123                FROM compat_refresh_tokens
124
125                WHERE refresh_token = $1
126            "#,
127            refresh_token,
128        )
129        .traced()
130        .fetch_optional(&mut *self.conn)
131        .await?;
132
133        let Some(res) = res else { return Ok(None) };
134
135        Ok(Some(res.into()))
136    }
137
138    #[tracing::instrument(
139        name = "db.compat_refresh_token.add",
140        skip_all,
141        fields(
142            db.query.text,
143            compat_refresh_token.id,
144            %compat_session.id,
145            user.id = %compat_session.user_id,
146        ),
147        err,
148    )]
149    async fn add(
150        &mut self,
151        rng: &mut (dyn RngCore + Send),
152        clock: &dyn Clock,
153        compat_session: &CompatSession,
154        compat_access_token: &CompatAccessToken,
155        token: String,
156    ) -> Result<CompatRefreshToken, Self::Error> {
157        let created_at = clock.now();
158        let id = Ulid::from_datetime_with_rng(created_at, rng);
159        tracing::Span::current().record("compat_refresh_token.id", tracing::field::display(id));
160
161        sqlx::query!(
162            r#"
163                INSERT INTO compat_refresh_tokens
164                    (compat_refresh_token_id, compat_session_id,
165                     compat_access_token_id, refresh_token, created_at)
166                VALUES ($1, $2, $3, $4, $5)
167            "#,
168            Uuid::from(id),
169            Uuid::from(compat_session.id),
170            Uuid::from(compat_access_token.id),
171            token,
172            created_at,
173        )
174        .traced()
175        .execute(&mut *self.conn)
176        .await?;
177
178        Ok(CompatRefreshToken {
179            id,
180            state: CompatRefreshTokenState::default(),
181            session_id: compat_session.id,
182            access_token_id: compat_access_token.id,
183            token,
184            created_at,
185        })
186    }
187
188    #[tracing::instrument(
189        name = "db.compat_refresh_token.consume_and_replace",
190        skip_all,
191        fields(
192            db.query.text,
193            %compat_refresh_token.id,
194            %successor_compat_refresh_token.id,
195            compat_session.id = %compat_refresh_token.session_id,
196        ),
197        err,
198    )]
199    async fn consume_and_replace(
200        &mut self,
201        clock: &dyn Clock,
202        compat_refresh_token: CompatRefreshToken,
203        successor_compat_refresh_token: &CompatRefreshToken,
204    ) -> Result<CompatRefreshToken, Self::Error> {
205        if compat_refresh_token.session_id != successor_compat_refresh_token.session_id {
206            return Err(DatabaseError::invalid_operation());
207        }
208
209        let consumed_at = clock.now();
210        let res = sqlx::query!(
211            r#"
212                UPDATE compat_refresh_tokens
213                SET consumed_at = $2
214                WHERE compat_session_id = $1
215                  AND consumed_at IS NULL
216                  AND compat_refresh_token_id <> $3
217            "#,
218            Uuid::from(compat_refresh_token.session_id),
219            consumed_at,
220            Uuid::from(successor_compat_refresh_token.id),
221        )
222        .traced()
223        .execute(&mut *self.conn)
224        .await?;
225
226        // This can affect multiple rows in case we've imported refresh tokens
227        // from Synapse. What we care about is that it at least affected one,
228        // which is what we're checking here
229        if res.rows_affected() == 0 {
230            return Err(DatabaseError::RowsAffected {
231                expected: 1,
232                actual: 0,
233            });
234        }
235
236        let compat_refresh_token = compat_refresh_token
237            .consume(consumed_at)
238            .map_err(DatabaseError::to_invalid_operation)?;
239
240        Ok(compat_refresh_token)
241    }
242}