mas_storage_pg/upstream_oauth2/
mod.rs

// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.

//! A module containing the PostgreSQL implementation of the repositories
//! related to the upstream OAuth 2.0 providers.

10mod link;
11mod provider;
12mod session;
13
14pub use self::{
15    link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16    session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        clock::MockClock,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Build the [`UpstreamOAuthProviderParams`] used by the tests below.
    ///
    /// Only the issuer and the client ID vary between tests. Every other field
    /// is fixed: `openid` scope, `None` token endpoint auth method, no client
    /// secret, RS256 ID tokens, no userinfo fetching, no endpoint overrides,
    /// OIDC discovery mode and `Auto` PKCE mode.
    fn test_provider_params(
        issuer: Option<&str>,
        client_id: String,
    ) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id,
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
        }
    }

    /// Exercise the provider, session and link repositories together: create
    /// a provider, start, complete and consume a session, create a link and
    /// attach it to a user, then disable and finally delete the provider.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // The provider list should be empty at the start
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Let's add a provider
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                test_provider_params(Some("https://example.com/"), "client-id".to_owned()),
            )
            .await
            .unwrap();

        // Look it up in the database
        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        // It should be in the list of all providers
        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start a session
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        // Look it up in the database
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        // We can look it up by its ID, and get back the link we just created
        let link_by_id = repo
            .upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link_by_id.id, link.id);
        assert_eq!(link_by_id.subject, "a-subject");

        // or by its subject, which must resolve to the same link
        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.id, link_by_id.id);
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        // Reload the session
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        // XXX: we should also try other combinations of the filter
        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // There should be exactly one enabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // Disable the provider
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // There should be exactly one disabled provider
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Try deleting the provider
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Test that the pagination works as expected in the upstream OAuth
    /// provider repository
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        // Count the number of providers before we start
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            0
        );

        let mut ids = Vec::with_capacity(20);
        // Create 20 providers
        for idx in 0..20 {
            let provider = repo
                .upstream_oauth_provider()
                .add(
                    &mut rng,
                    &clock,
                    test_provider_params(None, format!("client-{idx}")),
                )
                .await
                .unwrap();
            ids.push(provider.id);
            // Advance the mock clock by 10 seconds so each provider is created
            // at a distinct, later time; the page assertions below expect
            // results in insertion order.
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        // Now we have 20 providers
        assert_eq!(
            repo.upstream_oauth_provider().count(filter).await.unwrap(),
            20
        );

        // Lookup the first 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();

        // It returned the first 10 items
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Getting the same page with the "enabled only" filter should return the same
        // results
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();

        assert_eq!(page, other_page);

        // Lookup the next 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();

        // It returned the next 10 items
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the last 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();

        // It returned the last 10 items
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Lookup the previous 10 items
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();

        // It returned the previous 10 items
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Ask for up to 10 items strictly between ids[5] and ids[8]: only the
        // two providers in that range (ids[6] and ids[7]) should come back
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();

        // It returned the items in between
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // There should not be any disabled providers
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }
}