Skip to main content

mas_storage_pg/oauth2/
client.rs

1// Copyright 2025, 2026 Element Creations Ltd.
2// Copyright 2024, 2025 New Vector Ltd.
3// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
4//
5// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
6// Please see LICENSE files in the repository root for full details.
7
8use std::{
9    collections::{BTreeMap, BTreeSet},
10    string::ToString,
11};
12
13use async_trait::async_trait;
14use mas_data_model::{Client, Clock, JwksOrJwksUri, UlidExt as _};
15use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
16use mas_jose::jwk::PublicJsonWebKeySet;
17use mas_storage::{
18    Page, Pagination,
19    oauth2::{OAuth2ClientFilter, OAuth2ClientKind, OAuth2ClientRepository},
20    pagination::Node,
21};
22use oauth2_types::{oidc::ApplicationType, requests::GrantType};
23use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26    Expr, ExprTrait, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
27    extension::postgres::PgExpr as _,
28};
29use sea_query_sqlx::SqlxBinder;
30use sqlx::PgConnection;
31use tracing::{Instrument, info_span};
32use ulid::Ulid;
33use url::Url;
34use uuid::Uuid;
35
36use crate::{
37    DatabaseError, DatabaseInconsistencyError,
38    filter::{Filter, StatementExt},
39    iden::{OAuth2Clients, OAuth2Sessions},
40    pagination::QueryBuilderExt,
41    tracing::ExecuteExt,
42};
43
44/// An implementation of [`OAuth2ClientRepository`] for a PostgreSQL connection
45pub struct PgOAuth2ClientRepository<'c> {
46    conn: &'c mut PgConnection,
47}
48
49impl<'c> PgOAuth2ClientRepository<'c> {
50    /// Create a new [`PgOAuth2ClientRepository`] from an active PostgreSQL
51    /// connection
52    pub fn new(conn: &'c mut PgConnection) -> Self {
53        Self { conn }
54    }
55}
56
57#[expect(clippy::struct_excessive_bools)]
58#[derive(Debug, sqlx::FromRow)]
59#[enum_def]
60struct OAuth2ClientLookup {
61    oauth2_client_id: Uuid,
62    metadata_digest: Option<String>,
63    encrypted_client_secret: Option<String>,
64    application_type: Option<String>,
65    redirect_uris: Vec<String>,
66    grant_type_authorization_code: bool,
67    grant_type_refresh_token: bool,
68    grant_type_client_credentials: bool,
69    grant_type_device_code: bool,
70    client_name: Option<String>,
71    logo_uri: Option<String>,
72    client_uri: Option<String>,
73    policy_uri: Option<String>,
74    tos_uri: Option<String>,
75    jwks_uri: Option<String>,
76    jwks: Option<serde_json::Value>,
77    id_token_signed_response_alg: Option<String>,
78    userinfo_signed_response_alg: Option<String>,
79    token_endpoint_auth_method: Option<String>,
80    token_endpoint_auth_signing_alg: Option<String>,
81    initiate_login_uri: Option<String>,
82    is_static: bool,
83}
84
85impl Node<Ulid> for OAuth2ClientLookup {
86    fn cursor(&self) -> Ulid {
87        self.oauth2_client_id.into()
88    }
89}
90
91impl Filter for OAuth2ClientFilter<'_> {
92    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
93        sea_query::Condition::all()
94            .add_option(self.kind().map(|kind| {
95                let is_static = matches!(kind, OAuth2ClientKind::Static);
96                Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).eq(is_static)
97            }))
98            .add_option(self.client_name().map(|client_name| {
99                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName))
100                    .ilike(format!("%{client_name}%"))
101            }))
102            .add_option(self.client_uri().map(|client_uri| {
103                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri))
104                    .ilike(format!("%{client_uri}%"))
105            }))
106            .add_option(self.grant_type().map(|grant_type| -> SimpleExpr {
107                let column = match grant_type {
108                    GrantType::AuthorizationCode => OAuth2Clients::GrantTypeAuthorizationCode,
109                    GrantType::RefreshToken => OAuth2Clients::GrantTypeRefreshToken,
110                    GrantType::ClientCredentials => OAuth2Clients::GrantTypeClientCredentials,
111                    GrantType::DeviceCode => OAuth2Clients::GrantTypeDeviceCode,
112                    // The other grant types don't have a dedicated column, so no
113                    // client can declare them: the filter matches nothing.
114                    _ => return Expr::val(false),
115                };
116                Expr::col((OAuth2Clients::Table, column)).eq(true)
117            }))
118            .add_option(self.has_active_sessions().map(|has| -> SimpleExpr {
119                let exists = Expr::exists(
120                    Query::select()
121                        .expr(Expr::cust("1"))
122                        .from(OAuth2Sessions::Table)
123                        .and_where(
124                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
125                                .equals((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
126                        )
127                        .and_where(
128                            Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
129                                .is_null(),
130                        )
131                        .take(),
132                );
133                if has { exists } else { exists.not() }
134            }))
135    }
136}
137
138impl TryFrom<OAuth2ClientLookup> for Client {
139    type Error = DatabaseInconsistencyError;
140
141    fn try_from(value: OAuth2ClientLookup) -> Result<Self, Self::Error> {
142        let id = Ulid::from(value.oauth2_client_id);
143
144        let redirect_uris: Result<Vec<Url>, _> =
145            value.redirect_uris.iter().map(|s| s.parse()).collect();
146        let redirect_uris = redirect_uris.map_err(|e| {
147            DatabaseInconsistencyError::on("oauth2_clients")
148                .column("redirect_uris")
149                .row(id)
150                .source(e)
151        })?;
152
153        let application_type = value
154            .application_type
155            .map(|s| s.parse())
156            .transpose()
157            .map_err(|e| {
158                DatabaseInconsistencyError::on("oauth2_clients")
159                    .column("application_type")
160                    .row(id)
161                    .source(e)
162            })?;
163
164        let mut grant_types = Vec::new();
165        if value.grant_type_authorization_code {
166            grant_types.push(GrantType::AuthorizationCode);
167        }
168        if value.grant_type_refresh_token {
169            grant_types.push(GrantType::RefreshToken);
170        }
171        if value.grant_type_client_credentials {
172            grant_types.push(GrantType::ClientCredentials);
173        }
174        if value.grant_type_device_code {
175            grant_types.push(GrantType::DeviceCode);
176        }
177
178        let logo_uri = value.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
179            DatabaseInconsistencyError::on("oauth2_clients")
180                .column("logo_uri")
181                .row(id)
182                .source(e)
183        })?;
184
185        let client_uri = value
186            .client_uri
187            .map(|s| s.parse())
188            .transpose()
189            .map_err(|e| {
190                DatabaseInconsistencyError::on("oauth2_clients")
191                    .column("client_uri")
192                    .row(id)
193                    .source(e)
194            })?;
195
196        let policy_uri = value
197            .policy_uri
198            .map(|s| s.parse())
199            .transpose()
200            .map_err(|e| {
201                DatabaseInconsistencyError::on("oauth2_clients")
202                    .column("policy_uri")
203                    .row(id)
204                    .source(e)
205            })?;
206
207        let tos_uri = value.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
208            DatabaseInconsistencyError::on("oauth2_clients")
209                .column("tos_uri")
210                .row(id)
211                .source(e)
212        })?;
213
214        let id_token_signed_response_alg = value
215            .id_token_signed_response_alg
216            .map(|s| s.parse())
217            .transpose()
218            .map_err(|e| {
219                DatabaseInconsistencyError::on("oauth2_clients")
220                    .column("id_token_signed_response_alg")
221                    .row(id)
222                    .source(e)
223            })?;
224
225        let userinfo_signed_response_alg = value
226            .userinfo_signed_response_alg
227            .map(|s| s.parse())
228            .transpose()
229            .map_err(|e| {
230                DatabaseInconsistencyError::on("oauth2_clients")
231                    .column("userinfo_signed_response_alg")
232                    .row(id)
233                    .source(e)
234            })?;
235
236        let token_endpoint_auth_method = value
237            .token_endpoint_auth_method
238            .map(|s| s.parse())
239            .transpose()
240            .map_err(|e| {
241                DatabaseInconsistencyError::on("oauth2_clients")
242                    .column("token_endpoint_auth_method")
243                    .row(id)
244                    .source(e)
245            })?;
246
247        let token_endpoint_auth_signing_alg = value
248            .token_endpoint_auth_signing_alg
249            .map(|s| s.parse())
250            .transpose()
251            .map_err(|e| {
252                DatabaseInconsistencyError::on("oauth2_clients")
253                    .column("token_endpoint_auth_signing_alg")
254                    .row(id)
255                    .source(e)
256            })?;
257
258        let initiate_login_uri = value
259            .initiate_login_uri
260            .map(|s| s.parse())
261            .transpose()
262            .map_err(|e| {
263                DatabaseInconsistencyError::on("oauth2_clients")
264                    .column("initiate_login_uri")
265                    .row(id)
266                    .source(e)
267            })?;
268
269        let jwks = match (value.jwks, value.jwks_uri) {
270            (None, None) => None,
271            (Some(jwks), None) => {
272                let jwks = serde_json::from_value(jwks).map_err(|e| {
273                    DatabaseInconsistencyError::on("oauth2_clients")
274                        .column("jwks")
275                        .row(id)
276                        .source(e)
277                })?;
278                Some(JwksOrJwksUri::Jwks(jwks))
279            }
280            (None, Some(jwks_uri)) => {
281                let jwks_uri = jwks_uri.parse().map_err(|e| {
282                    DatabaseInconsistencyError::on("oauth2_clients")
283                        .column("jwks_uri")
284                        .row(id)
285                        .source(e)
286                })?;
287
288                Some(JwksOrJwksUri::JwksUri(jwks_uri))
289            }
290            _ => {
291                return Err(DatabaseInconsistencyError::on("oauth2_clients")
292                    .column("jwks(_uri)")
293                    .row(id));
294            }
295        };
296
297        Ok(Client {
298            id,
299            client_id: id.to_string(),
300            metadata_digest: value.metadata_digest,
301            encrypted_client_secret: value.encrypted_client_secret,
302            application_type,
303            redirect_uris,
304            grant_types,
305            client_name: value.client_name,
306            logo_uri,
307            client_uri,
308            policy_uri,
309            tos_uri,
310            jwks,
311            id_token_signed_response_alg,
312            userinfo_signed_response_alg,
313            token_endpoint_auth_method,
314            token_endpoint_auth_signing_alg,
315            initiate_login_uri,
316            is_static: value.is_static,
317        })
318    }
319}
320
321#[async_trait]
322impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
323    type Error = DatabaseError;
324
325    #[tracing::instrument(
326        name = "db.oauth2_client.lookup",
327        skip_all,
328        fields(
329            db.query.text,
330            oauth2_client.id = %id,
331        ),
332        err,
333    )]
334    async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
335        let res = sqlx::query_as!(
336            OAuth2ClientLookup,
337            r#"
338                SELECT oauth2_client_id
339                     , metadata_digest
340                     , encrypted_client_secret
341                     , application_type
342                     , redirect_uris
343                     , grant_type_authorization_code
344                     , grant_type_refresh_token
345                     , grant_type_client_credentials
346                     , grant_type_device_code
347                     , client_name
348                     , logo_uri
349                     , client_uri
350                     , policy_uri
351                     , tos_uri
352                     , jwks_uri
353                     , jwks
354                     , id_token_signed_response_alg
355                     , userinfo_signed_response_alg
356                     , token_endpoint_auth_method
357                     , token_endpoint_auth_signing_alg
358                     , initiate_login_uri
359                     , is_static
360                FROM oauth2_clients c
361
362                WHERE oauth2_client_id = $1
363            "#,
364            Uuid::from(id),
365        )
366        .traced()
367        .fetch_optional(&mut *self.conn)
368        .await?;
369
370        let Some(res) = res else { return Ok(None) };
371
372        Ok(Some(res.try_into()?))
373    }
374
375    #[tracing::instrument(
376        name = "db.oauth2_client.find_by_metadata_digest",
377        skip_all,
378        fields(
379            db.query.text,
380        ),
381        err,
382    )]
383    async fn find_by_metadata_digest(
384        &mut self,
385        digest: &str,
386    ) -> Result<Option<Client>, Self::Error> {
387        let res = sqlx::query_as!(
388            OAuth2ClientLookup,
389            r#"
390                SELECT oauth2_client_id
391                    , metadata_digest
392                    , encrypted_client_secret
393                    , application_type
394                    , redirect_uris
395                    , grant_type_authorization_code
396                    , grant_type_refresh_token
397                    , grant_type_client_credentials
398                    , grant_type_device_code
399                    , client_name
400                    , logo_uri
401                    , client_uri
402                    , policy_uri
403                    , tos_uri
404                    , jwks_uri
405                    , jwks
406                    , id_token_signed_response_alg
407                    , userinfo_signed_response_alg
408                    , token_endpoint_auth_method
409                    , token_endpoint_auth_signing_alg
410                    , initiate_login_uri
411                    , is_static
412                FROM oauth2_clients
413                WHERE metadata_digest = $1
414            "#,
415            digest,
416        )
417        .traced()
418        .fetch_optional(&mut *self.conn)
419        .await?;
420
421        let Some(res) = res else { return Ok(None) };
422
423        Ok(Some(res.try_into()?))
424    }
425
426    #[tracing::instrument(
427        name = "db.oauth2_client.load_batch",
428        skip_all,
429        fields(
430            db.query.text,
431        ),
432        err,
433    )]
434    async fn load_batch(
435        &mut self,
436        ids: BTreeSet<Ulid>,
437    ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
438        let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
439        let res = sqlx::query_as!(
440            OAuth2ClientLookup,
441            r#"
442                SELECT oauth2_client_id
443                     , metadata_digest
444                     , encrypted_client_secret
445                     , application_type
446                     , redirect_uris
447                     , grant_type_authorization_code
448                     , grant_type_refresh_token
449                     , grant_type_client_credentials
450                     , grant_type_device_code
451                     , client_name
452                     , logo_uri
453                     , client_uri
454                     , policy_uri
455                     , tos_uri
456                     , jwks_uri
457                     , jwks
458                     , id_token_signed_response_alg
459                     , userinfo_signed_response_alg
460                     , token_endpoint_auth_method
461                     , token_endpoint_auth_signing_alg
462                     , initiate_login_uri
463                     , is_static
464                FROM oauth2_clients c
465
466                WHERE oauth2_client_id = ANY($1::uuid[])
467            "#,
468            &ids,
469        )
470        .traced()
471        .fetch_all(&mut *self.conn)
472        .await?;
473
474        res.into_iter()
475            .map(|r| {
476                r.try_into()
477                    .map(|c: Client| (c.id, c))
478                    .map_err(DatabaseError::from)
479            })
480            .collect()
481    }
482
483    #[tracing::instrument(
484        name = "db.oauth2_client.add",
485        skip_all,
486        fields(
487            db.query.text,
488            client.id,
489            client.name = client_name
490        ),
491        err,
492    )]
493    async fn add(
494        &mut self,
495        rng: &mut (dyn RngCore + Send),
496        clock: &dyn Clock,
497        redirect_uris: Vec<Url>,
498        metadata_digest: Option<String>,
499        encrypted_client_secret: Option<String>,
500        application_type: Option<ApplicationType>,
501        grant_types: Vec<GrantType>,
502        client_name: Option<String>,
503        logo_uri: Option<Url>,
504        client_uri: Option<Url>,
505        policy_uri: Option<Url>,
506        tos_uri: Option<Url>,
507        jwks_uri: Option<Url>,
508        jwks: Option<PublicJsonWebKeySet>,
509        id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
510        userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
511        token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
512        token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
513        initiate_login_uri: Option<Url>,
514    ) -> Result<Client, Self::Error> {
515        let now = clock.now();
516        let id = Ulid::from_datetime_with_rng(now, rng);
517        tracing::Span::current().record("client.id", tracing::field::display(id));
518
519        let jwks_json = jwks
520            .as_ref()
521            .map(serde_json::to_value)
522            .transpose()
523            .map_err(DatabaseError::to_invalid_operation)?;
524
525        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
526
527        sqlx::query!(
528            r#"
529                INSERT INTO oauth2_clients
530                    ( oauth2_client_id
531                    , metadata_digest
532                    , encrypted_client_secret
533                    , application_type
534                    , redirect_uris
535                    , grant_type_authorization_code
536                    , grant_type_refresh_token
537                    , grant_type_client_credentials
538                    , grant_type_device_code
539                    , client_name
540                    , logo_uri
541                    , client_uri
542                    , policy_uri
543                    , tos_uri
544                    , jwks_uri
545                    , jwks
546                    , id_token_signed_response_alg
547                    , userinfo_signed_response_alg
548                    , token_endpoint_auth_method
549                    , token_endpoint_auth_signing_alg
550                    , initiate_login_uri
551                    , is_static
552                    )
553                VALUES
554                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
555                    $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
556            "#,
557            Uuid::from(id),
558            metadata_digest,
559            encrypted_client_secret,
560            application_type.as_ref().map(ToString::to_string),
561            &redirect_uris_array,
562            grant_types.contains(&GrantType::AuthorizationCode),
563            grant_types.contains(&GrantType::RefreshToken),
564            grant_types.contains(&GrantType::ClientCredentials),
565            grant_types.contains(&GrantType::DeviceCode),
566            client_name,
567            logo_uri.as_ref().map(Url::as_str),
568            client_uri.as_ref().map(Url::as_str),
569            policy_uri.as_ref().map(Url::as_str),
570            tos_uri.as_ref().map(Url::as_str),
571            jwks_uri.as_ref().map(Url::as_str),
572            jwks_json,
573            id_token_signed_response_alg
574                .as_ref()
575                .map(ToString::to_string),
576            userinfo_signed_response_alg
577                .as_ref()
578                .map(ToString::to_string),
579            token_endpoint_auth_method.as_ref().map(ToString::to_string),
580            token_endpoint_auth_signing_alg
581                .as_ref()
582                .map(ToString::to_string),
583            initiate_login_uri.as_ref().map(Url::as_str),
584        )
585        .traced()
586        .execute(&mut *self.conn)
587        .await?;
588
589        let jwks = match (jwks, jwks_uri) {
590            (None, None) => None,
591            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
592            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
593            _ => return Err(DatabaseError::invalid_operation()),
594        };
595
596        Ok(Client {
597            id,
598            client_id: id.to_string(),
599            metadata_digest: None,
600            encrypted_client_secret,
601            application_type,
602            redirect_uris,
603            grant_types,
604            client_name,
605            logo_uri,
606            client_uri,
607            policy_uri,
608            tos_uri,
609            jwks,
610            id_token_signed_response_alg,
611            userinfo_signed_response_alg,
612            token_endpoint_auth_method,
613            token_endpoint_auth_signing_alg,
614            initiate_login_uri,
615            is_static: false,
616        })
617    }
618
619    #[tracing::instrument(
620        name = "db.oauth2_client.upsert_static",
621        skip_all,
622        fields(
623            db.query.text,
624            client.id = %client_id,
625        ),
626        err,
627    )]
628    async fn upsert_static(
629        &mut self,
630        client_id: Ulid,
631        client_name: Option<String>,
632        client_auth_method: OAuthClientAuthenticationMethod,
633        encrypted_client_secret: Option<String>,
634        jwks: Option<PublicJsonWebKeySet>,
635        jwks_uri: Option<Url>,
636        redirect_uris: Vec<Url>,
637    ) -> Result<Client, Self::Error> {
638        let jwks_json = jwks
639            .as_ref()
640            .map(serde_json::to_value)
641            .transpose()
642            .map_err(DatabaseError::to_invalid_operation)?;
643
644        let client_auth_method = client_auth_method.to_string();
645        let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
646
647        sqlx::query!(
648            r#"
649                INSERT INTO oauth2_clients
650                    ( oauth2_client_id
651                    , encrypted_client_secret
652                    , redirect_uris
653                    , grant_type_authorization_code
654                    , grant_type_refresh_token
655                    , grant_type_client_credentials
656                    , grant_type_device_code
657                    , token_endpoint_auth_method
658                    , jwks
659                    , client_name
660                    , jwks_uri
661                    , is_static
662                    )
663                VALUES
664                    ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, TRUE)
665                ON CONFLICT (oauth2_client_id)
666                DO
667                    UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
668                             , redirect_uris = EXCLUDED.redirect_uris
669                             , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
670                             , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
671                             , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
672                             , grant_type_device_code = EXCLUDED.grant_type_device_code
673                             , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
674                             , jwks = EXCLUDED.jwks
675                             , client_name = EXCLUDED.client_name
676                             , jwks_uri = EXCLUDED.jwks_uri
677                             , is_static = TRUE
678            "#,
679            Uuid::from(client_id),
680            encrypted_client_secret,
681            &redirect_uris_array,
682            true,
683            true,
684            true,
685            true,
686            client_auth_method,
687            jwks_json,
688            client_name,
689            jwks_uri.as_ref().map(Url::as_str),
690        )
691        .traced()
692        .execute(&mut *self.conn)
693        .await?;
694
695        let jwks = match (jwks, jwks_uri) {
696            (None, None) => None,
697            (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
698            (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
699            _ => return Err(DatabaseError::invalid_operation()),
700        };
701
702        Ok(Client {
703            id: client_id,
704            client_id: client_id.to_string(),
705            metadata_digest: None,
706            encrypted_client_secret,
707            application_type: None,
708            redirect_uris,
709            grant_types: vec![
710                GrantType::AuthorizationCode,
711                GrantType::RefreshToken,
712                GrantType::ClientCredentials,
713            ],
714            client_name,
715            logo_uri: None,
716            client_uri: None,
717            policy_uri: None,
718            tos_uri: None,
719            jwks,
720            id_token_signed_response_alg: None,
721            userinfo_signed_response_alg: None,
722            token_endpoint_auth_method: None,
723            token_endpoint_auth_signing_alg: None,
724            initiate_login_uri: None,
725            is_static: true,
726        })
727    }
728
729    #[tracing::instrument(
730        name = "db.oauth2_client.all_static",
731        skip_all,
732        fields(
733            db.query.text,
734        ),
735        err,
736    )]
737    async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
738        let res = sqlx::query_as!(
739            OAuth2ClientLookup,
740            r#"
741                SELECT oauth2_client_id
742                     , metadata_digest
743                     , encrypted_client_secret
744                     , application_type
745                     , redirect_uris
746                     , grant_type_authorization_code
747                     , grant_type_refresh_token
748                     , grant_type_client_credentials
749                     , grant_type_device_code
750                     , client_name
751                     , logo_uri
752                     , client_uri
753                     , policy_uri
754                     , tos_uri
755                     , jwks_uri
756                     , jwks
757                     , id_token_signed_response_alg
758                     , userinfo_signed_response_alg
759                     , token_endpoint_auth_method
760                     , token_endpoint_auth_signing_alg
761                     , initiate_login_uri
762                     , is_static
763                FROM oauth2_clients c
764                WHERE is_static = TRUE
765            "#,
766        )
767        .traced()
768        .fetch_all(&mut *self.conn)
769        .await?;
770
771        res.into_iter()
772            .map(|r| r.try_into().map_err(DatabaseError::from))
773            .collect()
774    }
775
776    #[tracing::instrument(
777        name = "db.oauth2_client.delete_by_id",
778        skip_all,
779        fields(
780            db.query.text,
781            client.id = %id,
782        ),
783        err,
784    )]
785    async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
786        // Delete the authorization grants
787        {
788            let span = info_span!(
789                "db.oauth2_client.delete_by_id.authorization_grants",
790                { DB_QUERY_TEXT } = tracing::field::Empty,
791            );
792
793            sqlx::query!(
794                r#"
795                    DELETE FROM oauth2_authorization_grants
796                    WHERE oauth2_client_id = $1
797                "#,
798                Uuid::from(id),
799            )
800            .record(&span)
801            .execute(&mut *self.conn)
802            .instrument(span)
803            .await?;
804        }
805
806        // Delete the OAuth 2 sessions related data
807        {
808            let span = info_span!(
809                "db.oauth2_client.delete_by_id.access_tokens",
810                { DB_QUERY_TEXT } = tracing::field::Empty,
811            );
812
813            sqlx::query!(
814                r#"
815                    DELETE FROM oauth2_access_tokens
816                    WHERE oauth2_session_id IN (
817                        SELECT oauth2_session_id
818                        FROM oauth2_sessions
819                        WHERE oauth2_client_id = $1
820                    )
821                "#,
822                Uuid::from(id),
823            )
824            .record(&span)
825            .execute(&mut *self.conn)
826            .instrument(span)
827            .await?;
828        }
829
830        {
831            let span = info_span!(
832                "db.oauth2_client.delete_by_id.refresh_tokens",
833                { DB_QUERY_TEXT } = tracing::field::Empty,
834            );
835
836            sqlx::query!(
837                r#"
838                    DELETE FROM oauth2_refresh_tokens
839                    WHERE oauth2_session_id IN (
840                        SELECT oauth2_session_id
841                        FROM oauth2_sessions
842                        WHERE oauth2_client_id = $1
843                    )
844                "#,
845                Uuid::from(id),
846            )
847            .record(&span)
848            .execute(&mut *self.conn)
849            .instrument(span)
850            .await?;
851        }
852
853        {
854            let span = info_span!(
855                "db.oauth2_client.delete_by_id.sessions",
856                { DB_QUERY_TEXT } = tracing::field::Empty,
857            );
858
859            sqlx::query!(
860                r#"
861                    DELETE FROM oauth2_sessions
862                    WHERE oauth2_client_id = $1
863                "#,
864                Uuid::from(id),
865            )
866            .record(&span)
867            .execute(&mut *self.conn)
868            .instrument(span)
869            .await?;
870        }
871
872        // Delete any personal access tokens & sessions owned
873        // by the client
874        {
875            let span = info_span!(
876                "db.oauth2_client.delete_by_id.personal_access_tokens",
877                { DB_QUERY_TEXT } = tracing::field::Empty,
878            );
879
880            sqlx::query!(
881                r#"
882                    DELETE FROM personal_access_tokens
883                    WHERE personal_session_id IN (
884                        SELECT personal_session_id
885                        FROM personal_sessions
886                        WHERE owner_oauth2_client_id = $1
887                    )
888                "#,
889                Uuid::from(id),
890            )
891            .record(&span)
892            .execute(&mut *self.conn)
893            .instrument(span)
894            .await?;
895        }
896        {
897            let span = info_span!(
898                "db.oauth2_client.delete_by_id.personal_sessions",
899                { DB_QUERY_TEXT } = tracing::field::Empty,
900            );
901
902            sqlx::query!(
903                r#"
904                    DELETE FROM personal_sessions
905                    WHERE owner_oauth2_client_id = $1
906                "#,
907                Uuid::from(id),
908            )
909            .record(&span)
910            .execute(&mut *self.conn)
911            .instrument(span)
912            .await?;
913        }
914
915        // Now delete the client itself
916        let res = sqlx::query!(
917            r#"
918                DELETE FROM oauth2_clients
919                WHERE oauth2_client_id = $1
920            "#,
921            Uuid::from(id),
922        )
923        .traced()
924        .execute(&mut *self.conn)
925        .await?;
926
927        DatabaseError::ensure_affected_rows(&res, 1)
928    }
929
930    #[tracing::instrument(
931        name = "db.oauth2_client.list",
932        skip_all,
933        fields(
934            db.query.text,
935        ),
936        err,
937    )]
938    async fn list(
939        &mut self,
940        filter: OAuth2ClientFilter<'_>,
941        pagination: Pagination,
942    ) -> Result<Page<Client>, Self::Error> {
943        let (sql, arguments) = Query::select()
944            .expr_as(
945                Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
946                OAuth2ClientLookupIden::Oauth2ClientId,
947            )
948            .expr_as(
949                Expr::cust("metadata_digest"),
950                OAuth2ClientLookupIden::MetadataDigest,
951            )
952            .expr_as(
953                Expr::cust("encrypted_client_secret"),
954                OAuth2ClientLookupIden::EncryptedClientSecret,
955            )
956            .expr_as(
957                Expr::cust("application_type"),
958                OAuth2ClientLookupIden::ApplicationType,
959            )
960            .expr_as(
961                Expr::col((OAuth2Clients::Table, OAuth2Clients::RedirectUris)),
962                OAuth2ClientLookupIden::RedirectUris,
963            )
964            .expr_as(
965                Expr::cust("grant_type_authorization_code"),
966                OAuth2ClientLookupIden::GrantTypeAuthorizationCode,
967            )
968            .expr_as(
969                Expr::cust("grant_type_refresh_token"),
970                OAuth2ClientLookupIden::GrantTypeRefreshToken,
971            )
972            .expr_as(
973                Expr::cust("grant_type_client_credentials"),
974                OAuth2ClientLookupIden::GrantTypeClientCredentials,
975            )
976            .expr_as(
977                Expr::cust("grant_type_device_code"),
978                OAuth2ClientLookupIden::GrantTypeDeviceCode,
979            )
980            .expr_as(
981                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName)),
982                OAuth2ClientLookupIden::ClientName,
983            )
984            .expr_as(
985                Expr::col((OAuth2Clients::Table, OAuth2Clients::LogoUri)),
986                OAuth2ClientLookupIden::LogoUri,
987            )
988            .expr_as(
989                Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri)),
990                OAuth2ClientLookupIden::ClientUri,
991            )
992            .expr_as(Expr::cust("policy_uri"), OAuth2ClientLookupIden::PolicyUri)
993            .expr_as(Expr::cust("tos_uri"), OAuth2ClientLookupIden::TosUri)
994            .expr_as(Expr::cust("jwks_uri"), OAuth2ClientLookupIden::JwksUri)
995            .expr_as(Expr::cust("jwks"), OAuth2ClientLookupIden::Jwks)
996            .expr_as(
997                Expr::cust("id_token_signed_response_alg"),
998                OAuth2ClientLookupIden::IdTokenSignedResponseAlg,
999            )
1000            .expr_as(
1001                Expr::cust("userinfo_signed_response_alg"),
1002                OAuth2ClientLookupIden::UserinfoSignedResponseAlg,
1003            )
1004            .expr_as(
1005                Expr::cust("token_endpoint_auth_method"),
1006                OAuth2ClientLookupIden::TokenEndpointAuthMethod,
1007            )
1008            .expr_as(
1009                Expr::cust("token_endpoint_auth_signing_alg"),
1010                OAuth2ClientLookupIden::TokenEndpointAuthSigningAlg,
1011            )
1012            .expr_as(
1013                Expr::cust("initiate_login_uri"),
1014                OAuth2ClientLookupIden::InitiateLoginUri,
1015            )
1016            .expr_as(
1017                Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)),
1018                OAuth2ClientLookupIden::IsStatic,
1019            )
1020            .from(OAuth2Clients::Table)
1021            .apply_filter(filter)
1022            .generate_pagination(
1023                (OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId),
1024                pagination,
1025            )
1026            .build_sqlx(PostgresQueryBuilder);
1027
1028        let edges: Vec<OAuth2ClientLookup> = sqlx::query_as_with(&sql, arguments)
1029            .traced()
1030            .fetch_all(&mut *self.conn)
1031            .await?;
1032
1033        let page = pagination.process(edges).try_map(Client::try_from)?;
1034
1035        Ok(page)
1036    }
1037
1038    #[tracing::instrument(
1039        name = "db.oauth2_client.count",
1040        skip_all,
1041        fields(
1042            db.query.text,
1043        ),
1044        err,
1045    )]
1046    async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error> {
1047        let (sql, arguments) = Query::select()
1048            .expr(Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)).count())
1049            .from(OAuth2Clients::Table)
1050            .apply_filter(filter)
1051            .build_sqlx(PostgresQueryBuilder);
1052
1053        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
1054            .traced()
1055            .fetch_one(&mut *self.conn)
1056            .await?;
1057
1058        count
1059            .try_into()
1060            .map_err(DatabaseError::to_invalid_operation)
1061    }
1062}