mas_storage_pg/compat/
sso_login.rs

1// Copyright 2024, 2025 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use async_trait::async_trait;
8use chrono::{DateTime, Utc};
9use mas_data_model::{BrowserSession, Clock, CompatSession, CompatSsoLogin, CompatSsoLoginState};
10use mas_storage::{
11    Page, Pagination,
12    compat::{CompatSsoLoginFilter, CompatSsoLoginRepository},
13    pagination::Node,
14};
15use rand::RngCore;
16use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
17use sea_query_binder::SqlxBinder;
18use sqlx::PgConnection;
19use ulid::Ulid;
20use url::Url;
21use uuid::Uuid;
22
23use crate::{
24    DatabaseError, DatabaseInconsistencyError,
25    filter::{Filter, StatementExt},
26    iden::{CompatSsoLogins, UserSessions},
27    pagination::QueryBuilderExt,
28    tracing::ExecuteExt,
29};
30
31/// An implementation of [`CompatSsoLoginRepository`] for a PostgreSQL
32/// connection
33pub struct PgCompatSsoLoginRepository<'c> {
34    conn: &'c mut PgConnection,
35}
36
37impl<'c> PgCompatSsoLoginRepository<'c> {
38    /// Create a new [`PgCompatSsoLoginRepository`] from an active PostgreSQL
39    /// connection
40    pub fn new(conn: &'c mut PgConnection) -> Self {
41        Self { conn }
42    }
43}
44
45#[derive(sqlx::FromRow, Debug)]
46#[enum_def]
47struct CompatSsoLoginLookup {
48    compat_sso_login_id: Uuid,
49    login_token: String,
50    redirect_uri: String,
51    created_at: DateTime<Utc>,
52    fulfilled_at: Option<DateTime<Utc>>,
53    exchanged_at: Option<DateTime<Utc>>,
54    user_session_id: Option<Uuid>,
55    compat_session_id: Option<Uuid>,
56}
57
58impl Node<Ulid> for CompatSsoLoginLookup {
59    fn cursor(&self) -> Ulid {
60        self.compat_sso_login_id.into()
61    }
62}
63
64impl TryFrom<CompatSsoLoginLookup> for CompatSsoLogin {
65    type Error = DatabaseInconsistencyError;
66
67    fn try_from(res: CompatSsoLoginLookup) -> Result<Self, Self::Error> {
68        let id = res.compat_sso_login_id.into();
69        let redirect_uri = Url::parse(&res.redirect_uri).map_err(|e| {
70            DatabaseInconsistencyError::on("compat_sso_logins")
71                .column("redirect_uri")
72                .row(id)
73                .source(e)
74        })?;
75
76        let state = match (
77            res.fulfilled_at,
78            res.exchanged_at,
79            res.user_session_id,
80            res.compat_session_id,
81        ) {
82            (None, None, None, None) => CompatSsoLoginState::Pending,
83            (Some(fulfilled_at), None, Some(browser_session_id), None) => {
84                CompatSsoLoginState::Fulfilled {
85                    fulfilled_at,
86                    browser_session_id: browser_session_id.into(),
87                }
88            }
89            (Some(fulfilled_at), Some(exchanged_at), _, Some(compat_session_id)) => {
90                CompatSsoLoginState::Exchanged {
91                    fulfilled_at,
92                    exchanged_at,
93                    compat_session_id: compat_session_id.into(),
94                }
95            }
96            _ => return Err(DatabaseInconsistencyError::on("compat_sso_logins").row(id)),
97        };
98
99        Ok(CompatSsoLogin {
100            id,
101            login_token: res.login_token,
102            redirect_uri,
103            created_at: res.created_at,
104            state,
105        })
106    }
107}
108
109impl Filter for CompatSsoLoginFilter<'_> {
110    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
111        sea_query::Condition::all()
112            .add_option(self.user().map(|user| {
113                Expr::exists(
114                    Query::select()
115                        .expr(Expr::cust("1"))
116                        .from(UserSessions::Table)
117                        .and_where(
118                            Expr::col((UserSessions::Table, UserSessions::UserId))
119                                .eq(Uuid::from(user.id)),
120                        )
121                        .and_where(
122                            Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId))
123                                .equals((UserSessions::Table, UserSessions::UserSessionId)),
124                        )
125                        .take(),
126                )
127            }))
128            .add_option(self.state().map(|state| {
129                if state.is_exchanged() {
130                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)).is_not_null()
131                } else if state.is_fulfilled() {
132                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt))
133                        .is_not_null()
134                        .and(
135                            Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt))
136                                .is_null(),
137                        )
138                } else {
139                    Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)).is_null()
140                }
141            }))
142    }
143}
144
145#[async_trait]
146impl CompatSsoLoginRepository for PgCompatSsoLoginRepository<'_> {
147    type Error = DatabaseError;
148
149    #[tracing::instrument(
150        name = "db.compat_sso_login.lookup",
151        skip_all,
152        fields(
153            db.query.text,
154            compat_sso_login.id = %id,
155        ),
156        err,
157    )]
158    async fn lookup(&mut self, id: Ulid) -> Result<Option<CompatSsoLogin>, Self::Error> {
159        let res = sqlx::query_as!(
160            CompatSsoLoginLookup,
161            r#"
162                SELECT compat_sso_login_id
163                     , login_token
164                     , redirect_uri
165                     , created_at
166                     , fulfilled_at
167                     , exchanged_at
168                     , compat_session_id
169                     , user_session_id
170
171                FROM compat_sso_logins
172                WHERE compat_sso_login_id = $1
173            "#,
174            Uuid::from(id),
175        )
176        .traced()
177        .fetch_optional(&mut *self.conn)
178        .await?;
179
180        let Some(res) = res else { return Ok(None) };
181
182        Ok(Some(res.try_into()?))
183    }
184
185    #[tracing::instrument(
186        name = "db.compat_sso_login.find_for_session",
187        skip_all,
188        fields(
189            db.query.text,
190            %compat_session.id,
191        ),
192        err,
193    )]
194    async fn find_for_session(
195        &mut self,
196        compat_session: &CompatSession,
197    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
198        let res = sqlx::query_as!(
199            CompatSsoLoginLookup,
200            r#"
201                SELECT compat_sso_login_id
202                     , login_token
203                     , redirect_uri
204                     , created_at
205                     , fulfilled_at
206                     , exchanged_at
207                     , compat_session_id
208                     , user_session_id
209
210                FROM compat_sso_logins
211                WHERE compat_session_id = $1
212            "#,
213            Uuid::from(compat_session.id),
214        )
215        .traced()
216        .fetch_optional(&mut *self.conn)
217        .await?;
218
219        let Some(res) = res else { return Ok(None) };
220
221        Ok(Some(res.try_into()?))
222    }
223
224    #[tracing::instrument(
225        name = "db.compat_sso_login.find_by_token",
226        skip_all,
227        fields(
228            db.query.text,
229        ),
230        err,
231    )]
232    async fn find_by_token(
233        &mut self,
234        login_token: &str,
235    ) -> Result<Option<CompatSsoLogin>, Self::Error> {
236        let res = sqlx::query_as!(
237            CompatSsoLoginLookup,
238            r#"
239                SELECT compat_sso_login_id
240                     , login_token
241                     , redirect_uri
242                     , created_at
243                     , fulfilled_at
244                     , exchanged_at
245                     , compat_session_id
246                     , user_session_id
247
248                FROM compat_sso_logins
249                WHERE login_token = $1
250            "#,
251            login_token,
252        )
253        .traced()
254        .fetch_optional(&mut *self.conn)
255        .await?;
256
257        let Some(res) = res else { return Ok(None) };
258
259        Ok(Some(res.try_into()?))
260    }
261
262    #[tracing::instrument(
263        name = "db.compat_sso_login.add",
264        skip_all,
265        fields(
266            db.query.text,
267            compat_sso_login.id,
268            compat_sso_login.redirect_uri = %redirect_uri,
269        ),
270        err,
271    )]
272    async fn add(
273        &mut self,
274        rng: &mut (dyn RngCore + Send),
275        clock: &dyn Clock,
276        login_token: String,
277        redirect_uri: Url,
278    ) -> Result<CompatSsoLogin, Self::Error> {
279        let created_at = clock.now();
280        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
281        tracing::Span::current().record("compat_sso_login.id", tracing::field::display(id));
282
283        sqlx::query!(
284            r#"
285                INSERT INTO compat_sso_logins
286                    (compat_sso_login_id, login_token, redirect_uri, created_at)
287                VALUES ($1, $2, $3, $4)
288            "#,
289            Uuid::from(id),
290            &login_token,
291            redirect_uri.as_str(),
292            created_at,
293        )
294        .traced()
295        .execute(&mut *self.conn)
296        .await?;
297
298        Ok(CompatSsoLogin {
299            id,
300            login_token,
301            redirect_uri,
302            created_at,
303            state: CompatSsoLoginState::default(),
304        })
305    }
306
307    #[tracing::instrument(
308        name = "db.compat_sso_login.fulfill",
309        skip_all,
310        fields(
311            db.query.text,
312            %compat_sso_login.id,
313            %browser_session.id,
314            user.id = %browser_session.user.id,
315        ),
316        err,
317    )]
318    async fn fulfill(
319        &mut self,
320        clock: &dyn Clock,
321        compat_sso_login: CompatSsoLogin,
322        browser_session: &BrowserSession,
323    ) -> Result<CompatSsoLogin, Self::Error> {
324        let fulfilled_at = clock.now();
325        let compat_sso_login = compat_sso_login
326            .fulfill(fulfilled_at, browser_session)
327            .map_err(DatabaseError::to_invalid_operation)?;
328
329        let res = sqlx::query!(
330            r#"
331                UPDATE compat_sso_logins
332                SET
333                    user_session_id = $2,
334                    fulfilled_at = $3
335                WHERE
336                    compat_sso_login_id = $1
337            "#,
338            Uuid::from(compat_sso_login.id),
339            Uuid::from(browser_session.id),
340            fulfilled_at,
341        )
342        .traced()
343        .execute(&mut *self.conn)
344        .await?;
345
346        DatabaseError::ensure_affected_rows(&res, 1)?;
347
348        Ok(compat_sso_login)
349    }
350
351    #[tracing::instrument(
352        name = "db.compat_sso_login.exchange",
353        skip_all,
354        fields(
355            db.query.text,
356            %compat_sso_login.id,
357            %compat_session.id,
358            compat_session.device.id = compat_session.device.as_ref().map(mas_data_model::Device::as_str),
359        ),
360        err,
361    )]
362    async fn exchange(
363        &mut self,
364        clock: &dyn Clock,
365        compat_sso_login: CompatSsoLogin,
366        compat_session: &CompatSession,
367    ) -> Result<CompatSsoLogin, Self::Error> {
368        let exchanged_at = clock.now();
369        let compat_sso_login = compat_sso_login
370            .exchange(exchanged_at, compat_session)
371            .map_err(DatabaseError::to_invalid_operation)?;
372
373        let res = sqlx::query!(
374            r#"
375                UPDATE compat_sso_logins
376                SET
377                    exchanged_at = $2,
378                    compat_session_id = $3
379                WHERE
380                    compat_sso_login_id = $1
381            "#,
382            Uuid::from(compat_sso_login.id),
383            exchanged_at,
384            Uuid::from(compat_session.id),
385        )
386        .traced()
387        .execute(&mut *self.conn)
388        .await?;
389
390        DatabaseError::ensure_affected_rows(&res, 1)?;
391
392        Ok(compat_sso_login)
393    }
394
395    #[tracing::instrument(
396        name = "db.compat_sso_login.list",
397        skip_all,
398        fields(
399            db.query.text,
400        ),
401        err
402    )]
403    async fn list(
404        &mut self,
405        filter: CompatSsoLoginFilter<'_>,
406        pagination: Pagination,
407    ) -> Result<Page<CompatSsoLogin>, Self::Error> {
408        let (sql, arguments) = Query::select()
409            .expr_as(
410                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)),
411                CompatSsoLoginLookupIden::CompatSsoLoginId,
412            )
413            .expr_as(
414                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSessionId)),
415                CompatSsoLoginLookupIden::CompatSessionId,
416            )
417            .expr_as(
418                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::UserSessionId)),
419                CompatSsoLoginLookupIden::UserSessionId,
420            )
421            .expr_as(
422                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::LoginToken)),
423                CompatSsoLoginLookupIden::LoginToken,
424            )
425            .expr_as(
426                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::RedirectUri)),
427                CompatSsoLoginLookupIden::RedirectUri,
428            )
429            .expr_as(
430                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CreatedAt)),
431                CompatSsoLoginLookupIden::CreatedAt,
432            )
433            .expr_as(
434                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::FulfilledAt)),
435                CompatSsoLoginLookupIden::FulfilledAt,
436            )
437            .expr_as(
438                Expr::col((CompatSsoLogins::Table, CompatSsoLogins::ExchangedAt)),
439                CompatSsoLoginLookupIden::ExchangedAt,
440            )
441            .from(CompatSsoLogins::Table)
442            .apply_filter(filter)
443            .generate_pagination(
444                (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId),
445                pagination,
446            )
447            .build_sqlx(PostgresQueryBuilder);
448
449        let edges: Vec<CompatSsoLoginLookup> = sqlx::query_as_with(&sql, arguments)
450            .traced()
451            .fetch_all(&mut *self.conn)
452            .await?;
453
454        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
455
456        Ok(page)
457    }
458
459    #[tracing::instrument(
460        name = "db.compat_sso_login.count",
461        skip_all,
462        fields(
463            db.query.text,
464        ),
465        err
466    )]
467    async fn count(&mut self, filter: CompatSsoLoginFilter<'_>) -> Result<usize, Self::Error> {
468        let (sql, arguments) = Query::select()
469            .expr(Expr::col((CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId)).count())
470            .from(CompatSsoLogins::Table)
471            .apply_filter(filter)
472            .build_sqlx(PostgresQueryBuilder);
473
474        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
475            .traced()
476            .fetch_one(&mut *self.conn)
477            .await?;
478
479        count
480            .try_into()
481            .map_err(DatabaseError::to_invalid_operation)
482    }
483}