mas_storage_pg/user/
session.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    Authentication, AuthenticationMethod, BrowserSession, Password,
13    UpstreamOAuthAuthorizationSession, User,
14};
15use mas_storage::{
16    Clock, Page, Pagination,
17    user::{BrowserSessionFilter, BrowserSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{Expr, PostgresQueryBuilder, Query};
21use sea_query_binder::SqlxBinder;
22use sqlx::PgConnection;
23use ulid::Ulid;
24use uuid::Uuid;
25
26use crate::{
27    DatabaseError, DatabaseInconsistencyError,
28    filter::StatementExt,
29    iden::{UpstreamOAuthAuthorizationSessions, UserSessionAuthentications, UserSessions, Users},
30    pagination::QueryBuilderExt,
31    tracing::ExecuteExt,
32};
33
/// An implementation of [`BrowserSessionRepository`] for a PostgreSQL
/// connection
///
/// The repository holds a mutable borrow of the connection for the lifetime
/// `'c`; it does not own it.
pub struct PgBrowserSessionRepository<'c> {
    /// The borrowed PostgreSQL connection
    conn: &'c mut PgConnection,
}
39
impl<'c> PgBrowserSessionRepository<'c> {
    /// Create a new [`PgBrowserSessionRepository`] from an active PostgreSQL
    /// connection
    ///
    /// The connection is only borrowed and stored as-is: no query is run at
    /// construction time.
    pub fn new(conn: &'c mut PgConnection) -> Self {
        Self { conn }
    }
}
47
/// A flattened row holding a browser session together with the user it
/// belongs to.
///
/// The field names double as the column aliases of the queries which produce
/// this row, and `sea_query::enum_def` derives the `SessionLookupIden`
/// identifiers from them, so renaming a field means updating those queries.
// The shared `user_` prefix on the field names is deliberate (see above),
// hence the lint exemption
#[allow(clippy::struct_field_names)]
#[derive(sqlx::FromRow)]
#[sea_query::enum_def]
struct SessionLookup {
    /// `user_sessions.user_session_id`
    user_session_id: Uuid,
    /// `user_sessions.created_at`
    user_session_created_at: DateTime<Utc>,
    /// `user_sessions.finished_at`, `None` while the column is `NULL`
    user_session_finished_at: Option<DateTime<Utc>>,
    /// `user_sessions.user_agent`
    user_session_user_agent: Option<String>,
    /// `user_sessions.last_active_at`
    user_session_last_active_at: Option<DateTime<Utc>>,
    /// `user_sessions.last_active_ip`
    user_session_last_active_ip: Option<IpAddr>,
    /// `users.user_id`
    user_id: Uuid,
    /// `users.username`
    user_username: String,
    /// `users.created_at`
    user_created_at: DateTime<Utc>,
    /// `users.locked_at`
    user_locked_at: Option<DateTime<Utc>>,
    /// `users.deactivated_at`
    user_deactivated_at: Option<DateTime<Utc>>,
    /// `users.can_request_admin`
    user_can_request_admin: bool,
}
65
66impl TryFrom<SessionLookup> for BrowserSession {
67    type Error = DatabaseInconsistencyError;
68
69    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
70        let id = Ulid::from(value.user_id);
71        let user = User {
72            id,
73            username: value.user_username,
74            sub: id.to_string(),
75            created_at: value.user_created_at,
76            locked_at: value.user_locked_at,
77            deactivated_at: value.user_deactivated_at,
78            can_request_admin: value.user_can_request_admin,
79        };
80
81        Ok(BrowserSession {
82            id: value.user_session_id.into(),
83            user,
84            created_at: value.user_session_created_at,
85            finished_at: value.user_session_finished_at,
86            user_agent: value.user_session_user_agent,
87            last_active_at: value.user_session_last_active_at,
88            last_active_ip: value.user_session_last_active_ip,
89        })
90    }
91}
92
/// A row of the `user_session_authentications` table, as selected when
/// fetching the last authentication of a session.
///
/// At most one of `user_password_id` and
/// `upstream_oauth_authorization_session_id` is expected to be set; the
/// conversion to [`Authentication`] rejects rows where both are.
struct AuthenticationLookup {
    /// `user_session_authentications.user_session_authentication_id`
    user_session_authentication_id: Uuid,
    /// `user_session_authentications.created_at`
    created_at: DateTime<Utc>,
    /// Set when the authentication was done with a password
    user_password_id: Option<Uuid>,
    /// Set when the authentication was done through an upstream OAuth 2.0
    /// authorization session
    upstream_oauth_authorization_session_id: Option<Uuid>,
}
99
100impl TryFrom<AuthenticationLookup> for Authentication {
101    type Error = DatabaseInconsistencyError;
102
103    fn try_from(value: AuthenticationLookup) -> Result<Self, Self::Error> {
104        let id = Ulid::from(value.user_session_authentication_id);
105        let authentication_method = match (
106            value.user_password_id.map(Into::into),
107            value
108                .upstream_oauth_authorization_session_id
109                .map(Into::into),
110        ) {
111            (Some(user_password_id), None) => AuthenticationMethod::Password { user_password_id },
112            (None, Some(upstream_oauth2_session_id)) => AuthenticationMethod::UpstreamOAuth2 {
113                upstream_oauth2_session_id,
114            },
115            (None, None) => AuthenticationMethod::Unknown,
116            _ => {
117                return Err(DatabaseInconsistencyError::on("user_session_authentications").row(id));
118            }
119        };
120
121        Ok(Authentication {
122            id,
123            created_at: value.created_at,
124            authentication_method,
125        })
126    }
127}
128
129impl crate::filter::Filter for BrowserSessionFilter<'_> {
130    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
131        sea_query::Condition::all()
132            .add_option(self.user().map(|user| {
133                Expr::col((UserSessions::Table, UserSessions::UserId)).eq(Uuid::from(user.id))
134            }))
135            .add_option(self.state().map(|state| {
136                if state.is_active() {
137                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_null()
138                } else {
139                    Expr::col((UserSessions::Table, UserSessions::FinishedAt)).is_not_null()
140                }
141            }))
142            .add_option(self.last_active_after().map(|last_active_after| {
143                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).gt(last_active_after)
144            }))
145            .add_option(self.last_active_before().map(|last_active_before| {
146                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)).lt(last_active_before)
147            }))
148            .add_option(self.authenticated_by_upstream_sessions().map(|filter| {
149                // For filtering by upstream sessions, we need to hop over the
150                // `user_session_authentications` table
151                let join_expr = Expr::col((
152                    UserSessionAuthentications::Table,
153                    UserSessionAuthentications::UpstreamOAuthAuthorizationSessionId,
154                ))
155                .eq(Expr::col((
156                    UpstreamOAuthAuthorizationSessions::Table,
157                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
158                )));
159
160                Expr::col((UserSessions::Table, UserSessions::UserSessionId)).in_subquery(
161                    Query::select()
162                        .expr(Expr::col((
163                            UserSessionAuthentications::Table,
164                            UserSessionAuthentications::UserSessionId,
165                        )))
166                        .from(UserSessionAuthentications::Table)
167                        .inner_join(UpstreamOAuthAuthorizationSessions::Table, join_expr)
168                        .apply_filter(filter)
169                        .take(),
170                )
171            }))
172    }
173}
174
175#[async_trait]
176impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
177    type Error = DatabaseError;
178
179    #[tracing::instrument(
180        name = "db.browser_session.lookup",
181        skip_all,
182        fields(
183            db.query.text,
184            user_session.id = %id,
185        ),
186        err,
187    )]
188    async fn lookup(&mut self, id: Ulid) -> Result<Option<BrowserSession>, Self::Error> {
189        let res = sqlx::query_as!(
190            SessionLookup,
191            r#"
192                SELECT s.user_session_id
193                     , s.created_at            AS "user_session_created_at"
194                     , s.finished_at           AS "user_session_finished_at"
195                     , s.user_agent            AS "user_session_user_agent"
196                     , s.last_active_at        AS "user_session_last_active_at"
197                     , s.last_active_ip        AS "user_session_last_active_ip: IpAddr"
198                     , u.user_id
199                     , u.username              AS "user_username"
200                     , u.created_at            AS "user_created_at"
201                     , u.locked_at             AS "user_locked_at"
202                     , u.deactivated_at        AS "user_deactivated_at"
203                     , u.can_request_admin     AS "user_can_request_admin"
204                FROM user_sessions s
205                INNER JOIN users u
206                    USING (user_id)
207                WHERE s.user_session_id = $1
208            "#,
209            Uuid::from(id),
210        )
211        .traced()
212        .fetch_optional(&mut *self.conn)
213        .await?;
214
215        let Some(res) = res else { return Ok(None) };
216
217        Ok(Some(res.try_into()?))
218    }
219
220    #[tracing::instrument(
221        name = "db.browser_session.add",
222        skip_all,
223        fields(
224            db.query.text,
225            %user.id,
226            user_session.id,
227        ),
228        err,
229    )]
230    async fn add(
231        &mut self,
232        rng: &mut (dyn RngCore + Send),
233        clock: &dyn Clock,
234        user: &User,
235        user_agent: Option<String>,
236    ) -> Result<BrowserSession, Self::Error> {
237        let created_at = clock.now();
238        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
239        tracing::Span::current().record("user_session.id", tracing::field::display(id));
240
241        sqlx::query!(
242            r#"
243                INSERT INTO user_sessions (user_session_id, user_id, created_at, user_agent)
244                VALUES ($1, $2, $3, $4)
245            "#,
246            Uuid::from(id),
247            Uuid::from(user.id),
248            created_at,
249            user_agent.as_deref(),
250        )
251        .traced()
252        .execute(&mut *self.conn)
253        .await?;
254
255        let session = BrowserSession {
256            id,
257            // XXX
258            user: user.clone(),
259            created_at,
260            finished_at: None,
261            user_agent,
262            last_active_at: None,
263            last_active_ip: None,
264        };
265
266        Ok(session)
267    }
268
269    #[tracing::instrument(
270        name = "db.browser_session.finish",
271        skip_all,
272        fields(
273            db.query.text,
274            %user_session.id,
275        ),
276        err,
277    )]
278    async fn finish(
279        &mut self,
280        clock: &dyn Clock,
281        mut user_session: BrowserSession,
282    ) -> Result<BrowserSession, Self::Error> {
283        let finished_at = clock.now();
284        let res = sqlx::query!(
285            r#"
286                UPDATE user_sessions
287                SET finished_at = $1
288                WHERE user_session_id = $2
289            "#,
290            finished_at,
291            Uuid::from(user_session.id),
292        )
293        .traced()
294        .execute(&mut *self.conn)
295        .await?;
296
297        user_session.finished_at = Some(finished_at);
298
299        DatabaseError::ensure_affected_rows(&res, 1)?;
300
301        Ok(user_session)
302    }
303
304    #[tracing::instrument(
305        name = "db.browser_session.finish_bulk",
306        skip_all,
307        fields(
308            db.query.text,
309        ),
310        err,
311    )]
312    async fn finish_bulk(
313        &mut self,
314        clock: &dyn Clock,
315        filter: BrowserSessionFilter<'_>,
316    ) -> Result<usize, Self::Error> {
317        let finished_at = clock.now();
318        let (sql, arguments) = sea_query::Query::update()
319            .table(UserSessions::Table)
320            .value(UserSessions::FinishedAt, finished_at)
321            .apply_filter(filter)
322            .build_sqlx(PostgresQueryBuilder);
323
324        let res = sqlx::query_with(&sql, arguments)
325            .traced()
326            .execute(&mut *self.conn)
327            .await?;
328
329        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
330    }
331
    /// List the browser sessions matching the filter, one page at a time.
    ///
    /// Each session is returned along with the user it belongs to. Pages are
    /// cut along the session ID column.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails or if a row can't be converted to
    /// a [`BrowserSession`].
    #[tracing::instrument(
        name = "db.browser_session.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn list(
        &mut self,
        filter: BrowserSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<BrowserSession>, Self::Error> {
        // Every selected column is aliased to the `SessionLookup` field it
        // gets decoded into
        let (sql, arguments) = sea_query::Query::select()
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::UserSessionId)),
                SessionLookupIden::UserSessionId,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::CreatedAt)),
                SessionLookupIden::UserSessionCreatedAt,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::FinishedAt)),
                SessionLookupIden::UserSessionFinishedAt,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::UserAgent)),
                SessionLookupIden::UserSessionUserAgent,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::LastActiveAt)),
                SessionLookupIden::UserSessionLastActiveAt,
            )
            .expr_as(
                Expr::col((UserSessions::Table, UserSessions::LastActiveIp)),
                SessionLookupIden::UserSessionLastActiveIp,
            )
            .expr_as(
                Expr::col((Users::Table, Users::UserId)),
                SessionLookupIden::UserId,
            )
            .expr_as(
                Expr::col((Users::Table, Users::Username)),
                SessionLookupIden::UserUsername,
            )
            .expr_as(
                Expr::col((Users::Table, Users::CreatedAt)),
                SessionLookupIden::UserCreatedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::LockedAt)),
                SessionLookupIden::UserLockedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::DeactivatedAt)),
                SessionLookupIden::UserDeactivatedAt,
            )
            .expr_as(
                Expr::col((Users::Table, Users::CanRequestAdmin)),
                SessionLookupIden::UserCanRequestAdmin,
            )
            .from(UserSessions::Table)
            // Each session is joined with the user it belongs to
            .inner_join(
                Users::Table,
                Expr::col((UserSessions::Table, UserSessions::UserId))
                    .equals((Users::Table, Users::UserId)),
            )
            .apply_filter(filter)
            // Paginate on the session ID
            .generate_pagination(
                (UserSessions::Table, UserSessions::UserSessionId),
                pagination,
            )
            .build_sqlx(PostgresQueryBuilder);

        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        // Let the pagination turn the raw rows into a page, then convert
        // each row, stopping at the first one which fails to convert
        let page = pagination
            .process(edges)
            .try_map(BrowserSession::try_from)?;

        Ok(page)
    }
