1use std::{
11 fmt::Display,
12 net::IpAddr,
13 sync::{
14 Arc,
15 atomic::{AtomicU32, Ordering},
16 },
17};
18
19use chrono::{DateTime, Utc};
20use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
21use sqlx::{Executor, PgConnection, query, query_as};
22use thiserror::Error;
23use thiserror_ext::{Construct, ContextInto};
24use tokio::sync::mpsc::{self, Receiver, Sender};
25use tracing::{Instrument, error, info, warn};
26use uuid::{NonNilUuid, Uuid};
27
28use self::{
29 constraint_pausing::{ConstraintDescription, IndexDescription},
30 locking::LockedMasDatabase,
31};
32use crate::Progress;
33
34pub mod checks;
35pub mod locking;
36
37mod constraint_pausing;
38
/// Errors that can be produced while writing to the MAS database.
#[derive(Debug, Error, Construct, ContextInto)]
pub enum Error {
    /// A database operation failed. `context` is spliced into the message
    /// after "whilst", so it should describe what was being attempted.
    #[error("database error whilst {context}")]
    Database {
        #[source]
        source: sqlx::Error,
        context: String,
    },

    /// Returned when a connection is requested from the writer connection
    /// pool but the pool hands back an error left by a failed writer task.
    #[error("writer connection pool shut down due to error")]
    #[expect(clippy::enum_variant_names)]
    WriterConnectionPoolError,

    /// The database was found in a state the migration does not expect.
    #[error("inconsistent database: {0}")]
    Inconsistent(String),

    /// The writer was asked to finish while some write buffers had not yet
    /// been declared finished.
    #[error("bug in syn2mas: write buffers not finished")]
    WriteBuffersNotFinished,

    /// Several errors collected into one.
    #[error("{0}")]
    Multiple(MultipleErrors),
}
61
/// A list of [`Error`]s, displayed as a bulleted list under a
/// "multiple errors" heading.
#[derive(Debug)]
pub struct MultipleErrors {
    errors: Vec<Error>,
}
66
67impl Display for MultipleErrors {
68 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
69 write!(f, "multiple errors")?;
70 for error in &self.errors {
71 write!(f, "\n- {error}")?;
72 }
73 Ok(())
74 }
75}
76
77impl From<Vec<Error>> for MultipleErrors {
78 fn from(value: Vec<Error>) -> Self {
79 MultipleErrors { errors: value }
80 }
81}
82
/// A pool of writer connections, implemented as a channel through which
/// connections (or the error that replaced them) circulate.
struct WriterConnectionPool {
    /// Number of connections the pool was created with.
    num_connections: usize,

    /// Receiving end of the channel: idle connections, or errors left behind
    /// by failed writer tasks, are taken from here.
    connection_rx: Receiver<Result<PgConnection, Error>>,

    /// Sending end of the channel: cloned into every spawned writer task so
    /// the task can hand its connection (or error) back when it is done.
    connection_tx: Sender<Result<PgConnection, Error>>,
}
95
96impl WriterConnectionPool {
97 pub fn new(connections: Vec<PgConnection>) -> Self {
98 let num_connections = connections.len();
99 let (connection_tx, connection_rx) = mpsc::channel(num_connections);
100 for connection in connections {
101 connection_tx
102 .try_send(Ok(connection))
103 .expect("there should be room for this connection");
104 }
105
106 WriterConnectionPool {
107 num_connections,
108 connection_rx,
109 connection_tx,
110 }
111 }
112
113 pub async fn spawn_with_connection<F>(&mut self, task: F) -> Result<(), Error>
114 where
115 F: for<'conn> FnOnce(&'conn mut PgConnection) -> BoxFuture<'conn, Result<(), Error>>
116 + Send
117 + 'static,
118 {
119 match self.connection_rx.recv().await {
120 Some(Ok(mut connection)) => {
121 let connection_tx = self.connection_tx.clone();
122 tokio::task::spawn(
123 async move {
124 let to_return = match task(&mut connection).await {
125 Ok(()) => Ok(connection),
126 Err(error) => {
127 error!("error in writer: {error}");
128 Err(error)
129 }
130 };
131 let _: Result<_, _> = connection_tx.send(to_return).await;
134 }
135 .instrument(tracing::debug_span!("spawn_with_connection")),
136 );
137
138 Ok(())
139 }
140 Some(Err(error)) => {
141 let _: Result<_, _> = self.connection_tx.send(Err(error)).await;
144
145 Err(Error::WriterConnectionPoolError)
146 }
147 None => {
148 unreachable!("we still hold a reference to the sender, so this shouldn't happen")
149 }
150 }
151 }
152
153 pub async fn finish(self) -> Result<(), Vec<Error>> {
165 let mut errors = Vec::new();
166
167 let Self {
168 num_connections,
169 mut connection_rx,
170 connection_tx,
171 } = self;
172 drop(connection_tx);
174
175 let mut finished_connections = 0;
176
177 while let Some(connection_or_error) = connection_rx.recv().await {
178 finished_connections += 1;
179
180 match connection_or_error {
181 Ok(mut connection) => {
182 if let Err(err) = query("COMMIT;").execute(&mut connection).await {
183 errors.push(err.into_database("commit writer transaction"));
184 }
185 }
186 Err(error) => {
187 errors.push(error);
188 }
189 }
190 }
191 assert_eq!(
192 finished_connections, num_connections,
193 "syn2mas had a bug: connections went missing {finished_connections} != {num_connections}"
194 );
195
196 if errors.is_empty() {
197 Ok(())
198 } else {
199 Err(errors)
200 }
201 }
202}
203
/// Tracks how many handles have been given out and not yet declared
/// finished, via a shared atomic counter.
#[derive(Default)]
struct FinishChecker {
    counter: Arc<AtomicU32>,
}
210
/// A handle obtained from [`FinishChecker::handle`]; it shares the checker's
/// counter and decrements it when declared finished.
struct FinishCheckerHandle {
    counter: Arc<AtomicU32>,
}
214
215impl FinishChecker {
216 pub fn handle(&self) -> FinishCheckerHandle {
219 self.counter.fetch_add(1, Ordering::SeqCst);
220 FinishCheckerHandle {
221 counter: Arc::clone(&self.counter),
222 }
223 }
224
225 pub fn check_all_finished(self) -> Result<(), Error> {
227 if self.counter.load(Ordering::SeqCst) == 0 {
228 Ok(())
229 } else {
230 Err(Error::WriteBuffersNotFinished)
231 }
232 }
233}
234
235impl FinishCheckerHandle {
236 pub fn declare_finished(self) {
238 self.counter.fetch_sub(1, Ordering::SeqCst);
239 }
240}
241
/// Writes migrated data into the MAS database.
pub struct MasWriter {
    /// The main, locked connection; used for setup in `new` and for the
    /// restore and cleanup steps in `finish`.
    conn: LockedMasDatabase,
    /// The connections on which batches are written by spawned tasks.
    writer_pool: WriterConnectionPool,
    /// When set, `finish` truncates the affected tables before committing.
    dry_run: bool,

