1use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12 Authentication, AuthenticationMethod, BrowserSession, Password,
13 UpstreamOAuthAuthorizationSession, User,
14};
15use mas_storage::{
16 Clock, Page, Pagination,
17 user::{BrowserSessionFilter, BrowserSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27 DatabaseError, DatabaseInconsistencyError,
28 filter::StatementExt,
29 iden::{UpstreamOAuthAuthorizationSessions, UserSessionAuthentications, UserSessions, Users},
30 pagination::QueryBuilderExt,
31 tracing::ExecuteExt,
32};
33
34pub struct PgBrowserSessionRepository<'c> {
37 conn: &'c mut PgConnection,
38}
39
40impl<'c> PgBrowserSessionRepository<'c> {
41 pub fn new(conn: &'c mut PgConnection) -> Self {
44 Self { conn }
45 }
46}
47
48#[allow(clippy::struct_field_names)]
49#[derive(sqlx::FromRow)]
50#[sea_query::enum_def]
51struct SessionLookup {
52 user_session_id: Uuid,
53 user_session_created_at: DateTime<Utc>,
54 user_session_finished_at: Option<DateTime<Utc>>,
55 user_session_user_agent: Option<String>,
56 user_session_last_active_at: Option<DateTime<Utc>>,
57 user_session_last_active_ip: Option<IpAddr>,
58 user_id: Uuid,
59 user_username: String,
60 user_created_at: DateTime<Utc>,
61 user_locked_at: Option<DateTime<Utc>>,
62 user_deactivated_at: Option<DateTime<Utc>>,
63 user_can_request_admin: bool,
64}
65
66impl TryFrom<SessionLookup> for BrowserSession {
67 type Error = DatabaseInconsistencyError;
68
69 fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
70 let id = Ulid::from(value.user_id);
71 let user = User {
72 id,
73 username: value.user_username,
74 sub: id.to_string(),
75 created_at: value.user_created_at,
76 locked_at: value.user_locked_at,
77 deactivated_at: value.user_deactivated_at,
78 can_request_admin: value.user_can_request_admin,
79 };
80
81 Ok(BrowserSession {
82 id: value.user_session_id.into(),
83 user,
84 created_at: value.user_session_created_at,
85 finished_at: value.user_session_finished_at,
86 user_agent: value.user_session_user_agent,
87 last_active_at: value.user_session_last_active_at,
88 last_active_ip: value.user_session_last_active_ip,
89 })
90 }
91}
92
93struct AuthenticationLookup {
94 user_session_authentication_id: Uuid,
95 created_at: DateTime<Utc>,
96 user_password_id: Option<Uuid>,
97 upstream_oauth_authorization_session_id: Option<Uuid>,
98}
99
100impl TryFrom<AuthenticationLookup> for Authentication {
101 type Error = DatabaseInconsistencyError;
102
103 fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
104 let id = Ulid::from(value.user_session_authentication_id);
105 let authentication_method = match (
106 value.user_password_id.map(Into::into),
107 value
108 .upstream_oauth_authorization_session_id
109 .map(Into::into),
110 ) {
111 (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
112 (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
113 upstream_oauth2_session_id,
114 },
115 (None, None) => AuthenticationMethod::Unknown,
116 _ => {
117 return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
118 }
119 };
120
121 Ok(Authentication {
122 id,
123 created_at: value.created_at,
124 authentication_method,
125 })
126 }
127}
128
129impl crate::filter::Filter for BrowserSessionFilter<'_> {
130 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
131 sea_query::Condition::all()
132 .add_option(self.user().map(|user| {
133 Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
134 }))
135 .add_option(self.state().map(|state| {
136 if state.is_active() {
137 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
138 } else {
139 Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
140 }
141 }))
142 .add_option(self.last_active_after().map(|last_active_after| {
143 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
144 }))
145 .add_option(self.last_active_before().map(|last_active_before| {
146 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
147 }))
148 .add_option(self.authenticated_by_upstream_sessions().map(|filter| {
149 let join_expr = Expr::col((
152 UserSessionAuthentications::Table,
153 UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
154 ))
155 .eq(Expr::col((
156 UpstreamOAuthAuthorizationSessions::Table,
157 UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
158 )));
159
160 Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
161 Query::select()
162 .expr(Expr::col((
163 UserSessionAuthentications::Table,
164 UserSessionAuthentications::UserSessionId,
165 )))
166 .from(UserSessionAuthentications::Table)
167 .inner_join(UpstreamOAuthAuthorizationSessions::Table, join_expr)
168 .apply_filter(filter)
169 .take(),
170 )
171 }))
172 }
173}
174
175#[async_trait]
176impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
177 type Error = DatabaseError;
178
179 #[tracing::instrument(
180 name = "db.browser_session.lookup",
181 skip_all,
182 fields(
183 db.query.text,
184 user_session.id = %id,
185 ),
186 err,
187 )]
188 async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
189 let res = sqlx::query_as!(
190 SessionLookup,
191 r#"
192 SELECT s.user_session_id
193 , s.created_at AS "user_session_created_at"
194 , s.finished_at AS "user_session_finished_at"
195 , s.user_agent AS "user_session_user_agent"
196 , s.last_active_at AS "user_session_last_active_at"
197 , s.last_active_ip AS "user_session_last_active_ip: IpAddr"
198 , u.user_id
199 , u.username AS "user_username"
200 , u.created_at AS "user_created_at"
201 , u.locked_at AS "user_locked_at"
202 , u.deactivated_at AS "user_deactivated_at"
203 , u.can_request_admin AS "user_can_request_admin"
204 FROM user_sessions s
205 INNER JOIN users u
206 USING (user_id)
207 WHERE s.user_session_id = $1
208 "#,
209 Uuid::from(id),
210 )
211 .traced()
212 .fetch_optional(&mut *self.conn)
213 .await?;
214
215 let Some(res) = res else { return Ok(None) };
216
217 Ok(Some(res.try_into()?))
218 }
219
220 #[tracing::instrument(
221 name = "db.browser_session.add",
222 skip_all,
223 fields(
224 db.query.text,
225 %user.id,
226 user_session.id,
227 ),
228 err,
229 )]
230 async fn add(
231 &mut self,
232 rng: &mut (dyn RngCore + Send),
233 clock: &dyn Clock,
234 user: &User,
235 user_agent: Option<String>,
236 ) -> Result<BrowserSession, Self::Error> {
237 let created_at = clock.now();
238 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
239 tracing::Span::current().record("user_session.id", tracing::field::display(id));
240
241 sqlx::query!(
242 r#"
243 INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
244 VALUES ($1, $2, $3, $4)
245 "#,
246 Uuid::from(id),
247 Uuid::from(user.id),
248 created_at,
249 user_agent.as_deref(),
250 )
251 .traced()
252 .execute(&mut *self.conn)
253 .await?;
254
255 let session = BrowserSession {
256 id,
257 user: user.clone(),
259 created_at,
260 finished_at: None,
261 user_agent,
262 last_active_at: None,
263 last_active_ip: None,
264 };
265
266 Ok(session)
267 }
268
269 #[tracing::instrument(
270 name = "db.browser_session.finish",
271 skip_all,
272 fields(
273 db.query.text,
274 %user_session.id,
275 ),
276 err,
277 )]
278 async fn finish(
279 &mut self,
280 clock: &dyn Clock,
281 mut user_session: BrowserSession,
282 ) -> Result<BrowserSession, Self::Error> {
283 let finished_at = clock.now();
284 let res = sqlx::query!(
285 r#"
286 UPDATE user_sessions
287 SET finished_at = $1
288 WHERE user_session_id = $2
289 "#,
290 finished_at,
291 Uuid::from(user_session.id),
292 )
293 .traced()
294 .execute(&mut *self.conn)
295 .await?;
296
297 user_session.finished_at = Some(finished_at);
298
299 DatabaseError::ensure_affected_rows(&res, 1)?;
300
301 Ok(user_session)
302 }
303
304 #[tracing::instrument(
305 name = "db.browser_session.finish_bulk",
306 skip_all,
307 fields(
308 db.query.text,
309 ),
310 err,
311 )]
312 async fn finish_bulk(
313 &mut self,
314 clock: &dyn Clock,
315 filter: BrowserSessionFilter<'_>,
316 ) -> Result<usize, Self::Error> {
317 let finished_at = clock.now();
318 let (sql, arguments) = sea_query::Query::update()
319 .table(UserSessions::Table)
320 .value(UserSessions::FinishedAt, finished_at)
321 .apply_filter(filter)
322 .build_sqlx(PostgresQueryBuilder);
323
324 let res = sqlx::query_with(&sql, arguments)
325 .traced()
326 .execute(&mut *self.conn)
327 .await?;
328
329 Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
330 }
331
332 #[tracing::instrument(
333 name = "db.browser_session.list",
334 skip_all,
335 fields(
336 db.query.text,
337 ),
338 err,
339 )]
340 async fn list(
341 &mut self,
342 filter: BrowserSessionFilter<'_>,
343 pagination: Pagination,
344 ) -> Result<Page<BrowserSession>, Self::Error> {
345 let (sql, arguments) = sea_query::Query::select()
346 .expr_as(
347 Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
348 SessionLookupIden::UserSessionId,
349 )
350 .expr_as(
351 Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
352 SessionLookupIden::UserSessionCreatedAt,
353 )
354 .expr_as(
355 Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
356 SessionLookupIden::UserSessionFinishedAt,
357 )
358 .expr_as(
359 Expr::col((UserSessions::Table, UserSessions::UserAgent)),
360 SessionLookupIden::UserSessionUserAgent,
361 )
362 .expr_as(
363 Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
364 SessionLookupIden::UserSessionLastActiveAt,
365 )
366 .expr_as(
367 Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
368 SessionLookupIden::UserSessionLastActiveIp,
369 )
370 .expr_as(
371 Expr::col((Users::Table, Users::UserId)),
372 SessionLookupIden::UserId,
373 )
374 .expr_as(
375 Expr::col((Users::Table, Users::Username)),
376 SessionLookupIden::UserUsername,
377 )
378 .expr_as(
379 Expr::col((Users::Table, Users::CreatedAt)),
380 SessionLookupIden::UserCreatedAt,
381 )
382 .expr_as(
383 Expr::col((Users::Table, Users::LockedAt)),
384 SessionLookupIden::UserLockedAt,
385 )
386 .expr_as(
387 Expr::col((Users::Table, Users::DeactivatedAt)),
388 SessionLookupIden::UserDeactivatedAt,
389 )
390 .expr_as(
391 Expr::col((Users::Table, Users::CanRequestAdmin)),
392 SessionLookupIden::UserCanRequestAdmin,
393 )
394 .from(UserSessions::Table)
395 .inner_join(
396 Users::Table,
397 Expr::col((UserSessions::Table, UserSessions::UserId))
398 .equals((Users::Table, Users::UserId)),
399 )
400 .apply_filter(filter)
401 .generate_pagination(
402 (UserSessions::Table, UserSessions::UserSessionId),
403 pagination,
404 )
405 .build_sqlx(PostgresQueryBuilder);
406
407 let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
408 .traced()
409 .fetch_all(&mut *self.conn)
410 .await?;
411
412 let page = pagination
413 .process(edges)
414 .try_map(BrowserSession::try_from)?;
415
416 Ok(page)
417 }
418
419 #[tracing::instrument(
420 name = "db.browser_session.count",
421 skip_all,
422 fields(
423 db.query.text,
424 ),
425 err,
426 )]
427 async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
428 let (sql, arguments) = sea_query::Query::select()
429 .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
430 .from(UserSessions::Table)
431 .apply_filter(filter)
432 .build_sqlx(PostgresQueryBuilder);
433
434 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
435 .traced()
436 .fetch_one(&mut *self.conn)
437 .await?;
438
439 count
440 .try_into()
441 .map_err(DatabaseError::to_invalid_operation)
442 }
443
444 #[tracing::instrument(
445 name = "db.browser_session.authenticate_with_password",
446 skip_all,
447 fields(
448 db.query.text,
449 %user_session.id,
450 %user_password.id,
451 user_session_authentication.id,
452 ),
453 err,
454 )]
455 async fn authenticate_with_password(
456 &mut self,
457 rng: &mut (dyn RngCore + Send),
458 clock: &dyn Clock,
459 user_session: &BrowserSession,
460 user_password: &Password,
461 ) -> Result<Authentication, Self::Error> {
462 let created_at = clock.now();
463 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
464 tracing::Span::current().record(
465 "user_session_authentication.id",
466 tracing::field::display(id),
467 );
468
469 sqlx::query!(
470 r#"
471 INSERT INTO user_session_authentications
472 (user_session_authentication_id, user_session_id, created_at, user_password_id)
473 VALUES ($1, $2, $3, $4)
474 "#,
475 Uuid::from(id),
476 Uuid::from(user_session.id),
477 created_at,
478 Uuid::from(user_password.id),
479 )
480 .traced()
481 .execute(&mut *self.conn)
482 .await?;
483
484 Ok(Authentication {
485 id,
486 created_at,
487 authentication_method: AuthenticationMethod::Password {
488 user_password_id: user_password.id,
489 },
490 })
491 }
492
493 #[tracing::instrument(
494 name = "db.browser_session.authenticate_with_upstream",
495 skip_all,
496 fields(
497 db.query.text,
498 %user_session.id,
499 %upstream_oauth_session.id,
500 user_session_authentication.id,
501 ),
502 err,
503 )]
504 async fn authenticate_with_upstream(
505 &mut self,
506 rng: &mut (dyn RngCore + Send),
507 clock: &dyn Clock,
508 user_session: &BrowserSession,
509 upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
510 ) -> Result<Authentication, Self::Error> {
511 let created_at = clock.now();
512 let id = Ulid::from_datetime_with_source(created_at.into(), rng);
513 tracing::Span::current().record(
514 "user_session_authentication.id",
515 tracing::field::display(id),
516 );
517
518 sqlx::query!(
519 r#"
520 INSERT INTO user_session_authentications
521 (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
522 VALUES ($1, $2, $3, $4)
523 "#,
524 Uuid::from(id),
525 Uuid::from(user_session.id),
526 created_at,
527 Uuid::from(upstream_oauth_session.id),
528 )
529 .traced()
530 .execute(&mut *self.conn)
531 .await?;
532
533 Ok(Authentication {
534 id,
535 created_at,
536 authentication_method: AuthenticationMethod::UpstreamOAuth2 {
537 upstream_oauth2_session_id: upstream_oauth_session.id,
538 },
539 })
540 }
541
542 #[tracing::instrument(
543 name = "db.browser_session.get_last_authentication",
544 skip_all,
545 fields(
546 db.query.text,
547 %user_session.id,
548 ),
549 err,
550 )]
551 async fn get_last_authentication(
552 &mut self,
553 user_session: &BrowserSession,
554 ) -> Result<Option<Authentication>, Self::Error> {
555 let authentication = sqlx::query_as!(
556 AuthenticationLookup,
557 r#"
558 SELECT user_session_authentication_id
559 , created_at
560 , user_password_id
561 , upstream_oauth_authorization_session_id
562 FROM user_session_authentications
563 WHERE user_session_id = $1
564 ORDER BY created_at DESC
565 LIMIT 1
566 "#,
567 Uuid::from(user_session.id),
568 )
569 .traced()
570 .fetch_optional(&mut *self.conn)
571 .await?;
572
573 let Some(authentication) = authentication else {
574 return Ok(None);
575 };
576
577 let authentication = Authentication::try_from(authentication)?;
578 Ok(Some(authentication))
579 }
580
581 #[tracing::instrument(
582 name = "db.browser_session.record_batch_activity",
583 skip_all,
584 fields(
585 db.query.text,
586 ),
587 err,
588 )]
589 async fn record_batch_activity(
590 &mut self,
591 mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
592 ) -> Result<(), Self::Error> {
593 activities.sort_unstable();
596 let mut ids = Vec::with_capacity(activities.len());
597 let mut last_activities = Vec::with_capacity(activities.len());
598 let mut ips = Vec::with_capacity(activities.len());
599
600 for (id, last_activity, ip) in activities {
601 ids.push(Uuid::from(id));
602 last_activities.push(last_activity);
603 ips.push(ip);
604 }
605
606 let res = sqlx::query!(
607 r#"
608 UPDATE user_sessions
609 SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
610 , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
611 FROM (
612 SELECT *
613 FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
614 AS t(user_session_id, last_active_at, last_active_ip)
615 ) AS t
616 WHERE user_sessions.user_session_id = t.user_session_id
617 "#,
618 &ids,
619 &last_activities,
620 &ips as &[Option<IpAddr>],
621 )
622 .traced()
623 .execute(&mut *self.conn)
624 .await?;
625
626 DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
627
628 Ok(())
629 }
630}