1use async_trait::async_trait;
11use mas_data_model::{Clock, User};
12use mas_storage::user::{UserFilter, UserRepository};
13use rand::RngCore;
14use sea_query::{Expr, PostgresQueryBuilder, Query, extension::postgres::PgExpr as _};
15use sea_query_binder::SqlxBinder;
16use sqlx::PgConnection;
17use ulid::Ulid;
18use uuid::Uuid;
19
20use crate::{
21 DatabaseError,
22 filter::{Filter, StatementExt},
23 iden::Users,
24 pagination::QueryBuilderExt,
25 tracing::ExecuteExt,
26};
27
28mod email;
29mod password;
30mod recovery;
31mod registration;
32mod registration_token;
33mod session;
34mod terms;
35
36#[cfg(test)]
37mod tests;
38
39pub use self::{
40 email::PgUserEmailRepository, password::PgUserPasswordRepository,
41 recovery::PgUserRecoveryRepository, registration::PgUserRegistrationRepository,
42 registration_token::PgUserRegistrationTokenRepository, session::PgBrowserSessionRepository,
43 terms::PgUserTermsRepository,
44};
45
46pub struct PgUserRepository<'c> {
48 conn: &'c mut PgConnection,
49}
50
51impl<'c> PgUserRepository<'c> {
52 pub fn new(conn: &'c mut PgConnection) -> Self {
54 Self { conn }
55 }
56}
57
mod priv_ {
    // The `#[enum_def]` attribute below generates an identifier enum
    // (`UserLookupIden`) for the struct. This allow presumably exists so those
    // generated items don't trip the `missing_docs` lint — TODO confirm.
    #![allow(missing_docs)]

    use chrono::{DateTime, Utc};
    use sea_query::enum_def;
    use uuid::Uuid;

