1use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 Clock, UlidExt as _, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
12};
13use mas_storage::{
14 Page, Pagination,
15 pagination::Node,
16 upstream_oauth2::{
17 UpstreamOAuthProviderFilter, UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
18 },
19};
20use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
21use rand::RngCore;
22use sea_query::{Expr, ExprTrait, PostgresQueryBuilder, Query, enum_def};
23use sea_query_sqlx::SqlxBinder;
24use sqlx::{PgConnection, types::Json};
25use tracing::{Instrument, info_span};
26use ulid::Ulid;
27use uuid::Uuid;
28
29use crate::{
30 DatabaseError, DatabaseInconsistencyError,
31 filter::{Filter, StatementExt},
32 iden::UpstreamOAuthProviders,
33 pagination::QueryBuilderExt,
34 tracing::ExecuteExt,
35};
36
37pub struct PgUpstreamOAuthProviderRepository<'c> {
40 conn: &'c mut PgConnection,
41}
42
43impl<'c> PgUpstreamOAuthProviderRepository<'c> {
44 pub fn new(conn: &'c mut PgConnection) -> Self {
47 Self { conn }
48 }
49}
50
/// Raw row of the `upstream_oauth_providers` table, as decoded by the
/// `SELECT` queries of this repository.
///
/// Several columns are kept as plain text here and only parsed into their
/// typed representation when the row is converted into an
/// [`UpstreamOAuthProvider`] through its `TryFrom` implementation.
///
/// `#[enum_def]` generates the `ProviderLookupIden` enum from these field
/// names; the dynamically built `list` query uses its variants as column
/// aliases, so renaming a field here also renames the alias `list` selects.
#[derive(sqlx::FromRow)]
#[enum_def]
struct ProviderLookup {
    // Primary key; also the pagination cursor (see the `Node` impl).
    upstream_oauth_provider_id: Uuid,
    issuer: Option<String>,
    human_name: Option<String>,
    brand_name: Option<String>,
    // Text columns below are parsed in the `TryFrom` conversion.
    scope: String,
    client_id: String,
    encrypted_client_secret: Option<String>,
    token_endpoint_signing_alg: Option<String>,
    token_endpoint_auth_method: String,
    id_token_signed_response_alg: String,
    fetch_userinfo: bool,
    userinfo_signed_response_alg: Option<String>,
    created_at: DateTime<Utc>,
    // `None` while the provider is enabled.
    disabled_at: Option<DateTime<Utc>>,
    // JSON column.
    claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
    jwks_uri_override: Option<String>,
    authorization_endpoint_override: Option<String>,
    token_endpoint_override: Option<String>,
    userinfo_endpoint_override: Option<String>,
    discovery_mode: String,
    pkce_mode: String,
    response_mode: Option<String>,
    // JSON list of key/value pairs; a NULL column converts to an empty list.
    additional_parameters: Option<Json<Vec<(String, String)>>>,
    forward_login_hint: bool,
    on_backchannel_logout: String,
    registration_token_required: bool,
}
81
82impl Node<Ulid> for ProviderLookup {
83 fn cursor(&self) -> Ulid {
84 self.upstream_oauth_provider_id.into()
85 }
86}
87
88impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
89 type Error = DatabaseInconsistencyError;
90
91 fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
92 let id = value.upstream_oauth_provider_id.into();
93 let scope = value.scope.parse().map_err(|e| {
94 DatabaseInconsistencyError::on("upstream_oauth_providers")
95 .column("scope")
96 .row(id)
97 .source(e)
98 })?;
99 let token_endpoint_auth_method = value.token_endpoint_auth_method.parse().map_err(|e| {
100 DatabaseInconsistencyError::on("upstream_oauth_providers")
101 .column("token_endpoint_auth_method")
102 .row(id)
103 .source(e)
104 })?;
105 let token_endpoint_signing_alg = value
106 .token_endpoint_signing_alg
107 .map(|x| x.parse())
108 .transpose()
109 .map_err(|e| {
110 DatabaseInconsistencyError::on("upstream_oauth_providers")
111 .column("token_endpoint_signing_alg")
112 .row(id)
113 .source(e)
114 })?;
115 let id_token_signed_response_alg =
116 value.id_token_signed_response_alg.parse().map_err(|e| {
117 DatabaseInconsistencyError::on("upstream_oauth_providers")
118 .column("id_token_signed_response_alg")
119 .row(id)
120 .source(e)
121 })?;
122
123 let userinfo_signed_response_alg = value
124 .userinfo_signed_response_alg
125 .map(|x| x.parse())
126 .transpose()
127 .map_err(|e| {
128 DatabaseInconsistencyError::on("upstream_oauth_providers")
129 .column("userinfo_signed_response_alg")
130 .row(id)
131 .source(e)
132 })?;
133
134 let authorization_endpoint_override = value
135 .authorization_endpoint_override
136 .map(|x| x.parse())
137 .transpose()
138 .map_err(|e| {
139 DatabaseInconsistencyError::on("upstream_oauth_providers")
140 .column("authorization_endpoint_override")
141 .row(id)
142 .source(e)
143 })?;
144
145 let token_endpoint_override = value
146 .token_endpoint_override
147 .map(|x| x.parse())
148 .transpose()
149 .map_err(|e| {
150 DatabaseInconsistencyError::on("upstream_oauth_providers")
151 .column("token_endpoint_override")
152 .row(id)
153 .source(e)
154 })?;
155
156 let userinfo_endpoint_override = value
157 .userinfo_endpoint_override
158 .map(|x| x.parse())
159 .transpose()
160 .map_err(|e| {
161 DatabaseInconsistencyError::on("upstream_oauth_providers")
162 .column("userinfo_endpoint_override")
163 .row(id)
164 .source(e)
165 })?;
166
167 let jwks_uri_override = value
168 .jwks_uri_override
169 .map(|x| x.parse())
170 .transpose()
171 .map_err(|e| {
172 DatabaseInconsistencyError::on("upstream_oauth_providers")
173 .column("jwks_uri_override")
174 .row(id)
175 .source(e)
176 })?;
177
178 let discovery_mode = value.discovery_mode.parse().map_err(|e| {
179 DatabaseInconsistencyError::on("upstream_oauth_providers")
180 .column("discovery_mode")
181 .row(id)
182 .source(e)
183 })?;
184
185 let pkce_mode = value.pkce_mode.parse().map_err(|e| {
186 DatabaseInconsistencyError::on("upstream_oauth_providers")
187 .column("pkce_mode")
188 .row(id)
189 .source(e)
190 })?;
191
192 let response_mode = value
193 .response_mode
194 .map(|x| x.parse())
195 .transpose()
196 .map_err(|e| {
197 DatabaseInconsistencyError::on("upstream_oauth_providers")
198 .column("response_mode")
199 .row(id)
200 .source(e)
201 })?;
202
203 let additional_authorization_parameters = value
204 .additional_parameters
205 .map(|Json(x)| x)
206 .unwrap_or_default();
207
208 let on_backchannel_logout = value.on_backchannel_logout.parse().map_err(|e| {
209 DatabaseInconsistencyError::on("upstream_oauth_providers")
210 .column("on_backchannel_logout")
211 .row(id)
212 .source(e)
213 })?;
214
215 Ok(UpstreamOAuthProvider {
216 id,
217 issuer: value.issuer,
218 human_name: value.human_name,
219 brand_name: value.brand_name,
220 scope,
221 client_id: value.client_id,
222 encrypted_client_secret: value.encrypted_client_secret,
223 token_endpoint_auth_method,
224 token_endpoint_signing_alg,
225 id_token_signed_response_alg,
226 fetch_userinfo: value.fetch_userinfo,
227 userinfo_signed_response_alg,
228 created_at: value.created_at,
229 disabled_at: value.disabled_at,
230 claims_imports: value.claims_imports.0,
231 authorization_endpoint_override,
232 token_endpoint_override,
233 userinfo_endpoint_override,
234 jwks_uri_override,
235 discovery_mode,
236 pkce_mode,
237 response_mode,
238 additional_authorization_parameters,
239 forward_login_hint: value.forward_login_hint,
240 on_backchannel_logout,
241 registration_token_required: value.registration_token_required,
242 })
243 }
244}
245
246impl Filter for UpstreamOAuthProviderFilter<'_> {
247 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
248 sea_query::Condition::all().add_option(self.enabled().map(|enabled| {
249 Expr::col((
250 UpstreamOAuthProviders::Table,
251 UpstreamOAuthProviders::DisabledAt,
252 ))
253 .is_null()
254 .eq(enabled)
255 }))
256 }
257}
258
259#[async_trait]
260impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
261 type Error = DatabaseError;
262
263 #[tracing::instrument(
264 name = "db.upstream_oauth_provider.lookup",
265 skip_all,
266 fields(
267 db.query.text,
268 upstream_oauth_provider.id = %id,
269 ),
270 err,
271 )]
272 async fn lookup(&mut self, id: Ulid) -> Result<Option<UpstreamOAuthProvider>, Self::Error> {
273 let res = sqlx::query_as!(
274 ProviderLookup,
275 r#"
276 SELECT
277 upstream_oauth_provider_id,
278 issuer,
279 human_name,
280 brand_name,
281 scope,
282 client_id,
283 encrypted_client_secret,
284 token_endpoint_signing_alg,
285 token_endpoint_auth_method,
286 id_token_signed_response_alg,
287 fetch_userinfo,
288 userinfo_signed_response_alg,
289 created_at,
290 disabled_at,
291 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
292 jwks_uri_override,
293 authorization_endpoint_override,
294 token_endpoint_override,
295 userinfo_endpoint_override,
296 discovery_mode,
297 pkce_mode,
298 response_mode,
299 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
300 forward_login_hint,
301 on_backchannel_logout,
302 registration_token_required
303 FROM upstream_oauth_providers
304 WHERE upstream_oauth_provider_id = $1
305 "#,
306 Uuid::from(id),
307 )
308 .traced()
309 .fetch_optional(&mut *self.conn)
310 .await?;
311
312 let res = res
313 .map(UpstreamOAuthProvider::try_from)
314 .transpose()
315 .map_err(DatabaseError::from)?;
316
317 Ok(res)
318 }
319
320 #[tracing::instrument(
321 name = "db.upstream_oauth_provider.add",
322 skip_all,
323 fields(
324 db.query.text,
325 upstream_oauth_provider.id,
326 upstream_oauth_provider.issuer = params.issuer,
327 upstream_oauth_provider.client_id = %params.client_id,
328 ),
329 err,
330 )]
331 async fn add(
332 &mut self,
333 rng: &mut (dyn RngCore + Send),
334 clock: &dyn Clock,
335 params: UpstreamOAuthProviderParams,
336 ) -> Result<UpstreamOAuthProvider, Self::Error> {
337 let created_at = clock.now();
338 let id = Ulid::from_datetime_with_rng(created_at, rng);
339 tracing::Span::current().record("upstream_oauth_provider.id", tracing::field::display(id));
340
341 sqlx::query!(
342 r#"
343 INSERT INTO upstream_oauth_providers (
344 upstream_oauth_provider_id,
345 issuer,
346 human_name,
347 brand_name,
348 scope,
349 token_endpoint_auth_method,
350 token_endpoint_signing_alg,
351 id_token_signed_response_alg,
352 fetch_userinfo,
353 userinfo_signed_response_alg,
354 client_id,
355 encrypted_client_secret,
356 claims_imports,
357 authorization_endpoint_override,
358 token_endpoint_override,
359 userinfo_endpoint_override,
360 jwks_uri_override,
361 discovery_mode,
362 pkce_mode,
363 response_mode,
364 forward_login_hint,
365 on_backchannel_logout,
366 registration_token_required,
367 created_at
368 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
369 $12, $13, $14, $15, $16, $17, $18, $19, $20,
370 $21, $22, $23, $24)
371 "#,
372 Uuid::from(id),
373 params.issuer.as_deref(),
374 params.human_name.as_deref(),
375 params.brand_name.as_deref(),
376 params.scope.to_string(),
377 params.token_endpoint_auth_method.to_string(),
378 params
379 .token_endpoint_signing_alg
380 .as_ref()
381 .map(ToString::to_string),
382 params.id_token_signed_response_alg.to_string(),
383 params.fetch_userinfo,
384 params
385 .userinfo_signed_response_alg
386 .as_ref()
387 .map(ToString::to_string),
388 ¶ms.client_id,
389 params.encrypted_client_secret.as_deref(),
390 Json(¶ms.claims_imports) as _,
391 params
392 .authorization_endpoint_override
393 .as_ref()
394 .map(ToString::to_string),
395 params
396 .token_endpoint_override
397 .as_ref()
398 .map(ToString::to_string),
399 params
400 .userinfo_endpoint_override
401 .as_ref()
402 .map(ToString::to_string),
403 params.jwks_uri_override.as_ref().map(ToString::to_string),
404 params.discovery_mode.as_str(),
405 params.pkce_mode.as_str(),
406 params.response_mode.as_ref().map(ToString::to_string),
407 params.forward_login_hint,
408 params.on_backchannel_logout.as_str(),
409 params.registration_token_required,
410 created_at,
411 )
412 .traced()
413 .execute(&mut *self.conn)
414 .await?;
415
416 Ok(UpstreamOAuthProvider {
417 id,
418 issuer: params.issuer,
419 human_name: params.human_name,
420 brand_name: params.brand_name,
421 scope: params.scope,
422 client_id: params.client_id,
423 encrypted_client_secret: params.encrypted_client_secret,
424 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
425 token_endpoint_auth_method: params.token_endpoint_auth_method,
426 id_token_signed_response_alg: params.id_token_signed_response_alg,
427 fetch_userinfo: params.fetch_userinfo,
428 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
429 created_at,
430 disabled_at: None,
431 claims_imports: params.claims_imports,
432 authorization_endpoint_override: params.authorization_endpoint_override,
433 token_endpoint_override: params.token_endpoint_override,
434 userinfo_endpoint_override: params.userinfo_endpoint_override,
435 jwks_uri_override: params.jwks_uri_override,
436 discovery_mode: params.discovery_mode,
437 pkce_mode: params.pkce_mode,
438 response_mode: params.response_mode,
439 additional_authorization_parameters: params.additional_authorization_parameters,
440 on_backchannel_logout: params.on_backchannel_logout,
441 forward_login_hint: params.forward_login_hint,
442 registration_token_required: params.registration_token_required,
443 })
444 }
445
446 #[tracing::instrument(
447 name = "db.upstream_oauth_provider.delete_by_id",
448 skip_all,
449 fields(
450 db.query.text,
451 upstream_oauth_provider.id = %id,
452 ),
453 err,
454 )]
455 async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> {
456 {
459 let span = info_span!(
460 "db.oauth2_client.delete_by_id.authorization_sessions",
461 upstream_oauth_provider.id = %id,
462 { DB_QUERY_TEXT } = tracing::field::Empty,
463 );
464 sqlx::query!(
465 r#"
466 DELETE FROM upstream_oauth_authorization_sessions
467 WHERE upstream_oauth_provider_id = $1
468 "#,
469 Uuid::from(id),
470 )
471 .record(&span)
472 .execute(&mut *self.conn)
473 .instrument(span)
474 .await?;
475 }
476
477 {
480 let span = info_span!(
481 "db.oauth2_client.delete_by_id.links",
482 upstream_oauth_provider.id = %id,
483 { DB_QUERY_TEXT } = tracing::field::Empty,
484 );
485 sqlx::query!(
486 r#"
487 DELETE FROM upstream_oauth_links
488 WHERE upstream_oauth_provider_id = $1
489 "#,
490 Uuid::from(id),
491 )
492 .record(&span)
493 .execute(&mut *self.conn)
494 .instrument(span)
495 .await?;
496 }
497
498 let res = sqlx::query!(
499 r#"
500 DELETE FROM upstream_oauth_providers
501 WHERE upstream_oauth_provider_id = $1
502 "#,
503 Uuid::from(id),
504 )
505 .traced()
506 .execute(&mut *self.conn)
507 .await?;
508
509 DatabaseError::ensure_affected_rows(&res, 1)
510 }
511
512 #[tracing::instrument(
513 name = "db.upstream_oauth_provider.add",
514 skip_all,
515 fields(
516 db.query.text,
517 upstream_oauth_provider.id = %id,
518 upstream_oauth_provider.issuer = params.issuer,
519 upstream_oauth_provider.client_id = %params.client_id,
520 ),
521 err,
522 )]
523 async fn upsert(
524 &mut self,
525 clock: &dyn Clock,
526 id: Ulid,
527 params: UpstreamOAuthProviderParams,
528 ) -> Result<UpstreamOAuthProvider, Self::Error> {
529 let created_at = clock.now();
530
531 let created_at = sqlx::query_scalar!(
532 r#"
533 INSERT INTO upstream_oauth_providers (
534 upstream_oauth_provider_id,
535 issuer,
536 human_name,
537 brand_name,
538 scope,
539 token_endpoint_auth_method,
540 token_endpoint_signing_alg,
541 id_token_signed_response_alg,
542 fetch_userinfo,
543 userinfo_signed_response_alg,
544 client_id,
545 encrypted_client_secret,
546 claims_imports,
547 authorization_endpoint_override,
548 token_endpoint_override,
549 userinfo_endpoint_override,
550 jwks_uri_override,
551 discovery_mode,
552 pkce_mode,
553 response_mode,
554 additional_parameters,
555 forward_login_hint,
556 ui_order,
557 on_backchannel_logout,
558 registration_token_required,
559 created_at
560 ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
561 $11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
562 $21, $22, $23, $24, $25, $26)
563 ON CONFLICT (upstream_oauth_provider_id)
564 DO UPDATE
565 SET
566 issuer = EXCLUDED.issuer,
567 human_name = EXCLUDED.human_name,
568 brand_name = EXCLUDED.brand_name,
569 scope = EXCLUDED.scope,
570 token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
571 token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
572 id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,
573 fetch_userinfo = EXCLUDED.fetch_userinfo,
574 userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,
575 disabled_at = NULL,
576 client_id = EXCLUDED.client_id,
577 encrypted_client_secret = EXCLUDED.encrypted_client_secret,
578 claims_imports = EXCLUDED.claims_imports,
579 authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
580 token_endpoint_override = EXCLUDED.token_endpoint_override,
581 userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
582 jwks_uri_override = EXCLUDED.jwks_uri_override,
583 discovery_mode = EXCLUDED.discovery_mode,
584 pkce_mode = EXCLUDED.pkce_mode,
585 response_mode = EXCLUDED.response_mode,
586 additional_parameters = EXCLUDED.additional_parameters,
587 forward_login_hint = EXCLUDED.forward_login_hint,
588 ui_order = EXCLUDED.ui_order,
589 on_backchannel_logout = EXCLUDED.on_backchannel_logout,
590 registration_token_required = EXCLUDED.registration_token_required
591 RETURNING created_at
592 "#,
593 Uuid::from(id),
594 params.issuer.as_deref(),
595 params.human_name.as_deref(),
596 params.brand_name.as_deref(),
597 params.scope.to_string(),
598 params.token_endpoint_auth_method.to_string(),
599 params
600 .token_endpoint_signing_alg
601 .as_ref()
602 .map(ToString::to_string),
603 params.id_token_signed_response_alg.to_string(),
604 params.fetch_userinfo,
605 params
606 .userinfo_signed_response_alg
607 .as_ref()
608 .map(ToString::to_string),
609 ¶ms.client_id,
610 params.encrypted_client_secret.as_deref(),
611 Json(¶ms.claims_imports) as _,
612 params
613 .authorization_endpoint_override
614 .as_ref()
615 .map(ToString::to_string),
616 params
617 .token_endpoint_override
618 .as_ref()
619 .map(ToString::to_string),
620 params
621 .userinfo_endpoint_override
622 .as_ref()
623 .map(ToString::to_string),
624 params.jwks_uri_override.as_ref().map(ToString::to_string),
625 params.discovery_mode.as_str(),
626 params.pkce_mode.as_str(),
627 params.response_mode.as_ref().map(ToString::to_string),
628 Json(¶ms.additional_authorization_parameters) as _,
629 params.forward_login_hint,
630 params.ui_order,
631 params.on_backchannel_logout.as_str(),
632 params.registration_token_required,
633 created_at,
634 )
635 .traced()
636 .fetch_one(&mut *self.conn)
637 .await?;
638
639 Ok(UpstreamOAuthProvider {
640 id,
641 issuer: params.issuer,
642 human_name: params.human_name,
643 brand_name: params.brand_name,
644 scope: params.scope,
645 client_id: params.client_id,
646 encrypted_client_secret: params.encrypted_client_secret,
647 token_endpoint_signing_alg: params.token_endpoint_signing_alg,
648 token_endpoint_auth_method: params.token_endpoint_auth_method,
649 id_token_signed_response_alg: params.id_token_signed_response_alg,
650 fetch_userinfo: params.fetch_userinfo,
651 userinfo_signed_response_alg: params.userinfo_signed_response_alg,
652 created_at,
653 disabled_at: None,
654 claims_imports: params.claims_imports,
655 authorization_endpoint_override: params.authorization_endpoint_override,
656 token_endpoint_override: params.token_endpoint_override,
657 userinfo_endpoint_override: params.userinfo_endpoint_override,
658 jwks_uri_override: params.jwks_uri_override,
659 discovery_mode: params.discovery_mode,
660 pkce_mode: params.pkce_mode,
661 response_mode: params.response_mode,
662 additional_authorization_parameters: params.additional_authorization_parameters,
663 forward_login_hint: params.forward_login_hint,
664 on_backchannel_logout: params.on_backchannel_logout,
665 registration_token_required: params.registration_token_required,
666 })
667 }
668
669 #[tracing::instrument(
670 name = "db.upstream_oauth_provider.disable",
671 skip_all,
672 fields(
673 db.query.text,
674 %upstream_oauth_provider.id,
675 ),
676 err,
677 )]
678 async fn disable(
679 &mut self,
680 clock: &dyn Clock,
681 mut upstream_oauth_provider: UpstreamOAuthProvider,
682 ) -> Result<UpstreamOAuthProvider, Self::Error> {
683 let disabled_at = clock.now();
684 let res = sqlx::query!(
685 r#"
686 UPDATE upstream_oauth_providers
687 SET disabled_at = $2
688 WHERE upstream_oauth_provider_id = $1
689 "#,
690 Uuid::from(upstream_oauth_provider.id),
691 disabled_at,
692 )
693 .traced()
694 .execute(&mut *self.conn)
695 .await?;
696
697 DatabaseError::ensure_affected_rows(&res, 1)?;
698
699 upstream_oauth_provider.disabled_at = Some(disabled_at);
700
701 Ok(upstream_oauth_provider)
702 }
703
704 #[tracing::instrument(
705 name = "db.upstream_oauth_provider.list",
706 skip_all,
707 fields(
708 db.query.text,
709 ),
710 err,
711 )]
712 async fn list(
713 &mut self,
714 filter: UpstreamOAuthProviderFilter<'_>,
715 pagination: Pagination,
716 ) -> Result<Page<UpstreamOAuthProvider>, Self::Error> {
717 let (sql, arguments) = Query::select()
718 .expr_as(
719 Expr::col((
720 UpstreamOAuthProviders::Table,
721 UpstreamOAuthProviders::UpstreamOAuthProviderId,
722 )),
723 ProviderLookupIden::UpstreamOauthProviderId,
724 )
725 .expr_as(
726 Expr::col((
727 UpstreamOAuthProviders::Table,
728 UpstreamOAuthProviders::Issuer,
729 )),
730 ProviderLookupIden::Issuer,
731 )
732 .expr_as(
733 Expr::col((
734 UpstreamOAuthProviders::Table,
735 UpstreamOAuthProviders::HumanName,
736 )),
737 ProviderLookupIden::HumanName,
738 )
739 .expr_as(
740 Expr::col((
741 UpstreamOAuthProviders::Table,
742 UpstreamOAuthProviders::BrandName,
743 )),
744 ProviderLookupIden::BrandName,
745 )
746 .expr_as(
747 Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)),
748 ProviderLookupIden::Scope,
749 )
750 .expr_as(
751 Expr::col((
752 UpstreamOAuthProviders::Table,
753 UpstreamOAuthProviders::ClientId,
754 )),
755 ProviderLookupIden::ClientId,
756 )
757 .expr_as(
758 Expr::col((
759 UpstreamOAuthProviders::Table,
760 UpstreamOAuthProviders::EncryptedClientSecret,
761 )),
762 ProviderLookupIden::EncryptedClientSecret,
763 )
764 .expr_as(
765 Expr::col((
766 UpstreamOAuthProviders::Table,
767 UpstreamOAuthProviders::TokenEndpointSigningAlg,
768 )),
769 ProviderLookupIden::TokenEndpointSigningAlg,
770 )
771 .expr_as(
772 Expr::col((
773 UpstreamOAuthProviders::Table,
774 UpstreamOAuthProviders::TokenEndpointAuthMethod,
775 )),
776 ProviderLookupIden::TokenEndpointAuthMethod,
777 )
778 .expr_as(
779 Expr::col((
780 UpstreamOAuthProviders::Table,
781 UpstreamOAuthProviders::IdTokenSignedResponseAlg,
782 )),
783 ProviderLookupIden::IdTokenSignedResponseAlg,
784 )
785 .expr_as(
786 Expr::col((
787 UpstreamOAuthProviders::Table,
788 UpstreamOAuthProviders::FetchUserinfo,
789 )),
790 ProviderLookupIden::FetchUserinfo,
791 )
792 .expr_as(
793 Expr::col((
794 UpstreamOAuthProviders::Table,
795 UpstreamOAuthProviders::UserinfoSignedResponseAlg,
796 )),
797 ProviderLookupIden::UserinfoSignedResponseAlg,
798 )
799 .expr_as(
800 Expr::col((
801 UpstreamOAuthProviders::Table,
802 UpstreamOAuthProviders::CreatedAt,
803 )),
804 ProviderLookupIden::CreatedAt,
805 )
806 .expr_as(
807 Expr::col((
808 UpstreamOAuthProviders::Table,
809 UpstreamOAuthProviders::DisabledAt,
810 )),
811 ProviderLookupIden::DisabledAt,
812 )
813 .expr_as(
814 Expr::col((
815 UpstreamOAuthProviders::Table,
816 UpstreamOAuthProviders::ClaimsImports,
817 )),
818 ProviderLookupIden::ClaimsImports,
819 )
820 .expr_as(
821 Expr::col((
822 UpstreamOAuthProviders::Table,
823 UpstreamOAuthProviders::JwksUriOverride,
824 )),
825 ProviderLookupIden::JwksUriOverride,
826 )
827 .expr_as(
828 Expr::col((
829 UpstreamOAuthProviders::Table,
830 UpstreamOAuthProviders::TokenEndpointOverride,
831 )),
832 ProviderLookupIden::TokenEndpointOverride,
833 )
834 .expr_as(
835 Expr::col((
836 UpstreamOAuthProviders::Table,
837 UpstreamOAuthProviders::AuthorizationEndpointOverride,
838 )),
839 ProviderLookupIden::AuthorizationEndpointOverride,
840 )
841 .expr_as(
842 Expr::col((
843 UpstreamOAuthProviders::Table,
844 UpstreamOAuthProviders::UserinfoEndpointOverride,
845 )),
846 ProviderLookupIden::UserinfoEndpointOverride,
847 )
848 .expr_as(
849 Expr::col((
850 UpstreamOAuthProviders::Table,
851 UpstreamOAuthProviders::DiscoveryMode,
852 )),
853 ProviderLookupIden::DiscoveryMode,
854 )
855 .expr_as(
856 Expr::col((
857 UpstreamOAuthProviders::Table,
858 UpstreamOAuthProviders::PkceMode,
859 )),
860 ProviderLookupIden::PkceMode,
861 )
862 .expr_as(
863 Expr::col((
864 UpstreamOAuthProviders::Table,
865 UpstreamOAuthProviders::ResponseMode,
866 )),
867 ProviderLookupIden::ResponseMode,
868 )
869 .expr_as(
870 Expr::col((
871 UpstreamOAuthProviders::Table,
872 UpstreamOAuthProviders::AdditionalParameters,
873 )),
874 ProviderLookupIden::AdditionalParameters,
875 )
876 .expr_as(
877 Expr::col((
878 UpstreamOAuthProviders::Table,
879 UpstreamOAuthProviders::ForwardLoginHint,
880 )),
881 ProviderLookupIden::ForwardLoginHint,
882 )
883 .expr_as(
884 Expr::col((
885 UpstreamOAuthProviders::Table,
886 UpstreamOAuthProviders::OnBackchannelLogout,
887 )),
888 ProviderLookupIden::OnBackchannelLogout,
889 )
890 .expr_as(
891 Expr::col((
892 UpstreamOAuthProviders::Table,
893 UpstreamOAuthProviders::RegistrationTokenRequired,
894 )),
895 ProviderLookupIden::RegistrationTokenRequired,
896 )
897 .from(UpstreamOAuthProviders::Table)
898 .apply_filter(filter)
899 .generate_pagination(
900 (
901 UpstreamOAuthProviders::Table,
902 UpstreamOAuthProviders::UpstreamOAuthProviderId,
903 ),
904 pagination,
905 )
906 .build_sqlx(PostgresQueryBuilder);
907
908 let edges: Vec<ProviderLookup> = sqlx::query_as_with(&sql, arguments)
909 .traced()
910 .fetch_all(&mut *self.conn)
911 .await?;
912
913 let page = pagination
914 .process(edges)
915 .try_map(UpstreamOAuthProvider::try_from)?;
916
917 return Ok(page);
918 }
919
920 #[tracing::instrument(
921 name = "db.upstream_oauth_provider.count",
922 skip_all,
923 fields(
924 db.query.text,
925 ),
926 err,
927 )]
928 async fn count(
929 &mut self,
930 filter: UpstreamOAuthProviderFilter<'_>,
931 ) -> Result<usize, Self::Error> {
932 let (sql, arguments) = Query::select()
933 .expr(
934 Expr::col((
935 UpstreamOAuthProviders::Table,
936 UpstreamOAuthProviders::UpstreamOAuthProviderId,
937 ))
938 .count(),
939 )
940 .from(UpstreamOAuthProviders::Table)
941 .apply_filter(filter)
942 .build_sqlx(PostgresQueryBuilder);
943
944 let count: i64 = sqlx::query_scalar_with(&sql, arguments)
945 .traced()
946 .fetch_one(&mut *self.conn)
947 .await?;
948
949 count
950 .try_into()
951 .map_err(DatabaseError::to_invalid_operation)
952 }
953
954 #[tracing::instrument(
955 name = "db.upstream_oauth_provider.all_enabled",
956 skip_all,
957 fields(
958 db.query.text,
959 ),
960 err,
961 )]
962 async fn all_enabled(&mut self) -> Result<Vec<UpstreamOAuthProvider>, Self::Error> {
963 let res = sqlx::query_as!(
964 ProviderLookup,
965 r#"
966 SELECT
967 upstream_oauth_provider_id,
968 issuer,
969 human_name,
970 brand_name,
971 scope,
972 client_id,
973 encrypted_client_secret,
974 token_endpoint_signing_alg,
975 token_endpoint_auth_method,
976 id_token_signed_response_alg,
977 fetch_userinfo,
978 userinfo_signed_response_alg,
979 created_at,
980 disabled_at,
981 claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
982 jwks_uri_override,
983 authorization_endpoint_override,
984 token_endpoint_override,
985 userinfo_endpoint_override,
986 discovery_mode,
987 pkce_mode,
988 response_mode,
989 additional_parameters as "additional_parameters: Json<Vec<(String, String)>>",
990 forward_login_hint,
991 on_backchannel_logout,
992 registration_token_required
993
994 FROM upstream_oauth_providers
995 WHERE disabled_at IS NULL
996 ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
997 "#,
998 )
999 .traced()
1000 .fetch_all(&mut *self.conn)
1001 .await?;
1002
1003 let res: Result<Vec<_>, _> = res.into_iter().map(TryInto::try_into).collect();
1004 Ok(res?)
1005 }
1006}