syn2mas/mas_writer/
mod.rs

1// Copyright 2024, 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6//! # MAS Writer
7//!
8//! This module is responsible for writing new records to MAS' database.
9
10use std::{
11    fmt::Display,
12    net::IpAddr,
13    sync::{
14        Arc,
15        atomic::{AtomicU32, Ordering},
16    },
17};
18
19use chrono::{DateTime, Utc};
20use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
21use sqlx::{Executor, PgConnection, query, query_as};
22use thiserror::Error;
23use thiserror_ext::{Construct, ContextInto};
24use tokio::sync::mpsc::{self, Receiver, Sender};
25use tracing::{Instrument, error, info, warn};
26use uuid::{NonNilUuid, Uuid};
27
28use self::{
29    constraint_pausing::{ConstraintDescription, IndexDescription},
30    locking::LockedMasDatabase,
31};
32use crate::Progress;
33
34pub mod checks;
35pub mod locking;
36
37mod constraint_pausing;
38
/// Errors that can occur whilst writing records to the MAS database.
#[derive(Debug, Error, Construct, ContextInto)]
pub enum Error {
    /// A lower-level database (sqlx) error, annotated with what we were
    /// doing at the time (the `context` string).
    #[error("database error whilst {context}")]
    Database {
        #[source]
        source: sqlx::Error,
        context: String,
    },

    /// A previously-spawned writer task failed, so the writer connection
    /// pool refuses further work. The original error stays in the pool and
    /// is surfaced when the pool is finished.
    #[error("writer connection pool shut down due to error")]
    #[expect(clippy::enum_variant_names)]
    WriterConnectionPoolError,

    /// The database contents contradict what syn2mas expects, e.g. only
    /// some of the syn2mas restoration tables are present.
    #[error("inconsistent database: {0}")]
    Inconsistent(String),

    /// Internal invariant violation: not every write buffer was finished
    /// before attempting to commit.
    #[error("bug in syn2mas: write buffers not finished")]
    WriteBuffersNotFinished,

    /// Several errors occurred; see [`MultipleErrors`] for the list.
    #[error("{0}")]
    Multiple(MultipleErrors),
}
61
/// A collection of [`Error`]s, rendered as a bulleted list by its
/// [`Display`] implementation.
#[derive(Debug)]
pub struct MultipleErrors {
    // The individual errors, in the order they were collected.
    errors: Vec<Error>,
}
66
67impl Display for MultipleErrors {
68    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69        write!(f, "multiple errors")?;
70        for error in &self.errors {
71            write!(f, "\n- {error}")?;
72        }
73        Ok(())
74    }
75}
76
77impl From<Vec<Error>> for MultipleErrors {
78    fn from(value: Vec<Error>) -> Self {
79        MultipleErrors { errors: value }
80    }
81}
82
/// A fixed-size pool of Postgres connections — each holding an open
/// transaction — used to run write tasks concurrently.
/// Connections circulate through an mpsc channel; an `Err` in the channel
/// means a writer task failed and poisons the pool.
struct WriterConnectionPool {
    /// How many connections are in circulation
    num_connections: usize,

    /// A receiver handle to get a writer connection
    /// The writer connection will be mid-transaction!
    connection_rx: Receiver<Result<PgConnection, Error>>,

    /// A sender handle to return a writer connection to the pool
    /// The connection should still be mid-transaction!
    connection_tx: Sender<Result<PgConnection, Error>>,
}
95
impl WriterConnectionPool {
    /// Creates a pool from the given connections.
    ///
    /// Each connection is expected to already be mid-transaction; all of
    /// them are made immediately available through the internal channel.
    pub fn new(connections: Vec<PgConnection>) -> Self {
        let num_connections = connections.len();
        let (connection_tx, connection_rx) = mpsc::channel(num_connections);
        for connection in connections {
            // The channel was sized to hold every connection, so this cannot fail.
            connection_tx
                .try_send(Ok(connection))
                .expect("there should be room for this connection");
        }

        WriterConnectionPool {
            num_connections,
            connection_rx,
            connection_tx,
        }
    }

