Skip to main content

mas_storage_pg/personal/
session.rs

1// Copyright 2026 Element Creations Ltd.
2// Copyright 2025 New Vector Ltd.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use std::net::IpAddr;
8
9use async_trait::async_trait;
10use chrono::{DateTime, Utc};
11use mas_data_model::{
12    Clock, UlidExt as _, User,
13    personal::{
14        PersonalAccessToken,
15        session::{PersonalSession, PersonalSessionOwner, SessionState},
16    },
17};
18use mas_storage::{
19    Page, Pagination,
20    pagination::Node,
21    personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
22};
23use oauth2_types::scope::Scope;
24use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
25use rand::RngCore;
26use sea_query::{
27    Cond, Condition, Expr, ExprTrait, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
28    extension::postgres::PgExpr as _,
29};
30use sea_query_sqlx::SqlxBinder as _;
31use sqlx::PgConnection;
32use tracing::{Instrument as _, info_span};
33use ulid::Ulid;
34use uuid::Uuid;
35
36use crate::{
37    DatabaseError,
38    errors::DatabaseInconsistencyError,
39    filter::{Filter, StatementExt as _},
40    iden::{PersonalAccessTokens, PersonalSessions},
41    pagination::QueryBuilderExt as _,
42    tracing::ExecuteExt as _,
43};
44
45/// An implementation of [`PersonalSessionRepository`] for a PostgreSQL
46/// connection
47pub struct PgPersonalSessionRepository<'c> {
48    conn: &'c mut PgConnection,
49}
50
51impl<'c> PgPersonalSessionRepository<'c> {
52    /// Create a new [`PgPersonalSessionRepository`] from an active PostgreSQL
53    /// connection
54    pub fn new(conn: &'c mut PgConnection) -> Self {
55        Self { conn }
56    }
57}
58
/// Raw row of the `personal_sessions` table, as read by this repository's
/// `lookup` query and turned into a [`PersonalSession`] through `TryFrom`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionLookup {
    personal_session_id: Uuid,
    // Exactly one of the two owner columns is expected to be set; the
    // conversion into `PersonalSession` rejects any other combination.
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    // Raw scope token strings, parsed into a `Scope` during conversion.
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` while the session is valid; set once the session is revoked.
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
73
74impl Node<Ulid> for PersonalSessionLookup {
75    fn cursor(&self) -> Ulid {
76        self.personal_session_id.into()
77    }
78}
79
80impl TryFrom<PersonalSessionLookup> for PersonalSession {
81    type Error = DatabaseInconsistencyError;
82
83    fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
84        let id = Ulid::from(value.personal_session_id);
85        let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
86        let scope = scope.map_err(|e| {
87            DatabaseInconsistencyError::on("personal_sessions")
88                .column("scope")
89                .row(id)
90                .source(e)
91        })?;
92
93        let state = match value.revoked_at {
94            None => SessionState::Valid,
95            Some(revoked_at) => SessionState::Revoked { revoked_at },
96        };
97
98        let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
99            (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
100            (None, Some(owner_oauth2_client_id)) => {
101                PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
102            }
103            _ => {
104                // should be impossible (CHECK constraint in Postgres prevents it)
105                return Err(DatabaseInconsistencyError::on("personal_sessions")
106                    .column("owner_user_id, owner_oauth2_client_id")
107                    .row(id));
108            }
109        };
110
111        Ok(PersonalSession {
112            id,
113            state,
114            owner,
115            actor_user_id: Ulid::from(value.actor_user_id),
116            human_name: value.human_name,
117            scope,
118            created_at: value.created_at,
119            last_active_at: value.last_active_at,
120            last_active_ip: value.last_active_ip,
121        })
122    }
123}
124
/// A `personal_sessions` row LEFT JOINed with (at most) its unrevoked
/// `personal_access_tokens` row, as produced by this repository's `list`
/// query. Columns are aliased to these field names so `FromRow` can map them.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionAndAccessTokenLookup {
    personal_session_id: Uuid,
    // Exactly one of the two owner columns is expected to be set.
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    scope_list: Vec<String>,
    // Session creation time (the token's own creation time is `token_created_at`).
    created_at: DateTime<Utc>,
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,

    // tokens
    // All `None` when the LEFT JOIN found no unrevoked token for the session.
    personal_access_token_id: Option<Uuid>,
    token_created_at: Option<DateTime<Utc>>,
    // May be `None` even when a token was joined (a token without expiry).
    token_expires_at: Option<DateTime<Utc>>,
}
144
145impl Node<Ulid> for PersonalSessionAndAccessTokenLookup {
146    fn cursor(&self) -> Ulid {
147        self.personal_session_id.into()
148    }
149}
150
151impl TryFrom<PersonalSessionAndAccessTokenLookup>
152    for (PersonalSession, Option<PersonalAccessToken>)
153{
154    type Error = DatabaseInconsistencyError;
155
156    fn try_from(value: PersonalSessionAndAccessTokenLookup) -> Result<Self, Self::Error> {
157        let session = PersonalSession::try_from(PersonalSessionLookup {
158            personal_session_id: value.personal_session_id,
159            owner_user_id: value.owner_user_id,
160            owner_oauth2_client_id: value.owner_oauth2_client_id,
161            actor_user_id: value.actor_user_id,
162            human_name: value.human_name,
163            scope_list: value.scope_list,
164            created_at: value.created_at,
165            revoked_at: value.revoked_at,
166            last_active_at: value.last_active_at,
167            last_active_ip: value.last_active_ip,
168        })?;
169
170        let token_opt = if let Some(id) = value.personal_access_token_id {
171            let id = Ulid::from(id);
172            Some(PersonalAccessToken {
173                id,
174                session_id: session.id,
175                // should not be possible
176                created_at: value.token_created_at.ok_or(
177                    DatabaseInconsistencyError::on("personal_sessions")
178                        .column("created_at")
179                        .row(id),
180                )?,
181                expires_at: value.token_expires_at,
182                revoked_at: None,
183            })
184        } else {
185            None
186        };
187
188        Ok((session, token_opt))
189    }
190}
191
#[async_trait]
impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
    type Error = DatabaseError;

    /// Look up a single personal session by its ID.
    ///
    /// Returns `Ok(None)` when no `personal_sessions` row has this ID; a row
    /// that fails conversion into a [`PersonalSession`] is reported as an
    /// error.
    #[tracing::instrument(
        name = "db.personal_session.lookup",
        skip_all,
        fields(
            db.query.text,
            session.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
        let res = sqlx::query_as!(
            PersonalSessionLookup,
            r#"
                SELECT personal_session_id
                     , owner_user_id
                     , owner_oauth2_client_id
                     , actor_user_id
                     , scope_list
                     , created_at
                     , revoked_at
                     , human_name
                     , last_active_at
                     , last_active_ip as "last_active_ip: IpAddr"
                FROM personal_sessions

                WHERE personal_session_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(session) = res else { return Ok(None) };

        Ok(Some(session.try_into()?))
    }

    /// Insert a new, valid personal session and return it.
    ///
    /// The session ID is a ULID generated from the clock's current time and
    /// `rng`. `scope` is stored as the list of its token strings, and exactly
    /// one of the two owner columns is populated depending on the `owner`
    /// variant.
    #[tracing::instrument(
        name = "db.personal_session.add",
        skip_all,
        fields(
            db.query.text,
            session.id,
            session.scope = %scope,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        owner: PersonalSessionOwner,
        actor_user: &User,
        human_name: String,
        scope: Scope,
    ) -> Result<PersonalSession, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_rng(created_at, rng);
        // `session.id` is declared empty in the span and filled in once generated
        tracing::Span::current().record("session.id", tracing::field::display(id));

        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();

        let (owner_user_id, owner_oauth2_client_id) = match owner {
            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
        };

        sqlx::query!(
            r#"
                INSERT INTO personal_sessions
                    ( personal_session_id
                    , owner_user_id
                    , owner_oauth2_client_id
                    , actor_user_id
                    , human_name
                    , scope_list
                    , created_at
                    )
                VALUES ($1, $2, $3, $4, $5, $6, $7)
            "#,
            Uuid::from(id),
            owner_user_id,
            owner_oauth2_client_id,
            Uuid::from(actor_user.id),
            &human_name,
            &scope_list,
            created_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(PersonalSession {
            id,
            state: SessionState::Valid,
            owner,
            actor_user_id: actor_user.id,
            human_name,
            scope,
            created_at,
            last_active_at: None,
            last_active_ip: None,
        })
    }

    /// Revoke a personal session together with its not-yet-revoked access
    /// tokens, using the same `revoked_at` timestamp for both.
    ///
    /// The session update must affect exactly one row, and the returned
    /// session comes from `PersonalSession::finish`, whose failure is mapped
    /// to an invalid-operation error.
    #[tracing::instrument(
        name = "db.personal_session.revoke",
        skip_all,
        fields(
            db.query.text,
            %session.id,
            %session.scope,
        ),
        err,
    )]
    async fn revoke(
        &mut self,
        clock: &dyn Clock,
        session: PersonalSession,
    ) -> Result<PersonalSession, Self::Error> {
        let revoked_at = clock.now();

        {
            // Revoke dependent PATs (only those not already revoked), traced in
            // their own child span
            let span = info_span!(
                "db.personal_session.revoke.tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    UPDATE personal_access_tokens
                    SET revoked_at = $2
                    WHERE personal_session_id = $1 AND revoked_at IS NULL
                "#,
                Uuid::from(session.id),
                revoked_at,
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        let res = sqlx::query!(
            r#"
                UPDATE personal_sessions
                SET revoked_at = $2
                WHERE personal_session_id = $1
            "#,
            Uuid::from(session.id),
            revoked_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, 1)?;

        session
            .finish(revoked_at)
            .map_err(DatabaseError::to_invalid_operation)
    }

    /// Revoke every personal session matching `filter`, returning the number
    /// of updated rows (saturating at `usize::MAX`).
    ///
    /// NOTE(review): unlike `revoke`, this does not update
    /// `personal_access_tokens`, and it does not itself exclude sessions that
    /// are already revoked (their `revoked_at` is overwritten unless the
    /// filter restricts the state) — confirm both are intended.
    #[tracing::instrument(
        name = "db.personal_session.revoke_bulk",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn revoke_bulk(
        &mut self,
        clock: &dyn Clock,
        filter: PersonalSessionFilter<'_>,
    ) -> Result<usize, Self::Error> {
        let revoked_at = clock.now();

        let (sql, arguments) = Query::update()
            .table(PersonalSessions::Table)
            .value(PersonalSessions::RevokedAt, revoked_at)
            .and_where(
                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                    // Because filters apply to both the session and access token tables,
                    // use a subquery to make it possible to use a JOIN
                    // onto the personal access token table.
                    .in_subquery(
                        Query::select()
                            .expr(Expr::col((
                                PersonalSessions::Table,
                                PersonalSessions::PersonalSessionId,
                            )))
                            .from(PersonalSessions::Table)
                            .left_join(
                                PersonalAccessTokens::Table,
                                Cond::all()
                                    // Match session ID
                                    .add(
                                        Expr::col((
                                            PersonalSessions::Table,
                                            PersonalSessions::PersonalSessionId,
                                        ))
                                        .eq(Expr::col((
                                            PersonalAccessTokens::Table,
                                            PersonalAccessTokens::PersonalSessionId,
                                        ))),
                                    )
                                    // Only choose the active access token for each session
                                    .add(
                                        Expr::col((
                                            PersonalAccessTokens::Table,
                                            PersonalAccessTokens::RevokedAt,
                                        ))
                                        .is_null(),
                                    ),
                            )
                            .apply_filter(filter)
                            .take(),
                    ),
            )
            .build_sqlx(PostgresQueryBuilder);

        let res = sqlx::query_with(&sql, arguments)
            .traced()
            .execute(&mut *self.conn)
            .await?;

        // `rows_affected` is a u64; saturate rather than fail on narrow targets
        Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
    }

    /// List personal sessions matching `filter`, each paired with its
    /// unrevoked access token if one exists, paginated by session ID.
    ///
    /// NOTE(review): the LEFT JOIN assumes at most one unrevoked token per
    /// session; with several, a session would appear once per such token.
    #[tracing::instrument(
        name = "db.personal_session.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn list(
        &mut self,
        filter: PersonalSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<(PersonalSession, Option<PersonalAccessToken>)>, Self::Error> {
        // Every selected column is aliased to a field name of
        // `PersonalSessionAndAccessTokenLookup` (via its `enum_def` idents) so
        // `FromRow` can map it; token columns get the `token_` prefix to avoid
        // clashing with the session's own `created_at`.
        let (sql, arguments) = Query::select()
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
                PersonalSessionAndAccessTokenLookupIden::PersonalSessionId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
                PersonalSessionAndAccessTokenLookupIden::OwnerUserId,
            )
            .expr_as(
                Expr::col((
                    PersonalSessions::Table,
                    PersonalSessions::OwnerOAuth2ClientId,
                )),
                PersonalSessionAndAccessTokenLookupIden::OwnerOauth2ClientId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
                PersonalSessionAndAccessTokenLookupIden::ActorUserId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
                PersonalSessionAndAccessTokenLookupIden::HumanName,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
                PersonalSessionAndAccessTokenLookupIden::ScopeList,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
                PersonalSessionAndAccessTokenLookupIden::RevokedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveIp,
            )
            .expr_as(
                Expr::col((
                    PersonalAccessTokens::Table,
                    PersonalAccessTokens::PersonalAccessTokenId,
                )),
                PersonalSessionAndAccessTokenLookupIden::PersonalAccessTokenId,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenCreatedAt,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenExpiresAt,
            )
            .from(PersonalSessions::Table)
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    // Match session ID
                    .add(
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    // Only choose the active access token for each session
                    .add(
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .generate_pagination(
                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
                pagination,
            )
            .build_sqlx(PostgresQueryBuilder);

        let edges: Vec<PersonalSessionAndAccessTokenLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        let page = pagination.process(edges).try_map(TryFrom::try_from)?;

        Ok(page)
    }

    /// Count personal sessions matching `filter`.
    ///
    /// The tokens table is LEFT JOINed exactly as in `list`, so that
    /// token-based filter criteria can be evaluated.
    /// NOTE(review): assumes at most one unrevoked token per session;
    /// otherwise such sessions would be counted more than once.
    #[tracing::instrument(
        name = "db.personal_session.count",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
            .from(PersonalSessions::Table)
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    // Match session ID
                    .add(
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    // Only choose the active access token for each session
                    .add(
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);

        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;

        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }

    /// Record last-activity timestamps and IPs for many sessions in a single
    /// `UPDATE ... FROM UNNEST(...)` statement.
    ///
    /// For each session, `last_active_at` becomes the later of the stored and
    /// provided values, and `last_active_ip` is only replaced when a new IP is
    /// provided. The number of affected rows must equal the number of
    /// entries, so an unknown session ID makes the call fail.
    /// NOTE(review): presumably callers pass each session ID at most once; a
    /// duplicated ID would likely update its row only once and trip the
    /// affected-rows check — confirm.
    #[tracing::instrument(
        name = "db.personal_session.record_batch_activity",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn record_batch_activity(
        &mut self,
        mut activities: Vec<(Ulid, DateTime<Utc>, Option<IpAddr>)>,
    ) -> Result<(), Self::Error> {
        // Sort the activity by ID, so that when batching the updates, Postgres
        // locks the rows in a stable order, preventing deadlocks
        activities.sort_unstable();
        // Split into parallel column arrays, one element per session, for UNNEST
        let mut ids = Vec::with_capacity(activities.len());
        let mut last_activities = Vec::with_capacity(activities.len());
        let mut ips = Vec::with_capacity(activities.len());

        for (id, last_activity, ip) in activities {
            ids.push(Uuid::from(id));
            last_activities.push(last_activity);
            ips.push(ip);
        }

        let res = sqlx::query!(
            r#"
                UPDATE personal_sessions
                SET last_active_at = GREATEST(t.last_active_at, personal_sessions.last_active_at)
                  , last_active_ip = COALESCE(t.last_active_ip, personal_sessions.last_active_ip)
                FROM (
                    SELECT *
                    FROM UNNEST($1::uuid[], $2::timestamptz[], $3::inet[])
                        AS t(personal_session_id, last_active_at, last_active_ip)
                ) AS t
                WHERE personal_sessions.personal_session_id = t.personal_session_id
            "#,
            &ids,
            &last_activities,
            &ips as &[Option<IpAddr>],
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        DatabaseError::ensure_affected_rows(&res, ids.len().try_into().unwrap_or(u64::MAX))?;

        Ok(())
    }
}
627
impl Filter for PersonalSessionFilter<'_> {
    /// Build the condition for this filter: every criterion that is set is
    /// AND-ed together, and criteria that are not set are skipped.
    ///
    /// The `expires*` criteria reference `personal_access_tokens` columns, so
    /// the surrounding query must join that table (every query in this file
    /// LEFT JOINs it on the session's unrevoked token); `_has_joins` is not
    /// consulted.
    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
        sea_query::Condition::all()
            .add_option(self.owner_user().map(|user| {
                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
                    .eq(Uuid::from(user.id))
            }))
            .add_option(self.owner_oauth2_client().map(|client| {
                Expr::col((
                    PersonalSessions::Table,
                    PersonalSessions::OwnerOAuth2ClientId,
                ))
                .eq(Uuid::from(client.id))
            }))
            .add_option(self.actor_user().map(|user| {
                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
                    .eq(Uuid::from(user.id))
            }))
            // A device matches when either its stable or its unstable scope
            // token appears in the session's scope list
            .add_option(self.device().map(|device| -> SimpleExpr {
                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
                    Condition::any()
                        .add(
                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
                                PersonalSessions::Table,
                                PersonalSessions::ScopeList,
                            )))),
                        )
                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
                            Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
                        )))
                        .into()
                } else {
                    // If the device ID can't be encoded as a scope token, match no rows
                    Expr::val(false)
                }
            }))
            // Session state is derived solely from the session's `revoked_at`
            .add_option(self.state().map(|state| match state {
                PersonalSessionState::Active => {
                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
                }
                PersonalSessionState::Revoked => {
                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
                }
            }))
            // The session's scope list must contain every token of the
            // requested scope (Postgres array containment)
            .add_option(self.scope().map(|scope| {
                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
            }))
            // Strict comparisons; sessions never marked active (NULL) match neither
            .add_option(self.last_active_before().map(|last_active_before| {
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
                    .lt(last_active_before)
            }))
            .add_option(self.last_active_after().map(|last_active_after| {
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
                    .gt(last_active_after)
            }))
            // The following criteria look at the joined access token
            .add_option(self.expires_before().map(|expires_before| {
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
                    .lt(expires_before)
            }))
            .add_option(self.expires_after().map(|expires_after| {
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
                    .gt(expires_after)
            }))
            // NOTE(review): with the LEFT JOIN, `expires == false` (`IS NULL`)
            // also matches sessions that have no unrevoked token at all —
            // confirm this is the intended meaning.
            .add_option(self.expires().map(|expires| {
                let column =
                    Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt));

                if expires {
                    column.is_not_null()
                } else {
                    column.is_null()
                }
            }))
    }
}
699                    column.is_null()
700                }
701            }))
702    }
703}