    /// Indices dropped for the duration of the migration, rebuilt in `finish`.
    indices_to_restore: Vec<IndexDescription>,
    /// Constraints dropped for the duration of the migration, rebuilt in `finish`.
    constraints_to_restore: Vec<ConstraintDescription>,

    /// Used by `finish` to verify every write buffer was finished first.
    write_buffer_finish_checker: FinishChecker,
}
252
/// A row type that can be written to the database in batches.
pub trait WriteBatch: Send + Sync + Sized + 'static {
    /// Writes all rows of `batch` using `conn`.
    fn write_batch(
        conn: &mut PgConnection,
        batch: Vec<Self>,
    ) -> impl Future<Output = Result<(), Error>> + Send;
}
259
/// A user row to be inserted into `syn2mas__users`.
pub struct MasNewUser {
    pub user_id: NonNilUuid,
    pub username: String,
    pub created_at: DateTime<Utc>,
    pub locked_at: Option<DateTime<Utc>>,
    pub deactivated_at: Option<DateTime<Utc>>,
    pub can_request_admin: bool,
    pub is_guest: bool,
}
272
273impl WriteBatch for MasNewUser {
274 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
275 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
282 let mut usernames: Vec<String> = Vec::with_capacity(batch.len());
283 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
284 let mut locked_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
285 let mut deactivated_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
286 let mut can_request_admins: Vec<bool> = Vec::with_capacity(batch.len());
287 let mut is_guests: Vec<bool> = Vec::with_capacity(batch.len());
288 for MasNewUser {
289 user_id,
290 username,
291 created_at,
292 locked_at,
293 deactivated_at,
294 can_request_admin,
295 is_guest,
296 } in batch
297 {
298 user_ids.push(user_id.get());
299 usernames.push(username);
300 created_ats.push(created_at);
301 locked_ats.push(locked_at);
302 deactivated_ats.push(deactivated_at);
303 can_request_admins.push(can_request_admin);
304 is_guests.push(is_guest);
305 }
306
307 sqlx::query!(
308 r#"
309 INSERT INTO syn2mas__users (
310 user_id, username,
311 created_at, locked_at,
312 deactivated_at,
313 can_request_admin, is_guest)
314 SELECT * FROM UNNEST(
315 $1::UUID[], $2::TEXT[],
316 $3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],
317 $5::TIMESTAMP WITH TIME ZONE[],
318 $6::BOOL[], $7::BOOL[])
319 "#,
320 &user_ids[..],
321 &usernames[..],
322 &created_ats[..],
323 &locked_ats[..] as &[Option<DateTime<Utc>>],
325 &deactivated_ats[..] as &[Option<DateTime<Utc>>],
326 &can_request_admins[..],
327 &is_guests[..],
328 )
329 .execute(&mut *conn)
330 .await
331 .into_database("writing users to MAS")?;
332
333 Ok(())
334 }
335}
336
/// A password row to be inserted into `syn2mas__user_passwords`.
pub struct MasNewUserPassword {
    pub user_password_id: Uuid,
    pub user_id: NonNilUuid,
    pub hashed_password: String,
    pub created_at: DateTime<Utc>,
}
343
344impl WriteBatch for MasNewUserPassword {
345 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
346 let mut user_password_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
347 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
348 let mut hashed_passwords: Vec<String> = Vec::with_capacity(batch.len());
349 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
350 let mut versions: Vec<i32> = Vec::with_capacity(batch.len());
351 for MasNewUserPassword {
352 user_password_id,
353 user_id,
354 hashed_password,
355 created_at,
356 } in batch
357 {
358 user_password_ids.push(user_password_id);
359 user_ids.push(user_id.get());
360 hashed_passwords.push(hashed_password);
361 created_ats.push(created_at);
362 versions.push(MIGRATED_PASSWORD_VERSION.into());
363 }
364
365 sqlx::query!(
366 r#"
367 INSERT INTO syn2mas__user_passwords
368 (user_password_id, user_id, hashed_password, created_at, version)
369 SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[], $5::INTEGER[])
370 "#,
371 &user_password_ids[..],
372 &user_ids[..],
373 &hashed_passwords[..],
374 &created_ats[..],
375 &versions[..],
376 ).execute(&mut *conn).await.into_database("writing users to MAS")?;
377
378 Ok(())
379 }
380}
381
/// An email address row to be inserted into `syn2mas__user_emails`.
pub struct MasNewEmailThreepid {
    pub user_email_id: Uuid,
    pub user_id: NonNilUuid,
    pub email: String,
    pub created_at: DateTime<Utc>,
}
388
389impl WriteBatch for MasNewEmailThreepid {
390 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
391 let mut user_email_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
392 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
393 let mut emails: Vec<String> = Vec::with_capacity(batch.len());
394 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
395
396 for MasNewEmailThreepid {
397 user_email_id,
398 user_id,
399 email,
400 created_at,
401 } in batch
402 {
403 user_email_ids.push(user_email_id);
404 user_ids.push(user_id.get());
405 emails.push(email);
406 created_ats.push(created_at);
407 }
408
409 sqlx::query!(
410 r#"
411 INSERT INTO syn2mas__user_emails
412 (user_email_id, user_id, email, created_at)
413 SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[])
414 "#,
415 &user_email_ids[..],
416 &user_ids[..],
417 &emails[..],
418 &created_ats[..],
419 )
420 .execute(&mut *conn)
421 .await
422 .into_database("writing emails to MAS")?;
423
424 Ok(())
425 }
426}
427
/// A row to be inserted into `syn2mas__user_unsupported_third_party_ids`.
pub struct MasNewUnsupportedThreepid {
    pub user_id: NonNilUuid,
    pub medium: String,
    pub address: String,
    pub created_at: DateTime<Utc>,
}
434
435impl WriteBatch for MasNewUnsupportedThreepid {
436 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
437 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
438 let mut mediums: Vec<String> = Vec::with_capacity(batch.len());
439 let mut addresses: Vec<String> = Vec::with_capacity(batch.len());
440 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
441
442 for MasNewUnsupportedThreepid {
443 user_id,
444 medium,
445 address,
446 created_at,
447 } in batch
448 {
449 user_ids.push(user_id.get());
450 mediums.push(medium);
451 addresses.push(address);
452 created_ats.push(created_at);
453 }
454
455 sqlx::query!(
456 r#"
457 INSERT INTO syn2mas__user_unsupported_third_party_ids
458 (user_id, medium, address, created_at)
459 SELECT * FROM UNNEST($1::UUID[], $2::TEXT[], $3::TEXT[], $4::TIMESTAMP WITH TIME ZONE[])
460 "#,
461 &user_ids[..],
462 &mediums[..],
463 &addresses[..],
464 &created_ats[..],
465 )
466 .execute(&mut *conn)
467 .await
468 .into_database("writing unsupported threepids to MAS")?;
469
470 Ok(())
471 }
472}
473
/// A row to be inserted into `syn2mas__upstream_oauth_links`.
pub struct MasNewUpstreamOauthLink {
    pub link_id: Uuid,
    pub user_id: NonNilUuid,
    pub upstream_provider_id: Uuid,
    pub subject: String,
    pub created_at: DateTime<Utc>,
}
481
482impl WriteBatch for MasNewUpstreamOauthLink {
483 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
484 let mut link_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
485 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
486 let mut upstream_provider_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
487 let mut subjects: Vec<String> = Vec::with_capacity(batch.len());
488 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
489
490 for MasNewUpstreamOauthLink {
491 link_id,
492 user_id,
493 upstream_provider_id,
494 subject,
495 created_at,
496 } in batch
497 {
498 link_ids.push(link_id);
499 user_ids.push(user_id.get());
500 upstream_provider_ids.push(upstream_provider_id);
501 subjects.push(subject);
502 created_ats.push(created_at);
503 }
504
505 sqlx::query!(
506 r#"
507 INSERT INTO syn2mas__upstream_oauth_links
508 (upstream_oauth_link_id, user_id, upstream_oauth_provider_id, subject, created_at)
509 SELECT * FROM UNNEST($1::UUID[], $2::UUID[], $3::UUID[], $4::TEXT[], $5::TIMESTAMP WITH TIME ZONE[])
510 "#,
511 &link_ids[..],
512 &user_ids[..],
513 &upstream_provider_ids[..],
514 &subjects[..],
515 &created_ats[..],
516 ).execute(&mut *conn).await.into_database("writing unsupported threepids to MAS")?;
517
518 Ok(())
519 }
520}
521
/// A row to be inserted into `syn2mas__compat_sessions`.
pub struct MasNewCompatSession {
    pub session_id: Uuid,
    pub user_id: NonNilUuid,
    pub device_id: Option<String>,
    pub human_name: Option<String>,
    pub created_at: DateTime<Utc>,
    pub is_synapse_admin: bool,
    pub last_active_at: Option<DateTime<Utc>>,
    pub last_active_ip: Option<IpAddr>,
    pub user_agent: Option<String>,
}
533
534impl WriteBatch for MasNewCompatSession {
535 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
536 let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
537 let mut user_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
538 let mut device_ids: Vec<Option<String>> = Vec::with_capacity(batch.len());
539 let mut human_names: Vec<Option<String>> = Vec::with_capacity(batch.len());
540 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
541 let mut is_synapse_admins: Vec<bool> = Vec::with_capacity(batch.len());
542 let mut last_active_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
543 let mut last_active_ips: Vec<Option<IpAddr>> = Vec::with_capacity(batch.len());
544 let mut user_agents: Vec<Option<String>> = Vec::with_capacity(batch.len());
545
546 for MasNewCompatSession {
547 session_id,
548 user_id,
549 device_id,
550 human_name,
551 created_at,
552 is_synapse_admin,
553 last_active_at,
554 last_active_ip,
555 user_agent,
556 } in batch
557 {
558 session_ids.push(session_id);
559 user_ids.push(user_id.get());
560 device_ids.push(device_id);
561 human_names.push(human_name);
562 created_ats.push(created_at);
563 is_synapse_admins.push(is_synapse_admin);
564 last_active_ats.push(last_active_at);
565 last_active_ips.push(last_active_ip);
566 user_agents.push(user_agent);
567 }
568
569 sqlx::query!(
570 r#"
571 INSERT INTO syn2mas__compat_sessions (
572 compat_session_id, user_id,
573 device_id, human_name,
574 created_at, is_synapse_admin,
575 last_active_at, last_active_ip,
576 user_agent)
577 SELECT * FROM UNNEST(
578 $1::UUID[], $2::UUID[],
579 $3::TEXT[], $4::TEXT[],
580 $5::TIMESTAMP WITH TIME ZONE[], $6::BOOLEAN[],
581 $7::TIMESTAMP WITH TIME ZONE[], $8::INET[],
582 $9::TEXT[])
583 "#,
584 &session_ids[..],
585 &user_ids[..],
586 &device_ids[..] as &[Option<String>],
587 &human_names[..] as &[Option<String>],
588 &created_ats[..],
589 &is_synapse_admins[..],
590 &last_active_ats[..] as &[Option<DateTime<Utc>>],
592 &last_active_ips[..] as &[Option<IpAddr>],
593 &user_agents[..] as &[Option<String>],
594 )
595 .execute(&mut *conn)
596 .await
597 .into_database("writing compat sessions to MAS")?;
598
599 Ok(())
600 }
601}
602
/// A row to be inserted into `syn2mas__compat_access_tokens`.
pub struct MasNewCompatAccessToken {
    pub token_id: Uuid,
    pub session_id: Uuid,
    pub access_token: String,
    pub created_at: DateTime<Utc>,
    pub expires_at: Option<DateTime<Utc>>,
}
610
611impl WriteBatch for MasNewCompatAccessToken {
612 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
613 let mut token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
614 let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
615 let mut access_tokens: Vec<String> = Vec::with_capacity(batch.len());
616 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
617 let mut expires_ats: Vec<Option<DateTime<Utc>>> = Vec::with_capacity(batch.len());
618
619 for MasNewCompatAccessToken {
620 token_id,
621 session_id,
622 access_token,
623 created_at,
624 expires_at,
625 } in batch
626 {
627 token_ids.push(token_id);
628 session_ids.push(session_id);
629 access_tokens.push(access_token);
630 created_ats.push(created_at);
631 expires_ats.push(expires_at);
632 }
633
634 sqlx::query!(
635 r#"
636 INSERT INTO syn2mas__compat_access_tokens (
637 compat_access_token_id,
638 compat_session_id,
639 access_token,
640 created_at,
641 expires_at)
642 SELECT * FROM UNNEST(
643 $1::UUID[],
644 $2::UUID[],
645 $3::TEXT[],
646 $4::TIMESTAMP WITH TIME ZONE[],
647 $5::TIMESTAMP WITH TIME ZONE[])
648 "#,
649 &token_ids[..],
650 &session_ids[..],
651 &access_tokens[..],
652 &created_ats[..],
653 &expires_ats[..] as &[Option<DateTime<Utc>>],
655 )
656 .execute(&mut *conn)
657 .await
658 .into_database("writing compat access tokens to MAS")?;
659
660 Ok(())
661 }
662}
663
/// A row to be inserted into `syn2mas__compat_refresh_tokens`.
pub struct MasNewCompatRefreshToken {
    pub refresh_token_id: Uuid,
    pub session_id: Uuid,
    pub access_token_id: Uuid,
    pub refresh_token: String,
    pub created_at: DateTime<Utc>,
}
671
672impl WriteBatch for MasNewCompatRefreshToken {
673 async fn write_batch(conn: &mut PgConnection, batch: Vec<Self>) -> Result<(), Error> {
674 let mut refresh_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
675 let mut session_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
676 let mut access_token_ids: Vec<Uuid> = Vec::with_capacity(batch.len());
677 let mut refresh_tokens: Vec<String> = Vec::with_capacity(batch.len());
678 let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(batch.len());
679
680 for MasNewCompatRefreshToken {
681 refresh_token_id,
682 session_id,
683 access_token_id,
684 refresh_token,
685 created_at,
686 } in batch
687 {
688 refresh_token_ids.push(refresh_token_id);
689 session_ids.push(session_id);
690 access_token_ids.push(access_token_id);
691 refresh_tokens.push(refresh_token);
692 created_ats.push(created_at);
693 }
694
695 sqlx::query!(
696 r#"
697 INSERT INTO syn2mas__compat_refresh_tokens (
698 compat_refresh_token_id,
699 compat_session_id,
700 compat_access_token_id,
701 refresh_token,
702 created_at)
703 SELECT * FROM UNNEST(
704 $1::UUID[],
705 $2::UUID[],
706 $3::UUID[],
707 $4::TEXT[],
708 $5::TIMESTAMP WITH TIME ZONE[])
709 "#,
710 &refresh_token_ids[..],
711 &session_ids[..],
712 &access_token_ids[..],
713 &refresh_tokens[..],
714 &created_ats[..],
715 )
716 .execute(&mut *conn)
717 .await
718 .into_database("writing compat refresh tokens to MAS")?;
719
720 Ok(())
721 }
722}
723
/// The value written to the `version` column of every migrated password row.
pub const MIGRATED_PASSWORD_VERSION: u16 = 1;
729
/// Unprefixed names of the MAS tables the migration writes to.
///
/// During the migration the writer addresses these tables with a
/// `syn2mas__` prefix.
pub const MAS_TABLES_AFFECTED_BY_MIGRATION: &[&str] = &[
    "users",
    "user_passwords",
    "user_emails",
    "user_unsupported_third_party_ids",
    "upstream_oauth_links",
    "compat_sessions",
    "compat_access_tokens",
    "compat_refresh_tokens",
];
741
742pub async fn is_syn2mas_in_progress(conn: &mut PgConnection) -> Result<bool, Error> {
756 let restore_table_names = vec![
759 "syn2mas_restore_constraints".to_owned(),
760 "syn2mas_restore_indices".to_owned(),
761 ];
762
763 let num_resumption_tables = query!(
764 r#"
765 SELECT 1 AS _dummy FROM pg_tables WHERE schemaname = current_schema
766 AND tablename = ANY($1)
767 "#,
768 &restore_table_names,
769 )
770 .fetch_all(conn.as_mut())
771 .await
772 .into_database("failed to query count of resumption tables")?
773 .len();
774
775 if num_resumption_tables == 0 {
776 Ok(false)
777 } else if num_resumption_tables == restore_table_names.len() {
778 Ok(true)
779 } else {
780 Err(Error::inconsistent(
781 "some, but not all, syn2mas resumption tables were found",
782 ))
783 }
784}
785
786impl MasWriter {
787 #[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)]
795 pub async fn new(
796 mut conn: LockedMasDatabase,
797 mut writer_connections: Vec<PgConnection>,
798 dry_run: bool,
799 ) -> Result<Self, Error> {
800 query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
803 .execute(conn.as_mut())
804 .await
805 .into_database("begin MAS transaction")?;
806
807 let syn2mas_started = is_syn2mas_in_progress(conn.as_mut()).await?;
808
809 let indices_to_restore;
810 let constraints_to_restore;
811
812 if syn2mas_started {
813 warn!("Partial syn2mas migration has already been done; resetting.");
816 for table in MAS_TABLES_AFFECTED_BY_MIGRATION {
817 query(&format!("TRUNCATE syn2mas__{table};"))
818 .execute(conn.as_mut())
819 .await
820 .into_database_with(|| format!("failed to truncate table syn2mas__{table}"))?;
821 }
822
823 indices_to_restore = query_as!(
824 IndexDescription,
825 "SELECT table_name, name, definition FROM syn2mas_restore_indices ORDER BY order_key"
826 )
827 .fetch_all(conn.as_mut())
828 .await
829 .into_database("failed to get syn2mas restore data (index descriptions)")?;
830 constraints_to_restore = query_as!(
831 ConstraintDescription,
832 "SELECT table_name, name, definition FROM syn2mas_restore_constraints ORDER BY order_key"
833 )
834 .fetch_all(conn.as_mut())
835 .await
836 .into_database("failed to get syn2mas restore data (constraint descriptions)")?;
837 } else {
838 info!("Starting new syn2mas migration");
839
840 conn.as_mut()
841 .execute_many(include_str!("syn2mas_temporary_tables.sql"))
842 .try_collect::<Vec<_>>()
844 .await
845 .into_database("could not create temporary tables")?;
846
847 (indices_to_restore, constraints_to_restore) =
850 Self::pause_indices(conn.as_mut()).await?;
851
852 for IndexDescription {
854 name,
855 table_name,
856 definition,
857 } in &indices_to_restore
858 {
859 query!(
860 r#"
861 INSERT INTO syn2mas_restore_indices (name, table_name, definition)
862 VALUES ($1, $2, $3)
863 "#,
864 name,
865 table_name,
866 definition
867 )
868 .execute(conn.as_mut())
869 .await
870 .into_database("failed to save restore data (index)")?;
871 }
872 for ConstraintDescription {
873 name,
874 table_name,
875 definition,
876 } in &constraints_to_restore
877 {
878 query!(
879 r#"
880 INSERT INTO syn2mas_restore_constraints (name, table_name, definition)
881 VALUES ($1, $2, $3)
882 "#,
883 name,
884 table_name,
885 definition
886 )
887 .execute(conn.as_mut())
888 .await
889 .into_database("failed to save restore data (index)")?;
890 }
891 }
892
893 query("COMMIT;")
894 .execute(conn.as_mut())
895 .await
896 .into_database("begin MAS transaction")?;
897
898 for writer_connection in &mut writer_connections {
900 query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
901 .execute(&mut *writer_connection)
902 .await
903 .into_database("begin MAS writer transaction")?;
904 }
905
906 Ok(Self {
907 conn,
908 dry_run,
909 writer_pool: WriterConnectionPool::new(writer_connections),
910 indices_to_restore,
911 constraints_to_restore,
912 write_buffer_finish_checker: FinishChecker::default(),
913 })
914 }
915
916 #[tracing::instrument(skip_all)]
917 async fn pause_indices(
918 conn: &mut PgConnection,
919 ) -> Result<(Vec<IndexDescription>, Vec<ConstraintDescription>), Error> {
920 let mut indices_to_restore = Vec::new();
921 let mut constraints_to_restore = Vec::new();
922
923 for &unprefixed_table in MAS_TABLES_AFFECTED_BY_MIGRATION {
924 let table = format!("syn2mas__{unprefixed_table}");
925 for constraint in
927 constraint_pausing::describe_foreign_key_constraints_to_table(&mut *conn, &table)
928 .await?
929 {
930 constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
931 constraints_to_restore.push(constraint);
932 }
933 for constraint in
936 constraint_pausing::describe_constraints_on_table(&mut *conn, &table).await?
937 {
938 constraint_pausing::drop_constraint(&mut *conn, &constraint).await?;
939 constraints_to_restore.push(constraint);
940 }
941 for index in constraint_pausing::describe_indices_on_table(&mut *conn, &table).await? {
943 constraint_pausing::drop_index(&mut *conn, &index).await?;
944 indices_to_restore.push(index);
945 }
946 }
947
948 Ok((indices_to_restore, constraints_to_restore))
949 }
950
951 async fn restore_indices(
952 conn: &mut LockedMasDatabase,
953 indices_to_restore: &[IndexDescription],
954 constraints_to_restore: &[ConstraintDescription],
955 progress: &Progress,
956 ) -> Result<(), Error> {
957 for index in indices_to_restore.iter().rev() {
960 progress.rebuild_index(index.name.clone());
961 constraint_pausing::restore_index(conn.as_mut(), index).await?;
962 }
963 for constraint in constraints_to_restore.iter().rev() {
967 progress.rebuild_constraint(constraint.name.clone());
968 constraint_pausing::restore_constraint(conn.as_mut(), constraint).await?;
969 }
970 Ok(())
971 }
972
973 #[tracing::instrument(skip_all)]
982 pub async fn finish(mut self, progress: &Progress) -> Result<PgConnection, Error> {
983 self.write_buffer_finish_checker.check_all_finished()?;
984
985 self.writer_pool
987 .finish()
988 .await
989 .map_err(|errors| Error::Multiple(MultipleErrors::from(errors)))?;
990
991 query("BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED;")
994 .execute(self.conn.as_mut())
995 .await
996 .into_database("begin MAS transaction")?;
997
998 Self::restore_indices(
999 &mut self.conn,
1000 &self.indices_to_restore,
1001 &self.constraints_to_restore,
1002 progress,
1003 )
1004 .await?;
1005
1006 self.conn
1007 .as_mut()
1008 .execute_many(include_str!("syn2mas_revert_temporary_tables.sql"))
1009 .try_collect::<Vec<_>>()
1011 .await
1012 .into_database("could not revert temporary tables")?;
1013
1014 if self.dry_run {
1016 warn!("Migration ran in dry-run mode, deleting all imported data");
1017 let tables = MAS_TABLES_AFFECTED_BY_MIGRATION
1018 .iter()
1019 .map(|table| format!("\"{table}\""))
1020 .collect::<Vec<_>>()
1021 .join(", ");
1022
1023 query(&format!("TRUNCATE TABLE {tables} CASCADE;"))
1031 .execute(self.conn.as_mut())
1032 .await
1033 .into_database_with(|| "failed to truncate all tables")?;
1034 }
1035
1036 query("COMMIT;")
1037 .execute(self.conn.as_mut())
1038 .await
1039 .into_database("ending MAS transaction")?;
1040
1041 let conn = self
1042 .conn
1043 .unlock()
1044 .await
1045 .into_database("could not unlock MAS database")?;
1046
1047 Ok(conn)
1048 }
1049}
1050
/// Number of buffered rows at which a `MasWriteBuffer` flushes automatically.
const WRITE_BUFFER_BATCH_SIZE: usize = 4096;
1054
/// Buffers rows of type `T` so they can be written to the database in batches.
pub struct MasWriteBuffer<T> {
    /// Rows accumulated since the last flush.
    rows: Vec<T>,
    /// Declared finished in `finish`, so the writer can tell that this
    /// buffer was flushed before the writer itself finishes.
    finish_checker_handle: FinishCheckerHandle,
}
1061
1062impl<T> MasWriteBuffer<T>
1063where
1064 T: WriteBatch,
1065{
1066 pub fn new(writer: &MasWriter) -> Self {
1067 MasWriteBuffer {
1068 rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
1069 finish_checker_handle: writer.write_buffer_finish_checker.handle(),
1070 }
1071 }
1072
1073 pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
1074 self.flush(writer).await?;
1075 self.finish_checker_handle.declare_finished();
1076 Ok(())
1077 }
1078
1079 pub async fn flush(&mut self, writer: &mut MasWriter) -> Result<(), Error> {
1080 if self.rows.is_empty() {
1081 return Ok(());
1082 }
1083 let rows = std::mem::take(&mut self.rows);
1084 self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE);
1085 writer
1086 .writer_pool
1087 .spawn_with_connection(move |conn| T::write_batch(conn, rows).boxed())
1088 .boxed()
1089 .await?;
1090 Ok(())
1091 }
1092
1093 pub async fn write(&mut self, writer: &mut MasWriter, row: T) -> Result<(), Error> {
1094 self.rows.push(row);
1095 if self.rows.len() >= WRITE_BUFFER_BATCH_SIZE {
1096 self.flush(writer).await?;
1097 }
1098 Ok(())
1099 }
1100}
1101
1102#[cfg(test)]
1103mod test {
1104 use std::collections::{BTreeMap, BTreeSet};
1105
1106 use chrono::DateTime;
1107 use futures_util::TryStreamExt;
1108 use serde::Serialize;
1109 use sqlx::{Column, PgConnection, PgPool, Row};
1110 use uuid::{NonNilUuid, Uuid};
1111
1112 use crate::{
1113 LockedMasDatabase, MasWriter, Progress,
1114 mas_writer::{
1115 MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
1116 MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
1117 MasNewUserPassword, MasWriteBuffer,
1118 },
1119 };
1120
    /// Snapshot of a database: table name mapped to that table's rows.
    /// Serialized transparently as the inner map.
    #[derive(Default, Serialize)]
    #[serde(transparent)]
    struct DatabaseSnapshot {
        tables: BTreeMap<String, TableSnapshot>,
    }
1127
    /// Snapshot of one table: its rows held in an ordered set, so the
    /// serialized order does not depend on the order rows were fetched in.
    #[derive(Serialize)]
    #[serde(transparent)]
    struct TableSnapshot {
        rows: BTreeSet<RowSnapshot>,
    }
1133
    /// Snapshot of one row: column name mapped to the column's value as
    /// text, or `None` for SQL `NULL`.
    #[derive(PartialEq, Eq, PartialOrd, Ord, Serialize)]
    #[serde(transparent)]
    struct RowSnapshot {
        columns_to_values: BTreeMap<String, Option<String>>,
    }
1139
    /// Tables that `snapshot_database` leaves out of the snapshot.
    const SKIPPED_TABLES: &[&str] = &["_sqlx_migrations"];
1141
    /// Builds a snapshot of every table in the current schema.
    ///
    /// Tables listed in `SKIPPED_TABLES` are ignored and tables without rows
    /// are omitted from the result. Every column is read cast to `TEXT`.
    ///
    /// # Panics
    ///
    /// Panics if any of the queries fails.
    async fn snapshot_database(conn: &mut PgConnection) -> DatabaseSnapshot {
        let mut out = DatabaseSnapshot::default();
        let table_names: Vec<String> = sqlx::query_scalar(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = current_schema();",
        )
        .fetch_all(&mut *conn)
        .await
        .unwrap();

        for table_name in table_names {
            if SKIPPED_TABLES.contains(&table_name.as_str()) {
                continue;
            }

            let column_names: Vec<String> = sqlx::query_scalar(
                "SELECT column_name FROM information_schema.columns WHERE table_name = $1 AND table_schema = current_schema();"
            ).bind(&table_name).fetch_all(&mut *conn).await.expect("failed to get column names for table for snapshotting");

            // Cast every column to text, aliased back to its own name, so
            // each value can be read as `Option<String>` whatever its SQL type.
            let column_name_list = column_names
                .iter()
                .map(|column_name| format!("{column_name}::TEXT AS \"{column_name}\""))
                .collect::<Vec<_>>()
                .join(", ");

            let table_rows = sqlx::query(&format!("SELECT {column_name_list} FROM {table_name};"))
                .fetch(&mut *conn)
                .map_ok(|row| {
                    let mut columns_to_values = BTreeMap::new();
                    for (idx, column) in row.columns().iter().enumerate() {
                        columns_to_values.insert(column.name().to_owned(), row.get(idx));
                    }
                    RowSnapshot { columns_to_values }
                })
                .try_collect::<BTreeSet<RowSnapshot>>()
                .await
                .expect("failed to fetch rows from table for snapshotting");

            // Leave empty tables out of the snapshot.
            if !table_rows.is_empty() {
                out.tables
                    .insert(table_name, TableSnapshot { rows: table_rows });
            }
        }

        out
    }
1193
    /// Snapshots the database behind the given connection and asserts the
    /// result against the stored YAML snapshot.
    ///
    /// The expansion contains an `.await`, so it must be used in an async
    /// context.
    macro_rules! assert_db_snapshot {
        ($db: expr) => {
            let db_snapshot = snapshot_database($db).await;
            ::insta::assert_yaml_snapshot!(db_snapshot);
        };
    }
1201
1202 async fn make_mas_writer(pool: &PgPool) -> MasWriter {
1206 let main_conn = pool.acquire().await.unwrap().detach();
1207 let mut writer_conns = Vec::new();
1208 for _ in 0..2 {
1209 writer_conns.push(
1210 pool.acquire()
1211 .await
1212 .expect("failed to acquire MasWriter writer connection")
1213 .detach(),
1214 );
1215 }
1216 let locked_main_conn = LockedMasDatabase::try_new(main_conn)
1217 .await
1218 .expect("failed to lock MAS database")
1219 .expect_left("MAS database is already locked");
1220 MasWriter::new(locked_main_conn, writer_conns, false)
1221 .await
1222 .expect("failed to construct MasWriter")
1223 }
1224
1225 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1227 async fn test_write_user(pool: PgPool) {
1228 let mut writer = make_mas_writer(&pool).await;
1229 let mut buffer = MasWriteBuffer::new(&writer);
1230
1231 buffer
1232 .write(
1233 &mut writer,
1234 MasNewUser {
1235 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1236 username: "alice".to_owned(),
1237 created_at: DateTime::default(),
1238 locked_at: None,
1239 deactivated_at: None,
1240 can_request_admin: false,
1241 is_guest: false,
1242 },
1243 )
1244 .await
1245 .expect("failed to write user");
1246
1247 buffer
1248 .finish(&mut writer)
1249 .await
1250 .expect("failed to finish MasWriter");
1251
1252 let mut conn = writer
1253 .finish(&Progress::default())
1254 .await
1255 .expect("failed to finish MasWriter");
1256
1257 assert_db_snapshot!(&mut conn);
1258 }
1259
1260 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1262 async fn test_write_user_with_password(pool: PgPool) {
1263 const USER_ID: NonNilUuid = NonNilUuid::new(Uuid::from_u128(1u128)).unwrap();
1264
1265 let mut writer = make_mas_writer(&pool).await;
1266
1267 let mut user_buffer = MasWriteBuffer::new(&writer);
1268 let mut password_buffer = MasWriteBuffer::new(&writer);
1269
1270 user_buffer
1271 .write(
1272 &mut writer,
1273 MasNewUser {
1274 user_id: USER_ID,
1275 username: "alice".to_owned(),
1276 created_at: DateTime::default(),
1277 locked_at: None,
1278 deactivated_at: None,
1279 can_request_admin: false,
1280 is_guest: false,
1281 },
1282 )
1283 .await
1284 .expect("failed to write user");
1285
1286 password_buffer
1287 .write(
1288 &mut writer,
1289 MasNewUserPassword {
1290 user_password_id: Uuid::from_u128(42u128),
1291 user_id: USER_ID,
1292 hashed_password: "$bcrypt$aaaaaaaaaaa".to_owned(),
1293 created_at: DateTime::default(),
1294 },
1295 )
1296 .await
1297 .expect("failed to write password");
1298
1299 user_buffer
1300 .finish(&mut writer)
1301 .await
1302 .expect("failed to finish MasWriteBuffer");
1303 password_buffer
1304 .finish(&mut writer)
1305 .await
1306 .expect("failed to finish MasWriteBuffer");
1307
1308 let mut conn = writer
1309 .finish(&Progress::default())
1310 .await
1311 .expect("failed to finish MasWriter");
1312
1313 assert_db_snapshot!(&mut conn);
1314 }
1315
1316 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1318 async fn test_write_user_with_email(pool: PgPool) {
1319 let mut writer = make_mas_writer(&pool).await;
1320
1321 let mut user_buffer = MasWriteBuffer::new(&writer);
1322 let mut email_buffer = MasWriteBuffer::new(&writer);
1323
1324 user_buffer
1325 .write(
1326 &mut writer,
1327 MasNewUser {
1328 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1329 username: "alice".to_owned(),
1330 created_at: DateTime::default(),
1331 locked_at: None,
1332 deactivated_at: None,
1333 can_request_admin: false,
1334 is_guest: false,
1335 },
1336 )
1337 .await
1338 .expect("failed to write user");
1339
1340 email_buffer
1341 .write(
1342 &mut writer,
1343 MasNewEmailThreepid {
1344 user_email_id: Uuid::from_u128(2u128),
1345 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1346 email: "alice@example.org".to_owned(),
1347 created_at: DateTime::default(),
1348 },
1349 )
1350 .await
1351 .expect("failed to write e-mail");
1352
1353 user_buffer
1354 .finish(&mut writer)
1355 .await
1356 .expect("failed to finish user buffer");
1357 email_buffer
1358 .finish(&mut writer)
1359 .await
1360 .expect("failed to finish email buffer");
1361
1362 let mut conn = writer
1363 .finish(&Progress::default())
1364 .await
1365 .expect("failed to finish MasWriter");
1366
1367 assert_db_snapshot!(&mut conn);
1368 }
1369
1370 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1373 async fn test_write_user_with_unsupported_threepid(pool: PgPool) {
1374 let mut writer = make_mas_writer(&pool).await;
1375
1376 let mut user_buffer = MasWriteBuffer::new(&writer);
1377 let mut threepid_buffer = MasWriteBuffer::new(&writer);
1378
1379 user_buffer
1380 .write(
1381 &mut writer,
1382 MasNewUser {
1383 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1384 username: "alice".to_owned(),
1385 created_at: DateTime::default(),
1386 locked_at: None,
1387 deactivated_at: None,
1388 can_request_admin: false,
1389 is_guest: false,
1390 },
1391 )
1392 .await
1393 .expect("failed to write user");
1394
1395 threepid_buffer
1396 .write(
1397 &mut writer,
1398 MasNewUnsupportedThreepid {
1399 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1400 medium: "msisdn".to_owned(),
1401 address: "441189998819991197253".to_owned(),
1402 created_at: DateTime::default(),
1403 },
1404 )
1405 .await
1406 .expect("failed to write phone number (unsupported threepid)");
1407
1408 user_buffer
1409 .finish(&mut writer)
1410 .await
1411 .expect("failed to finish user buffer");
1412 threepid_buffer
1413 .finish(&mut writer)
1414 .await
1415 .expect("failed to finish threepid buffer");
1416
1417 let mut conn = writer
1418 .finish(&Progress::default())
1419 .await
1420 .expect("failed to finish MasWriter");
1421
1422 assert_db_snapshot!(&mut conn);
1423 }
1424
1425 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR", fixtures("upstream_provider"))]
1429 async fn test_write_user_with_upstream_provider_link(pool: PgPool) {
1430 let mut writer = make_mas_writer(&pool).await;
1431
1432 let mut user_buffer = MasWriteBuffer::new(&writer);
1433 let mut link_buffer = MasWriteBuffer::new(&writer);
1434
1435 user_buffer
1436 .write(
1437 &mut writer,
1438 MasNewUser {
1439 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1440 username: "alice".to_owned(),
1441 created_at: DateTime::default(),
1442 locked_at: None,
1443 deactivated_at: None,
1444 can_request_admin: false,
1445 is_guest: false,
1446 },
1447 )
1448 .await
1449 .expect("failed to write user");
1450
1451 link_buffer
1452 .write(
1453 &mut writer,
1454 MasNewUpstreamOauthLink {
1455 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1456 link_id: Uuid::from_u128(3u128),
1457 upstream_provider_id: Uuid::from_u128(4u128),
1458 subject: "12345.67890".to_owned(),
1459 created_at: DateTime::default(),
1460 },
1461 )
1462 .await
1463 .expect("failed to write link");
1464
1465 user_buffer
1466 .finish(&mut writer)
1467 .await
1468 .expect("failed to finish user buffer");
1469 link_buffer
1470 .finish(&mut writer)
1471 .await
1472 .expect("failed to finish link buffer");
1473
1474 let mut conn = writer
1475 .finish(&Progress::default())
1476 .await
1477 .expect("failed to finish MasWriter");
1478
1479 assert_db_snapshot!(&mut conn);
1480 }
1481
1482 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1484 async fn test_write_user_with_device(pool: PgPool) {
1485 let mut writer = make_mas_writer(&pool).await;
1486
1487 let mut user_buffer = MasWriteBuffer::new(&writer);
1488 let mut session_buffer = MasWriteBuffer::new(&writer);
1489
1490 user_buffer
1491 .write(
1492 &mut writer,
1493 MasNewUser {
1494 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1495 username: "alice".to_owned(),
1496 created_at: DateTime::default(),
1497 locked_at: None,
1498 deactivated_at: None,
1499 can_request_admin: false,
1500 is_guest: false,
1501 },
1502 )
1503 .await
1504 .expect("failed to write user");
1505
1506 session_buffer
1507 .write(
1508 &mut writer,
1509 MasNewCompatSession {
1510 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1511 session_id: Uuid::from_u128(5u128),
1512 created_at: DateTime::default(),
1513 device_id: Some("ADEVICE".to_owned()),
1514 human_name: Some("alice's pinephone".to_owned()),
1515 is_synapse_admin: true,
1516 last_active_at: Some(DateTime::default()),
1517 last_active_ip: Some("203.0.113.1".parse().unwrap()),
1518 user_agent: Some("Browser/5.0".to_owned()),
1519 },
1520 )
1521 .await
1522 .expect("failed to write compat session");
1523
1524 user_buffer
1525 .finish(&mut writer)
1526 .await
1527 .expect("failed to finish user buffer");
1528 session_buffer
1529 .finish(&mut writer)
1530 .await
1531 .expect("failed to finish session buffer");
1532
1533 let mut conn = writer
1534 .finish(&Progress::default())
1535 .await
1536 .expect("failed to finish MasWriter");
1537
1538 assert_db_snapshot!(&mut conn);
1539 }
1540
1541 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1543 async fn test_write_user_with_access_token(pool: PgPool) {
1544 let mut writer = make_mas_writer(&pool).await;
1545
1546 let mut user_buffer = MasWriteBuffer::new(&writer);
1547 let mut session_buffer = MasWriteBuffer::new(&writer);
1548 let mut token_buffer = MasWriteBuffer::new(&writer);
1549
1550 user_buffer
1551 .write(
1552 &mut writer,
1553 MasNewUser {
1554 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1555 username: "alice".to_owned(),
1556 created_at: DateTime::default(),
1557 locked_at: None,
1558 deactivated_at: None,
1559 can_request_admin: false,
1560 is_guest: false,
1561 },
1562 )
1563 .await
1564 .expect("failed to write user");
1565
1566 session_buffer
1567 .write(
1568 &mut writer,
1569 MasNewCompatSession {
1570 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1571 session_id: Uuid::from_u128(5u128),
1572 created_at: DateTime::default(),
1573 device_id: Some("ADEVICE".to_owned()),
1574 human_name: None,
1575 is_synapse_admin: false,
1576 last_active_at: None,
1577 last_active_ip: None,
1578 user_agent: None,
1579 },
1580 )
1581 .await
1582 .expect("failed to write compat session");
1583
1584 token_buffer
1585 .write(
1586 &mut writer,
1587 MasNewCompatAccessToken {
1588 token_id: Uuid::from_u128(6u128),
1589 session_id: Uuid::from_u128(5u128),
1590 access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
1591 created_at: DateTime::default(),
1592 expires_at: None,
1593 },
1594 )
1595 .await
1596 .expect("failed to write access token");
1597
1598 user_buffer
1599 .finish(&mut writer)
1600 .await
1601 .expect("failed to finish user buffer");
1602 session_buffer
1603 .finish(&mut writer)
1604 .await
1605 .expect("failed to finish session buffer");
1606 token_buffer
1607 .finish(&mut writer)
1608 .await
1609 .expect("failed to finish token buffer");
1610
1611 let mut conn = writer
1612 .finish(&Progress::default())
1613 .await
1614 .expect("failed to finish MasWriter");
1615
1616 assert_db_snapshot!(&mut conn);
1617 }
1618
1619 #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
1622 async fn test_write_user_with_refresh_token(pool: PgPool) {
1623 let mut writer = make_mas_writer(&pool).await;
1624
1625 let mut user_buffer = MasWriteBuffer::new(&writer);
1626 let mut session_buffer = MasWriteBuffer::new(&writer);
1627 let mut token_buffer = MasWriteBuffer::new(&writer);
1628 let mut refresh_token_buffer = MasWriteBuffer::new(&writer);
1629
1630 user_buffer
1631 .write(
1632 &mut writer,
1633 MasNewUser {
1634 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1635 username: "alice".to_owned(),
1636 created_at: DateTime::default(),
1637 locked_at: None,
1638 deactivated_at: None,
1639 can_request_admin: false,
1640 is_guest: false,
1641 },
1642 )
1643 .await
1644 .expect("failed to write user");
1645
1646 session_buffer
1647 .write(
1648 &mut writer,
1649 MasNewCompatSession {
1650 user_id: NonNilUuid::new(Uuid::from_u128(1u128)).unwrap(),
1651 session_id: Uuid::from_u128(5u128),
1652 created_at: DateTime::default(),
1653 device_id: Some("ADEVICE".to_owned()),
1654 human_name: None,
1655 is_synapse_admin: false,
1656 last_active_at: None,
1657 last_active_ip: None,
1658 user_agent: None,
1659 },
1660 )
1661 .await
1662 .expect("failed to write compat session");
1663
1664 token_buffer
1665 .write(
1666 &mut writer,
1667 MasNewCompatAccessToken {
1668 token_id: Uuid::from_u128(6u128),
1669 session_id: Uuid::from_u128(5u128),
1670 access_token: "syt_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
1671 created_at: DateTime::default(),
1672 expires_at: None,
1673 },
1674 )
1675 .await
1676 .expect("failed to write access token");
1677
1678 refresh_token_buffer
1679 .write(
1680 &mut writer,
1681 MasNewCompatRefreshToken {
1682 refresh_token_id: Uuid::from_u128(7u128),
1683 session_id: Uuid::from_u128(5u128),
1684 access_token_id: Uuid::from_u128(6u128),
1685 refresh_token: "syr_zxcvzxcvzxcvzxcv_zxcv".to_owned(),
1686 created_at: DateTime::default(),
1687 },
1688 )
1689 .await
1690 .expect("failed to write refresh token");
1691
1692 user_buffer
1693 .finish(&mut writer)
1694 .await
1695 .expect("failed to finish user buffer");
1696 session_buffer
1697 .finish(&mut writer)
1698 .await
1699 .expect("failed to finish session buffer");
1700 token_buffer
1701 .finish(&mut writer)
1702 .await
1703 .expect("failed to finish token buffer");
1704 refresh_token_buffer
1705 .finish(&mut writer)
1706 .await
1707 .expect("failed to finish refresh token buffer");
1708
1709 let mut conn = writer
1710 .finish(&Progress::default())
1711 .await
1712 .expect("failed to finish MasWriter");
1713
1714 assert_db_snapshot!(&mut conn);
1715 }
1716}