    /// Waits for a connection to become free, then spawns a Tokio task that
    /// runs `task` with it.
    ///
    /// On success, the connection is returned to the pool afterwards; if the
    /// task fails, the error is put into the pool instead, which makes
    /// subsequent `spawn_with_connection` calls fail.
    ///
    /// # Errors
    ///
    /// Returns [`Error::WriterConnectionPoolError`] if a previously-spawned
    /// task failed. (The underlying error is re-sent into the pool so that
    /// `finish` can report it.)
    pub async fn spawn_with_connection<F>(&mut self, task: F) -> Result<(), Error>
    where
        F: for<'conn> FnOnce(&'conn mut PgConnection) -> BoxFuture<'conn, Result<(), Error>>
            + Send
            + 'static,
    {
        match self.connection_rx.recv().await {
            Some(Ok(mut connection)) => {
                let connection_tx = self.connection_tx.clone();
                tokio::task::spawn(
                    async move {
                        // Return the connection on success, or the error on
                        // failure — either way the pool gets one item back.
                        let to_return = match task(&mut connection).await {
                            Ok(()) => Ok(connection),
                            Err(error) => {
                                error!("error in writer: {error}");
                                Err(error)
                            }
                        };
                        // This should always succeed in sending unless we're already shutting
                        // down for some other reason.
                        let _: Result<_, _> = connection_tx.send(to_return).await;
                    }
                    .instrument(tracing::debug_span!("spawn_with_connection")),
                );

                Ok(())
            }
            Some(Err(error)) => {
                // This should always succeed in sending unless we're already shutting
                // down for some other reason.
                let _: Result<_, _> = self.connection_tx.send(Err(error)).await;

                Err(Error::WriterConnectionPoolError)
            }
            None => {
                unreachable!("we still hold a reference to the sender, so this shouldn't happen")
            }
        }
    }

    /// Finishes writing to the database, committing all changes.
    ///
    /// # Errors
    ///
    /// - If any errors were returned to the pool.
    /// - If committing the changes failed.
    ///
    /// # Panics
    ///
    /// - If connections were not returned to the pool. (This indicates a
    ///   serious bug.)
    pub async fn finish(self) -> Result<(), Vec<Error>> {
        let mut errors = Vec::new();

        let Self {
            num_connections,
            mut connection_rx,
            connection_tx,
        } = self;
        // Drop the sender handle so we gracefully allow the receiver to close
        drop(connection_tx);

        let mut finished_connections = 0;

        while let Some(connection_or_error) = connection_rx.recv().await {
            finished_connections += 1;

            match connection_or_error {
                Ok(mut connection) => {
                    if let Err(err) = query("COMMIT;").execute(&mut connection).await {
                        errors.push(err.into_database("commit writer transaction"));
                    }
                }
                Err(error) => {
                    errors.push(error);
                }
            }
        }
        // Every connection handed out must have come back (as Ok or Err).
        assert_eq!(
            finished_connections, num_connections,
            "syn2mas had a bug: connections went missing {finished_connections} != {num_connections}"
        );

        if errors.is_empty() {
            Ok(())
        } else {
            Err(errors)
        }
    }
}
203
/// Small utility to make sure `finish()` is called on all write buffers
/// before committing to the database.
#[derive(Default)]
struct FinishChecker {
    // Number of handles acquired but not yet declared finished.
    counter: Arc<AtomicU32>,
}
210
/// A handle acquired from a [`FinishChecker`]; consume it via
/// `declare_finished` once the associated task has completed.
struct FinishCheckerHandle {
    // Shared outstanding-task counter, decremented on `declare_finished`.
    counter: Arc<AtomicU32>,
}
214
215impl FinishChecker {
216    /// Acquire a new handle, for a task that should declare when it has
217    /// finished.
218    pub fn handle(&self) -> FinishCheckerHandle {
219        self.counter.fetch_add(1, Ordering::SeqCst);
220        FinishCheckerHandle {
221            counter: Arc::clone(&self.counter),
222        }
223    }
224
225    /// Check that all handles have been declared as finished.
226    pub fn check_all_finished(self) -> Result<(), Error> {
227        if self.counter.load(Ordering::SeqCst) == 0 {
228            Ok(())
229        } else {
230            Err(Error::WriteBuffersNotFinished)
231        }
232    }
233}
234
235impl FinishCheckerHandle {
236    /// Declare that the task this handle represents has been finished.
237    pub fn declare_finished(self) {
238        self.counter.fetch_sub(1, Ordering::SeqCst);
239    }
240}
241
/// Writes new records to the MAS database: owns the main (locked) database
/// connection, a pool of writer connections for concurrent bulk inserts,
/// and the index/constraint definitions to restore afterwards.
pub struct MasWriter {
    /// Main MAS database connection, holding the syn2mas lock.
    conn: LockedMasDatabase,
    /// Pool of extra connections used for concurrent bulk writes.
    writer_pool: WriterConnectionPool,
    /// NOTE(review): presumably suppresses the final commit when set — the
    /// code path that reads this is not visible in this chunk; confirm.
    dry_run: bool,

    /// Indices dropped for the migration, to be recreated afterwards.
    indices_to_restore: Vec<IndexDescription>,
    /// Constraints dropped for the migration, to be recreated afterwards.
    constraints_to_restore: Vec<ConstraintDescription>,

