mas_storage_pg/user/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7//! A module containing the PostgreSQL implementation of the user-related
8//! repositories
9
10use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query, extension::postgres::PgExpr as _};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21    DatabaseError,
22    filter::{Filter, StatementExt},
23    iden::Users,
24    pagination::QueryBuilderExt,
25    tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40    email::PgUserEmailRepository, password::PgUserPasswordRepository,
41    recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42    registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43    terms::PgUserTermsRepository,
44};
45
46/// An implementation of [`UserRepository`] for a PostgreSQL connection
47pub struct PgUserRepository<'c> {
48    conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52    /// Create a new [`PgUserRepository`] from an active PostgreSQL connection
53    pub fn new(conn: &'c mut PgConnection) -> Self {
54        Self { conn }
55    }
56}
57
58mod priv_ {
59    // The enum_def macro generates a public enum, which we don't want, because it
60    // triggers the missing docs warning
61    #![allow(missing_docs)]
62
63    use chrono::{DateTime, Utc};
64    use sea_query::enum_def;
65    use uuid::Uuid;
66
67    #[derive(Debug, Clone, sqlx::FromRow)]
68    #[enum_def]
69    pub(super) struct UserLookup {
70        pub(super) user_id: Uuid,
71        pub(super) username: String,
72        pub(super) created_at: DateTime<Utc>,
73        pub(super) locked_at: Option<DateTime<Utc>>,
74        pub(super) deactivated_at: Option<DateTime<Utc>>,
75        pub(super) can_request_admin: bool,
76        pub(super) is_guest: bool,
77    }
78}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83    fn from(value: UserLookup) -> Self {
84        let id = value.user_id.into();
85        Self {
86            id,
87            username: value.username,
88            sub: id.to_string(),
89            created_at: value.created_at,
90            locked_at: value.locked_at,
91            deactivated_at: value.deactivated_at,
92            can_request_admin: value.can_request_admin,
93            is_guest: value.is_guest,
94        }
95    }
96}
97
98impl Filter for UserFilter<'_> {
99    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
100        sea_query::Condition::all()
101            .add_option(self.state().map(|state| {
102                match state {
103                    mas_storage::user::UserState::Deactivated => {
104                        Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
105                    }
106                    mas_storage::user::UserState::Locked => {
107                        Expr::col((Users::Table, Users::LockedAt)).is_not_null()
108                    }
109                    mas_storage::user::UserState::Active => {
110                        Expr::col((Users::Table, Users::LockedAt))
111                            .is_null()
112                            .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
113                    }
114                }
115            }))
116            .add_option(self.can_request_admin().map(|can_request_admin| {
117                Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
118            }))
119            .add_option(
120                self.is_guest()
121                    .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
122            )
123            .add_option(self.search().map(|search| {
124                Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
125            }))
126    }
127}
128
129#[async_trait]
130impl UserRepository for PgUserRepository<'_> {
131    type Error = DatabaseError;
132
133    #[tracing::instrument(
134        name = "db.user.lookup",
135        skip_all,
136        fields(
137            db.query.text,
138            user.id = %id,
139        ),
140        err,
141    )]
142    async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
143        let res = sqlx::query_as!(
144            UserLookup,
145            r#"
146                SELECT user_id
147                     , username
148                     , created_at
149                     , locked_at
150                     , deactivated_at
151                     , can_request_admin
152                     , is_guest
153                FROM users
154                WHERE user_id = $1
155            "#,
156            Uuid::from(id),
157        )
158        .traced()
159        .fetch_optional(&mut *self.conn)
160        .await?;
161
162        let Some(res) = res else { return Ok(None) };
163
164        Ok(Some(res.into()))
165    }
166
167    #[tracing::instrument(
168        name = "db.user.find_by_username",
169        skip_all,
170        fields(
171            db.query.text,
172            user.username = username,
173        ),
174        err,
175    )]
176    async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
177        // We may have multiple users with the same username, but with a different
178        // casing. In this case, we want to return the one which matches the exact
179        // casing
180        let res = sqlx::query_as!(
181            UserLookup,
182            r#"
183                SELECT user_id
184                     , username
185                     , created_at
186                     , locked_at
187                     , deactivated_at
188                     , can_request_admin
189                     , is_guest
190                FROM users
191                WHERE LOWER(username) = LOWER($1)
192            "#,
193            username,
194        )
195        .traced()
196        .fetch_all(&mut *self.conn)
197        .await?;
198
199        match &res[..] {
200            // Happy path: there is only one user matching the username…
201            [user] => Ok(Some(user.clone().into())),
202            // …or none.
203            [] => Ok(None),
204            list => {
205                // If there are multiple users with the same username, we want to
206                // return the one which matches the exact casing
207                if let Some(user) = list.iter().find(|user| user.username == username) {
208                    Ok(Some(user.clone().into()))
209                } else {
210                    // If none match exactly, we prefer to return nothing
211                    Ok(None)
212                }
213            }
214        }
215    }
216
217    #[tracing::instrument(
218        name = "db.user.add",
219        skip_all,
220        fields(
221            db.query.text,
222            user.username = username,
223            user.id,
224        ),
225        err,
226    )]
227    async fn add(
228        &mut self,
229        rng: &mut (dyn RngCore + Send),
230        clock: &dyn Clock,
231        username: String,
232    ) -> Result<User, Self::Error> {
233        let created_at = clock.now();
234        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
235        tracing::Span::current().record("user.id", tracing::field::display(id));
236
237        let res = sqlx::query!(
238            r#"
239                INSERT INTO users (user_id, username, created_at)
240                VALUES ($1, $2, $3)
241                ON CONFLICT (username) DO NOTHING
242            "#,
243            Uuid::from(id),
244            username,
245            created_at,
246        )
247        .traced()
248        .execute(&mut *self.conn)
249        .await?;
250
251        // If the user already exists, want to return an error but not poison the
252        // transaction
253        DatabaseError::ensure_affected_rows(&res, 1)?;
254
255        Ok(User {
256            id,
257            username,
258            sub: id.to_string(),
259            created_at,
260            locked_at: None,
261            deactivated_at: None,
262            can_request_admin: false,
263            is_guest: false,
264        })
265    }
266
267    #[tracing::instrument(
268        name = "db.user.exists",
269        skip_all,
270        fields(
271            db.query.text,
272            user.username = username,
273        ),
274        err,
275    )]
276    async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
277        let exists = sqlx::query_scalar!(
278            r#"
279                SELECT EXISTS(
280                    SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
281                ) AS "exists!"
282            "#,
283            username
284        )
285        .traced()
286        .fetch_one(&mut *self.conn)
287        .await?;
288
289        Ok(exists)
290    }
291
292    #[tracing::instrument(
293        name = "db.user.lock",
294        skip_all,
295        fields(
296            db.query.text,
297            %user.id,
298        ),
299        err,
300    )]
301    async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
302        if user.locked_at.is_some() {
303            return Ok(user);
304        }
305
306        let locked_at = clock.now();
307        let res = sqlx::query!(
308            r#"
309                UPDATE users
310                SET locked_at = $1
311                WHERE user_id = $2
312            "#,
313            locked_at,
314            Uuid::from(user.id),
315        )
316        .traced()
317        .execute(&mut *self.conn)
318        .await?;
319
320        DatabaseError::ensure_affected_rows(&res, 1)?;
321
322        user.locked_at = Some(locked_at);
323
324        Ok(user)
325    }
326
327    #[tracing::instrument(
328        name = "db.user.unlock",
329        skip_all,
330        fields(
331            db.query.text,
332            %user.id,
333        ),
334        err,
335    )]
336    async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
337        if user.locked_at.is_none() {
338            return Ok(user);
339        }
340
341        let res = sqlx::query!(
342            r#"
343                UPDATE users
344                SET locked_at = NULL
345                WHERE user_id = $1
346            "#,
347            Uuid::from(user.id),
348        )
349        .traced()
350        .execute(&mut *self.conn)
351        .await?;
352
353        DatabaseError::ensure_affected_rows(&res, 1)?;
354
355        user.locked_at = None;
356
357        Ok(user)
358    }
359
360    #[tracing::instrument(
361        name = "db.user.deactivate",
362        skip_all,
363        fields(
364            db.query.text,
365            %user.id,
366        ),
367        err,
368    )]
369    async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
370        if user.deactivated_at.is_some() {
371            return Ok(user);
372        }
373
374        let deactivated_at = clock.now();
375        let res = sqlx::query!(
376            r#"
377                UPDATE users
378                SET deactivated_at = $2
379                WHERE user_id = $1
380                  AND deactivated_at IS NULL
381            "#,
382            Uuid::from(user.id),
383            deactivated_at,
384        )
385        .traced()
386        .execute(&mut *self.conn)
387        .await?;
388
389        DatabaseError::ensure_affected_rows(&res, 1)?;
390
391        user.deactivated_at = Some(deactivated_at);
392
393        Ok(user)
394    }
395
396    #[tracing::instrument(
397        name = "db.user.reactivate",
398        skip_all,
399        fields(
400            db.query.text,
401            %user.id,
402        ),
403        err,
404    )]
405    async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
406        if user.deactivated_at.is_none() {
407            return Ok(user);
408        }
409
410        let res = sqlx::query!(
411            r#"
412                UPDATE users
413                SET deactivated_at = NULL
414                WHERE user_id = $1
415            "#,
416            Uuid::from(user.id),
417        )
418        .traced()
419        .execute(&mut *self.conn)
420        .await?;
421
422        DatabaseError::ensure_affected_rows(&res, 1)?;
423
424        user.deactivated_at = None;
425
426        Ok(user)
427    }
428
429    #[tracing::instrument(
430        name = "db.user.set_can_request_admin",
431        skip_all,
432        fields(
433            db.query.text,
434            %user.id,
435            user.can_request_admin = can_request_admin,
436        ),
437        err,
438    )]
439    async fn set_can_request_admin(
440        &mut self,
441        mut user: User,
442        can_request_admin: bool,
443    ) -> Result<User, Self::Error> {
444        let res = sqlx::query!(
445            r#"
446                UPDATE users
447                SET can_request_admin = $2
448                WHERE user_id = $1
449            "#,
450            Uuid::from(user.id),
451            can_request_admin,
452        )
453        .traced()
454        .execute(&mut *self.conn)
455        .await?;
456
457        DatabaseError::ensure_affected_rows(&res, 1)?;
458
459        user.can_request_admin = can_request_admin;
460
461        Ok(user)
462    }
463
464    #[tracing::instrument(
465        name = "db.user.list",
466        skip_all,
467        fields(
468            db.query.text,
469        ),
470        err,
471    )]
472    async fn list(
473        &mut self,
474        filter: UserFilter<'_>,
475        pagination: mas_storage::Pagination,
476    ) -> Result<mas_storage::Page<User>, Self::Error> {
477        let (sql, arguments) = Query::select()
478            .expr_as(
479                Expr::col((Users::Table, Users::UserId)),
480                UserLookupIden::UserId,
481            )
482            .expr_as(
483                Expr::col((Users::Table, Users::Username)),
484                UserLookupIden::Username,
485            )
486            .expr_as(
487                Expr::col((Users::Table, Users::CreatedAt)),
488                UserLookupIden::CreatedAt,
489            )
490            .expr_as(
491                Expr::col((Users::Table, Users::LockedAt)),
492                UserLookupIden::LockedAt,
493            )
494            .expr_as(
495                Expr::col((Users::Table, Users::DeactivatedAt)),
496                UserLookupIden::DeactivatedAt,
497            )
498            .expr_as(
499                Expr::col((Users::Table, Users::CanRequestAdmin)),
500                UserLookupIden::CanRequestAdmin,
501            )
502            .expr_as(
503                Expr::col((Users::Table, Users::IsGuest)),
504                UserLookupIden::IsGuest,
505            )
506            .from(Users::Table)
507            .apply_filter(filter)
508            .generate_pagination((Users::Table, Users::UserId), pagination)
509            .build_sqlx(PostgresQueryBuilder);
510
511        let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
512            .traced()
513            .fetch_all(&mut *self.conn)
514            .await?;
515
516        let page = pagination.process(edges).map(User::from);
517
518        Ok(page)
519    }
520
521    #[tracing::instrument(
522        name = "db.user.count",
523        skip_all,
524        fields(
525            db.query.text,
526        ),
527        err,
528    )]
529    async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
530        let (sql, arguments) = Query::select()
531            .expr(Expr::col((Users::Table, Users::UserId)).count())
532            .from(Users::Table)
533            .apply_filter(filter)
534            .build_sqlx(PostgresQueryBuilder);
535
536        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
537            .traced()
538            .fetch_one(&mut *self.conn)
539            .await?;
540
541        count
542            .try_into()
543            .map_err(DatabaseError::to_invalid_operation)
544    }
545
546    #[tracing::instrument(
547        name = "db.user.acquire_lock_for_sync",
548        skip_all,
549        fields(
550            db.query.text,
551            user.id = %user.id,
552        ),
553        err,
554    )]
555    async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
556        // XXX: this lock isn't stictly scoped to users, but as we don't use many
557        // postgres advisory locks, it's fine for now. Later on, we could use row-level
558        // locks to make sure we don't get into trouble
559
560        // Convert the user ID to a u128 and grab the lower 64 bits
561        // As this includes 64bit of the random part of the ULID, it should be random
562        // enough to not collide
563        let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
564
565        // Use a PG advisory lock, which will be released when the transaction is
566        // committed or rolled back
567        sqlx::query!(
568            r#"
569                SELECT pg_advisory_xact_lock($1)
570            "#,
571            lock_id,
572        )
573        .traced()
574        .execute(&mut *self.conn)
575        .await?;
576
577        Ok(())
578    }
579}