    /// A row of the `users` table, as selected by the user repository.
    ///
    /// It is filled both by the compile-time checked `query_as!` queries and by
    /// the sea-query built `list` query. The latter aliases each selected column
    /// with the matching `UserLookupIden` variant so that the names line up with
    /// these fields.
    #[derive(Debug, Clone, sqlx::FromRow)]
    #[enum_def]
    pub(super) struct UserLookup {
        // Converted into the ULID `User::id` when building a `User`.
        pub(super) user_id: Uuid,
        pub(super) username: String,
        pub(super) created_at: DateTime<Utc>,
        // `None` unless the user was locked (see `lock`/`unlock`).
        pub(super) locked_at: Option<DateTime<Utc>>,
        // `None` unless the user was deactivated (see `deactivate`/`reactivate`).
        pub(super) deactivated_at: Option<DateTime<Utc>>,
        pub(super) can_request_admin: bool,
        pub(super) is_guest: bool,
    }
}
79
80use priv_::{UserLookup, UserLookupIden};
81
82impl From<UserLookup> for User {
83 fn from(value: UserLookup) -> Self {
84 let id = value.user_id.into();
85 Self {
86 id,
87 username: value.username,
88 sub: id.to_string(),
89 created_at: value.created_at,
90 locked_at: value.locked_at,
91 deactivated_at: value.deactivated_at,
92 can_request_admin: value.can_request_admin,
93 is_guest: value.is_guest,
94 }
95 }
96}
97
98impl Filter for UserFilter<'_> {
99 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
100 sea_query::Condition::all()
101 .add_option(self.state().map(|state| {
102 match state {
103 mas_storage::user::UserState::Deactivated => {
104 Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
105 }
106 mas_storage::user::UserState::Locked => {
107 Expr::col((Users::Table, Users::LockedAt)).is_not_null()
108 }
109 mas_storage::user::UserState::Active => {
110 Expr::col((Users::Table, Users::LockedAt))
111 .is_null()
112 .and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
113 }
114 }
115 }))
116 .add_option(self.can_request_admin().map(|can_request_admin| {
117 Expr::col((Users::Table, Users::CanRequestAdmin)).eq(can_request_admin)
118 }))
119 .add_option(
120 self.is_guest()
121 .map(|is_guest| Expr::col((Users::Table, Users::IsGuest)).eq(is_guest)),
122 )
123 .add_option(self.search().map(|search| {
124 Expr::col((Users::Table, Users::Username)).ilike(format!("%{search}%"))
125 }))
126 }
127}
128
129#[async_trait]
130impl UserRepository for PgUserRepository<'_> {
131 type Error = DatabaseError;
132
133 #[tracing::instrument(
134 name = "db.user.lookup",
135 skip_all,
136 fields(
137 db.query.text,
138 user.id = %id,
139 ),
140 err,
141 )]
142 async fn lookup(&mut self, id: Ulid) -> Result<Option<User>, Self::Error> {
143 let res = sqlx::query_as!(
144 UserLookup,
145 r#"
146 SELECT user_id
147 , username
148 , created_at
149 , locked_at
150 , deactivated_at
151 , can_request_admin
152 , is_guest
153 FROM users
154 WHERE user_id = $1
155 "#,
156 Uuid::from(id),
157 )
158 .traced()
159 .fetch_optional(&mut *self.conn)
160 .await?;
161
162 let Some(res) = res else { return Ok(None) };
163
164 Ok(Some(res.into()))
165 }
166
167 #[tracing::instrument(
168 name = "db.user.find_by_username",
169 skip_all,
170 fields(
171 db.query.text,
172 user.username = username,
173 ),
174 err,
175 )]
176 async fn find_by_username(&mut self, username: &str) -> Result<Option<User>, Self::Error> {
177 let res = sqlx::query_as!(
181 UserLookup,
182 r#"
183 SELECT user_id
184 , username
185 , created_at
186 , locked_at
187 , deactivated_at
188 , can_request_admin
189 , is_guest
190 FROM users
191 WHERE LOWER(username) = LOWER($1)
192 "#,
193 username,
194 )
195 .traced()
196 .fetch_all(&mut *self.conn)
197 .await?;
198
199 match &res[..] {
200 [user] => Ok(Some(user.clone().into())),
202 [] => Ok(None),
204 list => {
205 if let Some(user) = list.iter().find(|user| user.username == username) {
208 Ok(Some(user.clone().into()))
209 } else {
210 Ok(None)
212 }
213 }
214 }
215 }
216
217 #[tracing::instrument(
218 name = "db.user.add",
219 skip_all,
220 fields(
221 db.query.text,
222 user.username = username,
223 user.id,
224 ),
225 err,
226 )]
227 async fn add(
228 &mut self,
229 rng: &mut (dyn RngCore + Send),
230 clock: &dyn Clock,
231 username: String,
232 ) -> Result<User, Self::Error> {
233 let created_at = clock.now();
234 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
235 tracing::Span::current().record("user.id", tracing::field::display(id));
236
237 let res = sqlx::query!(
238 r#"
239 INSERT INTO users (user_id, username, created_at)
240 VALUES ($1, $2, $3)
241 ON CONFLICT (username) DO NOTHING
242 "#,
243 Uuid::from(id),
244 username,
245 created_at,
246 )
247 .traced()
248 .execute(&mut *self.conn)
249 .await?;
250
251 DatabaseError::ensure_affected_rows(&res, 1)?;
254
255 Ok(User {
256 id,
257 username,
258 sub: id.to_string(),
259 created_at,
260 locked_at: None,
261 deactivated_at: None,
262 can_request_admin: false,
263 is_guest: false,
264 })
265 }
266
267 #[tracing::instrument(
268 name = "db.user.exists",
269 skip_all,
270 fields(
271 db.query.text,
272 user.username = username,
273 ),
274 err,
275 )]
276 async fn exists(&mut self, username: &str) -> Result<bool, Self::Error> {
277 let exists = sqlx::query_scalar!(
278 r#"
279 SELECT EXISTS(
280 SELECT 1 FROM users WHERE LOWER(username) = LOWER($1)
281 ) AS "exists!"
282 "#,
283 username
284 )
285 .traced()
286 .fetch_one(&mut *self.conn)
287 .await?;
288
289 Ok(exists)
290 }
291
292 #[tracing::instrument(
293 name = "db.user.lock",
294 skip_all,
295 fields(
296 db.query.text,
297 %user.id,
298 ),
299 err,
300 )]
301 async fn lock(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
302 if user.locked_at.is_some() {
303 return Ok(user);
304 }
305
306 let locked_at = clock.now();
307 let res = sqlx::query!(
308 r#"
309 UPDATE users
310 SET locked_at = $1
311 WHERE user_id = $2
312 "#,
313 locked_at,
314 Uuid::from(user.id),
315 )
316 .traced()
317 .execute(&mut *self.conn)
318 .await?;
319
320 DatabaseError::ensure_affected_rows(&res, 1)?;
321
322 user.locked_at = Some(locked_at);
323
324 Ok(user)
325 }
326
327 #[tracing::instrument(
328 name = "db.user.unlock",
329 skip_all,
330 fields(
331 db.query.text,
332 %user.id,
333 ),
334 err,
335 )]
336 async fn unlock(&mut self, mut user: User) -> Result<User, Self::Error> {
337 if user.locked_at.is_none() {
338 return Ok(user);
339 }
340
341 let res = sqlx::query!(
342 r#"
343 UPDATE users
344 SET locked_at = NULL
345 WHERE user_id = $1
346 "#,
347 Uuid::from(user.id),
348 )
349 .traced()
350 .execute(&mut *self.conn)
351 .await?;
352
353 DatabaseError::ensure_affected_rows(&res, 1)?;
354
355 user.locked_at = None;
356
357 Ok(user)
358 }
359
360 #[tracing::instrument(
361 name = "db.user.deactivate",
362 skip_all,
363 fields(
364 db.query.text,
365 %user.id,
366 ),
367 err,
368 )]
369 async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
370 if user.deactivated_at.is_some() {
371 return Ok(user);
372 }
373
374 let deactivated_at = clock.now();
375 let res = sqlx::query!(
376 r#"
377 UPDATE users
378 SET deactivated_at = $2
379 WHERE user_id = $1
380 AND deactivated_at IS NULL
381 "#,
382 Uuid::from(user.id),
383 deactivated_at,
384 )
385 .traced()
386 .execute(&mut *self.conn)
387 .await?;
388
389 DatabaseError::ensure_affected_rows(&res, 1)?;
390
391 user.deactivated_at = Some(deactivated_at);
392
393 Ok(user)
394 }
395
396 #[tracing::instrument(
397 name = "db.user.reactivate",
398 skip_all,
399 fields(
400 db.query.text,
401 %user.id,
402 ),
403 err,
404 )]
405 async fn reactivate(&mut self, mut user: User) -> Result<User, Self::Error> {
406 if user.deactivated_at.is_none() {
407 return Ok(user);
408 }
409
410 let res = sqlx::query!(
411 r#"
412 UPDATE users
413 SET deactivated_at = NULL
414 WHERE user_id = $1
415 "#,
416 Uuid::from(user.id),
417 )
418 .traced()
419 .execute(&mut *self.conn)
420 .await?;
421
422 DatabaseError::ensure_affected_rows(&res, 1)?;
423
424 user.deactivated_at = None;
425
426 Ok(user)
427 }
428
429 #[tracing::instrument(
430 name = "db.user.set_can_request_admin",
431 skip_all,
432 fields(
433 db.query.text,
434 %user.id,
435 user.can_request_admin = can_request_admin,
436 ),
437 err,
438 )]
439 async fn set_can_request_admin(
440 &mut self,
441 mut user: User,
442 can_request_admin: bool,
443 ) -> Result<User, Self::Error> {
444 let res = sqlx::query!(
445 r#"
446 UPDATE users
447 SET can_request_admin = $2
448 WHERE user_id = $1
449 "#,
450 Uuid::from(user.id),
451 can_request_admin,
452 )
453 .traced()
454 .execute(&mut *self.conn)
455 .await?;
456
457 DatabaseError::ensure_affected_rows(&res, 1)?;
458
459 user.can_request_admin = can_request_admin;
460
461 Ok(user)
462 }
463
464 #[tracing::instrument(
465 name = "db.user.list",
466 skip_all,
467 fields(
468 db.query.text,
469 ),
470 err,
471 )]
472 async fn list(
473 &mut self,
474 filter: UserFilter<'_>,
475 pagination: mas_storage::Pagination,
476 ) -> Result<mas_storage::Page<User>, Self::Error> {
477 let (sql, arguments) = Query::select()
478 .expr_as(
479 Expr::col((Users::Table, Users::UserId)),
480 UserLookupIden::UserId,
481 )
482 .expr_as(
483 Expr::col((Users::Table, Users::Username)),
484 UserLookupIden::Username,
485 )
486 .expr_as(
487 Expr::col((Users::Table, Users::CreatedAt)),
488 UserLookupIden::CreatedAt,
489 )
490 .expr_as(
491 Expr::col((Users::Table, Users::LockedAt)),
492 UserLookupIden::LockedAt,
493 )
494 .expr_as(
495 Expr::col((Users::Table, Users::DeactivatedAt)),
496 UserLookupIden::DeactivatedAt,
497 )
498 .expr_as(
499 Expr::col((Users::Table, Users::CanRequestAdmin)),
500 UserLookupIden::CanRequestAdmin,
501 )
502 .expr_as(
503 Expr::col((Users::Table, Users::IsGuest)),
504 UserLookupIden::IsGuest,
505 )
506 .from(Users::Table)
507 .apply_filter(filter)
508 .generate_pagination((Users::Table, Users::UserId), pagination)
509 .build_sqlx(PostgresQueryBuilder);
510
511 let edges: Vec<UserLookup> = sqlx::query_as_with(&sql, arguments)
512 .traced()
513 .fetch_all(&mut *self.conn)
514 .await?;
515
516 let page = pagination.process(edges).map(User::from);
517
518 Ok(page)
519 }
520
521 #[tracing::instrument(
522 name = "db.user.count",
523 skip_all,
524 fields(
525 db.query.text,
526 ),
527 err,
528 )]
529 async fn count(&mut self, filter: UserFilter<'_>) -> Result<usize, Self::Error> {
530 let (sql, arguments) = Query::select()
531 .expr(Expr::col((Users::Table, Users::UserId)).count())
532 .from(Users::Table)
533 .apply_filter(filter)
534 .build_sqlx(PostgresQueryBuilder);
535
536 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
537 .traced()
538 .fetch_one(&mut *self.conn)
539 .await?;
540
541 count
542 .try_into()
543 .map_err(DatabaseError::to_invalid_operation)
544 }
545
546 #[tracing::instrument(
547 name = "db.user.acquire_lock_for_sync",
548 skip_all,
549 fields(
550 db.query.text,
551 user.id = %user.id,
552 ),
553 err,
554 )]
555 async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> {
556 let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64;
564
565 sqlx::query!(
568 r#"
569 SELECT pg_advisory_xact_lock($1)
570 "#,
571 lock_id,
572 )
573 .traced()
574 .execute(&mut *self.conn)
575 .await?;
576
577 Ok(())
578 }
579}