    /// Ensures every write buffer was finished before commit.
    write_buffer_finish_checker: FinishChecker,
}
252
/// A record type that syn2mas can write to the MAS database in batches.
pub trait WriteBatch: Send + Sync + Sized + 'static {
    /// Writes a whole batch of records over the given connection,
    /// typically as a single bulk `INSERT` statement.
    fn write_batch(
        conn: &mut PgConnection,
        batch: Vec<Self>,
    ) -> impl Future<Output = Result<(), Error>> + Send;
}
259
/// A new row for the MAS users table (`syn2mas__users` during migration).
pub struct MasNewUser {
    /// Unique ID for the user row.
    pub user_id: NonNilUuid,
    /// The user's username.
    pub username: String,
    /// When the user was created.
    pub created_at: DateTime<Utc>,
    /// When the user was locked, if at all.
    pub locked_at: Option<DateTime<Utc>>,
    /// When the user was deactivated, if at all.
    pub deactivated_at: Option<DateTime<Utc>>,
    /// Whether the user may request admin privileges.
    pub can_request_admin: bool,
    /// Whether the user was a Synapse guest.
    /// Although MAS doesn't support guest access, it's still useful to track
    /// for the future.
    pub is_guest: bool,
}
272
273impl WriteBatch for MasNewUser {
274    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
275        // `UNNEST` is a fast way to do bulk inserts, as it lets us send multiple rows
276        // in one statement without having to change the statement
277        // SQL thus altering the query plan. See <https://github.com/launchbadge/sqlx/blob/main/FAQ.md#how-can-i-bind-an-array-to-a-values-clause-how-can-i-do-bulk-inserts>.
278        // In the future we could consider using sqlx's support for `PgCopyIn` / the
279        // `COPY FROM STDIN` statement, which is allegedly the best
280        // for insert performance, but is less simple to encode.
281        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
282        let mut usernames: Vec<String> = Vec::with_capacity(batch.len());
283        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
284        let mut locked_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
285        let mut deactivated_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
286        let mut can_request_admins: Vec<bool> = Vec::with_capacity(batch.len());
287        let mut is_guests: Vec<bool> = Vec::with_capacity(batch.len());
288        for MasNewUser {
289            user_id,
290            username,
291            created_at,
292            locked_at,
293            deactivated_at,
294            can_request_admin,
295            is_guest,
296        } in batch
297        {
298            user_ids.push(user_id.get());
299            usernames.push(username);
300            created_ats.push(created_at);
301            locked_ats.push(locked_at);
302            deactivated_ats.push(deactivated_at);
303            can_request_admins.push(can_request_admin);
304            is_guests.push(is_guest);
305        }
306
307        sqlx::query!(
308            r#"
309            INSERT INTO syn2mas__users (
310              user_id, username,
311              created_at, locked_at,
312              deactivated_at,
313              can_request_admin, is_guest)
314            SELECT * FROM UNNEST(
315              $1::UUID[], $2::TEXT[],
316              $3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],
317              $5::TIMESTAMP WITH TIME ZONE[],
318              $6::BOOL[], $7::BOOL[])
319            "#,
320            &user_ids[..],
321            &usernames[..],
322            &created_ats[..],
323            // We need to override the typing for arrays of optionals (sqlx limitation)
324            &locked_ats[..] as &[Option<DateTime<Utc>>],
325            &deactivated_ats[..] as &[Option<DateTime<Utc>>],
326            &can_request_admins[..],
327            &is_guests[..],
328        )
329        .execute(&mut *conn)
330        .await
331        .into_database("writing users to MAS")?;
332
333        Ok(())
334    }
335}
336
/// A new password-hash row for a migrated user.
pub struct MasNewUserPassword {
    /// Unique ID for the password row.
    pub user_password_id: Uuid,
    /// The user this password belongs to.
    pub user_id: NonNilUuid,
    /// The password hash as stored in Synapse.
    pub hashed_password: String,
    /// When the password was created.
    pub created_at: DateTime<Utc>,
}
343
344impl WriteBatch for MasNewUserPassword {
345    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
346        let mut user_password_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
347        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
348        let mut hashed_passwords: Vec<String> = Vec::with_capacity(batch.len());
349        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
350        let mut versions: Vec<i32> = Vec::with_capacity(batch.len());
351        for MasNewUserPassword {
352            user_password_id,
353            user_id,
354            hashed_password,
355            created_at,
356        } in batch
357        {
358            user_password_ids.push(user_password_id);
359            user_ids.push(user_id.get());
360            hashed_passwords.push(hashed_password);
361            created_ats.push(created_at);
362            versions.push(MIGRATED_PASSWORD_VERSION.into());
363        }
364
365        sqlx::query!(
366            r#"
367            INSERT INTO syn2mas__user_passwords
368            (user_password_id, user_id, hashed_password, created_at, version)
369            SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $5::INTEGER[])
370            "#,
371            &user_password_ids[..],
372            &user_ids[..],
373            &hashed_passwords[..],
374            &created_ats[..],
375            &versions[..],
376        ).execute(&mut *conn).await.into_database("writing users to MAS")?;
377
378        Ok(())
379    }
380}
381
/// A new e-mail address row for a migrated user.
pub struct MasNewEmailThreepid {
    /// Unique ID for the e-mail row.
    pub user_email_id: Uuid,
    /// The user this e-mail address belongs to.
    pub user_id: NonNilUuid,
    /// The e-mail address itself.
    pub email: String,
    /// When the e-mail address was added.
    pub created_at: DateTime<Utc>,
}
388
389impl WriteBatch for MasNewEmailThreepid {
390    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
391        let mut user_email_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
392        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
393        let mut emails: Vec<String> = Vec::with_capacity(batch.len());
394        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
395
396        for MasNewEmailThreepid {
397            user_email_id,
398            user_id,
399            email,
400            created_at,
401        } in batch
402        {
403            user_email_ids.push(user_email_id);
404            user_ids.push(user_id.get());
405            emails.push(email);
406            created_ats.push(created_at);
407        }
408
409        sqlx::query!(
410            r#"
411                INSERT INTO syn2mas__user_emails
412                (user_email_id, user_id, email, created_at)
413                SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[])
414            "#,
415            &user_email_ids[..],
416            &user_ids[..],
417            &emails[..],
418            &created_ats[..],
419        )
420        .execute(&mut *conn)
421        .await
422        .into_database("writing emails to MAS")?;
423
424        Ok(())
425    }
426}
427
/// A new third-party identifier row of a medium MAS does not natively
/// support (anything other than e-mail).
pub struct MasNewUnsupportedThreepid {
    /// The user this identifier belongs to.
    pub user_id: NonNilUuid,
    /// The identifier's medium (e.g. as recorded by Synapse).
    pub medium: String,
    /// The identifier's address.
    pub address: String,
    /// When the identifier was added.
    pub created_at: DateTime<Utc>,
}
434
435impl WriteBatch for MasNewUnsupportedThreepid {
436    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
437        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
438        let mut mediums: Vec<String> = Vec::with_capacity(batch.len());
439        let mut addresses: Vec<String> = Vec::with_capacity(batch.len());
440        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
441
442        for MasNewUnsupportedThreepid {
443            user_id,
444            medium,
445            address,
446            created_at,
447        } in batch
448        {
449            user_ids.push(user_id.get());
450            mediums.push(medium);
451            addresses.push(address);
452            created_ats.push(created_at);
453        }
454
455        sqlx::query!(
456            r#"
457            INSERT INTO syn2mas__user_unsupported_third_party_ids
458            (user_id, medium, address, created_at)
459            SELECT * FROM UNNEST($1::UUID[], $2::TEXT[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[])
460            "#,
461            &user_ids[..],
462            &mediums[..],
463            &addresses[..],
464            &created_ats[..],
465        )
466        .execute(&mut *conn)
467        .await
468        .into_database("writing unsupported threepids to MAS")?;
469
470        Ok(())
471    }
472}
473
/// A new link between a migrated user and an upstream OAuth 2.0 provider.
pub struct MasNewUpstreamOauthLink {
    /// Unique ID for the link row.
    pub link_id: Uuid,
    /// The user the link belongs to.
    pub user_id: NonNilUuid,
    /// The upstream provider being linked to.
    pub upstream_provider_id: Uuid,
    /// The subject identifier at the upstream provider.
    pub subject: String,
    /// When the link was created.
    pub created_at: DateTime<Utc>,
}
481
482impl WriteBatch for MasNewUpstreamOauthLink {
483    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
484        let mut link_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
485        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
486        let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
487        let mut subjects: Vec<String> = Vec::with_capacity(batch.len());
488        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
489
490        for MasNewUpstreamOauthLink {
491            link_id,
492            user_id,
493            upstream_provider_id,
494            subject,
495            created_at,
496        } in batch
497        {
498            link_ids.push(link_id);
499            user_ids.push(user_id.get());
500            upstream_provider_ids.push(upstream_provider_id);
501            subjects.push(subject);
502            created_ats.push(created_at);
503        }
504
505        sqlx::query!(
506            r#"
507            INSERT INTO syn2mas__upstream_oauth_links
508            (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
509            SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
510            "#,
511            &link_ids[..],
512            &user_ids[..],
513            &upstream_provider_ids[..],
514            &subjects[..],
515            &created_ats[..],
516        ).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;
517
518        Ok(())
519    }
520}
521
/// A new compatibility-layer session row for a migrated Synapse session.
pub struct MasNewCompatSession {
    /// Unique ID for the session row.
    pub session_id: Uuid,
    /// The user the session belongs to.
    pub user_id: NonNilUuid,
    /// The device ID associated with the session, if any.
    pub device_id: Option<String>,
    /// Human-readable name for the session, if any.
    pub human_name: Option<String>,
    /// When the session was created.
    pub created_at: DateTime<Utc>,
    /// Whether this session had Synapse admin rights.
    pub is_synapse_admin: bool,
    /// When the session was last active, if known.
    pub last_active_at: Option<DateTime<Utc>>,
    /// The IP address the session was last active from, if known.
    pub last_active_ip: Option<IpAddr>,
    /// The user agent last seen on the session, if known.
    pub user_agent: Option<String>,
}
533
impl WriteBatch for MasNewCompatSession {
    /// Bulk-inserts a batch of compat sessions with a single `UNNEST`-based
    /// statement.
    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
        // One column-major buffer per inserted column.
        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
        let mut device_ids: Vec<Option<String>> = Vec::with_capacity(batch.len());
        let mut human_names: Vec<Option<String>> = Vec::with_capacity(batch.len());
        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
        let mut is_synapse_admins: Vec<bool> = Vec::with_capacity(batch.len());
        let mut last_active_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
        let mut last_active_ips: Vec<Option<IpAddr>> = Vec::with_capacity(batch.len());
        let mut user_agents: Vec<Option<String>> = Vec::with_capacity(batch.len());

