Skip to main content

mas_storage_pg/upstream_oauth2/
mod.rs

// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.

//! A module containing the PostgreSQL implementation of the repositories
//! related to the upstream OAuth 2.0 providers.

10mod link;
11mod provider;
12mod session;
13
14pub use self::{
15    link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16    session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderOnBackchannelLogout,
        UpstreamOAuthProviderTokenAuthMethod, clock::MockClock,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionFilter, UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Issuer URL used for the providers created in these tests.
    ///
    /// Shared between the provider parameters, the assertions on the stored
    /// provider, and the `iss` claim of the fake ID tokens, so they cannot
    /// drift apart.
    const ISSUER: &str = "https://example.com/";

    /// Build the parameters for a minimal OIDC upstream provider.
    ///
    /// Only the issuer and the client ID vary between the tests in this
    /// module. Every other setting uses the same fixed value everywhere: the
    /// `openid` scope, no client authentication, RS256 ID tokens, OIDC
    /// discovery, automatic PKCE, and no endpoint overrides.
    fn provider_params(issuer: Option<&str>, client_id: String) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id,
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
            on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
            registration_token_required: false,
        }
    }

    /// End-to-end walkthrough of the provider, link and session repositories.
    ///
    /// The test proceeds in order:
    /// - create and look up a provider;
    /// - start a session, create a link, then complete and consume the session;
    /// - associate the link with a user and list/count links;
    /// - disable the provider and check the enabled/disabled counts;
    /// - list/count sessions, then delete the provider.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // The provider list should be empty at the start
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Let's add a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(&mut rng, &clock, provider_params(Some(ISSUER), "client-id".to_owned()))
            .await
            .unwrap();

        // Look it up in the database
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some(ISSUER));
        assert_eq!(provider.client_id, "client-id");

        // It should be in the list of all providers
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some(ISSUER));
        assert_eq!(providers[0].client_id, "client-id");

        // Start a session
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        // Look it up in the database
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // We can look it up by its ID
        repo.upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");

        // or by its subject
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None, None)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        // We need to create a user and start a browser session to consume the session
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        let browser_session = repo
            .browser_session()
            .add(&mut rng, &clock, &user, None)
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session, &browser_session)
            .await
            .unwrap();

        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // XXX: we should also try other combinations of the filter
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].node.id, link.id);
        assert_eq!(links.edges[0].node.user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // There should be exactly one provider, and it should be enabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disable the provider
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // There should still be exactly one provider, now disabled
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Test listing and counting sessions
        let session_filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the sessions for the provider
        let session_count = repo
            .upstream_oauth_session()
            .count(session_filter)
            .await
            .unwrap();
        assert_eq!(session_count, 1);

        // List the sessions for the provider
        let session_page = repo
            .upstream_oauth_session()
            .list(session_filter, Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(session_page.edges.len(), 1);
        assert_eq!(session_page.edges[0].node.id, session.id);
        assert!(!session_page.has_next_page);
        assert!(!session_page.has_previous_page);

        // Try deleting the provider
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// provider repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // Count the number of providers before we start
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        // Create 20 providers. The clock is advanced between insertions so
        // each provider is created later than the previous one; the
        // assertions below expect pages to come back in creation order.
        for idx in 0..20 {
            let provider = repo
                .upstream_oauth_provider()
                .add(&mut rng, &clock, provider_params(None, format!("client-{idx}")))
                .await
                .unwrap();
            ids.push(provider.id);
            clock.advance(Duration::seconds(10));
        }

        // Now we have 20 providers
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Getting the same page with the "enabled only" filter should return the same
        // results
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup up to 10 items strictly between two IDs: only ids[6] and
        // ids[7] lie in that range, so the page holds just those 2 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.node.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // There should not be any disabled providers
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// session repository, as well as the `sub`/`sid` ID token claim filters
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_session_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // Create a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(&mut rng, &clock, provider_params(Some(ISSUER), "client-id".to_owned()))
            .await
            .unwrap();

        let filter = UpstreamOAuthSessionFilter::new().for_provider(&provider);

        // Count the number of sessions before we start
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            0
        );

        let mut links = Vec::with_capacity(3);
        for subject in ["alice", "bob", "charlie"] {
            let link = repo
                .upstream_oauth_link()
                .add(&mut rng, &clock, &provider, subject.to_owned(), None)
                .await
                .unwrap();
            links.push(link);
        }

        let mut ids = Vec::with_capacity(20);
        let sids = ["one", "two"].into_iter().cycle();
        // Create 20 sessions. The links cycle with period 3 and the `sid`
        // values cycle with period 2, so each (subject, sid) pair repeats
        // every 6 sessions. For example, ("alice", "one") appears at indices
        // 0, 6, 12 and 18. The clock is advanced between insertions so
        // sessions come back in creation order.
        for (idx, (link, sid)) in links.iter().cycle().zip(sids).enumerate().take(20) {
            let state = format!("state-{idx}");
            let session = repo
                .upstream_oauth_session()
                .add(&mut rng, &clock, &provider, state, None, None)
                .await
                .unwrap();
            let id_token_claims = serde_json::json!({
                "sub": link.subject,
                "sid": sid,
                "aud": provider.client_id,
                "iss": ISSUER,
            });
            let session = repo
                .upstream_oauth_session()
                .complete_with_link(
                    &clock,
                    session,
                    link,
                    None,
                    Some(id_token_claims),
                    None,
                    None,
                )
                .await
                .unwrap();
            ids.push(session.id);
            clock.advance(Duration::seconds(10));
        }

        // Now we have 20 sessions
        assert_eq!(
            repo.upstream_oauth_session().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Lookup the 5 items strictly between two IDs (ids[6] to ids[10])
        let page = repo
            .upstream_oauth_session()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[11]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|s| s.node.id).collect();
        assert_eq!(&edge_ids, &ids[6..11]);

        // Check the sub/sid filters: each (subject, sid) pair below was used
        // by 4 of the 20 sessions (see the creation loop above)
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("alice").with_sid_claim("one"))
                .await
                .unwrap(),
            4
        );
        assert_eq!(
            repo.upstream_oauth_session()
                .count(filter.with_sub_claim("bob").with_sid_claim("two"))
                .await
                .unwrap(),
            4
        );

        let page = repo
            .upstream_oauth_session()
            .list(
                filter.with_sub_claim("alice").with_sid_claim("one"),
                Pagination::first(10),
            )
            .await
            .unwrap();
        assert_eq!(page.edges.len(), 4);
        for edge in page.edges {
            assert_eq!(
                edge.node
                    .id_token_claims()
                    .unwrap()
                    .get("sub")
                    .unwrap()
                    .as_str(),
                Some("alice")
            );
            assert_eq!(
                edge.node
                    .id_token_claims()
                    .unwrap()
                    .get("sid")
                    .unwrap()
                    .as_str(),
                Some("one")
            );
        }
    }
}