Skip to main content

mas_storage_pg/oauth2/
device_code_grant.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::net::IpAddr;
9
10use async_trait::async_trait;
11use chrono::{DateTime, Utc};
12use mas_data_model::{
13    BrowserSession, Clock, DeviceCodeGrant, DeviceCodeGrantState, Session, UlidExt as _,
14};
15use mas_storage::oauth2::{OAuth2DeviceCodeGrantParams, OAuth2DeviceCodeGrantRepository};
16use oauth2_types::scope::Scope;
17use rand::RngCore;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use uuid::Uuid;
21
22use crate::{DatabaseError, ExecuteExt, errors::DatabaseInconsistencyError};
23
24/// An implementation of [`OAuth2DeviceCodeGrantRepository`] for a PostgreSQL
25/// connection
26pub struct PgOAuth2DeviceCodeGrantRepository<'c> {
27    conn: &'c mut PgConnection,
28}
29
30impl<'c> PgOAuth2DeviceCodeGrantRepository<'c> {
31    /// Create a new [`PgOAuth2DeviceCodeGrantRepository`] from an active
32    /// PostgreSQL connection
33    pub fn new(conn: &'c mut PgConnection) -> Self {
34        Self { conn }
35    }
36}
37
38struct OAuth2DeviceGrantLookup {
39    oauth2_device_code_grant_id: Uuid,
40    oauth2_client_id: Uuid,
41    scope: String,
42    device_code: String,
43    user_code: String,
44    created_at: DateTime<Utc>,
45    expires_at: DateTime<Utc>,
46    fulfilled_at: Option<DateTime<Utc>>,
47    rejected_at: Option<DateTime<Utc>>,
48    exchanged_at: Option<DateTime<Utc>>,
49    user_session_id: Option<Uuid>,
50    oauth2_session_id: Option<Uuid>,
51    ip_address: Option<IpAddr>,
52    user_agent: Option<String>,
53    locale: Option<String>,
54}
55
56impl TryFrom<OAuth2DeviceGrantLookup> for DeviceCodeGrant {
57    type Error = DatabaseInconsistencyError;
58
59    fn try_from(
60        OAuth2DeviceGrantLookup {
61            oauth2_device_code_grant_id,
62            oauth2_client_id,
63            scope,
64            device_code,
65            user_code,
66            created_at,
67            expires_at,
68            fulfilled_at,
69            rejected_at,
70            exchanged_at,
71            user_session_id,
72            oauth2_session_id,
73            ip_address,
74            user_agent,
75            locale,
76        }: OAuth2DeviceGrantLookup,
77    ) -> Result<Self, Self::Error> {
78        let id = Ulid::from(oauth2_device_code_grant_id);
79        let client_id = Ulid::from(oauth2_client_id);
80
81        let scope: Scope = scope.parse().map_err(|e| {
82            DatabaseInconsistencyError::on("oauth2_authorization_grants")
83                .column("scope")
84                .row(id)
85                .source(e)
86        })?;
87
88        let state = match (
89            fulfilled_at,
90            rejected_at,
91            exchanged_at,
92            user_session_id,
93            oauth2_session_id,
94        ) {
95            (None, None, None, None, None) => DeviceCodeGrantState::Pending,
96
97            (Some(fulfilled_at), None, None, Some(user_session_id), None) => {
98                DeviceCodeGrantState::Fulfilled {
99                    browser_session_id: Ulid::from(user_session_id),
100                    fulfilled_at,
101                }
102            }
103
104            (None, Some(rejected_at), None, Some(user_session_id), None) => {
105                DeviceCodeGrantState::Rejected {
106                    browser_session_id: Ulid::from(user_session_id),
107                    rejected_at,
108                }
109            }
110
111            (
112                Some(fulfilled_at),
113                None,
114                Some(exchanged_at),
115                Some(user_session_id),
116                Some(oauth2_session_id),
117            ) => DeviceCodeGrantState::Exchanged {
118                browser_session_id: Ulid::from(user_session_id),
119                session_id: Ulid::from(oauth2_session_id),
120                fulfilled_at,
121                exchanged_at,
122            },
123
124            _ => return Err(DatabaseInconsistencyError::on("oauth2_device_code_grant").row(id)),
125        };
126
127        Ok(DeviceCodeGrant {
128            id,
129            state,
130            client_id,
131            scope,
132            user_code,
133            device_code,
134            created_at,
135            expires_at,
136            ip_address,
137            user_agent,
138            locale,
139        })
140    }
141}
142
143#[async_trait]
144impl OAuth2DeviceCodeGrantRepository for PgOAuth2DeviceCodeGrantRepository<'_> {
145    type Error = DatabaseError;
146
147    #[tracing::instrument(
148        name = "db.oauth2_device_code_grant.add",
149        skip_all,
150        fields(
151            db.query.text,
152            oauth2_device_code.id,
153            oauth2_device_code.scope = %params.scope,
154            oauth2_client.id = %params.client.id,
155        ),
156        err,
157    )]
158    async fn add(
159        &mut self,
160        rng: &mut (dyn RngCore + Send),
161        clock: &dyn Clock,
162        params: OAuth2DeviceCodeGrantParams<'_>,
163    ) -> Result<DeviceCodeGrant, Self::Error> {
164        let now = clock.now();
165        let id = Ulid::from_datetime_with_rng(now, rng);
166        tracing::Span::current().record("oauth2_device_code.id", tracing::field::display(id));
167
168        let created_at = now;
169        let expires_at = now + params.expires_in;
170        let client_id = params.client.id;
171
172        sqlx::query!(
173            r#"
174                INSERT INTO "oauth2_device_code_grant"
175                    ( oauth2_device_code_grant_id
176                    , oauth2_client_id
177                    , scope
178                    , device_code
179                    , user_code
180                    , created_at
181                    , expires_at
182                    , ip_address
183                    , user_agent
184                    )
185                VALUES
186                    ($1, $2, $3, $4, $5, $6, $7, $8, $9)
187            "#,
188            Uuid::from(id),
189            Uuid::from(client_id),
190            params.scope.to_string(),
191            &params.device_code,
192            &params.user_code,
193            created_at,
194            expires_at,
195            params.ip_address as Option<IpAddr>,
196            params.user_agent.as_deref(),
197        )
198        .traced()
199        .execute(&mut *self.conn)
200        .await?;
201
202        Ok(DeviceCodeGrant {
203            id,
204            state: DeviceCodeGrantState::Pending,
205            client_id,
206            scope: params.scope,
207            user_code: params.user_code,
208            device_code: params.device_code,
209            created_at,
210            expires_at,
211            ip_address: params.ip_address,
212            user_agent: params.user_agent,
213            locale: None,
214        })
215    }
216
217    #[tracing::instrument(
218        name = "db.oauth2_device_code_grant.lookup",
219        skip_all,
220        fields(
221            db.query.text,
222            oauth2_device_code.id = %id,
223        ),
224        err,
225    )]
226    async fn lookup(&mut self, id: Ulid) -> Result<Option<DeviceCodeGrant>, Self::Error> {
227        let res = sqlx::query_as!(
228            OAuth2DeviceGrantLookup,
229            r#"
230                SELECT oauth2_device_code_grant_id
231                     , oauth2_client_id
232                     , scope
233                     , device_code
234                     , user_code
235                     , created_at
236                     , expires_at
237                     , fulfilled_at
238                     , rejected_at
239                     , exchanged_at
240                     , user_session_id
241                     , oauth2_session_id
242                     , ip_address as "ip_address: IpAddr"
243                     , user_agent
244                     , locale
245                FROM
246                    oauth2_device_code_grant
247
248                WHERE oauth2_device_code_grant_id = $1
249            "#,
250            Uuid::from(id),
251        )
252        .traced()
253        .fetch_optional(&mut *self.conn)
254        .await?;
255
256        let Some(res) = res else { return Ok(None) };
257
258        Ok(Some(res.try_into()?))
259    }
260
261    #[tracing::instrument(
262        name = "db.oauth2_device_code_grant.find_by_user_code",
263        skip_all,
264        fields(
265            db.query.text,
266            oauth2_device_code.user_code = %user_code,
267        ),
268        err,
269    )]
270    async fn find_by_user_code(
271        &mut self,
272        user_code: &str,
273    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
274        let res = sqlx::query_as!(
275            OAuth2DeviceGrantLookup,
276            r#"
277                SELECT oauth2_device_code_grant_id
278                     , oauth2_client_id
279                     , scope
280                     , device_code
281                     , user_code
282                     , created_at
283                     , expires_at
284                     , fulfilled_at
285                     , rejected_at
286                     , exchanged_at
287                     , user_session_id
288                     , oauth2_session_id
289                     , ip_address as "ip_address: IpAddr"
290                     , user_agent
291                     , locale
292                FROM
293                    oauth2_device_code_grant
294
295                WHERE user_code = $1
296            "#,
297            user_code,
298        )
299        .traced()
300        .fetch_optional(&mut *self.conn)
301        .await?;
302
303        let Some(res) = res else { return Ok(None) };
304
305        Ok(Some(res.try_into()?))
306    }
307
308    #[tracing::instrument(
309        name = "db.oauth2_device_code_grant.find_by_device_code",
310        skip_all,
311        fields(
312            db.query.text,
313            oauth2_device_code.device_code = %device_code,
314        ),
315        err,
316    )]
317    async fn find_by_device_code(
318        &mut self,
319        device_code: &str,
320    ) -> Result<Option<DeviceCodeGrant>, Self::Error> {
321        let res = sqlx::query_as!(
322            OAuth2DeviceGrantLookup,
323            r#"
324                SELECT oauth2_device_code_grant_id
325                     , oauth2_client_id
326                     , scope
327                     , device_code
328                     , user_code
329                     , created_at
330                     , expires_at
331                     , fulfilled_at
332                     , rejected_at
333                     , exchanged_at
334                     , user_session_id
335                     , oauth2_session_id
336                     , ip_address as "ip_address: IpAddr"
337                     , user_agent
338                     , locale
339                FROM
340                    oauth2_device_code_grant
341
342                WHERE device_code = $1
343            "#,
344            device_code,
345        )
346        .traced()
347        .fetch_optional(&mut *self.conn)
348        .await?;
349
350        let Some(res) = res else { return Ok(None) };
351
352        Ok(Some(res.try_into()?))
353    }
354
355    #[tracing::instrument(
356        name = "db.oauth2_device_code_grant.fulfill",
357        skip_all,
358        fields(
359            db.query.text,
360            oauth2_device_code.id = %device_code_grant.id,
361            oauth2_client.id = %device_code_grant.client_id,
362            browser_session.id = %browser_session.id,
363            user.id = %browser_session.user.id,
364        ),
365        err,
366    )]
367    async fn fulfill(
368        &mut self,
369        clock: &dyn Clock,
370        device_code_grant: DeviceCodeGrant,
371        browser_session: &BrowserSession,
372        locale: Option<String>,
373    ) -> Result<DeviceCodeGrant, Self::Error> {
374        let fulfilled_at = clock.now();
375        let device_code_grant = device_code_grant
376            .fulfill(browser_session, locale, fulfilled_at)
377            .map_err(DatabaseError::to_invalid_operation)?;
378
379        let res = sqlx::query!(
380            r#"
381                UPDATE oauth2_device_code_grant
382                SET fulfilled_at = $1
383                  , user_session_id = $2
384                  , locale = $3
385                WHERE oauth2_device_code_grant_id = $4
386            "#,
387            fulfilled_at,
388            Uuid::from(browser_session.id),
389            device_code_grant.locale.as_deref(),
390            Uuid::from(device_code_grant.id),
391        )
392        .traced()
393        .execute(&mut *self.conn)
394        .await?;
395
396        DatabaseError::ensure_affected_rows(&res, 1)?;
397
398        Ok(device_code_grant)
399    }
400
401    #[tracing::instrument(
402        name = "db.oauth2_device_code_grant.reject",
403        skip_all,
404        fields(
405            db.query.text,
406            oauth2_device_code.id = %device_code_grant.id,
407            oauth2_client.id = %device_code_grant.client_id,
408            browser_session.id = %browser_session.id,
409            user.id = %browser_session.user.id,
410        ),
411        err,
412    )]
413    async fn reject(
414        &mut self,
415        clock: &dyn Clock,
416        device_code_grant: DeviceCodeGrant,
417        browser_session: &BrowserSession,
418    ) -> Result<DeviceCodeGrant, Self::Error> {
419        let fulfilled_at = clock.now();
420        let device_code_grant = device_code_grant
421            .reject(browser_session, fulfilled_at)
422            .map_err(DatabaseError::to_invalid_operation)?;
423
424        let res = sqlx::query!(
425            r#"
426                UPDATE oauth2_device_code_grant
427                SET rejected_at = $1
428                  , user_session_id = $2
429                WHERE oauth2_device_code_grant_id = $3
430            "#,
431            fulfilled_at,
432            Uuid::from(browser_session.id),
433            Uuid::from(device_code_grant.id),
434        )
435        .traced()
436        .execute(&mut *self.conn)
437        .await?;
438
439        DatabaseError::ensure_affected_rows(&res, 1)?;
440
441        Ok(device_code_grant)
442    }
443
444    #[tracing::instrument(
445        name = "db.oauth2_device_code_grant.exchange",
446        skip_all,
447        fields(
448            db.query.text,
449            oauth2_device_code.id = %device_code_grant.id,
450            oauth2_client.id = %device_code_grant.client_id,
451            oauth2_session.id = %session.id,
452        ),
453        err,
454    )]
455    async fn exchange(
456        &mut self,
457        clock: &dyn Clock,
458        device_code_grant: DeviceCodeGrant,
459        session: &Session,
460    ) -> Result<DeviceCodeGrant, Self::Error> {
461        let exchanged_at = clock.now();
462        let device_code_grant = device_code_grant
463            .exchange(session, exchanged_at)
464            .map_err(DatabaseError::to_invalid_operation)?;
465
466        let res = sqlx::query!(
467            r#"
468                UPDATE oauth2_device_code_grant
469                SET exchanged_at = $1
470                  , oauth2_session_id = $2
471                WHERE oauth2_device_code_grant_id = $3
472            "#,
473            exchanged_at,
474            Uuid::from(session.id),
475            Uuid::from(device_code_grant.id),
476        )
477        .traced()
478        .execute(&mut *self.conn)
479        .await?;
480
481        DatabaseError::ensure_affected_rows(&res, 1)?;
482
483        Ok(device_code_grant)
484    }
485
486    #[tracing::instrument(
487        name = "db.oauth2_device_code_grant.cleanup",
488        skip_all,
489        fields(
490            db.query.text,
491            since = since.map(tracing::field::display),
492            until = %until,
493            limit = limit,
494        ),
495        err,
496    )]
497    async fn cleanup(
498        &mut self,
499        since: Option<Ulid>,
500        until: Ulid,
501        limit: usize,
502    ) -> Result<(usize, Option<Ulid>), Self::Error> {
503        // `MAX(uuid)` isn't a thing in Postgres, so we can't just re-select the
504        // deleted rows and do a MAX on the `oauth2_device_code_grant_id`.
505        // Instead, we do the aggregation on the client side, which is a little
506        // less efficient, but good enough.
507        let res = sqlx::query_scalar!(
508            r#"
509                WITH to_delete AS (
510                    SELECT oauth2_device_code_grant_id
511                    FROM oauth2_device_code_grant
512                    WHERE ($1::uuid IS NULL OR oauth2_device_code_grant_id > $1)
513                    AND oauth2_device_code_grant_id <= $2
514                    ORDER BY oauth2_device_code_grant_id
515                    LIMIT $3
516                )
517                DELETE FROM oauth2_device_code_grant
518                USING to_delete
519                WHERE oauth2_device_code_grant.oauth2_device_code_grant_id = to_delete.oauth2_device_code_grant_id
520                RETURNING oauth2_device_code_grant.oauth2_device_code_grant_id
521            "#,
522            since.map(Uuid::from),
523            Uuid::from(until),
524            i64::try_from(limit).unwrap_or(i64::MAX)
525        )
526        .traced()
527        .fetch_all(&mut *self.conn)
528        .await?;
529
530        let count = res.len();
531        let max_id = res.into_iter().max();
532
533        Ok((count, max_id.map(Ulid::from)))
534    }
535}