        for MasNewCompatSession {
            session_id,
            user_id,
            device_id,
            human_name,
            created_at,
            is_synapse_admin,
            last_active_at,
            last_active_ip,
            user_agent,
        } in batch
        {
            session_ids.push(session_id);
            user_ids.push(user_id.get());
            device_ids.push(device_id);
            human_names.push(human_name);
            created_ats.push(created_at);
            is_synapse_admins.push(is_synapse_admin);
            last_active_ats.push(last_active_at);
            last_active_ips.push(last_active_ip);
            user_agents.push(user_agent);
        }

        sqlx::query!(
            r#"
            INSERT INTO syn2mas__compat_sessions (
              compat_session_id, user_id,
              device_id, human_name,
              created_at, is_synapse_admin,
              last_active_at, last_active_ip,
              user_agent)
            SELECT * FROM UNNEST(
              $1::UUID[], $2::UUID[],
              $3::TEXT[], $4::TEXT[],
              $5::TIMESTAMP WITH TIME ZONE[], $6::BOOLEAN[],
              $7::TIMESTAMP WITH TIME ZONE[], $8::INET[],
              $9::TEXT[])
            "#,
            &session_ids[..],
            &user_ids[..],
            &device_ids[..] as &[Option<String>],
            &human_names[..] as &[Option<String>],
            &created_ats[..],
            &is_synapse_admins[..],
            // We need to override the typing for arrays of optionals (sqlx limitation)
            &last_active_ats[..] as &[Option<DateTime<Utc>>],
            &last_active_ips[..] as &[Option<IpAddr>],
            &user_agents[..] as &[Option<String>],
        )
        .execute(&mut *conn)
        .await
        .into_database("writing compat sessions to MAS")?;

        Ok(())
    }
}
602
/// A new compatibility-layer access token row.
pub struct MasNewCompatAccessToken {
    /// Unique ID for the access token row.
    pub token_id: Uuid,
    /// The compat session the token belongs to.
    pub session_id: Uuid,
    /// The access token itself.
    pub access_token: String,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
    /// When the token expires, if it does.
    pub expires_at: Option<DateTime<Utc>>,
}
610
611impl WriteBatch for MasNewCompatAccessToken {
612    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
613        let mut token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
614        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
615        let mut access_tokens: Vec<String> = Vec::with_capacity(batch.len());
616        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
617        let mut expires_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
618
619        for MasNewCompatAccessToken {
620            token_id,
621            session_id,
622            access_token,
623            created_at,
624            expires_at,
625        } in batch
626        {
627            token_ids.push(token_id);
628            session_ids.push(session_id);
629            access_tokens.push(access_token);
630            created_ats.push(created_at);
631            expires_ats.push(expires_at);
632        }
633
634        sqlx::query!(
635            r#"
636            INSERT INTO syn2mas__compat_access_tokens (
637              compat_access_token_id,
638              compat_session_id,
639              access_token,
640              created_at,
641              expires_at)
642            SELECT * FROM UNNEST(
643              $1::UUID[],
644              $2::UUID[],
645              $3::TEXT[],
646              $4::TIMESTAMP WITH TIME ZONE[],
647              $5::TIMESTAMP WITH TIME ZONE[])
648            "#,
649            &token_ids[..],
650            &session_ids[..],
651            &access_tokens[..],
652            &created_ats[..],
653            // We need to override the typing for arrays of optionals (sqlx limitation)
654            &expires_ats[..] as &[Option<DateTime<Utc>>],
655        )
656        .execute(&mut *conn)
657        .await
658        .into_database("writing compat access tokens to MAS")?;
659
660        Ok(())
661    }
662}
663
/// A new compatibility-layer refresh token row.
pub struct MasNewCompatRefreshToken {
    /// Unique ID for the refresh token row.
    pub refresh_token_id: Uuid,
    /// The compat session the token belongs to.
    pub session_id: Uuid,
    /// The access token this refresh token is paired with.
    pub access_token_id: Uuid,
    /// The refresh token itself.
    pub refresh_token: String,
    /// When the token was created.
    pub created_at: DateTime<Utc>,
}
671
672impl WriteBatch for MasNewCompatRefreshToken {
673    async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
674        let mut refresh_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
675        let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
676        let mut access_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
677        let mut refresh_tokens: Vec<String> = Vec::with_capacity(batch.len());
678        let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
679
680        for MasNewCompatRefreshToken {
681            refresh_token_id,
682            session_id,
683            access_token_id,
684            refresh_token,
685            created_at,
686        } in batch
687        {
688            refresh_token_ids.push(refresh_token_id);
689            session_ids.push(session_id);
690            access_token_ids.push(access_token_id);
691            refresh_tokens.push(refresh_token);
692            created_ats.push(created_at);
693        }
694
695        sqlx::query!(
696            r#"
697            INSERT INTO syn2mas__compat_refresh_tokens (
698              compat_refresh_token_id,
699              compat_session_id,
700              compat_access_token_id,
701              refresh_token,
702              created_at)
703            SELECT * FROM UNNEST(
704              $1::UUID[],
705              $2::UUID[],
706              $3::UUID[],
707              $4::TEXT[],
708              $5::TIMESTAMP WITH TIME ZONE[])
709            "#,
710            &refresh_token_ids[..],
711            &session_ids[..],
712            &access_token_ids[..],
713            &refresh_tokens[..],
714            &created_ats[..],
715        )
716        .execute(&mut *conn)
717        .await
718        .into_database("writing compat refresh tokens to MAS")?;
719
720        Ok(())
721    }
722}
723
/// The 'version' of the password hashing scheme used for passwords when they
/// are migrated from Synapse to MAS.
/// This is version 1, as in the previous syn2mas script.
// TODO hardcoding version to `1` may not be correct long-term?
pub const MIGRATED_PASSWORD_VERSION: u16 = 1;

