1use async_trait::async_trait;
12use mas_data_model::{Clock, UlidExt as _, User};
13use mas_storage::user::{UserFilter, UserRepository};
14use rand::RngCore;
15use sea_query::{
16 Expr, ExprTrait, PostgresQueryBuilder, Query, SimpleExpr, extension::postgres::PgExpr as _,
17};
18use sea_query_sqlx::SqlxBinder;
19use sqlx::PgConnection;
20use ulid::Ulid;
21use uuid::Uuid;
22
23use crate::{
24 DatabaseError,
25 filter::{Filter, StatementExt},
26 iden::{CompatSessions, OAuth2Sessions, Users},
27 pagination::QueryBuilderExt,
28 tracing::ExecuteExt,
29};
30
31mod email;
32mod password;
33mod recovery;
34mod registration;
35mod registration_token;
36mod session;
37mod terms;
38
39#[cfg(test)]
40mod tests;
41
42pub use self::{
43 email::PgUserEmailRepository, password::PgUserPasswordRepository,
44 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
45 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
46 terms::PgUserTermsRepository,
47};
48
49pub struct PgUserRepository<'c> {
51 conn: &'c mut PgConnection,
52}
53
54impl<'c> PgUserRepository<'c> {
55 pub fn new(conn: &'c mut PgConnection) -> Self {
57 Self { conn }
58 }
59}
60
/// Row type shared by the user queries in this file.
mod priv_ {
    // Silences `missing_docs` for everything in this module. This is presumably
    // needed because `#[enum_def]` generates a `UserLookupIden` enum whose
    // items cannot be documented from here. NOTE(review): confirm.
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use mas_storage::pagination::Node;
    use sea_query::enum_def;
    use ulid::Ulid;
    use uuid::Uuid;

