Skip to main content

mas_storage_pg/upstream_oauth2/
session.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    BrowserSession, Clock, UlidExt as _, UpstreamOAuthAuthorizationSession,
12    UpstreamOAuthAuthorizationSessionState, UpstreamOAuthLink, UpstreamOAuthProvider,
13};
14use mas_storage::{
15    Page, Pagination,
16    pagination::Node,
17    upstream_oauth2::{UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository},
18};
19use rand::RngCore;
20use sea_query::{
21    Expr, ExprTrait, PostgresQueryBuilder, Query, enum_def, extension::postgres::PgExpr,
22};
23use sea_query_sqlx::SqlxBinder;
24use sqlx::PgConnection;
25use ulid::Ulid;
26use uuid::Uuid;
27
28use crate::{
29    DatabaseError, DatabaseInconsistencyError,
30    filter::{Filter, StatementExt},
31    iden::UpstreamOAuthAuthorizationSessions,
32    pagination::QueryBuilderExt,
33    tracing::ExecuteExt,
34};
35
36impl Filter for UpstreamOAuthSessionFilter<'_> {
37    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
38        sea_query::Condition::all()
39            .add_option(self.provider().map(|provider| {
40                Expr::col((
41                    UpstreamOAuthAuthorizationSessions::Table,
42                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
43                ))
44                .eq(Uuid::from(provider.id))
45            }))
46            .add_option(self.sub_claim().map(|sub| {
47                Expr::col((
48                    UpstreamOAuthAuthorizationSessions::Table,
49                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
50                ))
51                .cast_json_field("sub")
52                .eq(sub)
53            }))
54            .add_option(self.sid_claim().map(|sid| {
55                Expr::col((
56                    UpstreamOAuthAuthorizationSessions::Table,
57                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
58                ))
59                .cast_json_field("sid")
60                .eq(sid)
61            }))
62    }
63}
64
65/// An implementation of [`UpstreamOAuthSessionRepository`] for a PostgreSQL
66/// connection
67pub struct PgUpstreamOAuthSessionRepository<'c> {
68    conn: &'c mut PgConnection,
69}
70
71impl<'c> PgUpstreamOAuthSessionRepository<'c> {
72    /// Create a new [`PgUpstreamOAuthSessionRepository`] from an active
73    /// PostgreSQL connection
74    pub fn new(conn: &'c mut PgConnection) -> Self {
75        Self { conn }
76    }
77}
78
/// One raw row of `upstream_oauth_authorization_sessions`, as read back by
/// the `lookup` and `list` queries of [`PgUpstreamOAuthSessionRepository`].
///
/// Field names mirror the selected column names: `lookup` maps rows into
/// this struct with `sqlx::query_as!`, and `#[enum_def]` generates the
/// `SessionLookupIden` identifiers that `list` uses as column aliases. Renaming
/// a field therefore changes what the queries must return.
///
/// The nullable link/token/timestamp columns together encode the session's
/// lifecycle state; the `TryFrom<SessionLookup>` impl for
/// [`UpstreamOAuthAuthorizationSession`] decodes them.
#[derive(sqlx::FromRow)]
#[enum_def]
struct SessionLookup {
    upstream_oauth_authorization_session_id: Uuid,
    upstream_oauth_provider_id: Uuid,
    // Set by `complete_with_link` alongside `completed_at`.
    upstream_oauth_link_id: Option<Uuid>,
    // The `state` value stored at creation (`state_str` in the domain model).
    state: String,
    code_challenge_verifier: Option<String>,
    nonce: Option<String>,
    id_token: Option<String>,
    id_token_claims: Option<serde_json::Value>,
    userinfo: Option<serde_json::Value>,
    created_at: DateTime<Utc>,
    completed_at: Option<DateTime<Utc>>,
    consumed_at: Option<DateTime<Utc>>,
    extra_callback_parameters: Option<serde_json::Value>,
    // Not written anywhere in this repository; presumably set when the
    // upstream link is removed — TODO confirm against the link repository.
    unlinked_at: Option<DateTime<Utc>>,
}
97
98impl Node<Ulid> for SessionLookup {
99    fn cursor(&self) -> Ulid {
100        self.upstream_oauth_authorization_session_id.into()
101    }
102}
103
104impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
105    type Error = DatabaseInconsistencyError;
106
107    fn try_from(value: SessionLookup) -> Result<Self, Self::Error> {
108        let id = value.upstream_oauth_authorization_session_id.into();
109        let state = match (
110            value.upstream_oauth_link_id,
111            value.id_token,
112            value.id_token_claims,
113            value.extra_callback_parameters,
114            value.userinfo,
115            value.completed_at,
116            value.consumed_at,
117            value.unlinked_at,
118        ) {
119            (None, None, None, None, None, None, None, None) => {
120                UpstreamOAuthAuthorizationSessionState::Pending
121            }
122            (
123                Some(link_id),
124                id_token,
125                id_token_claims,
126                extra_callback_parameters,
127                userinfo,
128                Some(completed_at),
129                None,
130                None,
131            ) => UpstreamOAuthAuthorizationSessionState::Completed {
132                completed_at,
133                link_id: link_id.into(),
134                id_token,
135                id_token_claims,
136                extra_callback_parameters,
137                userinfo,
138            },
139            (
140                Some(link_id),
141                id_token,
142                id_token_claims,
143                extra_callback_parameters,
144                userinfo,
145                Some(completed_at),
146                Some(consumed_at),
147                None,
148            ) => UpstreamOAuthAuthorizationSessionState::Consumed {
149                completed_at,
150                link_id: link_id.into(),
151                id_token,
152                id_token_claims,
153                extra_callback_parameters,
154                userinfo,
155                consumed_at,
156            },
157            (
158                _,
159                id_token,
160                id_token_claims,
161                _,
162                _,
163                Some(completed_at),
164                consumed_at,
165                Some(unlinked_at),
166            ) => UpstreamOAuthAuthorizationSessionState::Unlinked {
167                completed_at,
168                id_token,
169                id_token_claims,
170                consumed_at,
171                unlinked_at,
172            },
173            _ => {
174                return Err(DatabaseInconsistencyError::on(
175                    "upstream_oauth_authorization_sessions",
176                )
177                .row(id));
178            }
179        };
180
181        Ok(Self {
182            id,
183            provider_id: value.upstream_oauth_provider_id.into(),
184            state_str: value.state,
185            nonce: value.nonce,
186            code_challenge_verifier: value.code_challenge_verifier,
187            created_at: value.created_at,
188            state,
189        })
190    }
191}
192
193#[async_trait]
194impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
195    type Error = DatabaseError;
196
197    #[tracing::instrument(
198        name = "db.upstream_oauth_authorization_session.lookup",
199        skip_all,
200        fields(
201            db.query.text,
202            upstream_oauth_provider.id = %id,
203        ),
204        err,
205    )]
206    async fn lookup(
207        &mut self,
208        id: Ulid,
209    ) -> Result<Option<UpstreamOAuthAuthorizationSession>, Self::Error> {
210        let res = sqlx::query_as!(
211            SessionLookup,
212            r#"
213                SELECT
214                    upstream_oauth_authorization_session_id,
215                    upstream_oauth_provider_id,
216                    upstream_oauth_link_id,
217                    state,
218                    code_challenge_verifier,
219                    nonce,
220                    id_token,
221                    id_token_claims,
222                    extra_callback_parameters,
223                    userinfo,
224                    created_at,
225                    completed_at,
226                    consumed_at,
227                    unlinked_at
228                FROM upstream_oauth_authorization_sessions
229                WHERE upstream_oauth_authorization_session_id = $1
230            "#,
231            Uuid::from(id),
232        )
233        .traced()
234        .fetch_optional(&mut *self.conn)
235        .await?;
236
237        let Some(res) = res else { return Ok(None) };
238
239        Ok(Some(res.try_into()?))
240    }
241
242    #[tracing::instrument(
243        name = "db.upstream_oauth_authorization_session.add",
244        skip_all,
245        fields(
246            db.query.text,
247            %upstream_oauth_provider.id,
248            upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
249            %upstream_oauth_provider.client_id,
250            upstream_oauth_authorization_session.id,
251        ),
252        err,
253    )]
254    async fn add(
255        &mut self,
256        rng: &mut (dyn RngCore + Send),
257        clock: &dyn Clock,
258        upstream_oauth_provider: &UpstreamOAuthProvider,
259        state_str: String,
260        code_challenge_verifier: Option<String>,
261        nonce: Option<String>,
262    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
263        let created_at = clock.now();
264        let id = Ulid::from_datetime_with_rng(created_at, rng);
265        tracing::Span::current().record(
266            "upstream_oauth_authorization_session.id",
267            tracing::field::display(id),
268        );
269
270        sqlx::query!(
271            r#"
272                INSERT INTO upstream_oauth_authorization_sessions (
273                    upstream_oauth_authorization_session_id,
274                    upstream_oauth_provider_id,
275                    state,
276                    code_challenge_verifier,
277                    nonce,
278                    created_at,
279                    completed_at,
280                    consumed_at,
281                    id_token,
282                    userinfo
283                ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
284            "#,
285            Uuid::from(id),
286            Uuid::from(upstream_oauth_provider.id),
287            &state_str,
288            code_challenge_verifier.as_deref(),
289            nonce,
290            created_at,
291        )
292        .traced()
293        .execute(&mut *self.conn)
294        .await?;
295
296        Ok(UpstreamOAuthAuthorizationSession {
297            id,
298            state: UpstreamOAuthAuthorizationSessionState::default(),
299            provider_id: upstream_oauth_provider.id,
300            state_str,
301            code_challenge_verifier,
302            nonce,
303            created_at,
304        })
305    }
306
307    #[tracing::instrument(
308        name = "db.upstream_oauth_authorization_session.complete_with_link",
309        skip_all,
310        fields(
311            db.query.text,
312            %upstream_oauth_authorization_session.id,
313            %upstream_oauth_link.id,
314        ),
315        err,
316    )]
317    async fn complete_with_link(
318        &mut self,
319        clock: &dyn Clock,
320        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
321        upstream_oauth_link: &UpstreamOAuthLink,
322        id_token: Option<String>,
323        id_token_claims: Option<serde_json::Value>,
324        extra_callback_parameters: Option<serde_json::Value>,
325        userinfo: Option<serde_json::Value>,
326    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
327        let completed_at = clock.now();
328
329        sqlx::query!(
330            r#"
331                UPDATE upstream_oauth_authorization_sessions
332                SET upstream_oauth_link_id = $1
333                  , completed_at = $2
334                  , id_token = $3
335                  , id_token_claims = $4
336                  , extra_callback_parameters = $5
337                  , userinfo = $6
338                WHERE upstream_oauth_authorization_session_id = $7
339            "#,
340            Uuid::from(upstream_oauth_link.id),
341            completed_at,
342            id_token,
343            id_token_claims,
344            extra_callback_parameters,
345            userinfo,
346            Uuid::from(upstream_oauth_authorization_session.id),
347        )
348        .traced()
349        .execute(&mut *self.conn)
350        .await?;
351
352        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
353            .complete(
354                completed_at,
355                upstream_oauth_link,
356                id_token,
357                id_token_claims,
358                extra_callback_parameters,
359                userinfo,
360            )
361            .map_err(DatabaseError::to_invalid_operation)?;
362
363        Ok(upstream_oauth_authorization_session)
364    }
365
366    /// Mark a session as consumed
367    #[tracing::instrument(
368        name = "db.upstream_oauth_authorization_session.consume",
369        skip_all,
370        fields(
371            db.query.text,
372            %upstream_oauth_authorization_session.id,
373        ),
374        err,
375    )]
376    async fn consume(
377        &mut self,
378        clock: &dyn Clock,
379        upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
380        browser_session: &BrowserSession,
381    ) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
382        let consumed_at = clock.now();
383        sqlx::query!(
384            r#"
385                UPDATE upstream_oauth_authorization_sessions
386                SET consumed_at = $1,
387                    user_session_id = $2
388                WHERE upstream_oauth_authorization_session_id = $3
389            "#,
390            consumed_at,
391            Uuid::from(browser_session.id),
392            Uuid::from(upstream_oauth_authorization_session.id),
393        )
394        .traced()
395        .execute(&mut *self.conn)
396        .await?;
397
398        let upstream_oauth_authorization_session = upstream_oauth_authorization_session
399            .consume(consumed_at)
400            .map_err(DatabaseError::to_invalid_operation)?;
401
402        Ok(upstream_oauth_authorization_session)
403    }
404
405    #[tracing::instrument(
406        name = "db.upstream_oauth_authorization_session.list",
407        skip_all,
408        fields(
409            db.query.text,
410        ),
411        err,
412    )]
413    async fn list(
414        &mut self,
415        filter: UpstreamOAuthSessionFilter<'_>,
416        pagination: Pagination,
417    ) -> Result<Page<UpstreamOAuthAuthorizationSession>, Self::Error> {
418        let (sql, arguments) = Query::select()
419            .expr_as(
420                Expr::col((
421                    UpstreamOAuthAuthorizationSessions::Table,
422                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
423                )),
424                SessionLookupIden::UpstreamOauthAuthorizationSessionId,
425            )
426            .expr_as(
427                Expr::col((
428                    UpstreamOAuthAuthorizationSessions::Table,
429                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthProviderId,
430                )),
431                SessionLookupIden::UpstreamOauthProviderId,
432            )
433            .expr_as(
434                Expr::col((
435                    UpstreamOAuthAuthorizationSessions::Table,
436                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthLinkId,
437                )),
438                SessionLookupIden::UpstreamOauthLinkId,
439            )
440            .expr_as(
441                Expr::col((
442                    UpstreamOAuthAuthorizationSessions::Table,
443                    UpstreamOAuthAuthorizationSessions::State,
444                )),
445                SessionLookupIden::State,
446            )
447            .expr_as(
448                Expr::col((
449                    UpstreamOAuthAuthorizationSessions::Table,
450                    UpstreamOAuthAuthorizationSessions::CodeChallengeVerifier,
451                )),
452                SessionLookupIden::CodeChallengeVerifier,
453            )
454            .expr_as(
455                Expr::col((
456                    UpstreamOAuthAuthorizationSessions::Table,
457                    UpstreamOAuthAuthorizationSessions::Nonce,
458                )),
459                SessionLookupIden::Nonce,
460            )
461            .expr_as(
462                Expr::col((
463                    UpstreamOAuthAuthorizationSessions::Table,
464                    UpstreamOAuthAuthorizationSessions::IdToken,
465                )),
466                SessionLookupIden::IdToken,
467            )
468            .expr_as(
469                Expr::col((
470                    UpstreamOAuthAuthorizationSessions::Table,
471                    UpstreamOAuthAuthorizationSessions::IdTokenClaims,
472                )),
473                SessionLookupIden::IdTokenClaims,
474            )
475            .expr_as(
476                Expr::col((
477                    UpstreamOAuthAuthorizationSessions::Table,
478                    UpstreamOAuthAuthorizationSessions::ExtraCallbackParameters,
479                )),
480                SessionLookupIden::ExtraCallbackParameters,
481            )
482            .expr_as(
483                Expr::col((
484                    UpstreamOAuthAuthorizationSessions::Table,
485                    UpstreamOAuthAuthorizationSessions::Userinfo,
486                )),
487                SessionLookupIden::Userinfo,
488            )
489            .expr_as(
490                Expr::col((
491                    UpstreamOAuthAuthorizationSessions::Table,
492                    UpstreamOAuthAuthorizationSessions::CreatedAt,
493                )),
494                SessionLookupIden::CreatedAt,
495            )
496            .expr_as(
497                Expr::col((
498                    UpstreamOAuthAuthorizationSessions::Table,
499                    UpstreamOAuthAuthorizationSessions::CompletedAt,
500                )),
501                SessionLookupIden::CompletedAt,
502            )
503            .expr_as(
504                Expr::col((
505                    UpstreamOAuthAuthorizationSessions::Table,
506                    UpstreamOAuthAuthorizationSessions::ConsumedAt,
507                )),
508                SessionLookupIden::ConsumedAt,
509            )
510            .expr_as(
511                Expr::col((
512                    UpstreamOAuthAuthorizationSessions::Table,
513                    UpstreamOAuthAuthorizationSessions::UnlinkedAt,
514                )),
515                SessionLookupIden::UnlinkedAt,
516            )
517            .from(UpstreamOAuthAuthorizationSessions::Table)
518            .apply_filter(filter)
519            .generate_pagination(
520                (
521                    UpstreamOAuthAuthorizationSessions::Table,
522                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
523                ),
524                pagination,
525            )
526            .build_sqlx(PostgresQueryBuilder);
527
528        let edges: Vec<SessionLookup> = sqlx::query_as_with(&sql, arguments)
529            .traced()
530            .fetch_all(&mut *self.conn)
531            .await?;
532
533        let page = pagination
534            .process(edges)
535            .try_map(UpstreamOAuthAuthorizationSession::try_from)?;
536
537        Ok(page)
538    }
539
540    #[tracing::instrument(
541        name = "db.upstream_oauth_authorization_session.count",
542        skip_all,
543        fields(
544            db.query.text,
545        ),
546        err,
547    )]
548    async fn count(
549        &mut self,
550        filter: UpstreamOAuthSessionFilter<'_>,
551    ) -> Result<usize, Self::Error> {
552        let (sql, arguments) = Query::select()
553            .expr(
554                Expr::col((
555                    UpstreamOAuthAuthorizationSessions::Table,
556                    UpstreamOAuthAuthorizationSessions::UpstreamOAuthAuthorizationSessionId,
557                ))
558                .count(),
559            )
560            .from(UpstreamOAuthAuthorizationSessions::Table)
561            .apply_filter(filter)
562            .build_sqlx(PostgresQueryBuilder);
563
564        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
565            .traced()
566            .fetch_one(&mut *self.conn)
567            .await?;
568
569        count
570            .try_into()
571            .map_err(DatabaseError::to_invalid_operation)
572    }
573
574    #[tracing::instrument(
575        name = "db.upstream_oauth_authorization_session.cleanup",
576        skip_all,
577        fields(
578            db.query.text,
579            since = since.map(tracing::field::display),
580            until = %until,
581            limit = limit,
582        ),
583        err,
584    )]
585    async fn cleanup_orphaned(
586        &mut self,
587        since: Option<Ulid>,
588        until: Ulid,
589        limit: usize,
590    ) -> Result<(usize, Option<Ulid>), Self::Error> {
591        // Use ULID cursor-based pagination for pending sessions only.
592        // We only delete sessions that are not yet completed.
593        // `MAX(uuid)` isn't a thing in Postgres, so we aggregate on the client side.
594        let res = sqlx::query_scalar!(
595            r#"
596                WITH to_delete AS (
597                    SELECT upstream_oauth_authorization_session_id
598                    FROM upstream_oauth_authorization_sessions
599                    WHERE ($1::uuid IS NULL OR upstream_oauth_authorization_session_id > $1)
600                      AND upstream_oauth_authorization_session_id <= $2
601                      AND user_session_id IS NULL
602                    ORDER BY upstream_oauth_authorization_session_id
603                    LIMIT $3
604                )
605                DELETE FROM upstream_oauth_authorization_sessions
606                USING to_delete
607                WHERE upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id = to_delete.upstream_oauth_authorization_session_id
608                RETURNING upstream_oauth_authorization_sessions.upstream_oauth_authorization_session_id
609            "#,
610            since.map(Uuid::from),
611            Uuid::from(until),
612            i64::try_from(limit).unwrap_or(i64::MAX)
613        )
614        .traced()
615        .fetch_all(&mut *self.conn)
616        .await?;
617
618        let count = res.len();
619        let max_id = res.into_iter().max();
620
621        Ok((count, max_id.map(Ulid::from)))
622    }
623}