/// List of all MAS tables that are written to by syn2mas.
/// Also used to derive the `syn2mas__`-prefixed temporary table names that
/// are truncated when a partial migration is restarted.
pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
    "users",
    "user_passwords",
    "user_emails",
    "user_unsupported_third_party_ids",
    "upstream_oauth_links",
    "compat_sessions",
    "compat_access_tokens",
    "compat_refresh_tokens",
];
741
/// Detect whether a syn2mas migration has started on the given database.
///
/// Concretely, this checks for the presence of syn2mas restoration tables.
///
/// Returns `true` if syn2mas has started, or `false` if it hasn't.
///
/// # Errors
///
/// Errors are returned under the following circumstances:
///
/// - If any database error occurs whilst querying the database.
/// - If some, but not all, syn2mas restoration tables are present. (This
///   shouldn't be possible without syn2mas having been sabotaged!)
pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result<bool, Error> {
    // Names of tables used for syn2mas resumption
    // Must be `String`s, not just `&str`, for the query.
    let restore_table_names = vec![
        "syn2mas_restore_constraints".to_owned(),
        "syn2mas_restore_indices".to_owned(),
    ];

    // Count how many of the restoration tables exist in the current schema.
    let num_resumption_tables = query!(
        r#"
        SELECT 1 AS _dummy FROM pg_tables WHERE schemaname = current_schema
        AND tablename = ANY($1)
        "#,
        &restore_table_names,
    )
    .fetch_all(conn.as_mut())
    .await
    .into_database("failed to query count of resumption tables")?
    .len();

    // All tables present: migration in progress. None present: not started.
    // Anything in between is an inconsistent database.
    if num_resumption_tables == 0 {
        Ok(false)
    } else if num_resumption_tables == restore_table_names.len() {
        Ok(true)
    } else {
        Err(Error::inconsistent(
            "some, but not all, syn2mas resumption tables were found",
        ))
    }
}
785
786impl MasWriter {
787    /// Creates a new MAS writer.
788    ///
789    /// # Errors
790    ///
791    /// Errors are returned in the following conditions:
792    ///
793    /// - If the database connection experiences an error.
794    #[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)]
795    pub async fn new(
796        mut conn: LockedMasDatabase,
797        mut writer_connections: Vec<PgConnection>,
798        dry_run: bool,
799    ) -> Result<Self, Error> {
800        // Given that we don't have any concurrent transactions here,
801        // the READ COMMITTED isolation level is sufficient.
802        query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
803            .execute(conn.as_mut())
804            .await
805            .into_database("begin MAS transaction")?;
806
807        let syn2mas_started = is_syn2mas_in_progress(conn.as_mut()).await?;
808
809        let indices_to_restore;
810        let constraints_to_restore;
811
812        if syn2mas_started {
813            // We are resuming from a partially-done syn2mas migration
814            // We should reset the database so that we're starting from scratch.
815            warn!("Partial syn2mas migration has already been done; resetting.");
816            for table in MAS_TABLES_AFFECTED_BY_MIGRATION {
817                query(&format!("TRUNCATE syn2mas__{table};"))
818                    .execute(conn.as_mut())
819                    .await
820                    .into_database_with(|| format!("failed to truncate table syn2mas__{table}"))?;
821            }
822
823            indices_to_restore = query_as!(
824                IndexDescription,
825                "SELECT table_name, name, definition FROM syn2mas_restore_indices ORDER BY order_key"
826            )
827                .fetch_all(conn.as_mut())
828                .await
829                .into_database("failed to get syn2mas restore data (index descriptions)")?;
830            constraints_to_restore = query_as!(
831                ConstraintDescription,
832                "SELECT table_name, name, definition FROM syn2mas_restore_constraints ORDER BY order_key"
833            )
834                .fetch_all(conn.as_mut())
835                .await
836                .into_database("failed to get syn2mas restore data (constraint descriptions)")?;
837        } else {
838            info!("Starting new syn2mas migration");
839
840            conn.as_mut()
841                .execute_many(include_str!("syn2mas_temporary_tables.sql"))
842                // We don't care about any query results
843                .try_collect::<Vec<_>>()
844                .await
845                .into_database("could not create temporary tables")?;
846
847            // Pause (temporarily drop) indices and constraints in order to improve
848            // performance of bulk data loading.
849            (indices_to_restore, constraints_to_restore) =
850                Self::pause_indices(conn.as_mut()).await?;
851
852            // Persist these index and constraint definitions.
853            for IndexDescription {
854                name,
855                table_name,
856                definition,
857            } in &indices_to_restore
858            {
859                query!(
860                    r#"
861                    INSERT INTO syn2mas_restore_indices (name, table_name, definition)
862                    VALUES ($1, $2, $3)
863                    "#,
864                    name,
865                    table_name,
866                    definition
867                )
868                .execute(conn.as_mut())
869                .await
870                .into_database("failed to save restore data (index)")?;
871            }
872            for ConstraintDescription {
873                name,
874                table_name,
875                definition,
876            } in &constraints_to_restore
877            {
878                query!(
879                    r#"
880                    INSERT INTO syn2mas_restore_constraints (name, table_name, definition)
881                    VALUES ($1, $2, $3)
882                    "#,
883                    name,
884                    table_name,
885                    definition
886                )
887                .execute(conn.as_mut())
888                .await
889                .into_database("failed to save restore data (index)")?;
890            }
891        }
892
893        query("COMMIT;")
894            .execute(conn.as_mut())
895            .await
896            .into_database("begin MAS transaction")?;
897
898        // Now after all the schema changes have been done, begin writer transactions
899        for writer_connection in &mut writer_connections {
900            query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
901                .execute(&mut *writer_connection)
902                .await
903                .into_database("begin MAS writer transaction")?;
904        }
905
906        Ok(Self {
907            conn,
908            dry_run,
909            writer_pool: WriterConnectionPool::new(writer_connections),
910            indices_to_restore,
911            constraints_to_restore,
912            write_buffer_finish_checker: FinishChecker::default(),
913        })
914    }
915
916    #[tracing::instrument(skip_all)]
917    async fn pause_indices(
918        conn: &mut PgConnection,
919    ) -> Result<(Vec<IndexDescription>, Vec<ConstraintDescription>), Error> {
920        let mut indices_to_restore = Vec::new();
921        let mut constraints_to_restore = Vec::new();
922
923        for &unprefixed_table in MAS_TABLES_AFFECTED_BY_MIGRATION {
924            let table = format!("syn2mas__{unprefixed_table}");
925            // First drop incoming foreign key constraints
926            for constraint in
927                constraint_pausing::describe_foreign_key_constraints_to_table(&mut *conn, &table)
928                    .await?
929            {
930                constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
931                constraints_to_restore.push(constraint);
932            }
933            // After all incoming foreign key constraints have been removed,
934            // we can now drop internal constraints.
935            for constraint in
936                constraint_pausing::describe_constraints_on_table(&mut *conn, &table).await?
937            {
938                constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
939                constraints_to_restore.push(constraint);
940            }
941            // After all constraints have been removed, we can drop indices.
942            for index in constraint_pausing::describe_indices_on_table(&mut *conn, &table).await? {
943                constraint_pausing::drop_index(&mut *conn, &index).await?;
944                indices_to_restore.push(index);
945            }
946        }
947
948        Ok((indices_to_restore, constraints_to_restore))
949    }
950
951    async fn restore_indices(
952        conn: &mut LockedMasDatabase,
953        indices_to_restore: &[IndexDescription],
954        constraints_to_restore: &[ConstraintDescription],
955        progress: &Progress,
956    ) -> Result<(), Error> {
957        // First restore all indices. The order is not important as far as I know.
958        // However the indices are needed before constraints.
959        for index in indices_to_restore.iter().rev() {
960            progress.rebuild_index(index.name.clone());
961            constraint_pausing::restore_index(conn.as_mut(), index).await?;
962        }
963        // Then restore all constraints.
964        // The order here is the reverse of drop order, since some constraints may rely
965        // on other constraints to work.
966        for constraint in constraints_to_restore.iter().rev() {
967            progress.rebuild_constraint(constraint.name.clone());
968            constraint_pausing::restore_constraint(conn.as_mut(), constraint).await?;
969        }
970        Ok(())
971    }
972
    /// Finish writing to the MAS database, flushing and committing all changes.
    /// It returns the unlocked underlying connection.
    ///
    /// The steps are, in order: check all write buffers were finished, commit
    /// every writer-pool transaction, restore the paused indices/constraints,
    /// revert the temporary tables, truncate everything if in dry-run mode,
    /// commit, and finally unlock the database.
    ///
    /// # Errors
    ///
    /// Errors are returned in the following conditions:
    ///
    /// - If the database connection experiences an error.
    /// - If any write buffer was not finished before this call
    ///   ([`Error::WriteBuffersNotFinished`]).
    #[tracing::instrument(skip_all)]
    pub async fn finish(mut self, progress: &Progress) -> Result<PgConnection, Error> {
        // Guard against silent data loss: every `MasWriteBuffer` must have
        // been `finish`ed (and therefore flushed) before we commit.
        self.write_buffer_finish_checker.check_all_finished()?;

        // Commit all writer transactions to the database.
        self.writer_pool
            .finish()
            .await
            .map_err(|errors| Error::Multiple(MultipleErrors::from(errors)))?;

        // Now all the data has been migrated, finish off by restoring indices and
        // constraints!
        query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
            .execute(self.conn.as_mut())
            .await
            .into_database("begin MAS transaction")?;

        Self::restore_indices(
            &mut self.conn,
            &self.indices_to_restore,
            &self.constraints_to_restore,
            progress,
        )
        .await?;

        self.conn
            .as_mut()
            .execute_many(include_str!("syn2mas_revert_temporary_tables.sql"))
            // We don't care about any query results
            .try_collect::<Vec<_>>()
            .await
            .into_database("could not revert temporary tables")?;

        // If we're in dry-run mode, truncate all the tables we've written to
        if self.dry_run {
            warn!("Migration ran in dry-run mode, deleting all imported data");
            let tables = MAS_TABLES_AFFECTED_BY_MIGRATION
                .iter()
                .map(|table| format!("\"{table}\""))
                .collect::<Vec<_>>()
                .join(", ");

            // Note that we do that with CASCADE, because we do that *after*
            // restoring the FK constraints.
            //
            // The alternative would be to list all the tables we have FK to
            // those tables, which would be a hassle, or to do that *before*
            // restoring the constraints, which would mean we wouldn't validate
            // that we've done valid FKs in dry-run mode.
            query(&format!("TRUNCATE TABLE {tables} CASCADE;"))
                .execute(self.conn.as_mut())
                .await
                .into_database_with(|| "failed to truncate all tables")?;
        }

        query("COMMIT;")
            .execute(self.conn.as_mut())
            .await
            .into_database("ending MAS transaction")?;

        // Release the advisory lock and hand the bare connection back.
        let conn = self
            .conn
            .unlock()
            .await
            .into_database("could not unlock MAS database")?;

        Ok(conn)
    }
1049}
1050
/// How many entries to buffer at once, before writing a batch of rows to the
/// database.
const WRITE_BUFFER_BATCH_SIZE: usize = 4096;
1054
/// A buffer for writing rows to the MAS database.
/// Generic over the type of rows.
pub struct MasWriteBuffer<T> {
    /// Rows waiting to be written; flushed as one batch once
    /// `WRITE_BUFFER_BATCH_SIZE` rows have accumulated (or on `finish`).
    rows: Vec<T>,
    /// Handle used to declare this buffer finished, so that
    /// `MasWriter::finish` can detect buffers that were never flushed
    /// (`Error::WriteBuffersNotFinished`).
    finish_checker_handle: FinishCheckerHandle,
}
1061
1062impl<T> MasWriteBuffer<T>
1063where
1064    T: WriteBatch,
1065{
1066    pub fn new(writer: &MasWriter) -> Self {
1067        MasWriteBuffer {
1068            rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
1069            finish_checker_handle: writer.write_buffer_finish_checker.handle(),
1070        }
1071    }
1072
1073    pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1074        self.flush(writer).await?;
1075        self.finish_checker_handle.declare_finished();
1076        Ok(())
1077    }
1078
1079    pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> {
1080        if self.rows.is_empty() {
1081            return Ok(());
1082        }
1083        let rows = std::mem::take(&mut self.rows);
1084        self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
1085        writer
1086            .writer_pool
1087            .spawn_with_connection(move |conn| T::write_batch(conn, rows).boxed())
1088            .boxed()
1089            .await?;
1090        Ok(())
1091    }
1092
1093    pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> {
1094        self.rows.push(row);
1095        if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
1096            self.flush(writer).await?;
1097        }
1098        Ok(())
1099    }
1100}
1101
#[cfg(test)]
mod test {
    use std::collections::{BTreeMap, BTreeSet};