418
419    #[tracing::instrument(
420        name = "db.browser_session.count",
421        skip_all,
422        fields(
423            db.query.text,
424        ),
425        err,
426    )]
427    async fn count(&mut self, filter: BrowserSessionFilter<'_>) -> Result<usize, Self::Error> {
428        let (sql, arguments) = sea_query::Query::select()
429            .expr(Expr::col((UserSessions::Table, UserSessions::UserSessionId)).count())
430            .from(UserSessions::Table)
431            .apply_filter(filter)
432            .build_sqlx(PostgresQueryBuilder);
433
434        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
435            .traced()
436            .fetch_one(&mut *self.conn)
437            .await?;
438
439        count
440            .try_into()
441            .map_err(DatabaseError::to_invalid_operation)
442    }
443
444    #[tracing::instrument(
445        name = "db.browser_session.authenticate_with_password",
446        skip_all,
447        fields(
448            db.query.text,
449            %user_session.id,
450            %user_password.id,
451            user_session_authentication.id,
452        ),
453        err,
454    )]
455    async fn authenticate_with_password(
456        &mut self,
457        rng: &mut (dyn RngCore + Send),
458        clock: &dyn Clock,
459        user_session: &BrowserSession,
460        user_password: &Password,
461    ) -> Result<Authentication, Self::Error> {
462        let created_at = clock.now();
463        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
464        tracing::Span::current().record(
465            "user_session_authentication.id",
466            tracing::field::display(id),
467        );
468
469        sqlx::query!(
470            r#"
471                INSERT INTO user_session_authentications
472                    (user_session_authentication_id, user_session_id, created_at, user_password_id)
473                VALUES ($1, $2, $3, $4)
474            "#,
475            Uuid::from(id),
476            Uuid::from(user_session.id),
477            created_at,
478            Uuid::from(user_password.id),
479        )
480        .traced()
481        .execute(&mut *self.conn)
482        .await?;
483
484        Ok(Authentication {
485            id,
486            created_at,
487            authentication_method: AuthenticationMethod::Password {
488                user_password_id: user_password.id,
489            },
490        })
491    }
492
493    #[tracing::instrument(
494        name = "db.browser_session.authenticate_with_upstream",
495        skip_all,
496        fields(
497            db.query.text,
498            %user_session.id,
499            %upstream_oauth_session.id,
500            user_session_authentication.id,
501        ),
502        err,
503    )]
504    async fn authenticate_with_upstream(
505        &mut self,
506        rng: &mut (dyn RngCore + Send),
507        clock: &dyn Clock,
508        user_session: &BrowserSession,
509        upstream_oauth_session: &UpstreamOAuthAuthorizationSession,
510    ) -> Result<Authentication, Self::Error> {
511        let created_at = clock.now();
512        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
513        tracing::Span::current().record(
514            "user_session_authentication.id",
515            tracing::field::display(id),
516        );
517
518        sqlx::query!(
519            r#"
520                INSERT INTO user_session_authentications
521                    (user_session_authentication_id, user_session_id, created_at, upstream_oauth_authorization_session_id)
522                VALUES ($1, $2, $3, $4)
523            "#,
524            Uuid::from(id),
525            Uuid::from(user_session.id),
526            created_at,
527            Uuid::from(upstream_oauth_session.id),
528        )
529        .traced()
530        .execute(&mut *self.conn)
531        .await?;
532
533        Ok(Authentication {
534            id,
535            created_at,
536            authentication_method: AuthenticationMethod::UpstreamOAuth2 {
537                upstream_oauth2_session_id: upstream_oauth_session.id,
538            },
539        })
540    }
541
542    #[tracing::instrument(
543        name = "db.browser_session.get_last_authentication",
544        skip_all,
545        fields(
546            db.query.text,
547            %user_session.id,
548        ),
549        err,
550    )]
551    async fn get_last_authentication(
552        &mut self,
553        user_session: &BrowserSession,
554    ) -> Result<Option<Authentication>, Self::Error> {
555        let authentication = sqlx::query_as!(
556            AuthenticationLookup,
557            r#"
558                SELECT user_session_authentication_id
559                     , created_at
560                     , user_password_id
561                     , upstream_oauth_authorization_session_id
562                FROM user_session_authentications
563                WHERE user_session_id = $1
564                ORDER BY created_at DESC
565                LIMIT 1
566            "#,
567            Uuid::from(user_session.id),
568        )
569        .traced()
570        .fetch_optional(&mut *self.conn)
571        .await?;
572
573        let Some(authentication) = authentication else {
574            return Ok(None);
575        };
576
577        let authentication = Authentication::try_from(authentication)?;
578        Ok(Some(authentication))
579    }
580
581    #[tracing::instrument(
582        name = "db.browser_session.record_batch_activity",
583        skip_all,
584        fields(
585            db.query.text,
586        ),
587        err,
588    )]
589    async fn record_batch_activity(
590        &mut self,
591        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
592    ) -> Result<(), Self::Error> {
593        // Sort the activity by ID, so that when batching the updates, Postgres
594        // locks the rows in a stable order, preventing deadlocks
595        activities.sort_unstable();
596        let mut ids = Vec::with_capacity(activities.len());
597        let mut last_activities = Vec::with_capacity(activities.len());
598        let mut ips = Vec::with_capacity(activities.len());
599
600        for (id, last_activity, ip) in activities {
601            ids.push(Uuid::from(id));
602            last_activities.push(last_activity);
603            ips.push(ip);
604        }
605
606        let res = sqlx::query!(
607            r#"
608                UPDATE user_sessions
609                SET last_active_at = GREATEST(t.last_active_at, user_sessions.last_active_at)
610                  , last_active_ip = COALESCE(t.last_active_ip, user_sessions.last_active_ip)
611                FROM (
612                    SELECT *
613                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
614                        AS t(user_session_id, last_active_at, last_active_ip)
615                ) AS t
616                WHERE user_sessions.user_session_id = t.user_session_id
617            "#,
618            &ids,
619            &last_activities,
620            &ips as &[Option<IpAddr>],
621        )
622        .traced()
623        .execute(&mut *self.conn)
624        .await?;
625
626        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
627
628        Ok(())
629    }
630}