1use std::{
9 collections::{BTreeMap, BTreeSet},
10 string::ToString,
11};
12
13use async_trait::async_trait;
14use mas_data_model::{Client, Clock, JwksOrJwksUri, UlidExt as _};
15use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
16use mas_jose::jwk::PublicJsonWebKeySet;
17use mas_storage::{
18 Page, Pagination,
19 oauth2::{OAuth2ClientFilter, OAuth2ClientKind, OAuth2ClientRepository},
20 pagination::Node,
21};
22use oauth2_types::{oidc::ApplicationType, requests::GrantType};
23use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26 Expr, ExprTrait, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
27 extension::postgres::PgExpr as _,
28};
29use sea_query_sqlx::SqlxBinder;
30use sqlx::PgConnection;
31use tracing::{Instrument, info_span};
32use ulid::Ulid;
33use url::Url;
34use uuid::Uuid;
35
36use crate::{
37 DatabaseError, DatabaseInconsistencyError,
38 filter::{Filter, StatementExt},
39 iden::{OAuth2Clients, OAuth2Sessions},
40 pagination::QueryBuilderExt,
41 tracing::ExecuteExt,
42};
43
44pub struct PgOAuth2ClientRepository<'c> {
46 conn: &'c mut PgConnection,
47}
48
49impl<'c> PgOAuth2ClientRepository<'c> {
50 pub fn new(conn: &'c mut PgConnection) -> Self {
53 Self { conn }
54 }
55}
56
/// Raw row shape of the `oauth2_clients` table, as returned by the lookup
/// queries in this file.
///
/// Values are kept in their database representation (strings, JSON, UUID)
/// and are parsed into a [`Client`] by its `TryFrom` implementation.
/// `#[enum_def]` also generates `OAuth2ClientLookupIden`, which the dynamic
/// `list` query uses to alias its selected columns to these field names.
#[expect(clippy::struct_excessive_bools)]
#[derive(Debug, sqlx::FromRow)]
#[enum_def]
struct OAuth2ClientLookup {
    // Primary key; also used as the pagination cursor (see `Node` impl).
    oauth2_client_id: Uuid,
    metadata_digest: Option<String>,
    encrypted_client_secret: Option<String>,
    // Parsed into `ApplicationType` by `TryFrom`.
    application_type: Option<String>,
    // Each entry is parsed into a `Url` by `TryFrom`.
    redirect_uris: Vec<String>,
    // One flag per persisted grant type; `TryFrom` turns the set flags into
    // a `Vec<GrantType>`.
    grant_type_authorization_code: bool,
    grant_type_refresh_token: bool,
    grant_type_client_credentials: bool,
    grant_type_device_code: bool,
    client_name: Option<String>,
    logo_uri: Option<String>,
    client_uri: Option<String>,
    policy_uri: Option<String>,
    tos_uri: Option<String>,
    // At most one of `jwks_uri` / `jwks` may be set; `TryFrom` reports a
    // row with both as an inconsistency.
    jwks_uri: Option<String>,
    jwks: Option<serde_json::Value>,
    id_token_signed_response_alg: Option<String>,
    userinfo_signed_response_alg: Option<String>,
    token_endpoint_auth_method: Option<String>,
    token_endpoint_auth_signing_alg: Option<String>,
    initiate_login_uri: Option<String>,
    is_static: bool,
}
84
85impl Node<Ulid> for OAuth2ClientLookup {
86 fn cursor(&self) -> Ulid {
87 self.oauth2_client_id.into()
88 }
89}
90
91impl Filter for OAuth2ClientFilter<'_> {
92 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
93 sea_query::Condition::all()
94 .add_option(self.kind().map(|kind| {
95 let is_static = matches!(kind, OAuth2ClientKind::Static);
96 Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)).eq(is_static)
97 }))
98 .add_option(self.client_name().map(|client_name| {
99 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName))
100 .ilike(format!("%{client_name}%"))
101 }))
102 .add_option(self.client_uri().map(|client_uri| {
103 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri))
104 .ilike(format!("%{client_uri}%"))
105 }))
106 .add_option(self.grant_type().map(|grant_type| -> SimpleExpr {
107 let column = match grant_type {
108 GrantType::AuthorizationCode => OAuth2Clients::GrantTypeAuthorizationCode,
109 GrantType::RefreshToken => OAuth2Clients::GrantTypeRefreshToken,
110 GrantType::ClientCredentials => OAuth2Clients::GrantTypeClientCredentials,
111 GrantType::DeviceCode => OAuth2Clients::GrantTypeDeviceCode,
112 _ => return Expr::val(false),
115 };
116 Expr::col((OAuth2Clients::Table, column)).eq(true)
117 }))
118 .add_option(self.has_active_sessions().map(|has| -> SimpleExpr {
119 let exists = Expr::exists(
120 Query::select()
121 .expr(Expr::cust("1"))
122 .from(OAuth2Sessions::Table)
123 .and_where(
124 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::OAuth2ClientId))
125 .equals((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
126 )
127 .and_where(
128 Expr::col((OAuth2Sessions::Table, OAuth2Sessions::FinishedAt))
129 .is_null(),
130 )
131 .take(),
132 );
133 if has { exists } else { exists.not() }
134 }))
135 }
136}
137
138impl TryFrom<OAuth2ClientLookup> for Client {
139 type Error = DatabaseInconsistencyError;
140
141 fn try_from(value: OAuth2ClientLookup) -> Result<Self, Self::Error> {
142 let id = Ulid::from(value.oauth2_client_id);
143
144 let redirect_uris: Result<Vec<Url>, _> =
145 value.redirect_uris.iter().map(|s| s.parse()).collect();
146 let redirect_uris = redirect_uris.map_err(|e| {
147 DatabaseInconsistencyError::on("oauth2_clients")
148 .column("redirect_uris")
149 .row(id)
150 .source(e)
151 })?;
152
153 let application_type = value
154 .application_type
155 .map(|s| s.parse())
156 .transpose()
157 .map_err(|e| {
158 DatabaseInconsistencyError::on("oauth2_clients")
159 .column("application_type")
160 .row(id)
161 .source(e)
162 })?;
163
164 let mut grant_types = Vec::new();
165 if value.grant_type_authorization_code {
166 grant_types.push(GrantType::AuthorizationCode);
167 }
168 if value.grant_type_refresh_token {
169 grant_types.push(GrantType::RefreshToken);
170 }
171 if value.grant_type_client_credentials {
172 grant_types.push(GrantType::ClientCredentials);
173 }
174 if value.grant_type_device_code {
175 grant_types.push(GrantType::DeviceCode);
176 }
177
178 let logo_uri = value.logo_uri.map(|s| s.parse()).transpose().map_err(|e| {
179 DatabaseInconsistencyError::on("oauth2_clients")
180 .column("logo_uri")
181 .row(id)
182 .source(e)
183 })?;
184
185 let client_uri = value
186 .client_uri
187 .map(|s| s.parse())
188 .transpose()
189 .map_err(|e| {
190 DatabaseInconsistencyError::on("oauth2_clients")
191 .column("client_uri")
192 .row(id)
193 .source(e)
194 })?;
195
196 let policy_uri = value
197 .policy_uri
198 .map(|s| s.parse())
199 .transpose()
200 .map_err(|e| {
201 DatabaseInconsistencyError::on("oauth2_clients")
202 .column("policy_uri")
203 .row(id)
204 .source(e)
205 })?;
206
207 let tos_uri = value.tos_uri.map(|s| s.parse()).transpose().map_err(|e| {
208 DatabaseInconsistencyError::on("oauth2_clients")
209 .column("tos_uri")
210 .row(id)
211 .source(e)
212 })?;
213
214 let id_token_signed_response_alg = value
215 .id_token_signed_response_alg
216 .map(|s| s.parse())
217 .transpose()
218 .map_err(|e| {
219 DatabaseInconsistencyError::on("oauth2_clients")
220 .column("id_token_signed_response_alg")
221 .row(id)
222 .source(e)
223 })?;
224
225 let userinfo_signed_response_alg = value
226 .userinfo_signed_response_alg
227 .map(|s| s.parse())
228 .transpose()
229 .map_err(|e| {
230 DatabaseInconsistencyError::on("oauth2_clients")
231 .column("userinfo_signed_response_alg")
232 .row(id)
233 .source(e)
234 })?;
235
236 let token_endpoint_auth_method = value
237 .token_endpoint_auth_method
238 .map(|s| s.parse())
239 .transpose()
240 .map_err(|e| {
241 DatabaseInconsistencyError::on("oauth2_clients")
242 .column("token_endpoint_auth_method")
243 .row(id)
244 .source(e)
245 })?;
246
247 let token_endpoint_auth_signing_alg = value
248 .token_endpoint_auth_signing_alg
249 .map(|s| s.parse())
250 .transpose()
251 .map_err(|e| {
252 DatabaseInconsistencyError::on("oauth2_clients")
253 .column("token_endpoint_auth_signing_alg")
254 .row(id)
255 .source(e)
256 })?;
257
258 let initiate_login_uri = value
259 .initiate_login_uri
260 .map(|s| s.parse())
261 .transpose()
262 .map_err(|e| {
263 DatabaseInconsistencyError::on("oauth2_clients")
264 .column("initiate_login_uri")
265 .row(id)
266 .source(e)
267 })?;
268
269 let jwks = match (value.jwks, value.jwks_uri) {
270 (None, None) => None,
271 (Some(jwks), None) => {
272 let jwks = serde_json::from_value(jwks).map_err(|e| {
273 DatabaseInconsistencyError::on("oauth2_clients")
274 .column("jwks")
275 .row(id)
276 .source(e)
277 })?;
278 Some(JwksOrJwksUri::Jwks(jwks))
279 }
280 (None, Some(jwks_uri)) => {
281 let jwks_uri = jwks_uri.parse().map_err(|e| {
282 DatabaseInconsistencyError::on("oauth2_clients")
283 .column("jwks_uri")
284 .row(id)
285 .source(e)
286 })?;
287
288 Some(JwksOrJwksUri::JwksUri(jwks_uri))
289 }
290 _ => {
291 return Err(DatabaseInconsistencyError::on("oauth2_clients")
292 .column("jwks(_uri)")
293 .row(id));
294 }
295 };
296
297 Ok(Client {
298 id,
299 client_id: id.to_string(),
300 metadata_digest: value.metadata_digest,
301 encrypted_client_secret: value.encrypted_client_secret,
302 application_type,
303 redirect_uris,
304 grant_types,
305 client_name: value.client_name,
306 logo_uri,
307 client_uri,
308 policy_uri,
309 tos_uri,
310 jwks,
311 id_token_signed_response_alg,
312 userinfo_signed_response_alg,
313 token_endpoint_auth_method,
314 token_endpoint_auth_signing_alg,
315 initiate_login_uri,
316 is_static: value.is_static,
317 })
318 }
319}
320
321#[async_trait]
322impl OAuth2ClientRepository for PgOAuth2ClientRepository<'_> {
323 type Error = DatabaseError;
324
325 #[tracing::instrument(
326 name = "db.oauth2_client.lookup",
327 skip_all,
328 fields(
329 db.query.text,
330 oauth2_client.id = %id,
331 ),
332 err,
333 )]
334 async fn lookup(&mut self, id: Ulid) -> Result<Option<Client>, Self::Error> {
335 let res = sqlx::query_as!(
336 OAuth2ClientLookup,
337 r#"
338 SELECT oauth2_client_id
339 , metadata_digest
340 , encrypted_client_secret
341 , application_type
342 , redirect_uris
343 , grant_type_authorization_code
344 , grant_type_refresh_token
345 , grant_type_client_credentials
346 , grant_type_device_code
347 , client_name
348 , logo_uri
349 , client_uri
350 , policy_uri
351 , tos_uri
352 , jwks_uri
353 , jwks
354 , id_token_signed_response_alg
355 , userinfo_signed_response_alg
356 , token_endpoint_auth_method
357 , token_endpoint_auth_signing_alg
358 , initiate_login_uri
359 , is_static
360 FROM oauth2_clients c
361
362 WHERE oauth2_client_id = $1
363 "#,
364 Uuid::from(id),
365 )
366 .traced()
367 .fetch_optional(&mut *self.conn)
368 .await?;
369
370 let Some(res) = res else { return Ok(None) };
371
372 Ok(Some(res.try_into()?))
373 }
374
375 #[tracing::instrument(
376 name = "db.oauth2_client.find_by_metadata_digest",
377 skip_all,
378 fields(
379 db.query.text,
380 ),
381 err,
382 )]
383 async fn find_by_metadata_digest(
384 &mut self,
385 digest: &str,
386 ) -> Result<Option<Client>, Self::Error> {
387 let res = sqlx::query_as!(
388 OAuth2ClientLookup,
389 r#"
390 SELECT oauth2_client_id
391 , metadata_digest
392 , encrypted_client_secret
393 , application_type
394 , redirect_uris
395 , grant_type_authorization_code
396 , grant_type_refresh_token
397 , grant_type_client_credentials
398 , grant_type_device_code
399 , client_name
400 , logo_uri
401 , client_uri
402 , policy_uri
403 , tos_uri
404 , jwks_uri
405 , jwks
406 , id_token_signed_response_alg
407 , userinfo_signed_response_alg
408 , token_endpoint_auth_method
409 , token_endpoint_auth_signing_alg
410 , initiate_login_uri
411 , is_static
412 FROM oauth2_clients
413 WHERE metadata_digest = $1
414 "#,
415 digest,
416 )
417 .traced()
418 .fetch_optional(&mut *self.conn)
419 .await?;
420
421 let Some(res) = res else { return Ok(None) };
422
423 Ok(Some(res.try_into()?))
424 }
425
426 #[tracing::instrument(
427 name = "db.oauth2_client.load_batch",
428 skip_all,
429 fields(
430 db.query.text,
431 ),
432 err,
433 )]
434 async fn load_batch(
435 &mut self,
436 ids: BTreeSet<Ulid>,
437 ) -> Result<BTreeMap<Ulid, Client>, Self::Error> {
438 let ids: Vec<Uuid> = ids.into_iter().map(Uuid::from).collect();
439 let res = sqlx::query_as!(
440 OAuth2ClientLookup,
441 r#"
442 SELECT oauth2_client_id
443 , metadata_digest
444 , encrypted_client_secret
445 , application_type
446 , redirect_uris
447 , grant_type_authorization_code
448 , grant_type_refresh_token
449 , grant_type_client_credentials
450 , grant_type_device_code
451 , client_name
452 , logo_uri
453 , client_uri
454 , policy_uri
455 , tos_uri
456 , jwks_uri
457 , jwks
458 , id_token_signed_response_alg
459 , userinfo_signed_response_alg
460 , token_endpoint_auth_method
461 , token_endpoint_auth_signing_alg
462 , initiate_login_uri
463 , is_static
464 FROM oauth2_clients c
465
466 WHERE oauth2_client_id = ANY($1::uuid[])
467 "#,
468 &ids,
469 )
470 .traced()
471 .fetch_all(&mut *self.conn)
472 .await?;
473
474 res.into_iter()
475 .map(|r| {
476 r.try_into()
477 .map(|c: Client| (c.id, c))
478 .map_err(DatabaseError::from)
479 })
480 .collect()
481 }
482
483 #[tracing::instrument(
484 name = "db.oauth2_client.add",
485 skip_all,
486 fields(
487 db.query.text,
488 client.id,
489 client.name = client_name
490 ),
491 err,
492 )]
493 async fn add(
494 &mut self,
495 rng: &mut (dyn RngCore + Send),
496 clock: &dyn Clock,
497 redirect_uris: Vec<Url>,
498 metadata_digest: Option<String>,
499 encrypted_client_secret: Option<String>,
500 application_type: Option<ApplicationType>,
501 grant_types: Vec<GrantType>,
502 client_name: Option<String>,
503 logo_uri: Option<Url>,
504 client_uri: Option<Url>,
505 policy_uri: Option<Url>,
506 tos_uri: Option<Url>,
507 jwks_uri: Option<Url>,
508 jwks: Option<PublicJsonWebKeySet>,
509 id_token_signed_response_alg: Option<JsonWebSignatureAlg>,
510 userinfo_signed_response_alg: Option<JsonWebSignatureAlg>,
511 token_endpoint_auth_method: Option<OAuthClientAuthenticationMethod>,
512 token_endpoint_auth_signing_alg: Option<JsonWebSignatureAlg>,
513 initiate_login_uri: Option<Url>,
514 ) -> Result<Client, Self::Error> {
515 let now = clock.now();
516 let id = Ulid::from_datetime_with_rng(now, rng);
517 tracing::Span::current().record("client.id", tracing::field::display(id));
518
519 let jwks_json = jwks
520 .as_ref()
521 .map(serde_json::to_value)
522 .transpose()
523 .map_err(DatabaseError::to_invalid_operation)?;
524
525 let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
526
527 sqlx::query!(
528 r#"
529 INSERT INTO oauth2_clients
530 ( oauth2_client_id
531 , metadata_digest
532 , encrypted_client_secret
533 , application_type
534 , redirect_uris
535 , grant_type_authorization_code
536 , grant_type_refresh_token
537 , grant_type_client_credentials
538 , grant_type_device_code
539 , client_name
540 , logo_uri
541 , client_uri
542 , policy_uri
543 , tos_uri
544 , jwks_uri
545 , jwks
546 , id_token_signed_response_alg
547 , userinfo_signed_response_alg
548 , token_endpoint_auth_method
549 , token_endpoint_auth_signing_alg
550 , initiate_login_uri
551 , is_static
552 )
553 VALUES
554 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13,
555 $14, $15, $16, $17, $18, $19, $20, $21, FALSE)
556 "#,
557 Uuid::from(id),
558 metadata_digest,
559 encrypted_client_secret,
560 application_type.as_ref().map(ToString::to_string),
561 &redirect_uris_array,
562 grant_types.contains(&GrantType::AuthorizationCode),
563 grant_types.contains(&GrantType::RefreshToken),
564 grant_types.contains(&GrantType::ClientCredentials),
565 grant_types.contains(&GrantType::DeviceCode),
566 client_name,
567 logo_uri.as_ref().map(Url::as_str),
568 client_uri.as_ref().map(Url::as_str),
569 policy_uri.as_ref().map(Url::as_str),
570 tos_uri.as_ref().map(Url::as_str),
571 jwks_uri.as_ref().map(Url::as_str),
572 jwks_json,
573 id_token_signed_response_alg
574 .as_ref()
575 .map(ToString::to_string),
576 userinfo_signed_response_alg
577 .as_ref()
578 .map(ToString::to_string),
579 token_endpoint_auth_method.as_ref().map(ToString::to_string),
580 token_endpoint_auth_signing_alg
581 .as_ref()
582 .map(ToString::to_string),
583 initiate_login_uri.as_ref().map(Url::as_str),
584 )
585 .traced()
586 .execute(&mut *self.conn)
587 .await?;
588
589 let jwks = match (jwks, jwks_uri) {
590 (None, None) => None,
591 (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
592 (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
593 _ => return Err(DatabaseError::invalid_operation()),
594 };
595
596 Ok(Client {
597 id,
598 client_id: id.to_string(),
599 metadata_digest: None,
600 encrypted_client_secret,
601 application_type,
602 redirect_uris,
603 grant_types,
604 client_name,
605 logo_uri,
606 client_uri,
607 policy_uri,
608 tos_uri,
609 jwks,
610 id_token_signed_response_alg,
611 userinfo_signed_response_alg,
612 token_endpoint_auth_method,
613 token_endpoint_auth_signing_alg,
614 initiate_login_uri,
615 is_static: false,
616 })
617 }
618
619 #[tracing::instrument(
620 name = "db.oauth2_client.upsert_static",
621 skip_all,
622 fields(
623 db.query.text,
624 client.id = %client_id,
625 ),
626 err,
627 )]
628 async fn upsert_static(
629 &mut self,
630 client_id: Ulid,
631 client_name: Option<String>,
632 client_auth_method: OAuthClientAuthenticationMethod,
633 encrypted_client_secret: Option<String>,
634 jwks: Option<PublicJsonWebKeySet>,
635 jwks_uri: Option<Url>,
636 redirect_uris: Vec<Url>,
637 ) -> Result<Client, Self::Error> {
638 let jwks_json = jwks
639 .as_ref()
640 .map(serde_json::to_value)
641 .transpose()
642 .map_err(DatabaseError::to_invalid_operation)?;
643
644 let client_auth_method = client_auth_method.to_string();
645 let redirect_uris_array = redirect_uris.iter().map(Url::to_string).collect::<Vec<_>>();
646
647 sqlx::query!(
648 r#"
649 INSERT INTO oauth2_clients
650 ( oauth2_client_id
651 , encrypted_client_secret
652 , redirect_uris
653 , grant_type_authorization_code
654 , grant_type_refresh_token
655 , grant_type_client_credentials
656 , grant_type_device_code
657 , token_endpoint_auth_method
658 , jwks
659 , client_name
660 , jwks_uri
661 , is_static
662 )
663 VALUES
664 ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, TRUE)
665 ON CONFLICT (oauth2_client_id)
666 DO
667 UPDATE SET encrypted_client_secret = EXCLUDED.encrypted_client_secret
668 , redirect_uris = EXCLUDED.redirect_uris
669 , grant_type_authorization_code = EXCLUDED.grant_type_authorization_code
670 , grant_type_refresh_token = EXCLUDED.grant_type_refresh_token
671 , grant_type_client_credentials = EXCLUDED.grant_type_client_credentials
672 , grant_type_device_code = EXCLUDED.grant_type_device_code
673 , token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method
674 , jwks = EXCLUDED.jwks
675 , client_name = EXCLUDED.client_name
676 , jwks_uri = EXCLUDED.jwks_uri
677 , is_static = TRUE
678 "#,
679 Uuid::from(client_id),
680 encrypted_client_secret,
681 &redirect_uris_array,
682 true,
683 true,
684 true,
685 true,
686 client_auth_method,
687 jwks_json,
688 client_name,
689 jwks_uri.as_ref().map(Url::as_str),
690 )
691 .traced()
692 .execute(&mut *self.conn)
693 .await?;
694
695 let jwks = match (jwks, jwks_uri) {
696 (None, None) => None,
697 (Some(jwks), None) => Some(JwksOrJwksUri::Jwks(jwks)),
698 (None, Some(jwks_uri)) => Some(JwksOrJwksUri::JwksUri(jwks_uri)),
699 _ => return Err(DatabaseError::invalid_operation()),
700 };
701
702 Ok(Client {
703 id: client_id,
704 client_id: client_id.to_string(),
705 metadata_digest: None,
706 encrypted_client_secret,
707 application_type: None,
708 redirect_uris,
709 grant_types: vec![
710 GrantType::AuthorizationCode,
711 GrantType::RefreshToken,
712 GrantType::ClientCredentials,
713 ],
714 client_name,
715 logo_uri: None,
716 client_uri: None,
717 policy_uri: None,
718 tos_uri: None,
719 jwks,
720 id_token_signed_response_alg: None,
721 userinfo_signed_response_alg: None,
722 token_endpoint_auth_method: None,
723 token_endpoint_auth_signing_alg: None,
724 initiate_login_uri: None,
725 is_static: true,
726 })
727 }
728
729 #[tracing::instrument(
730 name = "db.oauth2_client.all_static",
731 skip_all,
732 fields(
733 db.query.text,
734 ),
735 err,
736 )]
737 async fn all_static(&mut self) -> Result<Vec<Client>, Self::Error> {
738 let res = sqlx::query_as!(
739 OAuth2ClientLookup,
740 r#"
741 SELECT oauth2_client_id
742 , metadata_digest
743 , encrypted_client_secret
744 , application_type
745 , redirect_uris
746 , grant_type_authorization_code
747 , grant_type_refresh_token
748 , grant_type_client_credentials
749 , grant_type_device_code
750 , client_name
751 , logo_uri
752 , client_uri
753 , policy_uri
754 , tos_uri
755 , jwks_uri
756 , jwks
757 , id_token_signed_response_alg
758 , userinfo_signed_response_alg
759 , token_endpoint_auth_method
760 , token_endpoint_auth_signing_alg
761 , initiate_login_uri
762 , is_static
763 FROM oauth2_clients c
764 WHERE is_static = TRUE
765 "#,
766 )
767 .traced()
768 .fetch_all(&mut *self.conn)
769 .await?;
770
771 res.into_iter()
772 .map(|r| r.try_into().map_err(DatabaseError::from))
773 .collect()
774 }
775
776 #[tracing::instrument(
777 name = "db.oauth2_client.delete_by_id",
778 skip_all,
779 fields(
780 db.query.text,
781 client.id = %id,
782 ),
783 err,
784 )]
785 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
786 {
788 let span = info_span!(
789 "db.oauth2_client.delete_by_id.authorization_grants",
790 { DB_QUERY_TEXT } = tracing::field::Empty,
791 );
792
793 sqlx::query!(
794 r#"
795 DELETE FROM oauth2_authorization_grants
796 WHERE oauth2_client_id = $1
797 "#,
798 Uuid::from(id),
799 )
800 .record(&span)
801 .execute(&mut *self.conn)
802 .instrument(span)
803 .await?;
804 }
805
806 {
808 let span = info_span!(
809 "db.oauth2_client.delete_by_id.access_tokens",
810 { DB_QUERY_TEXT } = tracing::field::Empty,
811 );
812
813 sqlx::query!(
814 r#"
815 DELETE FROM oauth2_access_tokens
816 WHERE oauth2_session_id IN (
817 SELECT oauth2_session_id
818 FROM oauth2_sessions
819 WHERE oauth2_client_id = $1
820 )
821 "#,
822 Uuid::from(id),
823 )
824 .record(&span)
825 .execute(&mut *self.conn)
826 .instrument(span)
827 .await?;
828 }
829
830 {
831 let span = info_span!(
832 "db.oauth2_client.delete_by_id.refresh_tokens",
833 { DB_QUERY_TEXT } = tracing::field::Empty,
834 );
835
836 sqlx::query!(
837 r#"
838 DELETE FROM oauth2_refresh_tokens
839 WHERE oauth2_session_id IN (
840 SELECT oauth2_session_id
841 FROM oauth2_sessions
842 WHERE oauth2_client_id = $1
843 )
844 "#,
845 Uuid::from(id),
846 )
847 .record(&span)
848 .execute(&mut *self.conn)
849 .instrument(span)
850 .await?;
851 }
852
853 {
854 let span = info_span!(
855 "db.oauth2_client.delete_by_id.sessions",
856 { DB_QUERY_TEXT } = tracing::field::Empty,
857 );
858
859 sqlx::query!(
860 r#"
861 DELETE FROM oauth2_sessions
862 WHERE oauth2_client_id = $1
863 "#,
864 Uuid::from(id),
865 )
866 .record(&span)
867 .execute(&mut *self.conn)
868 .instrument(span)
869 .await?;
870 }
871
872 {
875 let span = info_span!(
876 "db.oauth2_client.delete_by_id.personal_access_tokens",
877 { DB_QUERY_TEXT } = tracing::field::Empty,
878 );
879
880 sqlx::query!(
881 r#"
882 DELETE FROM personal_access_tokens
883 WHERE personal_session_id IN (
884 SELECT personal_session_id
885 FROM personal_sessions
886 WHERE owner_oauth2_client_id = $1
887 )
888 "#,
889 Uuid::from(id),
890 )
891 .record(&span)
892 .execute(&mut *self.conn)
893 .instrument(span)
894 .await?;
895 }
896 {
897 let span = info_span!(
898 "db.oauth2_client.delete_by_id.personal_sessions",
899 { DB_QUERY_TEXT } = tracing::field::Empty,
900 );
901
902 sqlx::query!(
903 r#"
904 DELETE FROM personal_sessions
905 WHERE owner_oauth2_client_id = $1
906 "#,
907 Uuid::from(id),
908 )
909 .record(&span)
910 .execute(&mut *self.conn)
911 .instrument(span)
912 .await?;
913 }
914
915 let res = sqlx::query!(
917 r#"
918 DELETE FROM oauth2_clients
919 WHERE oauth2_client_id = $1
920 "#,
921 Uuid::from(id),
922 )
923 .traced()
924 .execute(&mut *self.conn)
925 .await?;
926
927 DatabaseError::ensure_affected_rows(&res, 1)
928 }
929
930 #[tracing::instrument(
931 name = "db.oauth2_client.list",
932 skip_all,
933 fields(
934 db.query.text,
935 ),
936 err,
937 )]
938 async fn list(
939 &mut self,
940 filter: OAuth2ClientFilter<'_>,
941 pagination: Pagination,
942 ) -> Result<Page<Client>, Self::Error> {
943 let (sql, arguments) = Query::select()
944 .expr_as(
945 Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)),
946 OAuth2ClientLookupIden::Oauth2ClientId,
947 )
948 .expr_as(
949 Expr::cust("metadata_digest"),
950 OAuth2ClientLookupIden::MetadataDigest,
951 )
952 .expr_as(
953 Expr::cust("encrypted_client_secret"),
954 OAuth2ClientLookupIden::EncryptedClientSecret,
955 )
956 .expr_as(
957 Expr::cust("application_type"),
958 OAuth2ClientLookupIden::ApplicationType,
959 )
960 .expr_as(
961 Expr::col((OAuth2Clients::Table, OAuth2Clients::RedirectUris)),
962 OAuth2ClientLookupIden::RedirectUris,
963 )
964 .expr_as(
965 Expr::cust("grant_type_authorization_code"),
966 OAuth2ClientLookupIden::GrantTypeAuthorizationCode,
967 )
968 .expr_as(
969 Expr::cust("grant_type_refresh_token"),
970 OAuth2ClientLookupIden::GrantTypeRefreshToken,
971 )
972 .expr_as(
973 Expr::cust("grant_type_client_credentials"),
974 OAuth2ClientLookupIden::GrantTypeClientCredentials,
975 )
976 .expr_as(
977 Expr::cust("grant_type_device_code"),
978 OAuth2ClientLookupIden::GrantTypeDeviceCode,
979 )
980 .expr_as(
981 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientName)),
982 OAuth2ClientLookupIden::ClientName,
983 )
984 .expr_as(
985 Expr::col((OAuth2Clients::Table, OAuth2Clients::LogoUri)),
986 OAuth2ClientLookupIden::LogoUri,
987 )
988 .expr_as(
989 Expr::col((OAuth2Clients::Table, OAuth2Clients::ClientUri)),
990 OAuth2ClientLookupIden::ClientUri,
991 )
992 .expr_as(Expr::cust("policy_uri"), OAuth2ClientLookupIden::PolicyUri)
993 .expr_as(Expr::cust("tos_uri"), OAuth2ClientLookupIden::TosUri)
994 .expr_as(Expr::cust("jwks_uri"), OAuth2ClientLookupIden::JwksUri)
995 .expr_as(Expr::cust("jwks"), OAuth2ClientLookupIden::Jwks)
996 .expr_as(
997 Expr::cust("id_token_signed_response_alg"),
998 OAuth2ClientLookupIden::IdTokenSignedResponseAlg,
999 )
1000 .expr_as(
1001 Expr::cust("userinfo_signed_response_alg"),
1002 OAuth2ClientLookupIden::UserinfoSignedResponseAlg,
1003 )
1004 .expr_as(
1005 Expr::cust("token_endpoint_auth_method"),
1006 OAuth2ClientLookupIden::TokenEndpointAuthMethod,
1007 )
1008 .expr_as(
1009 Expr::cust("token_endpoint_auth_signing_alg"),
1010 OAuth2ClientLookupIden::TokenEndpointAuthSigningAlg,
1011 )
1012 .expr_as(
1013 Expr::cust("initiate_login_uri"),
1014 OAuth2ClientLookupIden::InitiateLoginUri,
1015 )
1016 .expr_as(
1017 Expr::col((OAuth2Clients::Table, OAuth2Clients::IsStatic)),
1018 OAuth2ClientLookupIden::IsStatic,
1019 )
1020 .from(OAuth2Clients::Table)
1021 .apply_filter(filter)
1022 .generate_pagination(
1023 (OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId),
1024 pagination,
1025 )
1026 .build_sqlx(PostgresQueryBuilder);
1027
1028 let edges: Vec<OAuth2ClientLookup> = sqlx::query_as_with(&sql, arguments)
1029 .traced()
1030 .fetch_all(&mut *self.conn)
1031 .await?;
1032
1033 let page = pagination.process(edges).try_map(Client::try_from)?;
1034
1035 Ok(page)
1036 }
1037
1038 #[tracing::instrument(
1039 name = "db.oauth2_client.count",
1040 skip_all,
1041 fields(
1042 db.query.text,
1043 ),
1044 err,
1045 )]
1046 async fn count(&mut self, filter: OAuth2ClientFilter<'_>) -> Result<usize, Self::Error> {
1047 let (sql, arguments) = Query::select()
1048 .expr(Expr::col((OAuth2Clients::Table, OAuth2Clients::OAuth2ClientId)).count())
1049 .from(OAuth2Clients::Table)
1050 .apply_filter(filter)
1051 .build_sqlx(PostgresQueryBuilder);
1052
1053 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
1054 .traced()
1055 .fetch_one(&mut *self.conn)
1056 .await?;
1057
1058 count
1059 .try_into()
1060 .map_err(DatabaseError::to_invalid_operation)
1061 }
1062}