    use chrono::DateTime;
    use futures_util::TryStreamExt;
    use serde::Serialize;
    use sqlx::{Column, PgConnection, PgPool, Row};
    use uuid::{NonNilUuid, Uuid};

    use crate::{
        LockedMasDatabase, MasWriter, Progress,
        mas_writer::{
            MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
            MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
            MasNewUserPassword, MasWriteBuffer,
        },
    };

    /// A snapshot of a whole database
    #[derive(Default, Serialize)]
    #[serde(transparent)]
    struct DatabaseSnapshot {
        tables: BTreeMap<String, TableSnapshot>,
    }

    #[derive(Serialize)]
    #[serde(transparent)]
    struct TableSnapshot {
        rows: BTreeSet<RowSnapshot>,
    }

    #[derive(PartialEq, Eq, PartialOrd, Ord, Serialize)]
    #[serde(transparent)]
    struct RowSnapshot {
        columns_to_values: BTreeMap<String, Option<String>>,
    }

    const SKIPPED_TABLES: &[&str] = &["_sqlx_migrations"];

    /// Produces a serialisable snapshot of a database, usable for snapshot
    /// testing
    ///
    /// For brevity, empty tables, as well as [`SKIPPED_TABLES`], will not be
    /// included in the snapshot.
    async fn snapshot_database(conn: &mut PgConnection) -> DatabaseSnapshot {
        let mut out = DatabaseSnapshot::default();
        let table_names: Vec<String> = sqlx::query_scalar(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema();",
        )
        .fetch_all(&mut *conn)
        .await
        .unwrap();

        for table_name in table_names {
            if SKIPPED_TABLES.contains(&table_name.as_str()) {
                continue;
            }

            let column_names: Vec<String> = sqlx::query_scalar(
                "SELECT column_name FROM information_schema.columns WHERE table_name = $1 AND table_schema = current_schema();"
            ).bind(&table_name).fetch_all(&mut *conn).await.expect("failed to get column names for table for snapshotting");

            let column_name_list = column_names
                .iter()
                // stringify all the values for simplicity
                .map(|column_name| format!("{column_name}::TEXT AS \"{column_name}\""))
                .collect::<Vec<_>>()
                .join(", ");

            let table_rows = sqlx::query(&format!("SELECT {column_name_list} FROM {table_name};"))
                .fetch(&mut *conn)
                .map_ok(|row| {
                    let mut columns_to_values = BTreeMap::new();
                    for (idx, column) in row.columns().iter().enumerate() {
                        columns_to_values.insert(column.name().to_owned(), row.get(idx));
                    }
                    RowSnapshot { columns_to_values }
                })
                .try_collect::<BTreeSet<RowSnapshot>>()
                .await
                .expect("failed to fetch rows from table for snapshotting");

            if !table_rows.is_empty() {
                out.tables
                    .insert(table_name, TableSnapshot { rows: table_rows });
            }
        }

        out
    }

