mas_storage_pg/personal/
session.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    Clock, User,
12    personal::{
13        PersonalAccessToken,
14        session::{PersonalSession, PersonalSessionOwner, SessionState},
15    },
16};
17use mas_storage::{
18    Page, Pagination,
19    pagination::Node,
20    personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
21};
22use oauth2_types::scope::Scope;
23use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26    Cond, Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
27    extension::postgres::PgExpr as _,
28};
29use sea_query_binder::SqlxBinder as _;
30use sqlx::PgConnection;
31use tracing::{Instrument as _, info_span};
32use ulid::Ulid;
33use uuid::Uuid;
34
35use crate::{
36    DatabaseError,
37    errors::DatabaseInconsistencyError,
38    filter::{Filter, StatementExt as _},
39    iden::{PersonalAccessTokens, PersonalSessions},
40    pagination::QueryBuilderExt as _,
41    tracing::ExecuteExt as _,
42};
43
44/// An implementation of [`PersonalSessionRepository`] for a PostgreSQL
45/// connection
46pub struct PgPersonalSessionRepository<'c> {
47    conn: &'c mut PgConnection,
48}
49
50impl<'c> PgPersonalSessionRepository<'c> {
51    /// Create a new [`PgPersonalSessionRepository`] from an active PostgreSQL
52    /// connection
53    pub fn new(conn: &'c mut PgConnection) -> Self {
54        Self { conn }
55    }
56}
57
/// Row shape for a single `personal_sessions` row.
///
/// Filled by the `lookup` query and turned into a [`PersonalSession`] through
/// its `TryFrom` implementation. Exactly one of the two `owner_*` columns is
/// expected to be set; the conversion rejects any other combination.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionLookup {
    personal_session_id: Uuid,
    // The session is owned either by a user or by an OAuth 2.0 client
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    // Raw scope tokens, parsed into a `Scope` during conversion
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // NULL while the session is still valid
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
72
73impl Node<Ulid> for PersonalSessionLookup {
74    fn cursor(&self) -> Ulid {
75        self.personal_session_id.into()
76    }
77}
78
79impl TryFrom<PersonalSessionLookup> for PersonalSession {
80    type Error = DatabaseInconsistencyError;
81
82    fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
83        let id = Ulid::from(value.personal_session_id);
84        let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
85        let scope = scope.map_err(|e| {
86            DatabaseInconsistencyError::on("personal_sessions")
87                .column("scope")
88                .row(id)
89                .source(e)
90        })?;
91
92        let state = match value.revoked_at {
93            None => SessionState::Valid,
94            Some(revoked_at) => SessionState::Revoked { revoked_at },
95        };
96
97        let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
98            (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
99            (None, Some(owner_oauth2_client_id)) => {
100                PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
101            }
102            _ => {
103                // should be impossible (CHECK constraint in Postgres prevents it)
104                return Err(DatabaseInconsistencyError::on("personal_sessions")
105                    .column("owner_user_id, owner_oauth2_client_id")
106                    .row(id));
107            }
108        };
109
110        Ok(PersonalSession {
111            id,
112            state,
113            owner,
114            actor_user_id: Ulid::from(value.actor_user_id),
115            human_name: value.human_name,
116            scope,
117            created_at: value.created_at,
118            last_active_at: value.last_active_at,
119            last_active_ip: value.last_active_ip,
120        })
121    }
122}
123
/// Row shape for the `list` query: a `personal_sessions` row LEFT JOINed with
/// its active (non-revoked) `personal_access_tokens` row, if any.
///
/// The column aliases used by `list` come from the
/// `PersonalSessionAndAccessTokenLookupIden` enum generated by `enum_def`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionAndAccessTokenLookup {
    personal_session_id: Uuid,
    // The session is owned either by a user or by an OAuth 2.0 client
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,

    // tokens: all NULL when the LEFT JOIN found no active token
    personal_access_token_id: Option<Uuid>,
    token_created_at: Option<DateTime<Utc>>,
    token_expires_at: Option<DateTime<Utc>>,
}
143
144impl Node<Ulid> for PersonalSessionAndAccessTokenLookup {
145    fn cursor(&self) -> Ulid {
146        self.personal_session_id.into()
147    }
148}
149
150impl TryFrom<PersonalSessionAndAccessTokenLookup>
151    for (PersonalSession, Option<PersonalAccessToken>)
152{
153    type Error = DatabaseInconsistencyError;
154
155    fn try_from(value: PersonalSessionAndAccessTokenLookup) -> Result<Self, Self::Error> {
156        let session = PersonalSession::try_from(PersonalSessionLookup {
157            personal_session_id: value.personal_session_id,
158            owner_user_id: value.owner_user_id,
159            owner_oauth2_client_id: value.owner_oauth2_client_id,
160            actor_user_id: value.actor_user_id,
161            human_name: value.human_name,
162            scope_list: value.scope_list,
163            created_at: value.created_at,
164            revoked_at: value.revoked_at,
165            last_active_at: value.last_active_at,
166            last_active_ip: value.last_active_ip,
167        })?;
168
169        let token_opt = if let Some(id) = value.personal_access_token_id {
170            let id = Ulid::from(id);
171            Some(PersonalAccessToken {
172                id,
173                session_id: session.id,
174                // should not be possible
175                created_at: value.token_created_at.ok_or(
176                    DatabaseInconsistencyError::on("personal_sessions")
177                        .column("created_at")
178                        .row(id),
179                )?,
180                expires_at: value.token_expires_at,
181                revoked_at: None,
182            })
183        } else {
184            None
185        };
186
187        Ok((session, token_opt))
188    }
189}
190
191#[async_trait]
192impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
193    type Error = DatabaseError;
194
195    #[tracing::instrument(
196        name = "db.personal_session.lookup",
197        skip_all,
198        fields(
199            db.query.text,
200            session.id = %id,
201        ),
202        err,
203    )]
204    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
205        let res = sqlx::query_as!(
206            PersonalSessionLookup,
207            r#"
208                SELECT personal_session_id
209                     , owner_user_id
210                     , owner_oauth2_client_id
211                     , actor_user_id
212                     , scope_list
213                     , created_at
214                     , revoked_at
215                     , human_name
216                     , last_active_at
217                     , last_active_ip as "last_active_ip: IpAddr"
218                FROM personal_sessions
219
220                WHERE personal_session_id = $1
221            "#,
222            Uuid::from(id),
223        )
224        .traced()
225        .fetch_optional(&mut *self.conn)
226        .await?;
227
228        let Some(session) = res else { return Ok(None) };
229
230        Ok(Some(session.try_into()?))
231    }
232
233    #[tracing::instrument(
234        name = "db.personal_session.add",
235        skip_all,
236        fields(
237            db.query.text,
238            session.id,
239            session.scope = %scope,
240        ),
241        err,
242    )]
243    async fn add(
244        &mut self,
245        rng: &mut (dyn RngCore + Send),
246        clock: &dyn Clock,
247        owner: PersonalSessionOwner,
248        actor_user: &User,
249        human_name: String,
250        scope: Scope,
251    ) -> Result<PersonalSession, Self::Error> {
252        let created_at = clock.now();
253        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
254        tracing::Span::current().record("session.id", tracing::field::display(id));
255
256        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
257
258        let (owner_user_id, owner_oauth2_client_id) = match owner {
259            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
260            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
261        };
262
263        sqlx::query!(
264            r#"
265                INSERT INTO personal_sessions
266                    ( personal_session_id
267                    , owner_user_id
268                    , owner_oauth2_client_id
269                    , actor_user_id
270                    , human_name
271                    , scope_list
272                    , created_at
273                    )
274                VALUES ($1, $2, $3, $4, $5, $6, $7)
275            "#,
276            Uuid::from(id),
277            owner_user_id,
278            owner_oauth2_client_id,
279            Uuid::from(actor_user.id),
280            &human_name,
281            &scope_list,
282            created_at,
283        )
284        .traced()
285        .execute(&mut *self.conn)
286        .await?;
287
288        Ok(PersonalSession {
289            id,
290            state: SessionState::Valid,
291            owner,
292            actor_user_id: actor_user.id,
293            human_name,
294            scope,
295            created_at,
296            last_active_at: None,
297            last_active_ip: None,
298        })
299    }
300
301    #[tracing::instrument(
302        name = "db.personal_session.revoke",
303        skip_all,
304        fields(
305            db.query.text,
306            %session.id,
307            %session.scope,
308        ),
309        err,
310    )]
311    async fn revoke(
312        &mut self,
313        clock: &dyn Clock,
314        session: PersonalSession,
315    ) -> Result<PersonalSession, Self::Error> {
316        let revoked_at = clock.now();
317
318        {
319            // Revoke dependent PATs
320            let span = info_span!(
321                "db.personal_session.revoke.tokens",
322                { DB_QUERY_TEXT } = tracing::field::Empty,
323            );
324
325            sqlx::query!(
326                r#"
327                    UPDATE personal_access_tokens
328                    SET revoked_at = $2
329                    WHERE personal_session_id = $1 AND revoked_at IS NULL
330                "#,
331                Uuid::from(session.id),
332                revoked_at,
333            )
334            .record(&span)
335            .execute(&mut *self.conn)
336            .instrument(span)
337            .await?;
338        }
339
340        let res = sqlx::query!(
341            r#"
342                UPDATE personal_sessions
343                SET revoked_at = $2
344                WHERE personal_session_id = $1
345            "#,
346            Uuid::from(session.id),
347            revoked_at,
348        )
349        .traced()
350        .execute(&mut *self.conn)
351        .await?;
352
353        DatabaseError::ensure_affected_rows(&res, 1)?;
354
355        session
356            .finish(revoked_at)
357            .map_err(DatabaseError::to_invalid_operation)
358    }
359
360    #[tracing::instrument(
361        name = "db.personal_session.revoke_bulk",
362        skip_all,
363        fields(
364            db.query.text,
365        ),
366        err,
367    )]
368    async fn revoke_bulk(
369        &mut self,
370        clock: &dyn Clock,
371        filter: PersonalSessionFilter<'_>,
372    ) -> Result<usize, Self::Error> {
373        let revoked_at = clock.now();
374
375        let (sql, arguments) = Query::update()
376            .table(PersonalSessions::Table)
377            .value(PersonalSessions::RevokedAt, revoked_at)
378            .and_where(
379                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
380                    // Because filters apply to both the session and access token tables,
381                    // Use a subquery to make it possible to use a JOIN
382                    // onto the personal access token table.
383                    .in_subquery(
384                        Query::select()
385                            .expr(Expr::col((
386                                PersonalSessions::Table,
387                                PersonalSessions::PersonalSessionId,
388                            )))
389                            .from(PersonalSessions::Table)
390                            .left_join(
391                                PersonalAccessTokens::Table,
392                                Cond::all()
393                                    // Match session ID
394                                    .add(
395                                        Expr::col((
396                                            PersonalSessions::Table,
397                                            PersonalSessions::PersonalSessionId,
398                                        ))
399                                        .eq(Expr::col((
400                                            PersonalAccessTokens::Table,
401                                            PersonalAccessTokens::PersonalSessionId,
402                                        ))),
403                                    )
404                                    // Only choose the active access token for each session
405                                    .add(
406                                        Expr::col((
407                                            PersonalAccessTokens::Table,
408                                            PersonalAccessTokens::RevokedAt,
409                                        ))
410                                        .is_null(),
411                                    ),
412                            )
413                            .apply_filter(filter)
414                            .take(),
415                    ),
416            )
417            .build_sqlx(PostgresQueryBuilder);
418
419        let res = sqlx::query_with(&sql, arguments)
420            .traced()
421            .execute(&mut *self.conn)
422            .await?;
423
424        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
425    }
426
427    #[tracing::instrument(
428        name = "db.personal_session.list",
429        skip_all,
430        fields(
431            db.query.text,
432        ),
433        err,
434    )]
435    async fn list(
436        &mut self,
437        filter: PersonalSessionFilter<'_>,
438        pagination: Pagination,
439    ) -> Result<Page<(PersonalSession, Option<PersonalAccessToken>)>, Self::Error> {
440        let (sql, arguments) = Query::select()
441            .expr_as(
442                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
443                PersonalSessionAndAccessTokenLookupIden::PersonalSessionId,
444            )
445            .expr_as(
446                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
447                PersonalSessionAndAccessTokenLookupIden::OwnerUserId,
448            )
449            .expr_as(
450                Expr::col((
451                    PersonalSessions::Table,
452                    PersonalSessions::OwnerOAuth2ClientId,
453                )),
454                PersonalSessionAndAccessTokenLookupIden::OwnerOauth2ClientId,
455            )
456            .expr_as(
457                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
458                PersonalSessionAndAccessTokenLookupIden::ActorUserId,
459            )
460            .expr_as(
461                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
462                PersonalSessionAndAccessTokenLookupIden::HumanName,
463            )
464            .expr_as(
465                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
466                PersonalSessionAndAccessTokenLookupIden::ScopeList,
467            )
468            .expr_as(
469                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
470                PersonalSessionAndAccessTokenLookupIden::CreatedAt,
471            )
472            .expr_as(
473                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
474                PersonalSessionAndAccessTokenLookupIden::RevokedAt,
475            )
476            .expr_as(
477                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
478                PersonalSessionAndAccessTokenLookupIden::LastActiveAt,
479            )
480            .expr_as(
481                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
482                PersonalSessionAndAccessTokenLookupIden::LastActiveIp,
483            )
484            .expr_as(
485                Expr::col((
486                    PersonalAccessTokens::Table,
487                    PersonalAccessTokens::PersonalAccessTokenId,
488                )),
489                PersonalSessionAndAccessTokenLookupIden::PersonalAccessTokenId,
490            )
491            .expr_as(
492                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::CreatedAt)),
493                PersonalSessionAndAccessTokenLookupIden::TokenCreatedAt,
494            )
495            .expr_as(
496                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt)),
497                PersonalSessionAndAccessTokenLookupIden::TokenExpiresAt,
498            )
499            .from(PersonalSessions::Table)
500            .left_join(
501                PersonalAccessTokens::Table,
502                Cond::all()
503                    // Match session ID
504                    .add(
505                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
506                            .eq(Expr::col((
507                                PersonalAccessTokens::Table,
508                                PersonalAccessTokens::PersonalSessionId,
509                            ))),
510                    )
511                    // Only choose the active access token for each session
512                    .add(
513                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
514                            .is_null(),
515                    ),
516            )
517            .apply_filter(filter)
518            .generate_pagination(
519                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
520                pagination,
521            )
522            .build_sqlx(PostgresQueryBuilder);
523
524        let edges: Vec<PersonalSessionAndAccessTokenLookup> = sqlx::query_as_with(&sql, arguments)
525            .traced()
526            .fetch_all(&mut *self.conn)
527            .await?;
528
529        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
530
531        Ok(page)
532    }
533
534    #[tracing::instrument(
535        name = "db.personal_session.count",
536        skip_all,
537        fields(
538            db.query.text,
539        ),
540        err,
541    )]
542    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
543        let (sql, arguments) = Query::select()
544            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
545            .from(PersonalSessions::Table)
546            .left_join(
547                PersonalAccessTokens::Table,
548                Cond::all()
549                    // Match session ID
550                    .add(
551                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
552                            .eq(Expr::col((
553                                PersonalAccessTokens::Table,
554                                PersonalAccessTokens::PersonalSessionId,
555                            ))),
556                    )
557                    // Only choose the active access token for each session
558                    .add(
559                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
560                            .is_null(),
561                    ),
562            )
563            .apply_filter(filter)
564            .build_sqlx(PostgresQueryBuilder);
565
566        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
567            .traced()
568            .fetch_one(&mut *self.conn)
569            .await?;
570
571        count
572            .try_into()
573            .map_err(DatabaseError::to_invalid_operation)
574    }
575
576    #[tracing::instrument(
577        name = "db.personal_session.record_batch_activity",
578        skip_all,
579        fields(
580            db.query.text,
581        ),
582        err,
583    )]
584    async fn record_batch_activity(
585        &mut self,
586        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
587    ) -> Result<(), Self::Error> {
588        // Sort the activity by ID, so that when batching the updates, Postgres
589        // locks the rows in a stable order, preventing deadlocks
590        activities.sort_unstable();
591        let mut ids = Vec::with_capacity(activities.len());
592        let mut last_activities = Vec::with_capacity(activities.len());
593        let mut ips = Vec::with_capacity(activities.len());
594
595        for (id, last_activity, ip) in activities {
596            ids.push(Uuid::from(id));
597            last_activities.push(last_activity);
598            ips.push(ip);
599        }
600
601        let res = sqlx::query!(
602            r#"
603                UPDATE personal_sessions
604                SET last_active_at = GREATEST(t.last_active_at, personal_sessions.last_active_at)
605                  , last_active_ip = COALESCE(t.last_active_ip, personal_sessions.last_active_ip)
606                FROM (
607                    SELECT *
608                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
609                        AS t(personal_session_id, last_active_at, last_active_ip)
610                ) AS t
611                WHERE personal_sessions.personal_session_id = t.personal_session_id
612            "#,
613            &ids,
614            &last_activities,
615            &ips as &[Option<IpAddr>],
616        )
617        .traced()
618        .execute(&mut *self.conn)
619        .await?;
620
621        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;
622
623        Ok(())
624    }
625}
626
627impl Filter for PersonalSessionFilter<'_> {
628    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
629        sea_query::Condition::all()
630            .add_option(self.owner_user().map(|user| {
631                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
632                    .eq(Uuid::from(user.id))
633            }))
634            .add_option(self.owner_oauth2_client().map(|client| {
635                Expr::col((
636                    PersonalSessions::Table,
637                    PersonalSessions::OwnerOAuth2ClientId,
638                ))
639                .eq(Uuid::from(client.id))
640            }))
641            .add_option(self.actor_user().map(|user| {
642                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
643                    .eq(Uuid::from(user.id))
644            }))
645            .add_option(self.device().map(|device| -> SimpleExpr {
646                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
647                    Condition::any()
648                        .add(
649                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
650                                PersonalSessions::Table,
651                                PersonalSessions::ScopeList,
652                            )))),
653                        )
654                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
655                            Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
656                        )))
657                        .into()
658                } else {
659                    // If the device ID can't be encoded as a scope token, match no rows
660                    Expr::val(false).into()
661                }
662            }))
663            .add_option(self.state().map(|state| match state {
664                PersonalSessionState::Active => {
665                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
666                }
667                PersonalSessionState::Revoked => {
668                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
669                }
670            }))
671            .add_option(self.scope().map(|scope| {
672                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
673                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
674            }))
675            .add_option(self.last_active_before().map(|last_active_before| {
676                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
677                    .lt(last_active_before)
678            }))
679            .add_option(self.last_active_after().map(|last_active_after| {
680                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
681                    .gt(last_active_after)
682            }))
683            .add_option(self.expires_before().map(|expires_before| {
684                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
685                    .lt(expires_before)
686            }))
687            .add_option(self.expires_after().map(|expires_after| {
688                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
689                    .gt(expires_after)
690            }))
691            .add_option(self.expires().map(|expires| {
692                let column =
693                    Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt));
694
695                if expires {
696                    column.is_not_null()
697                } else {
698                    column.is_null()
699                }
700            }))
701    }
702}