mas_storage_pg/upstream_oauth2/
mod.rs1mod link;
11mod provider;
12mod session;
13
14pub use self::{
15 link::PgUpstreamOAuthLinkRepository, provider::PgUpstreamOAuthProviderRepository,
16 session::PgUpstreamOAuthSessionRepository,
17};
18
#[cfg(test)]
mod tests {
    use chrono::Duration;
    use mas_data_model::{
        UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderTokenAuthMethod,
    };
    use mas_iana::jose::JsonWebSignatureAlg;
    use mas_storage::{
        Pagination, RepositoryAccess,
        clock::MockClock,
        upstream_oauth2::{
            UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
            UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository,
            UpstreamOAuthSessionRepository,
        },
        user::UserRepository,
    };
    use oauth2_types::scope::{OPENID, Scope};
    use rand::SeedableRng;
    use sqlx::PgPool;

    use crate::PgRepository;

    /// Builds the parameters for a minimal test provider: OIDC discovery,
    /// `openid` scope only, no token endpoint client authentication, RS256
    /// ID tokens, automatic PKCE, and every optional override left unset.
    ///
    /// Only the issuer and the client ID vary between the tests in this
    /// module, so they are the only inputs.
    fn provider_params(issuer: Option<&str>, client_id: String) -> UpstreamOAuthProviderParams {
        UpstreamOAuthProviderParams {
            issuer: issuer.map(str::to_owned),
            human_name: None,
            brand_name: None,
            scope: Scope::from_iter([OPENID]),
            token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
            id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
            fetch_userinfo: false,
            userinfo_signed_response_alg: None,
            token_endpoint_signing_alg: None,
            client_id,
            encrypted_client_secret: None,
            claims_imports: UpstreamOAuthProviderClaimsImports::default(),
            token_endpoint_override: None,
            authorization_endpoint_override: None,
            userinfo_endpoint_override: None,
            jwks_uri_override: None,
            discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
            pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
            response_mode: None,
            additional_authorization_parameters: Vec::new(),
            forward_login_hint: false,
            ui_order: 0,
        }
    }

    /// Walks a provider, session and link through their whole lifecycle:
    /// creation, lookup, session completion/consumption, link association
    /// to a user, link filtering, provider disabling and provider deletion.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_repository(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        // A freshly migrated database has no providers.
        let all_providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert!(all_providers.is_empty());

        // Create a provider and read it back.
        let provider = repo
            .upstream_oauth_provider()
            .add(
                &mut rng,
                &clock,
                provider_params(Some("https://example.com/"), "client-id".to_owned()),
            )
            .await
            .unwrap();

        let provider = repo
            .upstream_oauth_provider()
            .lookup(provider.id)
            .await
            .unwrap()
            .expect("provider to be found in the database");
        assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(provider.client_id, "client-id");

        let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
        assert_eq!(providers.len(), 1);
        assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
        assert_eq!(providers[0].client_id, "client-id");

        // Start an authorization session against that provider; it begins
        // pending, with no link attached.
        let session = repo
            .upstream_oauth_session()
            .add(
                &mut rng,
                &clock,
                &provider,
                "some-state".to_owned(),
                None,
                Some("some-nonce".to_owned()),
            )
            .await
            .unwrap();

        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert_eq!(session.provider_id, provider.id);
        assert_eq!(session.link_id(), None);
        assert!(session.is_pending());
        assert!(!session.is_completed());
        assert!(!session.is_consumed());

        // Create a link for an upstream subject; it must be reachable both by
        // its ID and by (provider, subject).
        let link = repo
            .upstream_oauth_link()
            .add(&mut rng, &clock, &provider, "a-subject".to_owned(), None)
            .await
            .unwrap();

        let looked_up = repo
            .upstream_oauth_link()
            .lookup(link.id)
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(looked_up.id, link.id);
        assert_eq!(looked_up.subject, "a-subject");

        let link = repo
            .upstream_oauth_link()
            .find_by_subject(&provider, "a-subject")
            .await
            .unwrap()
            .expect("link to be found in database");
        assert_eq!(link.subject, "a-subject");
        assert_eq!(link.provider_id, provider.id);