    /// Make a snapshot assertion against the database.
    macro_rules! assert_db_snapshot {
        ($db: expr) => {
            let db_snapshot = snapshot_database($db).await;
            ::insta::assert_yaml_snapshot!(db_snapshot);
        };
    }

    /// Runs some code with a `MasWriter`.
    ///
    /// The callback is responsible for `finish`ing the `MasWriter`.
    async fn make_mas_writer(pool: &PgPool) -> MasWriter {
        let main_conn = pool.acquire().await.unwrap().detach();
        let mut writer_conns = Vec::new();
        for _ in 0..2 {
            writer_conns.push(
                pool.acquire()
                    .await
                    .expect("failed to acquire MasWriter writer connection")
                    .detach(),
            );
        }
        let locked_main_conn = LockedMasDatabase::try_new(main_conn)
            .await
            .expect("failed to lock MAS database")
            .expect_left("MAS database is already locked");
        MasWriter::new(locked_main_conn, writer_conns, false)
            .await
            .expect("failed to construct MasWriter")
    }

    /// User ID (`1`) of the standard test user created by [`alice`].
    fn alice_user_id() -> NonNilUuid {
        NonNilUuid::new(Uuid::from_u128(1u128)).unwrap()
    }

    /// The standard test user (`alice`, user ID 1) written by every test in
    /// this module. Extracted here so each test does not repeat the literal.
    fn alice() -> MasNewUser {
        MasNewUser {
            user_id: alice_user_id(),
            username: "alice".to_owned(),
            created_at: DateTime::default(),
            locked_at: None,
            deactivated_at: None,
            can_request_admin: false,
            is_guest: false,
        }
    }

