1mod link;
11mod provider;
12mod session;
13
14pub use self::{
15 link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16 session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// How far the mock clock is advanced between two insertions in the
    /// pagination tests: 10 seconds, expressed in microseconds.
    const INSERT_INTERVAL_MICROS: i64 = 10 * 1000 * 1000;

    /// Builds the provider parameters shared by every test in this module.
    ///
    /// Every provider uses the `openid` scope, no token endpoint client
    /// authentication, RS256-signed ID tokens, OIDC discovery and automatic
    /// PKCE; only the `issuer` and the `client_id` differ between tests.
    fn provider_params(issuer: Option<&str>, client_id: &str) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id: client_id.to_owned(),
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
            registration_token_required: false,
        }
    }

    /// End-to-end walk through the provider, session and link repositories:
    /// create a provider, start and complete an authorization session with a
    /// link, consume it, then exercise filters, disabling and deletion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // A fresh database has no enabled provider
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Add a provider, then read it back from the database
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id"),
            )
            .await
            .unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        // The new provider is listed as enabled
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start an authorization session against that provider
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        // A freshly created session is pending and not linked yet
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link for a subject on this provider
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // The link can be found both by ID...
        let looked_up_link = repo
            .upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(looked_up_link.id, link.id);

        // ...and by provider and subject
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        // Completing the session attaches the link to it
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None, None)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        // Consuming the session requires a user and a browser session
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session, &browser_session)
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        // Attach the link to the user, then check it is found by a combined filter
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].node.id, link.id);
        assert_eq!(links.edges[0].node.user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // Provider counts before disabling: one provider, enabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // `provider` is still needed below, hence the clone
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // Provider counts after disabling: still one provider, now disabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // The session is still listed for the (now disabled) provider
        let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        let session_count = repo
            .upstream_oauth_session()
            .count(session_filter)
            .await
            .unwrap();
        assert_eq!(session_count, 1);

        let session_page = repo
            .upstream_oauth_session()
            .list(session_filter, Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(session_page.edges.len(), 1);
        assert_eq!(session_page.edges[0].node.id, session.id);
        assert!(!session_page.has_next_page);
        assert!(!session_page.has_previous_page);

        // Deleting the provider removes it entirely
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Checks forward/backward cursor pagination and counting of providers.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // No provider exists yet
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        // Create 20 providers, recording their IDs in insertion order. The
        // clock is advanced between insertions because the assertions below
        // expect listing order to match insertion order (presumably the IDs
        // are time-ordered — confirm against the ID generator).
        let mut ids = Vec::with_capacity(20);
        for idx in 0..20 {
            let client_id = format!("client-{idx}");
            let provider = repo
                .upstream_oauth_provider()
                .add(&mut rng, &clock, provider_params(None, &client_id))
                .await
                .unwrap();
            ids.push(provider.id);
            clock.advance(Duration::microseconds(INSERT_INTERVAL_MICROS));
        }

        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // First page going forward
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // All providers are enabled, so filtering on enabled ones changes nothing
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Second page going forward
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last page going backward
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Previous page going backward
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Both cursors set: only the items strictly between them are returned
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // No provider is disabled
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }

    /// Checks cursor pagination of authorization sessions, and filtering them
    /// by the `sub` and `sid` claims of their ID token.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id"),
            )
            .await
            .unwrap();

        let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // No session exists yet
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            0
        );

        let mut links = Vec::with_capacity(3);
        for subject in ["alice", "bob", "charlie"] {
            let link = repo
                .upstream_oauth_link()
                .add(&mut rng, &clock, &provider, subject.to_owned(), None)
                .await
                .unwrap();
            links.push(link);
        }

        // Create 20 completed sessions. Links cycle with period 3 and `sid`s
        // with period 2, so each (subject, sid) pair repeats every 6 sessions:
        // e.g. ("alice", "one") appears at indices 0, 6, 12 and 18.
        let mut ids = Vec::with_capacity(20);
        let sids = ["one", "two"].into_iter().cycle();
        for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
            let state = format!("state-{idx}");
            let session = repo
                .upstream_oauth_session()
                .add(&mut rng, &clock, &provider, state, None, None)
                .await
                .unwrap();
            let id_token_claims = serde_json::json!({
                "sub": link.subject,
                "sid": sid,
                "aud": provider.client_id,
                "iss": "https://example.com/",
            });
            let session = repo
                .upstream_oauth_session()
                .complete_with_link(
                    &clock,
                    session,
                    link,
                    None,
                    Some(id_token_claims),
                    None,
                    None,
                )
                .await
                .unwrap();
            ids.push(session.id);
            // Keep listing order aligned with insertion order (see the
            // provider pagination test)
            clock.advance(Duration::microseconds(INSERT_INTERVAL_MICROS));
        }

        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            20
        );

        // First page going forward
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Second page going forward
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last page going backward
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Previous page going backward
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Both cursors set: only the items strictly between them are returned
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
            .await
            .unwrap();

        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[6..11]);

        // Each (sub, sid) pair matches 4 of the 20 sessions
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("alice").with_sid_claim("one"))
                .await
                .unwrap(),
            4
        );
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("bob").with_sid_claim("two"))
                .await
                .unwrap(),
            4
        );

        // Listing with the claim filters only returns matching sessions
        let page = repo
            .upstream_oauth_session()
            .list(
                filter.with_sub_claim("alice").with_sid_claim("one"),
                Pagination::first(10),
            )
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);
        for edge in page.edges {
            let claims = edge
                .node
                .id_token_claims()
                .expect("completed session to have ID token claims");
            assert_eq!(claims.get("sub").unwrap().as_str(), Some("alice"));
            assert_eq!(claims.get("sid").unwrap().as_str(), Some("one"));
        }
    }
}