        // Completing the session attaches the link...
        let session = repo
            .upstream_oauth_session()
            .complete_with_link(&clock, session, &link, None, None, None)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_completed());
        assert!(!session.is_consumed());
        assert_eq!(session.link_id(), Some(link.id));

        // ...and consuming it marks it as consumed.
        let session = repo
            .upstream_oauth_session()
            .consume(&clock, session)
            .await
            .unwrap();
        let session = repo
            .upstream_oauth_session()
            .lookup(session.id)
            .await
            .unwrap()
            .expect("session to be found in the database");
        assert!(session.is_consumed());

        // Associate the link with a local user, then find it again through a
        // filter combining every criterion.
        let user = repo
            .user()
            .add(&mut rng, &clock, "john".to_owned())
            .await
            .unwrap();
        repo.upstream_oauth_link()
            .associate_to_user(&link, &user)
            .await
            .unwrap();

        let filter = UpstreamOAuthLinkFilter::new()
            .for_user(&user)
            .for_provider(&provider)
            .for_subject("a-subject")
            .enabled_providers_only();

        let links = repo
            .upstream_oauth_link()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(!links.has_previous_page);
        assert!(!links.has_next_page);
        assert_eq!(links.edges.len(), 1);
        assert_eq!(links.edges[0].id, link.id);
        assert_eq!(links.edges[0].user_id, Some(user.id));

        assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);

        // Provider counts before disabling: one provider, enabled.
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            0
        );

        // `disable` takes the provider by value, and it is needed again for
        // `delete` below, hence the clone.
        repo.upstream_oauth_provider()
            .disable(&clock, provider.clone())
            .await
            .unwrap();

        // Provider counts after disabling: still one provider, now disabled.
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            1
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().enabled_only())
                .await
                .unwrap(),
            0
        );
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new().disabled_only())
                .await
                .unwrap(),
            1
        );

        // Deleting the provider removes it entirely.
        repo.upstream_oauth_provider()
            .delete(provider)
            .await
            .unwrap();
        assert_eq!(
            repo.upstream_oauth_provider()
                .count(UpstreamOAuthProviderFilter::new())
                .await
                .unwrap(),
            0
        );
    }

    /// Creates 20 providers and checks forward, backward and bounded
    /// cursor pagination over them.
    #[sqlx::test(migrator = "crate::MIGRATOR")]
    async fn test_provider_repository_pagination(pool: PgPool) {
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let clock = MockClock::default();
        let mut repo = PgRepository::from_pool(&pool).await.unwrap();

        let filter = UpstreamOAuthProviderFilter::new();

        assert_eq!(repo.upstream_oauth_provider().count(filter).await.unwrap(), 0);

        let mut ids = Vec::with_capacity(20);
        for idx in 0..20 {
            let provider = repo
                .upstream_oauth_provider()
                .add(&mut rng, &clock, provider_params(None, format!("client-{idx}")))
                .await
                .unwrap();
            ids.push(provider.id);
            // Move the mock clock forward 10 seconds (written in microseconds)
            // between creations. The page assertions below compare against
            // `ids` in creation order.
            clock.advance(Duration::microseconds(10 * 1000 * 1000));
        }

        assert_eq!(repo.upstream_oauth_provider().count(filter).await.unwrap(), 20);

        // First page, forward.
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10))
            .await
            .unwrap();
        assert!(page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Every provider is enabled, so the `enabled_only` filter must yield
        // exactly the same page.
        let other_page = repo
            .upstream_oauth_provider()
            .list(filter.enabled_only(), Pagination::first(10))
            .await
            .unwrap();
        assert_eq!(page, other_page);

        // Second page, forward.
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[9]))
            .await
            .unwrap();
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Last page, backward.
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10))
            .await
            .unwrap();
        assert!(page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[10..]);

        // Page before the last one, backward.
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::last(10).before(ids[10]))
            .await
            .unwrap();
        assert!(!page.has_previous_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[..10]);

        // Both cursors set: only the items strictly between them.
        let page = repo
            .upstream_oauth_provider()
            .list(filter, Pagination::first(10).after(ids[5]).before(ids[8]))
            .await
            .unwrap();
        assert!(!page.has_next_page);
        let edge_ids: Vec<_> = page.edges.iter().map(|p| p.id).collect();
        assert_eq!(&edge_ids, &ids[6..8]);

        // No provider was disabled.
        assert!(
            repo.upstream_oauth_provider()
                .list(
                    UpstreamOAuthProviderFilter::new().disabled_only(),
                    Pagination::first(1)
                )
                .await
                .unwrap()
                .edges
                .is_empty()
        );
    }
}