    /// Tests writing a single user, without a password.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;
        let mut buffer = MasWriteBuffer::new(&writer);

        buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish MasWriteBuffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a password.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_password(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut password_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        password_buffer
            .write(
                &mut writer,
                MasNewUserPassword {
                    user_password_id: Uuid::from_u128(42u128),
                    user_id: alice_user_id(),
                    hashed_password: "$bcrypt$aaaaaaaaaaa".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write password");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish MasWriteBuffer");
        password_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish MasWriteBuffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with an e-mail address associated.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_email(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut email_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        email_buffer
            .write(
                &mut writer,
                MasNewEmailThreepid {
                    user_email_id: Uuid::from_u128(2u128),
                    user_id: alice_user_id(),
                    email: "alice@example.org".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write e-mail");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        email_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish email buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a unsupported third-party ID
    /// associated.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_unsupported_threepid(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut threepid_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        threepid_buffer
            .write(
                &mut writer,
                MasNewUnsupportedThreepid {
                    user_id: alice_user_id(),
                    medium: "msisdn".to_owned(),
                    address: "441189998819991197253".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write phone number (unsupported threepid)");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        threepid_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish threepid buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a link to an upstream provider.
    /// There needs to be an upstream provider in the database already — in the
    /// real migration, this is done by running a provider sync first.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
    async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut link_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        link_buffer
            .write(
                &mut writer,
                MasNewUpstreamOauthLink {
                    user_id: alice_user_id(),
                    link_id: Uuid::from_u128(3u128),
                    upstream_provider_id: Uuid::from_u128(4u128),
                    subject: "12345.67890".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write link");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        link_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish link buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a device (compat session).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_device(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut session_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        session_buffer
            .write(
                &mut writer,
                MasNewCompatSession {
                    user_id: alice_user_id(),
                    session_id: Uuid::from_u128(5u128),
                    created_at: DateTime::default(),
                    device_id: Some("ADEVICE".to_owned()),
                    human_name: Some("alice's pinephone".to_owned()),
                    is_synapse_admin: true,
                    last_active_at: Some(DateTime::default()),
                    last_active_ip: Some("203.0.113.1".parse().unwrap()),
                    user_agent: Some("Browser/5.0".to_owned()),
                },
            )
            .await
            .expect("failed to write compat session");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        session_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish session buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a device and an access token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_access_token(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut session_buffer = MasWriteBuffer::new(&writer);
        let mut token_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        session_buffer
            .write(
                &mut writer,
                MasNewCompatSession {
                    user_id: alice_user_id(),
                    session_id: Uuid::from_u128(5u128),
                    created_at: DateTime::default(),
                    device_id: Some("ADEVICE".to_owned()),
                    human_name: None,
                    is_synapse_admin: false,
                    last_active_at: None,
                    last_active_ip: None,
                    user_agent: None,
                },
            )
            .await
            .expect("failed to write compat session");

        token_buffer
            .write(
                &mut writer,
                MasNewCompatAccessToken {
                    token_id: Uuid::from_u128(6u128),
                    session_id: Uuid::from_u128(5u128),
                    access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                    created_at: DateTime::default(),
                    expires_at: None,
                },
            )
            .await
            .expect("failed to write access token");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        session_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish session buffer");
        token_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish token buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }

    /// Tests writing a single user, with a device, an access token and a
    /// refresh token.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_write_user_with_refresh_token(pool: PgPool) {
        let mut writer = make_mas_writer(&pool).await;

        let mut user_buffer = MasWriteBuffer::new(&writer);
        let mut session_buffer = MasWriteBuffer::new(&writer);
        let mut token_buffer = MasWriteBuffer::new(&writer);
        let mut refresh_token_buffer = MasWriteBuffer::new(&writer);

        user_buffer
            .write(&mut writer, alice())
            .await
            .expect("failed to write user");

        session_buffer
            .write(
                &mut writer,
                MasNewCompatSession {
                    user_id: alice_user_id(),
                    session_id: Uuid::from_u128(5u128),
                    created_at: DateTime::default(),
                    device_id: Some("ADEVICE".to_owned()),
                    human_name: None,
                    is_synapse_admin: false,
                    last_active_at: None,
                    last_active_ip: None,
                    user_agent: None,
                },
            )
            .await
            .expect("failed to write compat session");

        token_buffer
            .write(
                &mut writer,
                MasNewCompatAccessToken {
                    token_id: Uuid::from_u128(6u128),
                    session_id: Uuid::from_u128(5u128),
                    access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                    created_at: DateTime::default(),
                    expires_at: None,
                },
            )
            .await
            .expect("failed to write access token");

        refresh_token_buffer
            .write(
                &mut writer,
                MasNewCompatRefreshToken {
                    refresh_token_id: Uuid::from_u128(7u128),
                    session_id: Uuid::from_u128(5u128),
                    access_token_id: Uuid::from_u128(6u128),
                    refresh_token: "syr_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
                    created_at: DateTime::default(),
                },
            )
            .await
            .expect("failed to write refresh token");

        user_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish user buffer");
        session_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish session buffer");
        token_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish token buffer");
        refresh_token_buffer
            .finish(&mut writer)
            .await
            .expect("failed to finish refresh token buffer");

        let mut conn = writer
            .finish(&Progress::default())
            .await
            .expect("failed to finish MasWriter");

        assert_db_snapshot!(&mut conn);
    }
}