// mas_storage_pg/oauth2/authorization_grant.rs

// Copyright 2025, 2026 Element Creations Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2021-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

8use std::collections::BTreeMap;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13    AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, Clock, Pkce, Session,
14    UlidExt as _,
15};
16use mas_iana::oauth::PkceCodeChallengeMethod;
17use mas_storage::oauth2::OAuth2AuthorizationGrantRepository;
18use oauth2_types::{requests::ResponseMode, scope::Scope};
19use rand::RngCore;
20use sqlx::{PgConnection, types::Json};
21use ulid::Ulid;
22use url::Url;
23use uuid::Uuid;
24
25use crate::{DatabaseError, DatabaseInconsistencyError, tracing::ExecuteExt};
26
27/// An implementation of [`OAuth2AuthorizationGrantRepository`] for a PostgreSQL
28/// connection
29pub struct PgOAuth2AuthorizationGrantRepository<'c> {
30    conn: &'c mut PgConnection,
31}
32
33impl<'c> PgOAuth2AuthorizationGrantRepository<'c> {
34    /// Create a new [`PgOAuth2AuthorizationGrantRepository`] from an active
35    /// PostgreSQL connection
36    pub fn new(conn: &'c mut PgConnection) -> Self {
37        Self { conn }
38    }
39}
40
41struct GrantLookup {
42    oauth2_authorization_grant_id: Uuid,
43    created_at: DateTime<Utc>,
44    cancelled_at: Option<DateTime<Utc>>,
45    fulfilled_at: Option<DateTime<Utc>>,
46    exchanged_at: Option<DateTime<Utc>>,
47    scope: String,
48    state: Option<String>,
49    nonce: Option<String>,
50    redirect_uri: String,
51    response_mode: String,
52    response_type_code: bool,
53    response_type_id_token: bool,
54    authorization_code: Option<String>,
55    code_challenge: Option<String>,
56    code_challenge_method: Option<String>,
57    login_hint: Option<String>,
58    locale: Option<String>,
59    raw_parameters: Option<Json<BTreeMap<String, String>>>,
60    oauth2_client_id: Uuid,
61    oauth2_session_id: Option<Uuid>,
62}
63
64impl TryFrom<GrantLookup> for AuthorizationGrant {
65    type Error = DatabaseInconsistencyError;
66
67    fn try_from(value: GrantLookup) -> Result<Self, Self::Error> {
68        let id = value.oauth2_authorization_grant_id.into();
69        let scope: Scope = value.scope.parse().map_err(|e| {
70            DatabaseInconsistencyError::on("oauth2_authorization_grants")
71                .column("scope")
72                .row(id)
73                .source(e)
74        })?;
75
76        let stage = match (
77            value.fulfilled_at,
78            value.exchanged_at,
79            value.cancelled_at,
80            value.oauth2_session_id,
81        ) {
82            (None, None, None, None) => AuthorizationGrantStage::Pending,
83            (Some(fulfilled_at), None, None, Some(session_id)) => {
84                AuthorizationGrantStage::Fulfilled {
85                    session_id: session_id.into(),
86                    fulfilled_at,
87                }
88            }
89            (Some(fulfilled_at), Some(exchanged_at), None, Some(session_id)) => {
90                AuthorizationGrantStage::Exchanged {
91                    session_id: session_id.into(),
92                    fulfilled_at,
93                    exchanged_at,
94                }
95            }
96            (None, None, Some(cancelled_at), None) => {
97                AuthorizationGrantStage::Cancelled { cancelled_at }
98            }
99            _ => {
100                return Err(
101                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
102                        .column("stage")
103                        .row(id),
104                );
105            }
106        };
107
108        let pkce = match (value.code_challenge, value.code_challenge_method) {
109            (Some(challenge), Some(challenge_method)) if challenge_method == "plain" => {
110                Some(Pkce {
111                    challenge_method: PkceCodeChallengeMethod::Plain,
112                    challenge,
113                })
114            }
115            (Some(challenge), Some(challenge_method)) if challenge_method == "S256" => Some(Pkce {
116                challenge_method: PkceCodeChallengeMethod::S256,
117                challenge,
118            }),
119            (None, None) => None,
120            _ => {
121                return Err(
122                    DatabaseInconsistencyError::on("oauth2_authorization_grants")
123                        .column("code_challenge_method")
124                        .row(id),
125                );
126            }
127        };
128
129        let code: Option<AuthorizationCode> =
130            match (value.response_type_code, value.authorization_code, pkce) {
131                (false, None, None) => None,
132                (true, Some(code), pkce) => Some(AuthorizationCode { code, pkce }),
133                _ => {
134                    return Err(
135                        DatabaseInconsistencyError::on("oauth2_authorization_grants")
136                            .column("authorization_code")
137                            .row(id),
138                    );
139                }
140            };
141
142        let redirect_uri = value.redirect_uri.parse().map_err(|e| {
143            DatabaseInconsistencyError::on("oauth2_authorization_grants")
144                .column("redirect_uri")
145                .row(id)
146                .source(e)
147        })?;
148
149        let response_mode = value.response_mode.parse().map_err(|e| {
150            DatabaseInconsistencyError::on("oauth2_authorization_grants")
151                .column("response_mode")
152                .row(id)
153                .source(e)
154        })?;
155
156        Ok(AuthorizationGrant {
157            id,
158            stage,
159            client_id: value.oauth2_client_id.into(),
160            code,
161            scope,
162            state: value.state,
163            nonce: value.nonce,
164            response_mode,
165            redirect_uri,
166            created_at: value.created_at,
167            response_type_id_token: value.response_type_id_token,
168            login_hint: value.login_hint,
169            locale: value.locale,
170            raw_parameters: value.raw_parameters.map(|Json(x)| x).unwrap_or_default(),
171        })
172    }
173}
174
175#[async_trait]
176impl OAuth2AuthorizationGrantRepository for PgOAuth2AuthorizationGrantRepository<'_> {
177    type Error = DatabaseError;
178
179    #[tracing::instrument(
180        name = "db.oauth2_authorization_grant.add",
181        skip_all,
182        fields(
183            db.query.text,
184            grant.id,
185            grant.scope = %scope,
186            %client.id,
187        ),
188        err,
189    )]
190    async fn add(
191        &mut self,
192        rng: &mut (dyn RngCore + Send),
193        clock: &dyn Clock,
194        client: &Client,
195        redirect_uri: Url,
196        scope: Scope,
197        code: Option<AuthorizationCode>,
198        state: Option<String>,
199        nonce: Option<String>,
200        response_mode: ResponseMode,
201        response_type_id_token: bool,
202        login_hint: Option<String>,
203        locale: Option<String>,
204        raw_parameters: BTreeMap<String, String>,
205    ) -> Result<AuthorizationGrant, Self::Error> {
206        let code_challenge = code
207            .as_ref()
208            .and_then(|c| c.pkce.as_ref())
209            .map(|p| &p.challenge);
210        let code_challenge_method = code
211            .as_ref()
212            .and_then(|c| c.pkce.as_ref())
213            .map(|p| p.challenge_method.to_string());
214        let code_str = code.as_ref().map(|c| &c.code);
215
216        let created_at = clock.now();
217        let id = Ulid::from_datetime_with_rng(created_at, rng);
218        tracing::Span::current().record("grant.id", tracing::field::display(id));
219
220        sqlx::query!(
221            r#"
222                INSERT INTO oauth2_authorization_grants (
223                     oauth2_authorization_grant_id,
224                     oauth2_client_id,
225                     redirect_uri,
226                     scope,
227                     state,
228                     nonce,
229                     response_mode,
230                     code_challenge,
231                     code_challenge_method,
232                     response_type_code,
233                     response_type_id_token,
234                     authorization_code,
235                     login_hint,
236                     locale,
237                     raw_parameters,
238                     created_at
239                )
240                VALUES
241                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
242            "#,
243            Uuid::from(id),
244            Uuid::from(client.id),
245            redirect_uri.to_string(),
246            scope.to_string(),
247            state,
248            nonce,
249            response_mode.to_string(),
250            code_challenge,
251            code_challenge_method,
252            code.is_some(),
253            response_type_id_token,
254            code_str,
255            login_hint,
256            locale,
257            Json(&raw_parameters) as _,
258            created_at,
259        )
260        .traced()
261        .execute(&mut *self.conn)
262        .await?;
263
264        Ok(AuthorizationGrant {
265            id,
266            stage: AuthorizationGrantStage::Pending,
267            code,
268            redirect_uri,
269            client_id: client.id,
270            scope,
271            state,
272            nonce,
273            response_mode,
274            created_at,
275            response_type_id_token,
276            login_hint,
277            locale,
278            raw_parameters,
279        })
280    }
281
282    #[tracing::instrument(
283        name = "db.oauth2_authorization_grant.lookup",
284        skip_all,
285        fields(
286            db.query.text,
287            grant.id = %id,
288        ),
289        err,
290    )]
291    async fn lookup(&mut self, id: Ulid) -> Result<Option<AuthorizationGrant>, Self::Error> {
292        let res = sqlx::query_as!(
293            GrantLookup,
294            r#"
295                SELECT oauth2_authorization_grant_id
296                     , created_at
297                     , cancelled_at
298                     , fulfilled_at
299                     , exchanged_at
300                     , scope
301                     , state
302                     , redirect_uri
303                     , response_mode
304                     , nonce
305                     , oauth2_client_id
306                     , authorization_code
307                     , response_type_code
308                     , response_type_id_token
309                     , code_challenge
310                     , code_challenge_method
311                     , login_hint
312                     , locale
313                     , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
314                     , oauth2_session_id
315                FROM
316                    oauth2_authorization_grants
317
318                WHERE oauth2_authorization_grant_id = $1
319            "#,
320            Uuid::from(id),
321        )
322        .traced()
323        .fetch_optional(&mut *self.conn)
324        .await?;
325
326        let Some(res) = res else { return Ok(None) };
327
328        Ok(Some(res.try_into()?))
329    }
330
331    #[tracing::instrument(
332        name = "db.oauth2_authorization_grant.find_by_code",
333        skip_all,
334        fields(
335            db.query.text,
336        ),
337        err,
338    )]
339    async fn find_by_code(
340        &mut self,
341        code: &str,
342    ) -> Result<Option<AuthorizationGrant>, Self::Error> {
343        let res = sqlx::query_as!(
344            GrantLookup,
345            r#"
346                SELECT oauth2_authorization_grant_id
347                     , created_at
348                     , cancelled_at
349                     , fulfilled_at
350                     , exchanged_at
351                     , scope
352                     , state
353                     , redirect_uri
354                     , response_mode
355                     , nonce
356                     , oauth2_client_id
357                     , authorization_code
358                     , response_type_code
359                     , response_type_id_token
360                     , code_challenge
361                     , code_challenge_method
362                     , login_hint
363                     , locale
364                     , raw_parameters AS "raw_parameters: Json<BTreeMap<String, String>>"
365                     , oauth2_session_id
366                FROM
367                    oauth2_authorization_grants
368
369                WHERE authorization_code = $1
370            "#,
371            code,
372        )
373        .traced()
374        .fetch_optional(&mut *self.conn)
375        .await?;
376
377        let Some(res) = res else { return Ok(None) };
378
379        Ok(Some(res.try_into()?))
380    }
381
382    #[tracing::instrument(
383        name = "db.oauth2_authorization_grant.fulfill",
384        skip_all,
385        fields(
386            db.query.text,
387            %grant.id,
388            client.id = %grant.client_id,
389            %session.id,
390        ),
391        err,
392    )]
393    async fn fulfill(
394        &mut self,
395        clock: &dyn Clock,
396        session: &Session,
397        grant: AuthorizationGrant,
398    ) -> Result<AuthorizationGrant, Self::Error> {
399        let fulfilled_at = clock.now();
400        let res = sqlx::query!(
401            r#"
402                UPDATE oauth2_authorization_grants
403                SET fulfilled_at = $2
404                  , oauth2_session_id = $3
405                WHERE oauth2_authorization_grant_id = $1
406            "#,
407            Uuid::from(grant.id),
408            fulfilled_at,
409            Uuid::from(session.id),
410        )
411        .traced()
412        .execute(&mut *self.conn)
413        .await?;
414
415        DatabaseError::ensure_affected_rows(&res, 1)?;
416
417        // XXX: check affected rows & new methods
418        let grant = grant
419            .fulfill(fulfilled_at, session)
420            .map_err(DatabaseError::to_invalid_operation)?;
421
422        Ok(grant)
423    }
424
425    #[tracing::instrument(
426        name = "db.oauth2_authorization_grant.exchange",
427        skip_all,
428        fields(
429            db.query.text,
430            %grant.id,
431            client.id = %grant.client_id,
432        ),
433        err,
434    )]
435    async fn exchange(
436        &mut self,
437        clock: &dyn Clock,
438        grant: AuthorizationGrant,
439    ) -> Result<AuthorizationGrant, Self::Error> {
440        let exchanged_at = clock.now();
441        let res = sqlx::query!(
442            r#"
443                UPDATE oauth2_authorization_grants
444                SET exchanged_at = $2
445                WHERE oauth2_authorization_grant_id = $1
446            "#,
447            Uuid::from(grant.id),
448            exchanged_at,
449        )
450        .traced()
451        .execute(&mut *self.conn)
452        .await?;
453
454        DatabaseError::ensure_affected_rows(&res, 1)?;
455
456        let grant = grant
457            .exchange(exchanged_at)
458            .map_err(DatabaseError::to_invalid_operation)?;
459
460        Ok(grant)
461    }
462
463    #[tracing::instrument(
464        name = "db.oauth2_authorization_grant.cleanup",
465        skip_all,
466        fields(
467            db.query.text,
468            since = since.map(tracing::field::display),
469            until = %until,
470            limit = limit,
471        ),
472        err,
473    )]
474    async fn cleanup(
475        &mut self,
476        since: Option<Ulid>,
477        until: Ulid,
478        limit: usize,
479    ) -> Result<(usize, Option<Ulid>), Self::Error> {
480        // `MAX(uuid)` isn't a thing in Postgres, so we can't just re-select the
481        // deleted rows and do a MAX on the `oauth2_authorization_grant_id`.
482        // Instead, we do the aggregation on the client side, which is a little
483        // less efficient, but good enough.
484        let res = sqlx::query_scalar!(
485            r#"
486                WITH to_delete AS (
487                    SELECT oauth2_authorization_grant_id
488                    FROM oauth2_authorization_grants
489                    WHERE ($1::uuid IS NULL OR oauth2_authorization_grant_id > $1)
490                    AND oauth2_authorization_grant_id <= $2
491                    ORDER BY oauth2_authorization_grant_id
492                    LIMIT $3
493                )
494                DELETE FROM oauth2_authorization_grants
495                USING to_delete
496                WHERE oauth2_authorization_grants.oauth2_authorization_grant_id = to_delete.oauth2_authorization_grant_id
497                RETURNING oauth2_authorization_grants.oauth2_authorization_grant_id
498            "#,
499            since.map(Uuid::from),
500            Uuid::from(until),
501            i64::try_from(limit).unwrap_or(i64::MAX)
502        )
503        .traced()
504        .fetch_all(&mut *self.conn)
505        .await?;
506
507        let count = res.len();
508        let max_id = res.into_iter().max();
509
510        Ok((count, max_id.map(Ulid::from)))
511    }
512}