Skip to main content

mas_storage_pg/
lib.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8//! An implementation of the storage traits for a PostgreSQL database
9//!
10//! This backend uses [`sqlx`] to interact with the database. Most queries are
11//! type-checked at compile time, using the offline query data recorded in the
12//! `.sqlx` directory. This data is generated by the `sqlx` CLI tool, and should
13//! be updated whenever the database schema changes, or new queries are added.
14//!
15//! # Implementing a new repository
16//!
17//! When a new repository is defined in [`mas_storage`], it should be
18//! implemented here, with the PostgreSQL backend.
19//!
20//! A typical implementation will look like this:
21//!
22//! ```rust
23//! # use async_trait::async_trait;
24//! # use ulid::Ulid;
25//! # use rand::RngCore;
26//! # use mas_data_model::Clock;
27//! # use mas_data_model::UlidExt;
28//! # use mas_storage_pg::{DatabaseError, ExecuteExt};
29//! # use sqlx::PgConnection;
30//! # use uuid::Uuid;
31//! #
32//! # // A fake data structure, usually defined in mas-data-model
33//! # #[derive(sqlx::FromRow)]
34//! # struct FakeData {
35//! #    id: Ulid,
36//! # }
37//! #
38//! # // A fake repository trait, usually defined in mas-storage
39//! # #[async_trait]
40//! # pub trait FakeDataRepository: Send + Sync {
41//! #     type Error;
42//! #     async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error>;
43//! #     async fn add(
44//! #         &mut self,
45//! #         rng: &mut (dyn RngCore + Send),
46//! #         clock: &dyn Clock,
47//! #     ) -> Result<FakeData, Self::Error>;
48//! # }
49//! #
50//! /// An implementation of [`FakeDataRepository`] for a PostgreSQL connection
51//! pub struct PgFakeDataRepository<'c> {
52//!     conn: &'c mut PgConnection,
53//! }
54//!
55//! impl<'c> PgFakeDataRepository<'c> {
56//!     /// Create a new [`PgFakeDataRepository`] from an active PostgreSQL connection
57//!     pub fn new(conn: &'c mut PgConnection) -> Self {
58//!         Self { conn }
59//!     }
60//! }
61//!
62//! #[derive(sqlx::FromRow)]
63//! struct FakeDataLookup {
64//!     fake_data_id: Uuid,
65//! }
66//!
67//! impl From<FakeDataLookup> for FakeData {
68//!     fn from(value: FakeDataLookup) -> Self {
69//!         Self {
70//!             id: value.fake_data_id.into(),
71//!         }
72//!     }
73//! }
74//!
75//! #[async_trait]
76//! impl<'c> FakeDataRepository for PgFakeDataRepository<'c> {
77//!     type Error = DatabaseError;
78//!
79//!     #[tracing::instrument(
80//!         name = "db.fake_data.lookup",
81//!         skip_all,
82//!         fields(
83//!             db.query.text,
84//!             fake_data.id = %id,
85//!         ),
86//!         err,
87//!     )]
88//!     async fn lookup(&mut self, id: Ulid) -> Result<Option<FakeData>, Self::Error> {
89//!         // Note: here we would use the macro version instead, but it's not possible here in
90//!         // this documentation example
91//!         let res: Option<FakeDataLookup> = sqlx::query_as(
92//!             r#"
93//!                 SELECT fake_data_id
94//!                 FROM fake_data
95//!                 WHERE fake_data_id = $1
96//!             "#,
97//!         )
98//!         .bind(Uuid::from(id))
99//!         .traced()
100//!         .fetch_optional(&mut *self.conn)
101//!         .await?;
102//!
103//!         let Some(res) = res else { return Ok(None) };
104//!
105//!         Ok(Some(res.into()))
106//!     }
107//!
108//!     #[tracing::instrument(
109//!         name = "db.fake_data.add",
110//!         skip_all,
111//!         fields(
112//!             db.query.text,
113//!             fake_data.id,
114//!         ),
115//!         err,
116//!     )]
117//!     async fn add(
118//!         &mut self,
119//!         rng: &mut (dyn RngCore + Send),
120//!         clock: &dyn Clock,
121//!     ) -> Result<FakeData, Self::Error> {
122//!         let created_at = clock.now();
123//!         let id = Ulid::from_datetime_with_rng(created_at, rng);
124//!         tracing::Span::current().record("fake_data.id", tracing::field::display(id));
125//!
126//!         // Note: here we would use the macro version instead, but it's not possible here in
127//!         // this documentation example
128//!         sqlx::query(
129//!             r#"
130//!                 INSERT INTO fake_data (fake_data_id)
131//!                 VALUES ($1)
132//!             "#,
133//!         )
134//!         .bind(Uuid::from(id))
135//!         .traced()
136//!         .execute(&mut *self.conn)
137//!         .await?;
138//!
139//!         Ok(FakeData {
140//!             id,
141//!         })
142//!     }
143//! }
144//! ```
145//!
146//! A few things to note with the implementation:
147//!
148//!  - All methods are traced, with an explicit, somewhat consistent name.
149//!  - The SQL statement is included as an attribute, by declaring a
150//!    `db.query.text` attribute on the tracing span, and then calling
151//!    [`ExecuteExt::traced`].
152//!  - The IDs are all [`Ulid`], and generated from the clock and the random
153//!    number generator passed as parameters. The generated IDs are recorded in
154//!    the span.
155//!  - The IDs are stored as [`Uuid`] in PostgreSQL, so conversions are required.
156//!  - "Not found" errors are handled by returning `Ok(None)` instead of an
157//!    error.
158//!
159//! [`Ulid`]: ulid::Ulid
160//! [`Uuid`]: uuid::Uuid
161
// Crate-wide lint policy: every future must be `Send`
// (`clippy::future_not_send`), and every public item must be documented
// (`missing_docs`). Violations are hard errors.
#![deny(clippy::future_not_send, missing_docs)]
// NOTE(review): the reasons for these allowances are not stated in this file.
// Presumably `module_name_repetitions` is allowed because type names such as
// `PgRepository` repeat their module name, and `blocks_in_conditions` because
// of macro-generated code (e.g. `#[tracing::instrument(err)]`). Confirm.
#![allow(clippy::module_name_repetitions, clippy::blocks_in_conditions)]
164
165use std::collections::{BTreeMap, BTreeSet, HashSet};
166
167use ::tracing::{Instrument, debug, info, info_span, warn};
168use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
169use sqlx::{
170    Either, PgConnection,
171    migrate::{AppliedMigration, Migrate, MigrateError, Migration, Migrator},
172    postgres::{PgAdvisoryLock, PgAdvisoryLockKey},
173};
174
175pub mod app_session;
176pub mod compat;
177pub mod oauth2;
178pub mod personal;
179pub mod queue;
180pub mod upstream_oauth2;
181pub mod user;
182
183mod errors;
184pub(crate) mod filter;
185pub(crate) mod iden;
186pub(crate) mod pagination;
187pub(crate) mod policy_data;
188pub(crate) mod repository;
189pub(crate) mod telemetry;
190pub(crate) mod tracing;
191pub(crate) mod ulid_at;
192
193pub(crate) use self::errors::DatabaseInconsistencyError;
194pub use self::{
195    errors::DatabaseError,
196    repository::{PgRepository, PgRepositoryFactory},
197    tracing::ExecuteExt,
198};
199
/// Embedded migrations in the binary
///
/// Built at compile time by [`sqlx::migrate!`], invoked here without a path
/// argument, so sqlx's default migrations directory for this crate is used.
/// Iterating over it yields every embedded migration, including any down
/// migrations.
pub static MIGRATOR: Migrator = sqlx::migrate!();
202
203fn available_migrations() -> BTreeMap<i64, &'static Migration> {
204    MIGRATOR.iter().map(|m| (m.version, m)).collect()
205}
206
207/// This is the list of migrations we've removed from the migration history but
208/// might have been applied in the past
209#[expect(clippy::inconsistent_digit_grouping)]
210const ALLOWED_MISSING_MIGRATIONS: &[i64] = &[
211    // https://github.com/matrix-org/matrix-authentication-service/pull/1585
212    20220709_210445,
213    20230330_210841,
214    20230408_110421,
215];
216
217fn allowed_missing_migrations() -> BTreeSet<i64> {
218    ALLOWED_MISSING_MIGRATIONS.iter().copied().collect()
219}
220
221/// This is a list of possible additional checksums from previous versions of
222/// migrations. The checksum we store in the database is 48 bytes long. We're
223/// not really concerned with partial hash collisions, and to avoid this file to
224/// be completely unreadable, we only store the upper 16 bytes of that hash.
225#[expect(clippy::inconsistent_digit_grouping)]
226const ALLOWED_ALTERNATE_CHECKSUMS: &[(i64, u128)] = &[
227    // https://github.com/element-hq/matrix-authentication-service/pull/5300
228    (20250410_000000, 0x8811_c3ef_dbee_8c00_5b49_25da_5d55_9c3f),
229    (20250410_000001, 0x7990_37b3_2193_8a5d_c72f_bccd_95fd_82e5),
230    (20250410_000002, 0xf2b8_f120_deae_27e7_60d0_79a3_0b77_eea3),
231    (20250410_000003, 0x06be_fc2b_cedc_acf4_b981_02c7_b40c_c469),
232    (20250410_000004, 0x0a90_9c6a_dba7_545c_10d9_60eb_6d30_2f50),
233    (20250410_000006, 0xcc7f_5152_6497_5729_d94b_be0d_9c95_8316),
234    (20250410_000007, 0x12e7_cfab_a017_a5a5_4f2c_18fa_541c_ce62),
235    (20250410_000008, 0x171d_62e5_ee1a_f0d9_3639_6c5a_277c_54cd),
236    (20250410_000009, 0xb1a0_93c7_6645_92ad_df45_b395_57bb_a281),
237    (20250410_000010, 0x8089_86ac_7cff_8d86_2850_d287_cdb1_2b57),
238    (20250410_000011, 0x8d9d_3fae_02c9_3d3f_81e4_6242_2b39_b5b8),
239    (20250410_000012, 0x9805_1372_41aa_d5b0_ebe1_ba9d_28c7_faf6),
240    (20250410_000013, 0x7291_9a97_e4d1_0d45_1791_6e8c_3f2d_e34d),
241    (20250410_000014, 0x811d_f965_8127_e168_4aa2_f177_a4e6_f077),
242    (20250410_000015, 0xa639_0780_aab7_d60d_5fcb_771d_13ed_73ee),
243    (20250410_000016, 0x22b6_e909_6de4_39e3_b2b9_c684_7417_fe07),
244    (20250410_000017, 0x9dfe_b6d3_89e4_e509_651b_2793_8d8d_cd32),
245    (20250410_000018, 0x638f_bdbc_2276_5094_020b_cec1_ab95_c07f),
246    (20250410_000019, 0xa283_84bc_5fd5_7cbd_b5fb_b5fe_0255_6845),
247    (20250410_000020, 0x17d1_54b1_7c6e_fc48_61dd_da3d_f8a5_9546),
248    (20250410_000022, 0xbc36_af82_994a_6f93_8aca_a46b_fc3c_ffde),
249    (20250410_000023, 0x54ec_3b07_ac79_443b_9e18_a2b3_2d17_5ab9),
250    (20250410_000024, 0x8ab4_4f80_00b6_58b2_d757_c40f_bc72_3d87),
251    (20250410_000025, 0x5dc4_2ff3_3042_2f45_046d_10af_ab3a_b583),
252    (20250410_000026, 0x5263_c547_0b64_6425_5729_48b2_ce84_7cad),
253    (20250410_000027, 0x0aad_cb50_1d6a_7794_9017_d24d_55e7_1b9d),
254    (20250410_000028, 0x8fc1_92f8_68df_ca4e_3e2b_cddf_bc12_cffe),
255    (20250410_000029, 0x416c_9446_b6a3_1b49_2940_a8ac_c1c2_665a),
256    (20250410_000030, 0x83a5_e51e_25a6_77fb_2b79_6ea5_db1e_364f),
257    (20250410_000031, 0xfa18_a707_9438_dbc7_2cde_b5f1_ee21_5c7e),
258    (20250410_000032, 0xd669_662e_8930_838a_b142_c3fa_7b39_d2a0),
259    (20250410_000033, 0x4019_1053_cabc_191c_c02e_9aa9_407c_0de5),
260    (20250410_000034, 0xdd59_e595_24e6_4dad_c5f7_fef2_90b8_df57),
261    (20250410_000035, 0x09b4_ea53_2da4_9c39_eb10_db33_6a6d_608b),
262    (20250410_000036, 0x3ca5_9c78_8480_e342_d729_907c_d293_2049),
263    (20250410_000037, 0xc857_2a10_450b_0612_822c_2b86_535a_ea7d),
264    (20250410_000038, 0x1642_39da_9c3b_d9fd_b1e1_72b1_db78_b978),
265    (20250410_000039, 0xdd70_b211_6016_bb84_0d84_f04e_eb8a_59d9),
266    (20250410_000040, 0xe435_ead6_c363_a0b6_e048_dd85_0ecb_9499),
267    (20250410_000041, 0xe9f3_122f_70d4_9839_c818_4b18_0192_ae26),
268    (20250410_000043, 0xec5e_1400_483d_c4bf_6014_aba4_ffc3_6236),
269    (20250410_000044, 0x4750_5eba_4095_6664_78d0_27f9_64bf_64f4),
270    (20250410_000045, 0x9a53_bd70_4cad_2bf1_61d4_f143_0c82_681d),
271    (20250410_121612, 0x25f0_9d20_a897_df18_162d_1c47_b68e_81bd),
272    (20250602_212101, 0xd1a8_782c_b3f0_5045_3f46_49a0_bab0_822b),
273    (20250708_155857, 0xb78e_6957_a588_c16a_d292_a0c7_cae9_f290),
274    (20250915_092635, 0x6854_d58b_99d7_3ac5_82f8_25e5_b1c3_cc0b),
275    (20251127_145951, 0x3bcd_d92e_8391_2a2c_8a18_1d76_354f_96c6),
276];
277
278fn alternate_checksums_map() -> BTreeMap<i64, HashSet<u128>> {
279    let mut map = BTreeMap::new();
280    for (version, checksum) in ALLOWED_ALTERNATE_CHECKSUMS {
281        map.entry(*version)
282            .or_insert_with(HashSet::new)
283            .insert(*checksum);
284    }
285    map
286}
287
288/// Load the list of applied migrations into a map.
289///
290/// It's important to use a [`BTreeMap`] so that the migrations are naturally
291/// ordered by version.
292async fn applied_migrations_map(
293    conn: &mut PgConnection,
294) -> Result<BTreeMap<i64, AppliedMigration>, MigrateError> {
295    let applied_migrations = conn
296        .list_applied_migrations()
297        .await?
298        .into_iter()
299        .map(|m| (m.version, m))
300        .collect();
301
302    Ok(applied_migrations)
303}
304
305/// Checks if the migration table exists
306async fn migration_table_exists(conn: &mut PgConnection) -> Result<bool, sqlx::Error> {
307    sqlx::query_scalar!(
308        r#"
309            SELECT EXISTS (
310                SELECT 1
311                FROM information_schema.tables
312                WHERE table_name = '_sqlx_migrations'
313            ) AS "exists!"
314        "#,
315    )
316    .fetch_one(conn)
317    .await
318}
319
320/// Run the migrations on the given connection
321///
322/// This function acquires an advisory lock on the database to ensure that only
323/// one migrator is running at a time.
324///
325/// # Errors
326///
327/// This function returns an error if the migration fails.
328#[::tracing::instrument(name = "db.migrate", skip_all, err)]
329pub async fn migrate(conn: &mut PgConnection) -> Result<(), MigrateError> {
330    // Get the database name and use it to derive an advisory lock key. This
331    // is the same lock key used by SQLx default migrator, so that it works even
332    // with older versions of MAS, and when running through `cargo sqlx migrate run`
333    let database_name = sqlx::query_scalar!(r#"SELECT current_database() as "current_database!""#)
334        .fetch_one(&mut *conn)
335        .await
336        .map_err(MigrateError::from)?;
337
338    let lock =
339        PgAdvisoryLock::with_key(PgAdvisoryLockKey::BigInt(generate_lock_id(&database_name)));
340
341    // Try to acquire the migration lock in a loop.
342    //
343    // The reason we do that with a `try_acquire` is because in Postgres, `CREATE
344    // INDEX CONCURRENTLY` will *not* complete whilst an advisory lock is being
345    // acquired on another connection. This then means that if we run two
346    // migration process at the same time, one of them will go through and block
347    // on concurrent index creations, because the other will get stuck trying to
348    // acquire this lock.
349    //
350    // To avoid this, we use `try_acquire`/`pg_advisory_lock_try` in a loop, which
351    // will fail immediately if the lock is held by another connection, allowing
352    // potential 'CREATE INDEX CONCURRENTLY' statements to complete.
353    let mut backoff = std::time::Duration::from_millis(250);
354    let mut conn = conn;
355    let mut locked_connection = loop {
356        match lock.try_acquire(conn).await? {
357            Either::Left(guard) => break guard,
358            Either::Right(conn_) => {
359                warn!(
360                    "Another process is already running migrations on the database, waiting {duration}s and trying again…",
361                    duration = backoff.as_secs_f32()
362                );
363                tokio::time::sleep(backoff).await;
364                backoff = std::cmp::min(backoff * 2, std::time::Duration::from_secs(5));
365                conn = conn_;
366            }
367        }
368    };
369
370    // Creates the migration table if missing
371    // We check if the table exists before calling `ensure_migrations_table` to
372    // avoid the pesky 'relation "_sqlx_migrations" already exists, skipping' notice
373    if !migration_table_exists(locked_connection.as_mut()).await? {
374        locked_connection.as_mut().ensure_migrations_table().await?;
375    }
376
377    for migration in pending_migrations(locked_connection.as_mut()).await? {
378        info!(
379            "Applying migration {version}: {description}",
380            version = migration.version,
381            description = migration.description
382        );
383        locked_connection
384            .as_mut()
385            .apply(migration)
386            .instrument(info_span!(
387                "db.migrate.run_migration",
388                db.migration.version = migration.version,
389                db.migration.description = &*migration.description,
390                { DB_QUERY_TEXT } = &*migration.sql,
391            ))
392            .await?;
393    }
394
395    locked_connection.release_now().await?;
396
397    Ok(())
398}
399
400/// Get the list of pending migrations
401///
402/// # Errors
403///
404/// This function returns an error if there is a problem checking the applied
405/// migrations
406pub async fn pending_migrations(
407    conn: &mut PgConnection,
408) -> Result<Vec<&'static Migration>, MigrateError> {
409    // Load the maps of available migrations, applied migrations, migrations that
410    // are allowed to be missing, alternate checksums for migrations that changed
411    let available_migrations = available_migrations();
412    let allowed_missing = allowed_missing_migrations();
413    let alternate_checksums = alternate_checksums_map();
414    let applied_migrations = if migration_table_exists(&mut *conn).await? {
415        applied_migrations_map(&mut *conn).await?
416    } else {
417        BTreeMap::new()
418    };
419
420    // Check that all applied migrations are still valid
421    for applied_migration in applied_migrations.values() {
422        // Check that we know about the applied migration
423        if let Some(migration) = available_migrations.get(&applied_migration.version) {
424            // Check the migration checksum
425            if applied_migration.checksum != migration.checksum {
426                // The checksum we have in the database doesn't match the one we
427                // have embedded. This might be because a migration was
428                // intentionally changed, so we check the alternate checksums
429                if let Some(alternates) = alternate_checksums.get(&applied_migration.version) {
430                    // This converts the first 16 bytes of the checksum into a u128
431                    let Some(applied_checksum_prefix) = applied_migration
432                        .checksum
433                        .get(..16)
434                        .and_then(|bytes| bytes.try_into().ok())
435                        .map(u128::from_be_bytes)
436                    else {
437                        return Err(MigrateError::ExecuteMigration(
438                            sqlx::Error::InvalidArgument(
439                                "checksum stored in database is invalid".to_owned(),
440                            ),
441                            applied_migration.version,
442                        ));
443                    };
444
445                    if !alternates.contains(&applied_checksum_prefix) {
446                        warn!(
447                            "The database has a migration applied ({version}) which has known alternative checksums {alternates:x?}, but none of them matched {applied_checksum_prefix:x}",
448                            version = applied_migration.version,
449                        );
450                        return Err(MigrateError::VersionMismatch(applied_migration.version));
451                    }
452                } else {
453                    return Err(MigrateError::VersionMismatch(applied_migration.version));
454                }
455            }
456        } else if allowed_missing.contains(&applied_migration.version) {
457            // The migration is missing, but allowed to be missing
458            debug!(
459                "The database has a migration applied ({version}) that doesn't exist anymore, but it was intentionally removed",
460                version = applied_migration.version
461            );
462        } else {
463            // The migration is missing, warn about it
464            warn!(
465                "The database has a migration applied ({version}) that doesn't exist anymore! This should not happen, unless rolling back to an older version of MAS.",
466                version = applied_migration.version
467            );
468        }
469    }
470
471    Ok(available_migrations
472        .values()
473        .copied()
474        .filter(|migration| {
475            !migration.migration_type.is_down_migration()
476                && !applied_migrations.contains_key(&migration.version)
477        })
478        .collect())
479}
480
481// Copied from the sqlx source code, so that we generate the same lock ID
482fn generate_lock_id(database_name: &str) -> i64 {
483    const CRC_IEEE: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
484    // 0x3d32ad9e chosen by fair dice roll
485    0x3d32_ad9e * i64::from(CRC_IEEE.checksum(database_name.as_bytes()))
486}