    /// One row of the `users` table, as selected by the user queries.
    ///
    /// `#[enum_def]` derives `UserLookupIden` from these field names, and
    /// `list` uses it to alias the selected columns. The `query_as!` queries
    /// select columns with the same names, so field names must stay in sync
    /// with both.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        pub(super) locked_at: Option<DateTime<Utc>>,
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
        pub(super) is_guest: bool,
    }

    impl Node<Ulid> for UserLookup {
        // The user ID, converted to a ULID, is the pagination cursor; `list`
        // paginates on the `user_id` column.
        fn cursor(&self) -> Ulid {
            self.user_id.into()
        }
    }
}
90
91use priv_::{UserLookup, UserLookupIden};
92
93impl From<UserLookup> for User {
94 fn from(value: UserLookup) -> Self {
95 let id = value.user_id.into();
96 Self {
97 id,
98 username: value.username,
99 sub: id.to_string(),
100 created_at: value.created_at,
101 locked_at: value.locked_at,
102 deactivated_at: value.deactivated_at,
103 can_request_admin: value.can_request_admin,
104 is_guest: value.is_guest,
105 }
106 }
107}
108
109impl Filter for UserFilter<'_> {
110 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
111 sea_query::Condition::all()
112 .add_option(self.state().map(|state| {
113 match state {
114 mas_storage::user::UserState::Deactivated => {
115 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
116 }
117 mas_storage::user::UserState::Locked => {
118 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
119 }
120 mas_storage::user::UserState::Active => {
121 Expr::col((Users::Table, Users::LockedAt))
122 .is_null()
123 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
124 }
125 }
126 }))
127 .add_option(self.can_request_admin().map(|can_request_admin| {
128 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
129 }))
130 .add_option(
131 self.is_guest()
132 .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
133 )
134 .add_option(self.search().map(|search| {
135 Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
136 }))
137 .add_option(self.active_oauth2_session_for_any_of_clients().map(
138 |clients| -> SimpleExpr {
139 let client_ids: Vec<SimpleExpr> =
140 clients.iter().map(|c| Expr::val(Uuid::from(*c))).collect();
141 Expr::exists(
142 Query::select()
143 .expr(Expr::cust("1"))
144 .from(OAuth2Sessions::Table)
145 .and_where(
146 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId))
147 .equals((Users::Table, Users::UserId)),
148 )
149 .and_where(
150 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
151 .is_null(),
152 )
153 .and_where(
154 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
155 .is_in(client_ids),
156 )
157 .take(),
158 )
159 },
160 ))
161 .add_option(self.has_active_oauth2_session().map(|has| -> SimpleExpr {
162 let exists = Expr::exists(
163 Query::select()
164 .expr(Expr::cust("1"))
165 .from(OAuth2Sessions::Table)
166 .and_where(
167 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserId))
168 .equals((Users::Table, Users::UserId)),
169 )
170 .and_where(
171 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
172 .is_null(),
173 )
174 .take(),
175 );
176 if has { exists } else { exists.not() }
177 }))
178 .add_option(self.has_active_compat_session().map(|has| -> SimpleExpr {
179 let exists = Expr::exists(
180 Query::select()
181 .expr(Expr::cust("1"))
182 .from(CompatSessions::Table)
183 .and_where(
184 Expr::col((CompatSessions::Table, CompatSessions::UserId))
185 .equals((Users::Table, Users::UserId)),
186 )
187 .and_where(
188 Expr::col((CompatSessions::Table, CompatSessions::FinishedAt))
189 .is_null(),
190 )
191 .take(),
192 );
193 if has { exists } else { exists.not() }
194 }))
195 }
196}
197
198#[async_trait]
199impl UserRepository for PgUserRepository<'_> {
200 type Error = DatabaseError;
201
202 #[tracing::instrument(
203 name = "db.user.lookup",
204 skip_all,
205 fields(
206 db.query.text,
207 user.id = %id,
208 ),
209 err,
210 )]
211 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
212 let res = sqlx::query_as!(
213 UserLookup,
214 r#"
215 SELECT user_id
216 , username
217 , created_at
218 , locked_at
219 , deactivated_at
220 , can_request_admin
221 , is_guest
222 FROM users
223 WHERE user_id = $1
224 "#,
225 Uuid::from(id),
226 )
227 .traced()
228 .fetch_optional(&mut *self.conn)
229 .await?;
230
231 let Some(res) = res else { return Ok(None) };
232
233 Ok(Some(res.into()))
234 }
235
236 #[tracing::instrument(
237 name = "db.user.find_by_username",
238 skip_all,
239 fields(
240 db.query.text,
241 user.username = username,
242 ),
243 err,
244 )]
245 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
246 let res = sqlx::query_as!(
250 UserLookup,
251 r#"
252 SELECT user_id
253 , username
254 , created_at
255 , locked_at
256 , deactivated_at
257 , can_request_admin
258 , is_guest
259 FROM users
260 WHERE LOWER(username) = LOWER($1)
261 "#,
262 username,
263 )
264 .traced()
265 .fetch_all(&mut *self.conn)
266 .await?;
267
268 match &res[..] {
269 [user] => Ok(Some(user.clone().into())),
271 [] => Ok(None),
273 list => {
274 if let Some(user) = list.iter().find(|user| user.username == username) {
277 Ok(Some(user.clone().into()))
278 } else {
279 Ok(None)
281 }
282 }
283 }
284 }
285
286 #[tracing::instrument(
287 name = "db.user.add",
288 skip_all,
289 fields(
290 db.query.text,
291 user.username = username,
292 user.id,
293 ),
294 err,
295 )]
296 async fn add(
297 &mut self,
298 rng: &mut (dyn RngCore + Send),
299 clock: &dyn Clock,
300 username: String,
301 ) -> Result<User, Self::Error> {
302 let created_at = clock.now();
303 let id = Ulid::from_datetime_with_rng(created_at, rng);
304 tracing::Span::current().record("user.id", tracing::field::display(id));
305
306 let res = sqlx::query!(
307 r#"
308 INSERT INTO users (user_id, username, created_at)
309 VALUES ($1, $2, $3)
310 ON CONFLICT (username) DO NOTHING
311 "#,
312 Uuid::from(id),
313 username,
314 created_at,
315 )
316 .traced()
317 .execute(&mut *self.conn)
318 .await?;
319
320 DatabaseError::ensure_affected_rows(&res, 1)?;
323
324 Ok(User {
325 id,
326 username,
327 sub: id.to_string(),
328 created_at,
329 locked_at: None,
330 deactivated_at: None,
331 can_request_admin: false,
332 is_guest: false,
333 })
334 }
335
336 #[tracing::instrument(
337 name = "db.user.exists",
338 skip_all,
339 fields(
340 db.query.text,
341 user.username = username,
342 ),
343 err,
344 )]
345 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
346 let exists = sqlx::query_scalar!(
347 r#"
348 SELECT EXISTS(
349 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
350 ) AS "exists!"
351 "#,
352 username
353 )
354 .traced()
355 .fetch_one(&mut *self.conn)
356 .await?;
357
358 Ok(exists)
359 }
360
361 #[tracing::instrument(
362 name = "db.user.lock",
363 skip_all,
364 fields(
365 db.query.text,
366 %user.id,
367 ),
368 err,
369 )]
370 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
371 if user.locked_at.is_some() {
372 return Ok(user);
373 }
374
375 let locked_at = clock.now();
376 let res = sqlx::query!(
377 r#"
378 UPDATE users
379 SET locked_at = $1
380 WHERE user_id = $2
381 "#,
382 locked_at,
383 Uuid::from(user.id),
384 )
385 .traced()
386 .execute(&mut *self.conn)
387 .await?;
388
389 DatabaseError::ensure_affected_rows(&res, 1)?;
390
391 user.locked_at = Some(locked_at);
392
393 Ok(user)
394 }
395
396 #[tracing::instrument(
397 name = "db.user.unlock",
398 skip_all,
399 fields(
400 db.query.text,
401 %user.id,
402 ),
403 err,
404 )]
405 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
406 if user.locked_at.is_none() {
407 return Ok(user);
408 }
409
410 let res = sqlx::query!(
411 r#"
412 UPDATE users
413 SET locked_at = NULL
414 WHERE user_id = $1
415 "#,
416 Uuid::from(user.id),
417 )
418 .traced()
419 .execute(&mut *self.conn)
420 .await?;
421
422 DatabaseError::ensure_affected_rows(&res, 1)?;
423
424 user.locked_at = None;
425
426 Ok(user)
427 }
428
429 #[tracing::instrument(
430 name = "db.user.deactivate",
431 skip_all,
432 fields(
433 db.query.text,
434 %user.id,
435 ),
436 err,
437 )]
438 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
439 if user.deactivated_at.is_some() {
440 return Ok(user);
441 }
442
443 let deactivated_at = clock.now();
444 let res = sqlx::query!(
445 r#"
446 UPDATE users
447 SET deactivated_at = $2
448 WHERE user_id = $1
449 AND deactivated_at IS NULL
450 "#,
451 Uuid::from(user.id),
452 deactivated_at,
453 )
454 .traced()
455 .execute(&mut *self.conn)
456 .await?;
457
458 DatabaseError::ensure_affected_rows(&res, 1)?;
459
460 user.deactivated_at = Some(deactivated_at);
461
462 Ok(user)
463 }
464
465 #[tracing::instrument(
466 name = "db.user.reactivate",
467 skip_all,
468 fields(
469 db.query.text,
470 %user.id,
471 ),
472 err,
473 )]
474 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
475 if user.deactivated_at.is_none() {
476 return Ok(user);
477 }
478
479 let res = sqlx::query!(
480 r#"
481 UPDATE users
482 SET deactivated_at = NULL
483 WHERE user_id = $1
484 "#,
485 Uuid::from(user.id),
486 )
487 .traced()
488 .execute(&mut *self.conn)
489 .await?;
490
491 DatabaseError::ensure_affected_rows(&res, 1)?;
492
493 user.deactivated_at = None;
494
495 Ok(user)
496 }
497
498 #[tracing::instrument(
499 name = "db.user.delete_unsupported_threepids",
500 skip_all,
501 fields(
502 db.query.text,
503 %user.id,
504 ),
505 err,
506 )]
507 async fn delete_unsupported_threepids(&mut self, user: &User) -> Result<usize, Self::Error> {
508 let res = sqlx::query!(
509 r#"
510 DELETE FROM user_unsupported_third_party_ids
511 WHERE user_id = $1
512 "#,
513 Uuid::from(user.id),
514 )
515 .traced()
516 .execute(&mut *self.conn)
517 .await?;
518
519 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
520 }
521
522 #[tracing::instrument(
523 name = "db.user.set_can_request_admin",
524 skip_all,
525 fields(
526 db.query.text,
527 %user.id,
528 user.can_request_admin = can_request_admin,
529 ),
530 err,
531 )]
532 async fn set_can_request_admin(
533 &mut self,
534 mut user: User,
535 can_request_admin: bool,
536 ) -> Result<User, Self::Error> {
537 let res = sqlx::query!(
538 r#"
539 UPDATE users
540 SET can_request_admin = $2
541 WHERE user_id = $1
542 "#,
543 Uuid::from(user.id),
544 can_request_admin,
545 )
546 .traced()
547 .execute(&mut *self.conn)
548 .await?;
549
550 DatabaseError::ensure_affected_rows(&res, 1)?;
551
552 user.can_request_admin = can_request_admin;
553
554 Ok(user)
555 }
556
557 #[tracing::instrument(
558 name = "db.user.list",
559 skip_all,
560 fields(
561 db.query.text,
562 ),
563 err,
564 )]
565 async fn list(
566 &mut self,
567 filter: UserFilter<'_>,
568 pagination: mas_storage::Pagination,
569 ) -> Result<mas_storage::Page<User>, Self::Error> {
570 let (sql, arguments) = Query::select()
571 .expr_as(
572 Expr::col((Users::Table, Users::UserId)),
573 UserLookupIden::UserId,
574 )
575 .expr_as(
576 Expr::col((Users::Table, Users::Username)),
577 UserLookupIden::Username,
578 )
579 .expr_as(
580 Expr::col((Users::Table, Users::CreatedAt)),
581 UserLookupIden::CreatedAt,
582 )
583 .expr_as(
584 Expr::col((Users::Table, Users::LockedAt)),
585 UserLookupIden::LockedAt,
586 )
587 .expr_as(
588 Expr::col((Users::Table, Users::DeactivatedAt)),
589 UserLookupIden::DeactivatedAt,
590 )
591 .expr_as(
592 Expr::col((Users::Table, Users::CanRequestAdmin)),
593 UserLookupIden::CanRequestAdmin,
594 )
595 .expr_as(
596 Expr::col((Users::Table, Users::IsGuest)),
597 UserLookupIden::IsGuest,
598 )
599 .from(Users::Table)
600 .apply_filter(filter)
601 .generate_pagination((Users::Table, Users::UserId), pagination)
602 .build_sqlx(PostgresQueryBuilder);
603
604 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
605 .traced()
606 .fetch_all(&mut *self.conn)
607 .await?;
608
609 let page = pagination.process(edges).map(User::from);
610
611 Ok(page)
612 }
613
614 #[tracing::instrument(
615 name = "db.user.count",
616 skip_all,
617 fields(
618 db.query.text,
619 ),
620 err,
621 )]
622 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
623 let (sql, arguments) = Query::select()
624 .expr(Expr::col((Users::Table, Users::UserId)).count())
625 .from(Users::Table)
626 .apply_filter(filter)
627 .build_sqlx(PostgresQueryBuilder);
628
629 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
630 .traced()
631 .fetch_one(&mut *self.conn)
632 .await?;
633
634 count
635 .try_into()
636 .map_err(DatabaseError::to_invalid_operation)
637 }
638
639 #[tracing::instrument(
640 name = "db.user.acquire_lock_for_sync",
641 skip_all,
642 fields(
643 db.query.text,
644 user.id = %user.id,
645 ),
646 err,
647 )]
648 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
649 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
657
658 sqlx::query!(
661 r#"
662 SELECT pg_advisory_xact_lock($1)
663 "#,
664 lock_id,
665 )
666 .traced()
667 .execute(&mut *self.conn)
668 .await?;
669
670 Ok(())
671 }
672}