mas_storage_pg/compat/
sso_login.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{BrowserSession, CompatSession, CompatSsoLogin, CompatSsoLoginState};
10use mas_storage::{
11    Clock, Page, Pagination,
12    compat::{CompatSsoLoginFilter, CompatSsoLoginRepository},
13};
14use rand::RngCore;
15use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
16use sea_query_binder::SqlxBinder;
17use sqlx::PgConnection;
18use ulid::Ulid;
19use url::Url;
20use uuid::Uuid;
21
22use crate::{
23    DatabaseError, DatabaseInconsistencyError,
24    filter::{Filter, StatementExt},
25    iden::{CompatSsoLogins, UserSessions},
26    pagination::QueryBuilderExt,
27    tracing::ExecuteExt,
28};
29
30/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL
31/// connection
32pub struct PgCompatSsoLoginRepository<'c> {
33    conn: &'c mut PgConnection,
34}
35
36impl<'c> PgCompatSsoLoginRepository<'c> {
37    /// Create a new [`PgCompatSsoLoginRepository`] from an active PostgreSQL
38    /// connection
39    pub fn new(conn: &'c mut PgConnection) -> Self {
40        Self { conn }
41    }
42}
43
44#[derive(sqlx::FromRow, Debug)]
45#[enum_def]
46struct CompatSsoLoginLookup {
47    compat_sso_login_id: Uuid,
48    login_token: String,
49    redirect_uri: String,
50    created_at: DateTime<Utc>,
51    fulfilled_at: Option<DateTime<Utc>>,
52    exchanged_at: Option<DateTime<Utc>>,
53    user_session_id: Option<Uuid>,
54    compat_session_id: Option<Uuid>,
55}
56
57impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
58    type Error = DatabaseInconsistencyError;
59
60    fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
61        let id = res.compat_sso_login_id.into();
62        let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
63            DatabaseInconsistencyError::on("compat_sso_logins")
64                .column("redirect_uri")
65                .row(id)
66                .source(e)
67        })?;
68
69        let state = match (
70            res.fulfilled_at,
71            res.exchanged_at,
72            res.user_session_id,
73            res.compat_session_id,
74        ) {
75            (None, None, None, None) => CompatSsoLoginState::Pending,
76            (Some(fulfilled_at), None, Some(browser_session_id), None) => {
77                CompatSsoLoginState::Fulfilled {
78                    fulfilled_at,
79                    browser_session_id: browser_session_id.into(),
80                }
81            }
82            (Some(fulfilled_at), Some(exchanged_at), _, Some(compat_session_id)) => {
83                CompatSsoLoginState::Exchanged {
84                    fulfilled_at,
85                    exchanged_at,
86                    compat_session_id: compat_session_id.into(),
87                }
88            }
89            _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
90        };
91
92        Ok(CompatSsoLogin {
93            id,
94            login_token: res.login_token,
95            redirect_uri,
96            created_at: res.created_at,
97            state,
98        })
99    }
100}
101
102impl Filter for CompatSsoLoginFilter<'_> {
103    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
104        sea_query::Condition::all()
105            .add_option(self.user().map(|user| {
106                Expr::exists(
107                    Query::select()
108                        .expr(Expr::cust("1"))
109                        .from(UserSessions::Table)
110                        .and_where(
111                            Expr::col((UserSessions::Table, UserSessions::UserId))
112                                .eq(Uuid::from(user.id)),
113                        )
114                        .and_where(
115                            Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId))
116                                .equals((UserSessions::Table, UserSessions::UserSessionId)),
117                        )
118                        .take(),
119                )
120            }))
121            .add_option(self.state().map(|state| {
122                if state.is_exchanged() {
123                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
124                } else if state.is_fulfilled() {
125                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
126                        .is_not_null()
127                        .and(
128                            Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
129                                .is_null(),
130                        )
131                } else {
132                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
133                }
134            }))
135    }
136}
137
138#[async_trait]
139impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> {
140    type Error = DatabaseError;
141
142    #[tracing::instrument(
143        name = "db.compat_sso_login.lookup",
144        skip_all,
145        fields(
146            db.query.text,
147            compat_sso_login.id = %id,
148        ),
149        err,
150    )]
151    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
152        let res = sqlx::query_as!(
153            CompatSsoLoginLookup,
154            r#"
155                SELECT compat_sso_login_id
156                     , login_token
157                     , redirect_uri
158                     , created_at
159                     , fulfilled_at
160                     , exchanged_at
161                     , compat_session_id
162                     , user_session_id
163
164                FROM compat_sso_logins
165                WHERE compat_sso_login_id = $1
166            "#,
167            Uuid::from(id),
168        )
169        .traced()
170        .fetch_optional(&mut *self.conn)
171        .await?;
172
173        let Some(res) = res else { return Ok(None) };
174
175        Ok(Some(res.try_into()?))
176    }
177
178    #[tracing::instrument(
179        name = "db.compat_sso_login.find_for_session",
180        skip_all,
181        fields(
182            db.query.text,
183            %compat_session.id,
184        ),
185        err,
186    )]
187    async fn find_for_session(
188        &mut self,
189        compat_session: &CompatSession,
190    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
191        let res = sqlx::query_as!(
192            CompatSsoLoginLookup,
193            r#"
194                SELECT compat_sso_login_id
195                     , login_token
196                     , redirect_uri
197                     , created_at
198                     , fulfilled_at
199                     , exchanged_at
200                     , compat_session_id
201                     , user_session_id
202
203                FROM compat_sso_logins
204                WHERE compat_session_id = $1
205            "#,
206            Uuid::from(compat_session.id),
207        )
208        .traced()
209        .fetch_optional(&mut *self.conn)
210        .await?;
211
212        let Some(res) = res else { return Ok(None) };
213
214        Ok(Some(res.try_into()?))
215    }
216
217    #[tracing::instrument(
218        name = "db.compat_sso_login.find_by_token",
219        skip_all,
220        fields(
221            db.query.text,
222        ),
223        err,
224    )]
225    async fn find_by_token(
226        &mut self,
227        login_token: &str,
228    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
229        let res = sqlx::query_as!(
230            CompatSsoLoginLookup,
231            r#"
232                SELECT compat_sso_login_id
233                     , login_token
234                     , redirect_uri
235                     , created_at
236                     , fulfilled_at
237                     , exchanged_at
238                     , compat_session_id
239                     , user_session_id
240
241                FROM compat_sso_logins
242                WHERE login_token = $1
243            "#,
244            login_token,
245        )
246        .traced()
247        .fetch_optional(&mut *self.conn)
248        .await?;
249
250        let Some(res) = res else { return Ok(None) };
251
252        Ok(Some(res.try_into()?))
253    }
254
255    #[tracing::instrument(
256        name = "db.compat_sso_login.add",
257        skip_all,
258        fields(
259            db.query.text,
260            compat_sso_login.id,
261            compat_sso_login.redirect_uri = %redirect_uri,
262        ),
263        err,
264    )]
265    async fn add(
266        &mut self,
267        rng: &mut (dyn RngCore + Send),
268        clock: &dyn Clock,
269        login_token: String,
270        redirect_uri: Url,
271    ) -> Result<CompatSsoLogin, Self::Error> {
272        let created_at = clock.now();
273        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
274        tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
275
276        sqlx::query!(
277            r#"
278                INSERT INTO compat_sso_logins
279                    (compat_sso_login_id, login_token, redirect_uri, created_at)
280                VALUES ($1, $2, $3, $4)
281            "#,
282            Uuid::from(id),
283            &login_token,
284            redirect_uri.as_str(),
285            created_at,
286        )
287        .traced()
288        .execute(&mut *self.conn)
289        .await?;
290
291        Ok(CompatSsoLogin {
292            id,
293            login_token,
294            redirect_uri,
295            created_at,
296            state: CompatSsoLoginState::default(),
297        })
298    }
299
300    #[tracing::instrument(
301        name = "db.compat_sso_login.fulfill",
302        skip_all,
303        fields(
304            db.query.text,
305            %compat_sso_login.id,
306            %browser_session.id,
307            user.id = %browser_session.user.id,
308        ),
309        err,
310    )]
311    async fn fulfill(
312        &mut self,
313        clock: &dyn Clock,
314        compat_sso_login: CompatSsoLogin,
315        browser_session: &BrowserSession,
316    ) -> Result<CompatSsoLogin, Self::Error> {
317        let fulfilled_at = clock.now();
318        let compat_sso_login = compat_sso_login
319            .fulfill(fulfilled_at, browser_session)
320            .map_err(DatabaseError::to_invalid_operation)?;
321
322        let res = sqlx::query!(
323            r#"
324                UPDATE compat_sso_logins
325                SET
326                    user_session_id = $2,
327                    fulfilled_at = $3
328                WHERE
329                    compat_sso_login_id = $1
330            "#,
331            Uuid::from(compat_sso_login.id),
332            Uuid::from(browser_session.id),
333            fulfilled_at,
334        )
335        .traced()
336        .execute(&mut *self.conn)
337        .await?;
338
339        DatabaseError::ensure_affected_rows(&res, 1)?;
340
341        Ok(compat_sso_login)
342    }
343
344    #[tracing::instrument(
345        name = "db.compat_sso_login.exchange",
346        skip_all,
347        fields(
348            db.query.text,
349            %compat_sso_login.id,
350            %compat_session.id,
351            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
352        ),
353        err,
354    )]
355    async fn exchange(
356        &mut self,
357        clock: &dyn Clock,
358        compat_sso_login: CompatSsoLogin,
359        compat_session: &CompatSession,
360    ) -> Result<CompatSsoLogin, Self::Error> {
361        let exchanged_at = clock.now();
362        let compat_sso_login = compat_sso_login
363            .exchange(exchanged_at, compat_session)
364            .map_err(DatabaseError::to_invalid_operation)?;
365
366        let res = sqlx::query!(
367            r#"
368                UPDATE compat_sso_logins
369                SET
370                    exchanged_at = $2,
371                    compat_session_id = $3
372                WHERE
373                    compat_sso_login_id = $1
374            "#,
375            Uuid::from(compat_sso_login.id),
376            exchanged_at,
377            Uuid::from(compat_session.id),
378        )
379        .traced()
380        .execute(&mut *self.conn)
381        .await?;
382
383        DatabaseError::ensure_affected_rows(&res, 1)?;
384
385        Ok(compat_sso_login)
386    }
387
388    #[tracing::instrument(
389        name = "db.compat_sso_login.list",
390        skip_all,
391        fields(
392            db.query.text,
393        ),
394        err
395    )]
396    async fn list(
397        &mut self,
398        filter: CompatSsoLoginFilter<'_>,
399        pagination: Pagination,
400    ) -> Result<Page<CompatSsoLogin>, Self::Error> {
401        let (sql, arguments) = Query::select()
402            .expr_as(
403                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
404                CompatSsoLoginLookupIden::CompatSsoLoginId,
405            )
406            .expr_as(
407                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
408                CompatSsoLoginLookupIden::CompatSessionId,
409            )
410            .expr_as(
411                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId)),
412                CompatSsoLoginLookupIden::UserSessionId,
413            )
414            .expr_as(
415                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
416                CompatSsoLoginLookupIden::LoginToken,
417            )
418            .expr_as(
419                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
420                CompatSsoLoginLookupIden::RedirectUri,
421            )
422            .expr_as(
423                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
424                CompatSsoLoginLookupIden::CreatedAt,
425            )
426            .expr_as(
427                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
428                CompatSsoLoginLookupIden::FulfilledAt,
429            )
430            .expr_as(
431                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
432                CompatSsoLoginLookupIden::ExchangedAt,
433            )
434            .from(CompatSsoLogins::Table)
435            .apply_filter(filter)
436            .generate_pagination(
437                (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId),
438                pagination,
439            )
440            .build_sqlx(PostgresQueryBuilder);
441
442        let edges: Vec<CompatSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
443            .traced()
444            .fetch_all(&mut *self.conn)
445            .await?;
446
447        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
448
449        Ok(page)
450    }
451
452    #[tracing::instrument(
453        name = "db.compat_sso_login.count",
454        skip_all,
455        fields(
456            db.query.text,
457        ),
458        err
459    )]
460    async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error> {
461        let (sql, arguments) = Query::select()
462            .expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count())
463            .from(CompatSsoLogins::Table)
464            .apply_filter(filter)
465            .build_sqlx(PostgresQueryBuilder);
466
467        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
468            .traced()
469            .fetch_one(&mut *self.conn)
470            .await?;
471
472        count
473            .try_into()
474            .map_err(DatabaseError::to_invalid_operation)
475    }
476}